Commit

[FEAT]-[Module]: [return_loss_text]: Add [return_loss_text] function for enhanced loss computation readability

[FEAT]-[Module]: [calc_z_loss]: Introduce [calc_z_loss] function to calculate Z loss in model training
[FEAT]-[Module]: [max_neg_value]: Implement [max_neg_value] function for negative value handling in computations
[FEAT]-[Module]: [TextTokenEmbedding]: Deploy [TextTokenEmbedding] for improved text token embedding functionality
[FEAT]-[Module]: [dropout_seq]: Add [dropout_seq] function for sequence dropout in neural network layers
[FEAT]-[Module]: [transformer_generate]: Introduce [transformer_generate] function for efficient transformer text generation
[FEAT]-[Module]: [vit_output_head]: Add [vit_output_head] for Vision Transformer model output handling
[FEAT]-[Module]: [patch_linear_flatten]: Implement [patch_linear_flatten] for streamlined linear patch flattening in ViT
[FEAT]-[Module]: [ScalableImgSelfAttention]: Introduce [ScalableImgSelfAttention] for scalable image self-attention mechanism
Kye committed Apr 6, 2024
1 parent b9b67a7 commit cb58448
Showing 10 changed files with 530 additions and 15 deletions.
Empty file added playground/models/spectra.py
Empty file.
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "zetascale"
version = "2.3.3"
version = "2.3.5"
description = "Rapidly Build, Optimize, and Deploy SOTA AI Models"
authors = ["Zeta Team <[email protected]>"]
license = "MIT"
@@ -35,6 +35,7 @@ tqdm = "4.66.2"
rich = "13.7.1"
colt5-attention = "*"
argparse = "^1.4.0"
local-attention = "*"

[build-system]
requires = ["poetry-core>=1.0.0"]
6 changes: 2 additions & 4 deletions zeta/nn/attention/__init__.py
@@ -22,10 +22,7 @@
from zeta.nn.attention.spatial_linear_attention import SpatialLinearAttention
from zeta.structs.transformer import Attention, AttentionLayers
from zeta.nn.attention.multi_grouped_attn import MultiGroupedQueryAttn

# from zeta.nn.attention.flash_attention2 import FlashAttentionTwo
# from zeta.nn.attention.mgqa import MGQA

from zeta.nn.attention.scalable_img_self_attn import ScalableImgSelfAttention

__all__ = [
"Attend",
@@ -48,4 +45,5 @@
"Attention",
"AttentionLayers",
"MultiGroupedQueryAttn",
"ScalableImgSelfAttention",
]
129 changes: 129 additions & 0 deletions zeta/nn/attention/scalable_img_self_attn.py
@@ -0,0 +1,129 @@
import torch
from torch import nn, Tensor
from zeta.nn.modules.chan_layer_norm import ChanLayerNorm
from einops import rearrange


class ScalableImgSelfAttention(nn.Module):
    """
    ScalableImgSelfAttention applies a self-attention mechanism to image feature maps.

    Args:
        dim (int): The input dimension of the image.
        heads (int, optional): The number of attention heads. Defaults to 8.
        dim_key (int, optional): The dimension of the key vectors. Defaults to 32.
        dim_value (int, optional): The dimension of the value vectors. Defaults to 32.
        dropout (float, optional): The dropout rate. Defaults to 0.0.
        reduction_factor (int, optional): The reduction factor for downscaling the keys and values. Defaults to 1.

    Attributes:
        dim (int): The input dimension of the image.
        heads (int): The number of attention heads.
        dim_key (int): The dimension of the key vectors.
        dim_value (int): The dimension of the value vectors.
        reduction_factor (int): The reduction factor for downscaling the keys and values.
        scale (float): The scaling factor applied to the attention scores.
        attend (nn.Softmax): The softmax function for attention calculation.
        dropout (nn.Dropout): The dropout layer.
        norm (ChanLayerNorm): The channel-wise layer normalization.
        to_q (nn.Conv2d): The convolutional layer for query projection.
        to_k (nn.Conv2d): The convolutional layer for key projection.
        to_v (nn.Conv2d): The convolutional layer for value projection.
        to_out (nn.Sequential): The sequential layer for output projection.
    """

    def __init__(
        self,
        dim: int,
        heads: int = 8,
        dim_key: int = 32,
        dim_value: int = 32,
        dropout: float = 0.0,
        reduction_factor: int = 1,
        *args,
        **kwargs,
    ):
        super().__init__()
        self.dim = dim
        self.heads = heads
        self.dim_key = dim_key
        self.dim_value = dim_value
        self.reduction_factor = reduction_factor

        self.scale = dim_key**-0.5
        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)
        self.norm = ChanLayerNorm(dim)

        # Projections
        self.to_q = nn.Conv2d(dim, dim_key * heads, 1, bias=False)
        self.to_k = nn.Conv2d(
            dim,
            dim_key * heads,
            reduction_factor,
            stride=reduction_factor,
            bias=False,
        )
        self.to_v = nn.Conv2d(
            dim,
            dim_value * heads,
            reduction_factor,
            stride=reduction_factor,
            bias=False,
        )

        self.to_out = nn.Sequential(
            nn.Conv2d(dim_value * heads, dim, 1), nn.Dropout(dropout)
        )

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass of the ScalableImgSelfAttention module.

        Args:
            x (Tensor): The input tensor of shape (batch_size, channels, height, width).

        Returns:
            Tensor: The output tensor of shape (batch_size, channels, height, width).
        """
        height, width, heads = *x.shape[-2:], self.heads

        x = self.norm(x)

        q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)

        # Split out heads
        q, k, v = map(
            lambda t: rearrange(t, "b (h d) ... -> b h (...) d", h=heads),
            (q, k, v),
        )

        # Similarity
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        # Attention
        attn = self.attend(dots)
        attn = self.dropout(attn)

        # Aggregate values
        out = torch.matmul(attn, v)

        # Merge back heads
        out = rearrange(
            out,
            "b h (x y) d -> b (h d) x y",
            x=height,
            y=width,
        )
        return self.to_out(out)


# x = torch.randn(1, 3, 64, 64)
# attn = ScalableImgSelfAttention(3)
# out = attn(x)
# print(out.shape)  # torch.Size([1, 3, 64, 64])
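
For reference, a slightly larger usage sketch (sizes here are illustrative, not part of the commit): with reduction_factor=2 the key/value projections run at half resolution, while the output keeps the input's spatial shape.

import torch
from zeta.nn.attention import ScalableImgSelfAttention

# 64-channel feature map, 8 heads, keys/values downsampled by 2x
attn = ScalableImgSelfAttention(dim=64, heads=8, reduction_factor=2)
x = torch.randn(1, 64, 32, 32)
out = attn(x)
print(out.shape)  # torch.Size([1, 64, 32, 32])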
24 changes: 24 additions & 0 deletions zeta/nn/modules/__init__.py
@@ -195,6 +195,20 @@
NormalSparseMoE,
HeirarchicalSparseMoE,
)
from zeta.nn.modules.return_loss_text import (
return_loss_text,
calc_z_loss,
max_neg_value,
TextTokenEmbedding,
dropout_seq,
transformer_generate,
)
from zeta.nn.modules.patch_linear_flatten import (
vit_output_head,
patch_linear_flatten,
posemb_sincos_2d,
)
from zeta.nn.modules.chan_layer_norm import ChanLayerNorm


# from zeta.nn.modules.img_reshape import image_reshape
# from zeta.nn.modules.flatten_features import flatten_features
@@ -392,4 +406,14 @@
"Top2Gating",
"NormalSparseMoE",
"HeirarchicalSparseMoE",
"return_loss_text",
"calc_z_loss",
"max_neg_value",
"TextTokenEmbedding",
"dropout_seq",
"transformer_generate",
"patch_linear_flatten",
"vit_output_head",
"posemb_sincos_2d",
"ChanLayerNorm",
]
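
With these exports in place, a minimal smoke test of the new public names might look like the following (a sketch; it only checks that the symbols resolve from the package root shown above).

from zeta.nn.modules import (
    ChanLayerNorm,
    patch_linear_flatten,
    vit_output_head,
    return_loss_text,
    calc_z_loss,
    dropout_seq,
)

# If this __init__.py is in place, the imports above should succeed.
print(ChanLayerNorm, patch_linear_flatten, vit_output_head)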
37 changes: 37 additions & 0 deletions zeta/nn/modules/chan_layer_norm.py
@@ -0,0 +1,37 @@
import torch
from torch import nn, Tensor


class ChanLayerNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        """
        Initializes the ChanLayerNorm module.

        Args:
            dim (int): The input dimension.
            eps (float, optional): The epsilon value. Defaults to 1e-5.
        """
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x: Tensor):
        """
        Forward pass of the ChanLayerNorm module.

        Args:
            x (Tensor): The input tensor.

        Returns:
            Tensor: The normalized tensor.
        """
        var = torch.var(
            x,
            dim=1,
            unbiased=False,
            keepdim=True,
        )
        mean = torch.mean(x, dim=1, keepdim=True)
        return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
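
A minimal usage sketch (input sizes are illustrative): the normalization is taken over the channel dimension of a (batch, channels, height, width) tensor.

import torch
from zeta.nn.modules.chan_layer_norm import ChanLayerNorm

norm = ChanLayerNorm(dim=32)
x = torch.randn(2, 32, 16, 16)
y = norm(x)
print(y.shape)                    # torch.Size([2, 32, 16, 16])
print(y.mean(dim=1).abs().max())  # per-position channel mean is ~0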
88 changes: 88 additions & 0 deletions zeta/nn/modules/patch_linear_flatten.py
@@ -0,0 +1,88 @@
import torch
from torch import nn, Tensor
from einops import rearrange
from einops.layers.torch import Rearrange


def posemb_sincos_2d(patches, temperature=10000, dtype=torch.float32):
    _, h, w, dim, device, dtype = *patches.shape, patches.device, patches.dtype

    y, x = torch.meshgrid(
        torch.arange(h, device=device),
        torch.arange(w, device=device),
        indexing="ij",
    )
    assert (
        dim % 4
    ) == 0, "feature dimension must be multiple of 4 for sincos emb"
    omega = torch.arange(dim // 4, device=device) / (dim // 4 - 1)
    omega = 1.0 / (temperature**omega)

    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]
    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
    return pe.type(dtype)


def vit_output_head(x: Tensor, dim: int, num_classes: int = None):
    """
    Applies a Vision Transformer (ViT) output head to the input tensor.

    Args:
        x (Tensor): The input tensor.
        dim (int): The dimension of the input tensor.
        num_classes (int, optional): The number of output classes. Defaults to None.

    Returns:
        Tensor: The output tensor after applying the ViT output head.
    """
    return nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, num_classes))(x)


def patch_linear_flatten(
    x: Tensor,
    patch_size: int,
    dim: int,
    image_size: int,
    channels: int = 3,
    add_pos_embeddings: bool = False,
    *args,
    **kwargs,
):
    """
    Applies patch embedding to the input tensor and flattens it.

    Args:
        x (Tensor): Input tensor of shape (batch_size, channels, image_height, image_width).
        patch_size (int): Size of the square patch.
        dim (int): Dimension of the output tensor.
        image_size (int): Size of the input image (assumed to be square).
        channels (int, optional): Number of input channels. Defaults to 3.
        add_pos_embeddings (bool, optional): Whether to add positional embeddings. Defaults to False.

    Returns:
        Tensor: Flattened tensor of shape (batch_size, num_patches, dim).
    """
    patch_height, patch_width = patch_size, patch_size
    patch_dim = channels * patch_height * patch_width

    # Patch embedding: (b, c, H, W) -> (b, H/p, W/p, dim)
    patch_embeddings = nn.Sequential(
        Rearrange(
            "b c (h p1) (w p2) -> b h w (p1 p2 c)",
            p1=patch_height,
            p2=patch_width,
        ),
        nn.LayerNorm(patch_dim),
        nn.Linear(patch_dim, dim),
        nn.LayerNorm(dim),
    )(x)

    if add_pos_embeddings:
        # Sin-cos embeddings are computed from the patch grid, not the raw image
        pos_embeddings = posemb_sincos_2d(patch_embeddings, *args, **kwargs)
    else:
        pos_embeddings = None

    # Flatten the patch grid into a sequence of patch tokens
    patch_embeddings = rearrange(patch_embeddings, "b h w d -> b (h w) d")

    if pos_embeddings is not None:
        patch_embeddings = patch_embeddings + pos_embeddings

    return patch_embeddings
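
Given the patch_linear_flatten above, a rough end-to-end sketch of a ViT-style stem plus output head (the patch size, embedding width, and class count are illustrative, not part of this commit):

import torch
from zeta.nn.modules.patch_linear_flatten import (
    patch_linear_flatten,
    vit_output_head,
)

img = torch.randn(1, 3, 64, 64)

# 8x8 patches of a 64x64 RGB image -> 64 patch tokens of width 512
tokens = patch_linear_flatten(
    img, patch_size=8, dim=512, image_size=64, add_pos_embeddings=True
)
print(tokens.shape)  # torch.Size([1, 64, 512])

# Mean-pool the tokens and project to 10 class logits
logits = vit_output_head(tokens.mean(dim=1), dim=512, num_classes=10)
print(logits.shape)  # torch.Size([1, 10])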
34 changes: 34 additions & 0 deletions zeta/nn/modules/peg.py
@@ -0,0 +1,34 @@
from torch import nn, Tensor


class PEG(nn.Module):
    """
    PEG (Positional Encoding Generator) module.

    Args:
        dim (int): The input dimension.
        kernel_size (int, optional): The size of the convolutional kernel. Defaults to 3.
    """

    def __init__(self, dim: int, kernel_size: int = 3):
        super().__init__()
        self.proj = nn.Conv2d(
            dim,
            dim,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            groups=dim,
            stride=1,
        )

    def forward(self, x: Tensor):
        """
        Forward pass of the PEG module.

        Args:
            x (Tensor): The input tensor.

        Returns:
            Tensor: The output tensor.
        """
        return self.proj(x) + x
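
A small usage sketch of PEG (sizes illustrative): the depthwise convolution produces a conditional positional signal that is added back to the input, so the shape is preserved.

import torch
from zeta.nn.modules.peg import PEG

peg = PEG(dim=64, kernel_size=3)
x = torch.randn(1, 64, 32, 32)
out = peg(x)
print(out.shape)  # torch.Size([1, 64, 32, 32])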