diff --git a/playground/models/spectra.py b/playground/models/spectra.py new file mode 100644 index 00000000..e69de29b diff --git a/pyproject.toml b/pyproject.toml index 55979046..17bcfe24 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "zetascale" -version = "2.3.3" +version = "2.3.5" description = "Rapidly Build, Optimize, and Deploy SOTA AI Models" authors = ["Zeta Team "] license = "MIT" @@ -35,6 +35,7 @@ tqdm = "4.66.2" rich = "13.7.1" colt5-attention = "*" argparse = "^1.4.0" +local-attention = "*" [build-system] requires = ["poetry-core>=1.0.0"] diff --git a/zeta/nn/attention/__init__.py b/zeta/nn/attention/__init__.py index 1f55a15c..563c96a2 100644 --- a/zeta/nn/attention/__init__.py +++ b/zeta/nn/attention/__init__.py @@ -22,10 +22,7 @@ from zeta.nn.attention.spatial_linear_attention import SpatialLinearAttention from zeta.structs.transformer import Attention, AttentionLayers from zeta.nn.attention.multi_grouped_attn import MultiGroupedQueryAttn - -# from zeta.nn.attention.flash_attention2 import FlashAttentionTwo -# from zeta.nn.attention.mgqa import MGQA - +from zeta.nn.attention.scalable_img_self_attn import ScalableImgSelfAttention __all__ = [ "Attend", @@ -48,4 +45,5 @@ "Attention", "AttentionLayers", "MultiGroupedQueryAttn", + "ScalableImgSelfAttention", ] diff --git a/zeta/nn/attention/scalable_img_self_attn.py b/zeta/nn/attention/scalable_img_self_attn.py new file mode 100644 index 00000000..7a885c01 --- /dev/null +++ b/zeta/nn/attention/scalable_img_self_attn.py @@ -0,0 +1,129 @@ +import torch +from torch import nn, Tensor +from zeta.nn.modules.chan_layer_norm import ChanLayerNorm +from einops import rearrange + + +class ScalableImgSelfAttention(nn.Module): + """ + ScalableImgSelfAttention module applies self-attention mechanism to image data. + + Args: + dim (int): The input dimension of the image. + heads (int, optional): The number of attention heads. Defaults to 8. + dim_key (int, optional): The dimension of the key vectors. Defaults to 32. + dim_value (int, optional): The dimension of the value vectors. Defaults to 32. + dropout (float, optional): The dropout rate. Defaults to 0.0. + reduction_factor (int, optional): The reduction factor for downscaling the image. Defaults to 1. + + Attributes: + dim (int): The input dimension of the image. + heads (int): The number of attention heads. + dim_key (int): The dimension of the key vectors. + dim_value (int): The dimension of the value vectors. + reduction_factor (int): The reduction factor for downscaling the image. + scale (float): The scaling factor for the key vectors. + attend (nn.Softmax): The softmax function for attention calculation. + dropout (nn.Dropout): The dropout layer. + norm (ChanLayerNorm): The channel-wise layer normalization. + to_q (nn.Conv2d): The convolutional layer for query projection. + to_k (nn.Conv2d): The convolutional layer for key projection. + to_v (nn.Conv2d): The convolutional layer for value projection. + to_out (nn.Sequential): The sequential layer for output projection. + + """ + + def __init__( + self, + dim: int, + heads: int = 8, + dim_key: int = 32, + dim_value: int = 32, + dropout: float = 0.0, + reduction_factor: int = 1, + *args, + **kwargs, + ): + super().__init__() + self.dim = dim + self.heads = heads + self.dim_key = dim_key + self.dim_value = dim_value + self.reduction_factor = reduction_factor + + self.scale = dim_key**-0.5 + self.attend = nn.Softmax(dim=-1) + self.dropout = nn.Dropout(dropout) + self.norm = ChanLayerNorm(dim) + + # Projections + self.to_q = nn.Conv2d(dim, dim_key * heads, 1, bias=False) + self.to_k = nn.Conv2d( + dim, + dim_key * heads, + reduction_factor, + stride=reduction_factor, + bias=False, + ) + self.to_v = nn.Conv2d( + dim, + dim_value * heads, + reduction_factor, + stride=reduction_factor, + bias=False, + ) + + self.to_out = nn.Sequential( + nn.Conv2d(dim_value * heads, dim, 1), nn.Dropout(dropout) + ) + + def forward(self, x: Tensor) -> Tensor: + """ + Forward pass of the ScalableImgSelfAttention module. + + Args: + x (Tensor): The input tensor of shape (batch_size, channels, height, width). + + Returns: + Tensor: The output tensor of shape (batch_size, channels, height, width). + + """ + h, w, h = *x.shape[-2:], self.heads + + x = self.norm(x) + + q, k, v = self.to_q(x), self.to_k(x), self.to_v(x) + + # Split out heads + q, k, v = map( + lambda t: rearrange(t, "b (h d) ... -> b h (...) d", h=h), + ( + q, + k, + ), + ) + + # Similarity + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + # Attention + attn = self.attend(dots) + attn = self.dropout(attn) + + # Aggregate values + out = torch.matmul(attn, v) + + # Merge back heads + out = rearrange( + out, + "b h (x y) d -> b (h d) x y", + x=h, + y=w, + ) + return self.to_out(out) + + +# x = torch.randn(1, 3, 64, 64) +# peg = ScalableImgSelfAttention(3) +# out = peg(x) +# print(out.shape) diff --git a/zeta/nn/modules/__init__.py b/zeta/nn/modules/__init__.py index 52491e3d..f8fcc0be 100644 --- a/zeta/nn/modules/__init__.py +++ b/zeta/nn/modules/__init__.py @@ -195,6 +195,20 @@ NormalSparseMoE, HeirarchicalSparseMoE, ) +from zeta.nn.modules.return_loss_text import ( + return_loss_text, + calc_z_loss, + max_neg_value, + TextTokenEmbedding, + dropout_seq, + transformer_generate, +) +from zeta.nn.modules.patch_linear_flatten import ( + vit_output_head, + patch_linear_flatten, +) +from zeta.nn.modules.chan_layer_norm import ChanLayerNorm + # from zeta.nn.modules.img_reshape import image_reshape # from zeta.nn.modules.flatten_features import flatten_features @@ -392,4 +406,14 @@ "Top2Gating", "NormalSparseMoE", "HeirarchicalSparseMoE", + "return_loss_text", + "calc_z_loss", + "max_neg_value", + "TextTokenEmbedding", + "dropout_seq", + "transformer_generate", + "patch_linear_flatten", + "vit_output_head", + "posemb_sincos_2d", + "ChanLayerNorm", ] diff --git a/zeta/nn/modules/chan_layer_norm.py b/zeta/nn/modules/chan_layer_norm.py new file mode 100644 index 00000000..72c835d9 --- /dev/null +++ b/zeta/nn/modules/chan_layer_norm.py @@ -0,0 +1,37 @@ +import torch +from torch import nn, Tensor + + +class ChanLayerNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-5): + """ + Initializes the ChanLayerNorm module. + + Args: + dim (int): The input dimension. + eps (float, optional): The epsilon value. Defaults to 1e-5. + """ + super().__init__() + self.dim = dim + self.eps = eps + self.g = nn.Parameter(torch.ones(1, dim, 1, 1)) + self.b = nn.Parameter(torch.zeros(1, dim, 1, 1)) + + def forward(self, x: Tensor): + """ + Forward pass of the ChanLayerNorm module. + + Args: + x (Tensor): The input tensor. + + Returns: + Tensor: The normalized tensor. + """ + var = torch.car( + x, + dim=1, + unbiased=False, + keepdim=True, + ) + mean = torch.mean(x, dim=1, keepdim=True) + return (x - mean) / (var + self.eps).sqrt() * self.g + self.b diff --git a/zeta/nn/modules/patch_linear_flatten.py b/zeta/nn/modules/patch_linear_flatten.py new file mode 100644 index 00000000..43fd786a --- /dev/null +++ b/zeta/nn/modules/patch_linear_flatten.py @@ -0,0 +1,88 @@ +import torch +from torch import nn, Tensor +from einops.layers.torch import Rearrange + + +def posemb_sincos_2d(patches, temperature=10000, dtype=torch.float32): + _, h, w, dim, device, dtype = *patches.shape, patches.device, patches.dtype + + y, x = torch.meshgrid( + torch.arange(h, device=device), + torch.arange(w, device=device), + indexing="ij", + ) + assert ( + dim % 4 + ) == 0, "feature dimension must be multiple of 4 for sincos emb" + omega = torch.arange(dim // 4, device=device) / (dim // 4 - 1) + omega = 1.0 / (temperature**omega) + + y = y.flatten()[:, None] * omega[None, :] + x = x.flatten()[:, None] * omega[None, :] + pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1) + return pe.type(dtype) + + +def vit_output_head(x: Tensor, dim: int, num_classes: int = None): + """ + Applies a Vision Transformer (ViT) output head to the input tensor. + + Args: + x (Tensor): The input tensor. + dim (int): The dimension of the input tensor. + num_classes (int, optional): The number of output classes. Defaults to None. + + Returns: + Tensor: The output tensor after applying the ViT output head. + """ + return nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, num_classes))(x) + + +def patch_linear_flatten( + x: Tensor, + patch_size: int, + dim: int, + image_size: int, + channels: int = 3, + add_pos_embeddings: bool = False, + *args, + **kwargs, +): + """ + Applies patch embedding to the input tensor and flattens it. + + Args: + x (Tensor): Input tensor of shape (batch_size, channels, image_height, image_width). + patch_size (int): Size of the square patch. + dim (int): Dimension of the output tensor. + image_size (int): Size of the input image (assumed to be square). + channels (int, optional): Number of input channels. Defaults to 3. + add_pos_embeddings (bool, optional): Whether to add positional embeddings. Defaults to False. + + Returns: + Tensor: Flattened tensor of shape (batch_size, num_patches, dim). + """ + image_height, image_width = image_size, image_size + patch_height, patch_width = patch_size, patch_size + + # calculate number of patches + (image_height // patch_height) * (image_width // patch_width) + patch_dim = channels * patch_height * patch_width + + # Patch Embedding layer + to_patch_embeddings = nn.Sequential( + Rearrange( + "b c (h p1) (w p2) -> b h w (p1 p2 c)", + p1=patch_height, + p2=patch_width, + ), + nn.LayerNorm(patch_dim), + nn.Linear(patch_dim, dim), + nn.LayerNorm(dim), + )(x) + + if add_pos_embeddings is not False: + pos_embeddings = posemb_sincos_2d(x, *args, **kwargs) + to_patch_embeddings + +pos_embeddings + + return to_patch_embeddings diff --git a/zeta/nn/modules/peg.py b/zeta/nn/modules/peg.py new file mode 100644 index 00000000..c1f18287 --- /dev/null +++ b/zeta/nn/modules/peg.py @@ -0,0 +1,34 @@ +from torch import nn, Tensor + + +class PEG(nn.Module): + """ + PEG (Positional Encoding Generator) module. + + Args: + dim (int): The input dimension. + kernel_size (int, optional): The size of the convolutional kernel. Defaults to 3. + """ + + def __init__(self, dim: int, kernel_size: int = 3): + super().__init__() + self.proj = nn.Conv2d( + dim, + dim, + kernel_size=kernel_size, + padding=kernel_size // 2, + groups=dim, + stride=1, + ) + + def forward(self, x: Tensor): + """ + Forward pass of the PEG module. + + Args: + x (Tensor): The input tensor. + + Returns: + Tensor: The output tensor. + """ + return self.proj(x) + x diff --git a/zeta/nn/modules/return_loss_text.py b/zeta/nn/modules/return_loss_text.py new file mode 100644 index 00000000..7a8dd132 --- /dev/null +++ b/zeta/nn/modules/return_loss_text.py @@ -0,0 +1,196 @@ +import torch +from einops import rearrange +import torch.nn.functional as F +from torch import Tensor +from torch import nn +from zeta.structs.auto_regressive_wrapper import AutoRegressiveWrapper +from typing import List +from einops import reduce + + +def exists(val): + return val is not None + + +def return_loss_text( + x: Tensor, logits: Tensor, labels: Tensor, ignore_index, mask: Tensor +): + """ + Computes the cross-entropy loss between the predicted logits and the target labels. + + Args: + logits (Tensor): The predicted logits of shape (batch_size, num_classes, sequence_length). + labels (Tensor): The target labels of shape (batch_size, sequence_length). + ignore_index (int): The index to ignore when computing the loss. + + Returns: + Tensor: The computed cross-entropy loss. + """ + seq, labels = x[:, :-1], x[:, 1:] + + labels = labels.masked_fill(~mask[:, 1:], ignore_index) + + loss = F.cross_entropy( + rearrange(logits, "b n c -> b c n"), labels, ignore_index=ignore_index + ) + + return loss + + +def add_masking_llm(x: Tensor, mask: Tensor, ignore_index: int): + """ + Adds masking to the input tensor. + + Args: + x (Tensor): The input tensor. + ignore_index (int): The index to ignore. + + Returns: + Tensor: The masked input tensor. + """ + ... + + +def calc_z_loss( + pre_softmax_attns: List[Tensor], mask: Tensor = None, weight: float = 1.0 +): + lse = 0.0 + + for attn in pre_softmax_attns: + lse = lse + attn.logsumexp(dim=-1) + + loss = torch.square(lse) + loss = reduce(loss, "b h n -> b n", "sum") + + if not exists(mask): + return loss.mean() * weight + + loss = loss[mask].sum() / mask.sum().clamp(min=1e-5) + return loss * weight + + +def max_neg_value(tensor: Tensor): + return -torch.finfo(tensor.dtype).max + + +def l2norm(x: Tensor, groups: int = 1): + """ + Applies L2 normalization to the input tensor. + + Args: + x (Tensor): The input tensor to be normalized. + groups (int, optional): The number of groups to divide the input tensor into. Defaults to 1. + + Returns: + Tensor: The normalized tensor. + + """ + x = rearrange(x, "... (g d) -> ... g d", g=groups) + x = F.normalize(x, p=2, dim=-1) + return rearrange(x, "... g d -> ... (g d)") + + +class TextTokenEmbedding(nn.Module): + def __init__( + self, + dim: int, + num_tokens: int, + l2norm_embed: bool = True, + ): + """ + Initializes a TextTokenEmbedding module. + + Args: + dim (int): The dimension of the embedding. + num_tokens (int): The number of tokens in the vocabulary. + l2norm_embed (bool, optional): Whether to apply L2 normalization to the embeddings. Defaults to True. + """ + super().__init__() + self.dim = dim + self.num_tokens = num_tokens + self.l2norm_embed = l2norm_embed + self.embed = nn.Embedding(num_tokens, dim) + + def forward(self, x: Tensor): + """ + Forward pass of the TextTokenEmbedding module. + + Args: + x (Tensor): The input tensor of shape (batch_size, sequence_length). + + Returns: + Tensor: The embedded tensor of shape (batch_size, sequence_length, dim). + """ + token_embed = self.embed(x.long()) + return l2norm(token_embed) if self.l2norm_embed else token_embed + + +def dropout_seq(seq: Tensor, mask: Tensor, dropout: float = 0.0): + """ + Applies dropout to a sequence of tensors. + + Args: + seq (Tensor): The input sequence tensor of shape (batch_size, sequence_length, ...). + mask (Tensor): The mask tensor of shape (batch_size, sequence_length) indicating which elements to keep. + dropout (float, optional): The dropout probability. Defaults to 0. + + Returns: + Tuple[Tensor, Tensor]: A tuple containing the modified sequence tensor and the modified mask tensor. + + """ + b, n, *_, device = *seq.shape, seq.device + logits = torch.randn(b, n, device=device) + + if exists(mask): + mask_value = max_neg_value(logits) + logits = logits.masked_fill(~mask, mask_value) + + keep_prob = 1.0 - dropout + num_keep = max(1, int(keep_prob * n)) + keep_indices = logits.topk(num_keep, dim=1).indices + + batch_indices = torch.arange(b, device=device) + batch_indices = rearrange(batch_indices, "b -> b 1") + + seq = seq[batch_indices, keep_indices] + + if exists(mask): + seq_counts = mask.sum(dim=-1) + seq_keep_counts = torch.ceil(seq_counts * keep_prob).int() + keep_mask = torch.arange(num_keep, device=device) < rearrange( + seq_keep_counts, "b -> b 1" + ) + + mask = mask[batch_indices, keep_indices] & keep_mask + + return seq, mask + + +@torch.no_grad() +def transformer_generate( + model: nn.Module, + prompt: Tensor, + temperature: float = 0.5, + filter_threshold: float = 0.9, + *args, + **kwargs, +): + """ + Generates text given a prompt. + + Args: + model (nn.Module): The model to generate text. + prompt (Tensor): The prompt tensor. + + Returns: + Tensor: The generated text. + """ + model = AutoRegressiveWrapper(net=model) + + return model.generate( + prompt, + filter_thres=filter_threshold, + temperature=temperature, + *args, + **kwargs, + ) diff --git a/zeta/structs/auto_regressive_wrapper.py b/zeta/structs/auto_regressive_wrapper.py index a7df7879..3f77cbb5 100644 --- a/zeta/structs/auto_regressive_wrapper.py +++ b/zeta/structs/auto_regressive_wrapper.py @@ -1,16 +1,16 @@ import torch import torch.nn.functional as F from einops import pack, rearrange, unpack -from torch import nn +from torch import Tensor, nn -from zeta.utils.main import once # noqa: F401 from zeta.utils.main import ( eval_decorator, exists, + once, # noqa: F401 top_a, top_k, top_p, -) # noqa: E402 +) # Utils @@ -86,7 +86,7 @@ def contrastive_guidance(self, logits, k): return torch.multinomial(F.softmax(top_k_logits, dim=-1), 1) -class AutoregressiveWrapper(nn.Module): +class AutoRegressiveWrapper(nn.Module): """ Auto-regressive wrapper for any nn.Module that takes in a sequence of @@ -114,11 +114,11 @@ class AutoregressiveWrapper(nn.Module): def __init__( self, - net, - ignore_index=-100, - pad_value=0, - mask_prob=0.0, - speculative=False, + net: nn.Module, + ignore_index: int = -100, + pad_value: int = 0, + mask_prob: float = 0.0, + speculative: bool = False, ): super().__init__() self.pad_value = pad_value @@ -138,7 +138,7 @@ def __init__( def generate( self, start_tokens, - seq_len, + seq_len: int, eos_token=None, strategy="temperature", temperature=1.0, @@ -352,3 +352,11 @@ def evaluate_and_select_best_solution( def grade_solution(self, solution): """Grade a solution.""" + ... + return self.net(solution) + + def majority_voting(self, task: Tensor): + """ + Majority voting. + """ + ...