Skip to content

Commit

Permalink
Update imports
Browse files Browse the repository at this point in the history
  • Loading branch information
Ram authored and Ram committed Apr 6, 2024
1 parent cb58448 commit d3941ee
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 10 deletions.
4 changes: 2 additions & 2 deletions zeta/models/andromeda.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# the best llm ever made
from torch.nn import Module

from zeta.structs.auto_regressive_wrapper import AutoregressiveWrapper
from zeta.structs.auto_regressive_wrapper import AutoRegressiveWrapper
from zeta.structs.transformer import Decoder, Transformer


Expand Down Expand Up @@ -74,7 +74,7 @@ def __init__(
),
)

self.decoder = AutoregressiveWrapper(self.Andromeda)
self.decoder = AutoregRessiveWrapper(self.Andromeda)

except Exception as e:
print("Failed to initialize Andromeda: ", e)
Expand Down
4 changes: 2 additions & 2 deletions zeta/models/gpt4.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
from torch import Tensor, nn

from zeta.structs.auto_regressive_wrapper import AutoregressiveWrapper
from zeta.structs.auto_regressive_wrapper import AutoRegressiveWrapper
from zeta.structs.transformer import (
Decoder,
Encoder,
Expand Down Expand Up @@ -81,7 +81,7 @@ def __init__(
),
)

self.decoder = AutoregressiveWrapper(self.decoder)
self.decoder = AutoRegressiveWrapper(self.decoder)

except Exception as e:
print("Failed to initialize Andromeda: ", e)
Expand Down
4 changes: 2 additions & 2 deletions zeta/models/llama.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from zeta.structs.auto_regressive_wrapper import AutoregressiveWrapper
from zeta.structs.auto_regressive_wrapper import AutoRegressiveWrapper
from zeta.structs.transformer import Decoder, Transformer


Expand Down Expand Up @@ -28,7 +28,7 @@ def __init__(
rotary_xpos=rotary_xpos,
),
)
self.decoder = AutoregressiveWrapper(self.decoder)
self.decoder = AutoRegressiveWrapper(self.decoder)

def forward(self, text):
model_input = self.decoder.forward(text)[0]
Expand Down
4 changes: 2 additions & 2 deletions zeta/models/palme.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch

from zeta.structs.auto_regressive_wrapper import AutoregressiveWrapper
from zeta.structs.auto_regressive_wrapper import AutoRegressiveWrapper
from zeta.structs.transformer import (
Decoder,
Encoder,
Expand Down Expand Up @@ -57,7 +57,7 @@ def __init__(
),
)

self.decoder = AutoregressiveWrapper(self.decoder)
self.decoder = AutoRegressiveWrapper(self.decoder)

def forward(self, img, text):
try:
Expand Down
4 changes: 2 additions & 2 deletions zeta/structs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from zeta.structs.auto_regressive_wrapper import AutoregressiveWrapper
from zeta.structs.auto_regressive_wrapper import AutoRegressiveWrapper
from zeta.structs.clip_encoder import CLIPVisionTower, build_vision_tower
from zeta.structs.encoder_decoder import EncoderDecoder
from zeta.structs.hierarchical_transformer import (
Expand All @@ -21,7 +21,7 @@
from zeta.structs.transformer_block import TransformerBlock

__all__ = [
"AutoregressiveWrapper",
"AutoRegressiveWrapper",
"Encoder",
"Decoder",
"EncoderDecoder",
Expand Down

0 comments on commit d3941ee

Please sign in to comment.