[FEAT]-[Module]: [return_loss_text]: Add [return_loss_text] function for enhanced loss computation readability
[FEAT]-[Module]: [calc_z_loss]: Introduce [calc_z_loss] function to calculate Z loss in model training
[FEAT]-[Module]: [max_neg_value]: Implement [max_neg_value] function for negative value handling in computations
[FEAT]-[Module]: [TextTokenEmbedding]: Deploy [TextTokenEmbedding] for improved text token embedding functionality
[FEAT]-[Module]: [dropout_seq]: Add [dropout_seq] function for sequence dropout in neural network layers
[FEAT]-[Module]: [transformer_generate]: Introduce [transformer_generate] function for efficient transformer text generation
[FEAT]-[Module]: [vit_output_head]: Add [vit_output_head] for Vision Transformer model output handling
[FEAT]-[Module]: [patch_linear_flatten]: Implement [patch_linear_flatten] for streamlined linear patch flattening in ViT
[FEAT]-[Module]: [ScalableImgSelfAttention]: Introduce [ScalableImgSelfAttention] for scalable image self-attention mechanism
Kye committed Apr 6, 2024
1 parent b9b67a7 · commit cb58448
Showing 10 changed files with 530 additions and 15 deletions.
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "zetascale"
-version = "2.3.3"
+version = "2.3.5"
 description = "Rapidly Build, Optimize, and Deploy SOTA AI Models"
 authors = ["Zeta Team <[email protected]>"]
 license = "MIT"

@@ -35,6 +35,7 @@ tqdm = "4.66.2"
 rich = "13.7.1"
 colt5-attention = "*"
 argparse = "^1.4.0"
+local-attention = "*"

 [build-system]
 requires = ["poetry-core>=1.0.0"]
@@ -0,0 +1,129 @@
import torch
from torch import nn, Tensor
from zeta.nn.modules.chan_layer_norm import ChanLayerNorm
from einops import rearrange


class ScalableImgSelfAttention(nn.Module):
    """
    ScalableImgSelfAttention applies a self-attention mechanism to image data.

    Args:
        dim (int): The input dimension of the image.
        heads (int, optional): The number of attention heads. Defaults to 8.
        dim_key (int, optional): The dimension of the key vectors. Defaults to 32.
        dim_value (int, optional): The dimension of the value vectors. Defaults to 32.
        dropout (float, optional): The dropout rate. Defaults to 0.0.
        reduction_factor (int, optional): The reduction factor for downscaling the keys and values. Defaults to 1.

    Attributes:
        dim (int): The input dimension of the image.
        heads (int): The number of attention heads.
        dim_key (int): The dimension of the key vectors.
        dim_value (int): The dimension of the value vectors.
        reduction_factor (int): The reduction factor for downscaling the keys and values.
        scale (float): The scaling factor applied to the attention scores.
        attend (nn.Softmax): The softmax function for attention calculation.
        dropout (nn.Dropout): The dropout layer.
        norm (ChanLayerNorm): The channel-wise layer normalization.
        to_q (nn.Conv2d): The convolutional layer for query projection.
        to_k (nn.Conv2d): The convolutional layer for key projection.
        to_v (nn.Conv2d): The convolutional layer for value projection.
        to_out (nn.Sequential): The sequential layer for output projection.
    """

    def __init__(
        self,
        dim: int,
        heads: int = 8,
        dim_key: int = 32,
        dim_value: int = 32,
        dropout: float = 0.0,
        reduction_factor: int = 1,
        *args,
        **kwargs,
    ):
        super().__init__()
        self.dim = dim
        self.heads = heads
        self.dim_key = dim_key
        self.dim_value = dim_value
        self.reduction_factor = reduction_factor

        self.scale = dim_key**-0.5
        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)
        self.norm = ChanLayerNorm(dim)

        # Projections: queries keep full resolution, while keys and values are
        # downsampled by reduction_factor via strided convolutions.
        self.to_q = nn.Conv2d(dim, dim_key * heads, 1, bias=False)
        self.to_k = nn.Conv2d(
            dim,
            dim_key * heads,
            reduction_factor,
            stride=reduction_factor,
            bias=False,
        )
        self.to_v = nn.Conv2d(
            dim,
            dim_value * heads,
            reduction_factor,
            stride=reduction_factor,
            bias=False,
        )

        self.to_out = nn.Sequential(
            nn.Conv2d(dim_value * heads, dim, 1), nn.Dropout(dropout)
        )
    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass of the ScalableImgSelfAttention module.

        Args:
            x (Tensor): The input tensor of shape (batch_size, channels, height, width).

        Returns:
            Tensor: The output tensor of shape (batch_size, channels, height, width).
        """
        height, width, heads = *x.shape[-2:], self.heads

        x = self.norm(x)

        q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)

        # Split out heads
        q, k, v = map(
            lambda t: rearrange(t, "b (h d) ... -> b h (...) d", h=heads),
            (q, k, v),
        )

        # Similarity
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        # Attention
        attn = self.attend(dots)
        attn = self.dropout(attn)

        # Aggregate values
        out = torch.matmul(attn, v)

        # Merge back heads
        out = rearrange(
            out,
            "b h (x y) d -> b (h d) x y",
            x=height,
            y=width,
        )
        return self.to_out(out)


# x = torch.randn(1, 3, 64, 64)
# peg = ScalableImgSelfAttention(3)
# out = peg(x)
# print(out.shape)
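
A minimal usage sketch of the module above (parameter values are illustrative, not part of the commit). Because only keys and values are downsampled, the output keeps the input's spatial shape:

import torch

# With ScalableImgSelfAttention defined as above:
x = torch.randn(1, 3, 64, 64)   # (batch, channels, height, width)

# reduction_factor=2 halves the key/value resolution; queries stay full size,
# so the output spatial size matches the input.
attn = ScalableImgSelfAttention(dim=3, heads=4, reduction_factor=2)

out = attn(x)
print(out.shape)                # expected: torch.Size([1, 3, 64, 64])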
@@ -0,0 +1,37 @@
import torch
from torch import nn, Tensor


class ChanLayerNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        """
        Initializes the ChanLayerNorm module.

        Args:
            dim (int): The input dimension.
            eps (float, optional): The epsilon value. Defaults to 1e-5.
        """
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x: Tensor):
        """
        Forward pass of the ChanLayerNorm module.

        Args:
            x (Tensor): The input tensor.

        Returns:
            Tensor: The normalized tensor.
        """
        # Normalize over the channel dimension (dim=1) of a (B, C, H, W) tensor.
        var = torch.var(
            x,
            dim=1,
            unbiased=False,
            keepdim=True,
        )
        mean = torch.mean(x, dim=1, keepdim=True)
        return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
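
A quick sanity check for the channel-wise normalization above (shapes are illustrative): at initialization, with g=1 and b=0, each spatial position should have roughly zero mean and unit variance across channels.

import torch

# With ChanLayerNorm defined as above:
norm = ChanLayerNorm(dim=16)
x = torch.randn(2, 16, 8, 8)

y = norm(x)
print(y.shape)                               # torch.Size([2, 16, 8, 8])
print(y.mean(dim=1).abs().max())             # close to 0
print(y.var(dim=1, unbiased=False).mean())   # close to 1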
@@ -0,0 +1,88 @@
import torch
from torch import nn, Tensor
from einops.layers.torch import Rearrange


def posemb_sincos_2d(patches, temperature=10000, dtype=torch.float32):
    _, h, w, dim, device, dtype = *patches.shape, patches.device, patches.dtype

    y, x = torch.meshgrid(
        torch.arange(h, device=device),
        torch.arange(w, device=device),
        indexing="ij",
    )
    assert (
        dim % 4
    ) == 0, "feature dimension must be multiple of 4 for sincos emb"
    omega = torch.arange(dim // 4, device=device) / (dim // 4 - 1)
    omega = 1.0 / (temperature**omega)

    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]
    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
    return pe.type(dtype)


def vit_output_head(x: Tensor, dim: int, num_classes: int = None):
    """
    Applies a Vision Transformer (ViT) output head to the input tensor.

    Args:
        x (Tensor): The input tensor.
        dim (int): The dimension of the input tensor.
        num_classes (int, optional): The number of output classes. Defaults to None.

    Returns:
        Tensor: The output tensor after applying the ViT output head.
    """
    return nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, num_classes))(x)


def patch_linear_flatten(
    x: Tensor,
    patch_size: int,
    dim: int,
    image_size: int,
    channels: int = 3,
    add_pos_embeddings: bool = False,
    *args,
    **kwargs,
):
    """
    Applies patch embedding to the input tensor.

    Args:
        x (Tensor): Input tensor of shape (batch_size, channels, image_height, image_width).
        patch_size (int): Size of the square patch.
        dim (int): Dimension of the output tensor.
        image_size (int): Size of the input image (assumed to be square).
        channels (int, optional): Number of input channels. Defaults to 3.
        add_pos_embeddings (bool, optional): Whether to add sincos positional embeddings. Defaults to False.

    Returns:
        Tensor: Patch embeddings of shape (batch_size, h_patches, w_patches, dim), or
            (batch_size, num_patches, dim) when add_pos_embeddings is True.
    """
    image_height, image_width = image_size, image_size
    patch_height, patch_width = patch_size, patch_size

    # Dimension of a flattened patch before the linear projection
    patch_dim = channels * patch_height * patch_width

    # Patch embedding layer: split into patches, flatten each patch, project to dim
    to_patch_embeddings = nn.Sequential(
        Rearrange(
            "b c (h p1) (w p2) -> b h w (p1 p2 c)",
            p1=patch_height,
            p2=patch_width,
        ),
        nn.LayerNorm(patch_dim),
        nn.Linear(patch_dim, dim),
        nn.LayerNorm(dim),
    )(x)

    if add_pos_embeddings:
        # Sincos embeddings are computed from the patch grid, then added after
        # the grid is flattened to a sequence of patches.
        pos_embeddings = posemb_sincos_2d(to_patch_embeddings, *args, **kwargs)
        to_patch_embeddings = to_patch_embeddings.flatten(1, 2) + pos_embeddings

    return to_patch_embeddings
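
A short usage sketch of the two helpers above (parameter values are illustrative): it embeds a batch of images into a patch grid, then applies the classification head to a mean-pooled representation.

import torch

# With patch_linear_flatten and vit_output_head defined as above:
imgs = torch.randn(2, 3, 64, 64)                   # (batch, channels, H, W)

# 8x8 patches over a 64x64 image -> an 8x8 grid of 256-dim patch embeddings
patches = patch_linear_flatten(imgs, patch_size=8, dim=256, image_size=64)
print(patches.shape)                               # torch.Size([2, 8, 8, 256])

# With add_pos_embeddings=True the grid is flattened to a patch sequence
seq = patch_linear_flatten(
    imgs, patch_size=8, dim=256, image_size=64, add_pos_embeddings=True
)
print(seq.shape)                                   # torch.Size([2, 64, 256])

# Mean-pool the patch grid and map to class logits
pooled = patches.mean(dim=(1, 2))                  # (batch, dim)
logits = vit_output_head(pooled, dim=256, num_classes=10)
print(logits.shape)                                # torch.Size([2, 10])

Note that vit_output_head constructs a fresh, untrained LayerNorm/Linear head on every call; in practice the head would normally be built once and reused so its weights can be trained.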
@@ -0,0 +1,34 @@
from torch import nn, Tensor


class PEG(nn.Module):
    """
    PEG (Positional Encoding Generator) module.

    Args:
        dim (int): The input dimension.
        kernel_size (int, optional): The size of the convolutional kernel. Defaults to 3.
    """

    def __init__(self, dim: int, kernel_size: int = 3):
        super().__init__()
        # Depthwise convolution (groups=dim) that generates position-dependent
        # features from each pixel's local neighborhood.
        self.proj = nn.Conv2d(
            dim,
            dim,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            groups=dim,
            stride=1,
        )

    def forward(self, x: Tensor):
        """
        Forward pass of the PEG module.

        Args:
            x (Tensor): The input tensor.

        Returns:
            Tensor: The output tensor with the positional encoding added residually.
        """
        return self.proj(x) + x
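
A minimal sketch of how the PEG module above might be used (shapes are illustrative); the residual depthwise convolution leaves the tensor shape unchanged:

import torch

# With PEG defined as above:
peg = PEG(dim=32, kernel_size=3)
x = torch.randn(1, 32, 16, 16)   # (batch, channels, height, width)

out = peg(x)
print(out.shape)                 # torch.Size([1, 32, 16, 16])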