
Commit

Sentence transformers 3.3.1 (#1628)
yafshar authored Jan 27, 2025
1 parent 3c251d0 commit 8f930cb
Showing 6 changed files with 690 additions and 161 deletions.
9 changes: 2 additions & 7 deletions optimum/habana/sentence_transformers/modeling_utils.py
@@ -19,8 +19,9 @@ def adapt_sentence_transformers_to_gaudi():
     Replaces some SentenceTransformer methods with equivalent methods optimized
     for Gaudi.
     """
-
     from sentence_transformers import SentenceTransformer
+    from sentence_transformers.data_collator import SentenceTransformerDataCollator
+    from sentence_transformers.models import Transformer
 
     from optimum.habana.sentence_transformers import (
         st_gaudi_data_collator_call,
@@ -30,12 +31,6 @@ def adapt_sentence_transformers_to_gaudi():
     )
 
     SentenceTransformer.encode = st_gaudi_encode
-
-    from sentence_transformers.models import Transformer
-
     Transformer.tokenize = st_gaudi_transformer_tokenize
     Transformer.save = st_gaudi_transformer_save
-
-    from sentence_transformers.data_collator import SentenceTransformerDataCollator
-
     SentenceTransformerDataCollator.__call__ = st_gaudi_data_collator_call
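
The patch swaps Gaudi-optimized implementations onto the upstream sentence-transformers classes, so it has to run before a model is created. A minimal usage sketch (the model name is illustrative, not part of this commit):

    from sentence_transformers import SentenceTransformer

    from optimum.habana.sentence_transformers.modeling_utils import (
        adapt_sentence_transformers_to_gaudi,
    )

    # Patch SentenceTransformer.encode, Transformer.tokenize/save and the
    # data collator before instantiating the model.
    adapt_sentence_transformers_to_gaudi()

    model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    embeddings = model.encode(["Sentence embeddings on Gaudi"])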
42 changes: 25 additions & 17 deletions optimum/habana/sentence_transformers/st_gaudi_data_collator.py
@@ -5,47 +5,55 @@
 
 
 def st_gaudi_data_collator_call(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
-    """data collator for sentence transformer"""
+    """Collator for a SentenceTransformers model."""
 
-    columns = list(features[0].keys())
+    column_names = list(features[0].keys())
 
     # We should always be able to return a loss, label or not:
     batch = {"return_loss": True}
 
-    if "dataset_name" in columns:
-        columns.remove("dataset_name")
+    if "dataset_name" in column_names:
+        column_names.remove("dataset_name")
         batch["dataset_name"] = features[0]["dataset_name"]
 
+    if tuple(column_names) not in self._warned_columns:
+        self.maybe_warn_about_column_order(column_names)
+
     # Extract the label column if it exists
     for label_column in self.valid_label_columns:
-        if label_column in columns:
+        if label_column in column_names:
             batch["label"] = torch.tensor([row[label_column] for row in features])
-            columns.remove(label_column)
+            column_names.remove(label_column)
             break
 
     # Extract the feature columns
     cnt = 0
+    cnt1 = 0
     power2_len = [0, 0]
-    for column in columns:
-        tokenized = self.tokenize_fn([row[column] for row in features])
+    for column_name in column_names:
+        # If the prompt length has been set, we should add it to the batch
+        if column_name.endswith("_prompt_length") and column_name[: -len("_prompt_length")] in column_names:
+            batch[column_name] = torch.tensor([row[column_name] for row in features], dtype=torch.int)
+            continue
+
+        tokenized = self.tokenize_fn([row[column_name] for row in features])
         for key, value in tokenized.items():
             curr_tokenize_len = value.shape
             if curr_tokenize_len[1] > 4096:
-                power2_len[cnt % 2] = math.ceil(curr_tokenize_len[1] / 128) * 128
-                additional_pad_len = math.ceil(curr_tokenize_len[1] / 128) * 128 - curr_tokenize_len[1]
+                power2_len[cnt1] = math.ceil(curr_tokenize_len[1] / 128) * 128
             else:
-                power2_len[cnt % 2] = 2 ** math.ceil(math.log2(curr_tokenize_len[1]))
-                additional_pad_len = 2 ** math.ceil(math.log2(curr_tokenize_len[1])) - curr_tokenize_len[1]
-
-            if (cnt % 2 == 1) and (power2_len[0] == power2_len[1]):
-                additional_pad_len = additional_pad_len + 1
+                power2_len[cnt1] = 2 ** math.ceil(math.log2(curr_tokenize_len[1]))
+            additional_pad_len = power2_len[cnt1] - curr_tokenize_len[1]
+            if (cnt1 == 1) and (power2_len[0] == power2_len[1]):
+                additional_pad_len += 1
 
-            batch[f"{column}_{key}"] = torch.cat(
+            batch[f"{column_name}_{key}"] = torch.cat(
                 (
                     value,
                     torch.zeros((curr_tokenize_len[0], additional_pad_len), dtype=torch.int8),
                 ),
                 -1,
             )
-        cnt = cnt + 1
+        cnt += 1
+        cnt1 = cnt & 1
     return batch
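
To keep HPU shapes static, the collator pads each tokenized tensor up to a bucketed length: the next power of two for sequences of at most 4096 tokens, and the next multiple of 128 beyond that. When the second of two consecutive feature columns would land in the same bucket as the first, one extra pad token keeps the two shapes distinct. A standalone sketch of the bucketing rule (the helper name is ours, not part of the diff):

    import math

    def bucketed_length(seq_len: int) -> int:
        # Next power of two up to 4096 tokens; next multiple of 128 beyond.
        if seq_len > 4096:
            return math.ceil(seq_len / 128) * 128
        return 2 ** math.ceil(math.log2(seq_len))

    print(bucketed_length(100))   # 128
    print(bucketed_length(129))   # 256
    print(bucketed_length(5000))  # 5120

Bucketing bounds the number of distinct input shapes the device sees, which limits graph recompilations across batches.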
23 changes: 13 additions & 10 deletions optimum/habana/sentence_transformers/st_gaudi_encoder.py
@@ -5,13 +5,11 @@
 
 import numpy as np
 import torch
-from numpy import ndarray
 from sentence_transformers.quantization import quantize_embeddings
 from sentence_transformers.util import (
     batch_to_device,
     truncate_embeddings,
 )
-from torch import Tensor
 from tqdm.autonotebook import trange
@@ -24,14 +22,15 @@ def st_gaudi_encode(
     prompt_name: Optional[str] = None,
     prompt: Optional[str] = None,
     batch_size: int = 32,
-    show_progress_bar: bool = None,
+    show_progress_bar: Optional[bool] = None,
     output_value: Optional[Literal["sentence_embedding", "token_embeddings"]] = "sentence_embedding",
     precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = "float32",
     convert_to_numpy: bool = True,
     convert_to_tensor: bool = False,
-    device: str = None,
+    device: Optional[str] = None,
     normalize_embeddings: bool = False,
-) -> Union[List[Tensor], ndarray, Tensor]:
+    **kwargs,
+) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
"""
Computes sentence embeddings.
Expand Down Expand Up @@ -63,7 +62,7 @@ def st_gaudi_encode(
the faster dot-product (util.dot_score) instead of cosine similarity can be used. Defaults to False.
Returns:
Union[List[Tensor], ndarray, Tensor]: By default, a 2d numpy array with shape [num_inputs, output_dimension] is returned.
Union[List[torch.Tensor], np.ndarray, torch.Tensor]: By default, a 2d numpy array with shape [num_inputs, output_dimension] is returned.
If only one string input is provided, then the output is a 1d array with shape [output_dimension]. If ``convert_to_tensor``,
a torch Tensor is returned instead. If ``self.truncate_dim <= output_dimension`` then output_dimension is ``self.truncate_dim``.
@@ -85,9 +84,10 @@
         print(embeddings.shape)
         # (3, 768)
     """
+
     self.eval()
     if show_progress_bar is None:
-        show_progress_bar = logger.getEffectiveLevel() == logging.INFO or logger.getEffectiveLevel() == logging.DEBUG
+        show_progress_bar = logger.getEffectiveLevel() in (logging.INFO, logging.DEBUG)
 
     if convert_to_tensor:
         convert_to_numpy = False
@@ -119,6 +119,7 @@
             "Encode with either a `prompt`, a `prompt_name`, or neither, but not both. "
             "Ignoring the `prompt_name` in favor of `prompt`."
         )
+
     extra_features = {}
     if prompt is not None:
         sentences = [prompt + sentence for sentence in sentences]
@@ -132,14 +133,15 @@
     if device is None:
         device = self.device
 
     self.to(device)
 
     all_embeddings = []
     length_sorted_idx = np.argsort([-self._text_length(sen) for sen in sentences])
     sentences_sorted = [sentences[idx] for idx in length_sorted_idx]
 
     for start_index in trange(0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar):
         sentences_batch = sentences_sorted[start_index : start_index + batch_size]
         features = self.tokenize(sentences_batch)
+
         if self.device.type == "hpu":
             if "input_ids" in features:
                 curr_tokenize_len = features["input_ids"].shape
@@ -166,11 +168,12 @@
                 ),
                 -1,
             )
+
         features = batch_to_device(features, device)
         features.update(extra_features)
 
         with torch.no_grad():
-            out_features = self.forward(features)
+            out_features = self.forward(features, **kwargs)
             if self.device.type == "hpu":
                 out_features = copy.deepcopy(out_features)
 
@@ -218,7 +221,7 @@
         all_embeddings = torch.Tensor()
     elif convert_to_numpy:
         if not isinstance(all_embeddings, np.ndarray):
-            if all_embeddings[0].dtype == torch.bfloat16:
+            if all_embeddings and all_embeddings[0].dtype == torch.bfloat16:
                 all_embeddings = np.asarray([emb.float().numpy() for emb in all_embeddings])
             else:
                 all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
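
The final hunk hardens the numpy conversion: numpy has no bfloat16 dtype, so bf16 embeddings are upcast to float32 before np.asarray, and the added truthiness check avoids indexing an empty list when nothing was encoded. A quick illustration of the dtype constraint:

    import torch

    emb = torch.randn(4, dtype=torch.bfloat16)
    # emb.numpy() raises TypeError because numpy has no bfloat16 dtype,
    # hence the .float() upcast before .numpy() in the hunk above.
    arr = emb.float().numpy()
    print(arr.dtype)  # float32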
(Diffs for the remaining 3 changed files are not shown.)