
Commit

Merge branch 'kyegomez:master' into feat/init-experimental-layers
dtunai authored Apr 6, 2024
2 parents 3779845 + cb58448 commit f7acd58
Showing 18 changed files with 1,166 additions and 75 deletions.
148 changes: 148 additions & 0 deletions playground/models/nirvana.py
@@ -0,0 +1,148 @@
"""
Nirvana
Multi grouped query attention + feedforward
"""
import torch
from torch import Tensor, nn

from zeta.nn import FeedForward, OutputHead
from zeta.nn.attention import MultiQueryAttention


class TransformerBlock(nn.Module):
"""
TransformerBlock is a module that represents a single block in a transformer model.
Args:
dim (int): The input dimension of the block.
heads (int): The number of attention heads.
mult (int): The multiplier for the hidden dimension in the feed-forward network.
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
"""

def __init__(self, dim: int, heads: int, mult: int, *args, **kwargs):
super().__init__()
self.dim = dim
self.heads = heads
self.mult = mult

        # Multi-query attention
self.attn = MultiQueryAttention(dim, heads, *args, **kwargs)

# Ffn
self.ffn = FeedForward(dim, dim, mult, swish=True, post_act_ln=True)

# LayerNorm
self.norm = nn.LayerNorm(dim)

def forward(self, x: Tensor):
"""
Forward pass of the TransformerBlock.
Args:
x (Tensor): The input tensor.
Returns:
Tensor: The output tensor after passing through the TransformerBlock.
"""
skip = x

x = self.norm(x)

        # Attention with residual connection
        x, _, _ = self.attn(x)
        x = x + skip

        # Feed-forward with residual connection
        skip_two = x

        return self.ffn(x) + skip_two
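
# Editor's usage sketch (hedged): assuming zeta's MultiQueryAttention returns a
# 3-tuple as unpacked above and FeedForward maps dim -> dim, a single block
# preserves the (batch, seq_len, dim) shape:
#
#     block = TransformerBlock(dim=512, heads=8, mult=4)
#     tokens = torch.randn(1, 100, 512)
#     out = block(tokens)  # expected shape: (1, 100, 512)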


class Nirvana(nn.Module):
    """
    A class representing the Nirvana model.
Args:
dim (int): The dimension of the model.
heads (int): The number of attention heads.
mult (int): The multiplier for the hidden dimension in the feed-forward network.
depth (int, optional): The number of transformer blocks. Defaults to 8.
num_tokens (int, optional): The number of tokens in the input vocabulary. Defaults to None.
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
Attributes:
dim (int): The dimension of the model.
heads (int): The number of attention heads.
mult (int): The multiplier for the hidden dimension in the feed-forward network.
depth (int): The number of transformer blocks.
num_tokens (int): The number of tokens in the input vocabulary.
embed (nn.Embedding): The embedding layer.
layers (nn.ModuleList): The list of transformer blocks.
"""

def __init__(
self,
dim: int,
heads: int,
mult: int,
depth: int = 8,
num_tokens: int = None,
*args,
**kwargs,
):
super().__init__()
self.dim = dim
self.heads = heads
self.mult = mult
self.depth = depth
self.num_tokens = num_tokens

# Embedding
self.embed = nn.Embedding(num_tokens, dim)

# Layers
self.layers = nn.ModuleList(
[
TransformerBlock(dim, heads, mult, *args, **kwargs)
for _ in range(depth)
]
)

def forward(self, x):
"""
        Forward pass of the Nirvana model.
Args:
x: The input tensor.
Returns:
The output tensor.
"""
x = self.embed(x)

for layer in self.layers:
x = layer(x)

x = OutputHead(self.dim, -1)(x)
return x


# Example input: a batch of token ids
x = torch.randint(0, 100, (1, 100))

# Model
model = Nirvana(512, 8, 4, 8, 100)

# Forward
y = model(x)
print(y)
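
# Hedged shape note: with num_tokens=100 and dim=512, the embedding maps the
# (1, 100) token ids to (1, 100, 512) and each TransformerBlock preserves that
# shape; the final shape of `y` depends on how zeta's OutputHead interprets the
# `-1` passed to it in forward().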
Empty file added playground/models/spectra.py
Empty file.
24 changes: 7 additions & 17 deletions pyproject.toml
@@ -1,7 +1,7 @@
[tool.poetry]
name = "zetascale"
version = "2.2.6"
description = "Transformers at zeta scales"
version = "2.3.5"
description = "Rapidly Build, Optimize, and Deploy SOTA AI Models"
authors = ["Zeta Team <[email protected]>"]
license = "MIT"
readme = "README.md"
@@ -16,35 +16,26 @@ packages = [
]

[tool.poetry.dependencies]
python = "^3.8"
torch = "2.2.0"
timm = "0.9.16"
torchdiffeq = "0.2.3"
python = "^3.9"
torch = ">=2.1.1,<3.0"
pytest = "8.1.1"
torchfix = "*"
einops = "0.7.0"
bitsandbytes = "0.42.0"
typing = "3.7.4.3"
bitsandbytes = "0.43.0"
transformers = "4.39.1"
einops-exts = "0.0.4"
torchvision = "0.17.0"
accelerate = "0.28.0"
datasets = "*"
lion-pytorch = "0.1.2"
loguru = "*"
sentencepiece = "0.2.0"
vector-quantize-pytorch = "1.14.5"
tokenmonster = "1.1.12"
scipy = "1.9.3"
beartype = "0.17.2"
tiktoken = "0.6.0"
tqdm = "4.66.2"
rich = "13.7.0"
rich = "13.7.1"
colt5-attention = "*"
argparse = "^1.4.0"
skypilot = "0.4.1"
numexpr = "*"

local-attention = "*"

[build-system]
requires = ["poetry-core>=1.0.0"]
@@ -61,7 +52,6 @@ types-chardet = "^5.0.4.6"
mypy-protobuf = "^3.0.0"
pytest = "8.1.1"


[tool.ruff]
line-length = 80

2 changes: 1 addition & 1 deletion requirements.txt
@@ -19,7 +19,7 @@ loguru
rich==13.7.0
tiktoken==0.6.0
transformers==4.36.0
tqdm==4.66.1
tqdm==4.66.2
mkdocs
mkdocs-material
mkdocs-glightbox
5 changes: 3 additions & 2 deletions zeta/__init__.py
@@ -2,13 +2,14 @@

disable_warnings_and_logs()

from zeta.cloud import * # noqa: F403, E402
# from zeta.cloud import * # noqa: F403, E402
from zeta.models import * # noqa: F403, E402
from zeta.nn import * # noqa: F403, E402
from zeta.ops import * # noqa: F403, E402
from zeta.optim import * # noqa: F403, E402
from zeta.quant import * # noqa: F403, E402
from zeta.rl import * # noqa: F403, E402
from zeta.tokenizers import * # noqa: F403, E402

# from zeta.tokenizers import * # noqa: F403, E402
from zeta.training import * # noqa: F403, E402
from zeta.utils import * # noqa: F403, E402
6 changes: 2 additions & 4 deletions zeta/nn/attention/__init__.py
@@ -22,10 +22,7 @@
from zeta.nn.attention.spatial_linear_attention import SpatialLinearAttention
from zeta.structs.transformer import Attention, AttentionLayers
from zeta.nn.attention.multi_grouped_attn import MultiGroupedQueryAttn

# from zeta.nn.attention.flash_attention2 import FlashAttentionTwo
# from zeta.nn.attention.mgqa import MGQA

from zeta.nn.attention.scalable_img_self_attn import ScalableImgSelfAttention

__all__ = [
"Attend",
@@ -48,4 +45,5 @@
"Attention",
"AttentionLayers",
"MultiGroupedQueryAttn",
"ScalableImgSelfAttention",
]
2 changes: 1 addition & 1 deletion zeta/nn/attention/multihead_attention.py
@@ -19,7 +19,7 @@ def __init__(
self,
embed_dim: int = None,
num_heads: int = None,
dropout: int = 0.0,
dropout: float = 0.0,
self_attention: bool = False,
subln: bool = False,
layernorm_eps=1e-05,
129 changes: 129 additions & 0 deletions zeta/nn/attention/scalable_img_self_attn.py
@@ -0,0 +1,129 @@
import torch
from torch import nn, Tensor
from zeta.nn.modules.chan_layer_norm import ChanLayerNorm
from einops import rearrange


class ScalableImgSelfAttention(nn.Module):
"""
ScalableImgSelfAttention module applies self-attention mechanism to image data.
Args:
dim (int): The input dimension of the image.
heads (int, optional): The number of attention heads. Defaults to 8.
dim_key (int, optional): The dimension of the key vectors. Defaults to 32.
dim_value (int, optional): The dimension of the value vectors. Defaults to 32.
dropout (float, optional): The dropout rate. Defaults to 0.0.
reduction_factor (int, optional): The reduction factor for downscaling the image. Defaults to 1.
Attributes:
dim (int): The input dimension of the image.
heads (int): The number of attention heads.
dim_key (int): The dimension of the key vectors.
dim_value (int): The dimension of the value vectors.
reduction_factor (int): The reduction factor for downscaling the image.
scale (float): The scaling factor for the key vectors.
attend (nn.Softmax): The softmax function for attention calculation.
dropout (nn.Dropout): The dropout layer.
norm (ChanLayerNorm): The channel-wise layer normalization.
to_q (nn.Conv2d): The convolutional layer for query projection.
to_k (nn.Conv2d): The convolutional layer for key projection.
to_v (nn.Conv2d): The convolutional layer for value projection.
to_out (nn.Sequential): The sequential layer for output projection.
"""

def __init__(
self,
dim: int,
heads: int = 8,
dim_key: int = 32,
dim_value: int = 32,
dropout: float = 0.0,
reduction_factor: int = 1,
*args,
**kwargs,
):
super().__init__()
self.dim = dim
self.heads = heads
self.dim_key = dim_key
self.dim_value = dim_value
self.reduction_factor = reduction_factor

self.scale = dim_key**-0.5
self.attend = nn.Softmax(dim=-1)
self.dropout = nn.Dropout(dropout)
self.norm = ChanLayerNorm(dim)

# Projections
self.to_q = nn.Conv2d(dim, dim_key * heads, 1, bias=False)
self.to_k = nn.Conv2d(
dim,
dim_key * heads,
reduction_factor,
stride=reduction_factor,
bias=False,
)
self.to_v = nn.Conv2d(
dim,
dim_value * heads,
reduction_factor,
stride=reduction_factor,
bias=False,
)

self.to_out = nn.Sequential(
nn.Conv2d(dim_value * heads, dim, 1), nn.Dropout(dropout)
)
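
    # Hedged note on reduction_factor: `to_k` and `to_v` use a conv with kernel
    # and stride equal to reduction_factor, so keys and values are computed on a
    # roughly (H/r, W/r) grid while the queries keep the full H x W resolution.
    # Per head, the attention matrix is then about (H*W) x (H*W / r^2) rather
    # than (H*W) x (H*W), which is where the cost saving comes from.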

def forward(self, x: Tensor) -> Tensor:
"""
Forward pass of the ScalableImgSelfAttention module.
Args:
x (Tensor): The input tensor of shape (batch_size, channels, height, width).
Returns:
Tensor: The output tensor of shape (batch_size, channels, height, width).
"""
        height, width, heads = *x.shape[-2:], self.heads

        x = self.norm(x)

        q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)

        # Split out heads
        q, k, v = map(
            lambda t: rearrange(t, "b (h d) ... -> b h (...) d", h=heads),
            (q, k, v),
        )

# Similarity
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

# Attention
attn = self.attend(dots)
attn = self.dropout(attn)

# Aggregate values
out = torch.matmul(attn, v)

        # Merge heads back into the channel dimension
        out = rearrange(
            out,
            "b h (x y) d -> b (h d) x y",
            x=height,
            y=width,
        )
return self.to_out(out)


# x = torch.randn(1, 3, 64, 64)
# peg = ScalableImgSelfAttention(3)
# out = peg(x)
# print(out.shape)
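
# Hedged note: since the queries keep the input's spatial resolution and
# `to_out` projects back to `dim`, the commented example above should return a
# tensor of shape (1, 3, 64, 64), matching its input.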
