@@ -88,10 +88,11 @@ def __init__(self,
                  sparse: bool = False,
                  _weight: Optional[Tensor] = None,
                  _freeze: bool = False,
-                 device=None, dtype=None) -> None:
+                 device=None,
+                 dtype=None):
         super().__init__(num_embeddings, embedding_dim, padding_idx,
                          max_norm, norm_type, scale_grad_by_freq,
-                         sparse, _weight, _freeze, device, dtype)
+                         sparse, _weight, True, device, dtype)
         self.filename = "embeddings.bin"
         self.weight.data.flatten().half().numpy().tofile(self.filename)
         dummy_weight = torch.empty(0, 0, dtype=self.weight.dtype, device=self.weight.device)
@@ -118,6 +119,22 @@ def restore(self):
         )
         self.weight = torch.nn.Parameter(embeds, requires_grad=False)
 
+    @classmethod
+    def from_embedding(cls, embedding: torch.nn.Embedding):
+        return cls(
+            embedding.num_embeddings,
+            embedding.embedding_dim,
+            embedding.padding_idx,
+            embedding.max_norm,
+            embedding.norm_type,
+            embedding.scale_grad_by_freq,
+            embedding.sparse,
+            embedding.weight.data,
+            True,
+            embedding.weight.device,
+            embedding.weight.dtype,
+        )
+
 
 class LowBitEmbedding(torch.nn.Embedding):
     def __init__(self,
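
A minimal usage sketch for the new classmethod added above. The modified class's name is not visible in this hunk, so "DiskEmbedding" below is an assumed import name, and the vocabulary/dimension sizes are placeholders; per the diff, the converted copy is always constructed frozen (True is passed in place of _freeze).

# Sketch only: "DiskEmbedding" is an assumed name for the class modified above;
# its import path is omitted and the sizes are placeholders.
import torch

base = torch.nn.Embedding(num_embeddings=32000, embedding_dim=4096)

# Convert an existing embedding layer into the disk-backed, frozen variant.
disk_emb = DiskEmbedding.from_embedding(base)

# restore() (see the second hunk) reloads the dumped weights into self.weight.
disk_emb.restore()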