Skip to content

Commit c279849

Browse files
authored
add disk embedding api (#11585)
1 parent 79c742d commit c279849

File tree

1 file changed

+19
-2
lines changed

1 file changed

+19
-2
lines changed

python/llm/src/ipex_llm/transformers/embedding.py

+19-2
Original file line numberDiff line numberDiff line change
@@ -88,10 +88,11 @@ def __init__(self,
8888
sparse: bool = False,
8989
_weight: Optional[Tensor] = None,
9090
_freeze: bool = False,
91-
device=None, dtype=None) -> None:
91+
device=None,
92+
dtype=None):
9293
super().__init__(num_embeddings, embedding_dim, padding_idx,
9394
max_norm, norm_type, scale_grad_by_freq,
94-
sparse, _weight, _freeze, device, dtype)
95+
sparse, _weight, True, device, dtype)
9596
self.filename = "embeddings.bin"
9697
self.weight.data.flatten().half().numpy().tofile(self.filename)
9798
dummy_weight = torch.empty(0, 0, dtype=self.weight.dtype, device=self.weight.device)
@@ -118,6 +119,22 @@ def restore(self):
118119
)
119120
self.weight = torch.nn.Parameter(embeds, requires_grad=False)
120121

122+
@classmethod
123+
def from_embedding(cls, embedding: torch.nn.Embedding):
124+
return cls(
125+
embedding.num_embeddings,
126+
embedding.embedding_dim,
127+
embedding.padding_idx,
128+
embedding.max_norm,
129+
embedding.norm_type,
130+
embedding.scale_grad_by_freq,
131+
embedding.sparse,
132+
embedding.weight.data,
133+
True,
134+
embedding.weight.device,
135+
embedding.weight.dtype,
136+
)
137+
121138

122139
class LowBitEmbedding(torch.nn.Embedding):
123140
def __init__(self,

0 commit comments

Comments
 (0)