From 8f930cbd6f12c53db70129aafae0939c332da320 Mon Sep 17 00:00:00 2001 From: Yaser Afshar Date: Mon, 27 Jan 2025 09:23:36 -0800 Subject: [PATCH] Sentence transformers 3.3.1 (#1628) --- .../sentence_transformers/modeling_utils.py | 9 +- .../st_gaudi_data_collator.py | 42 +- .../sentence_transformers/st_gaudi_encoder.py | 23 +- .../sentence_transformers/st_gaudi_trainer.py | 742 +++++++++++++++--- .../st_gaudi_training_args.py | 33 +- setup.py | 2 +- 6 files changed, 690 insertions(+), 161 deletions(-) diff --git a/optimum/habana/sentence_transformers/modeling_utils.py b/optimum/habana/sentence_transformers/modeling_utils.py index 7690483e39..6be3065c32 100644 --- a/optimum/habana/sentence_transformers/modeling_utils.py +++ b/optimum/habana/sentence_transformers/modeling_utils.py @@ -19,8 +19,9 @@ def adapt_sentence_transformers_to_gaudi(): Replaces some SentenceTransformer' methods for equivalent methods optimized for Gaudi. """ - from sentence_transformers import SentenceTransformer + from sentence_transformers.data_collator import SentenceTransformerDataCollator + from sentence_transformers.models import Transformer from optimum.habana.sentence_transformers import ( st_gaudi_data_collator_call, @@ -30,12 +31,6 @@ def adapt_sentence_transformers_to_gaudi(): ) SentenceTransformer.encode = st_gaudi_encode - - from sentence_transformers.models import Transformer - Transformer.tokenize = st_gaudi_transformer_tokenize Transformer.save = st_gaudi_transformer_save - - from sentence_transformers.data_collator import SentenceTransformerDataCollator - SentenceTransformerDataCollator.__call__ = st_gaudi_data_collator_call diff --git a/optimum/habana/sentence_transformers/st_gaudi_data_collator.py b/optimum/habana/sentence_transformers/st_gaudi_data_collator.py index 25e015fe24..51e823e1ae 100644 --- a/optimum/habana/sentence_transformers/st_gaudi_data_collator.py +++ b/optimum/habana/sentence_transformers/st_gaudi_data_collator.py @@ -5,47 +5,55 @@ def st_gaudi_data_collator_call(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]: - """data collator for sentence transformer""" + """Collator for a SentenceTransformers model.""" - columns = list(features[0].keys()) + column_names = list(features[0].keys()) # We should always be able to return a loss, label or not: batch = {"return_loss": True} - if "dataset_name" in columns: - columns.remove("dataset_name") + if "dataset_name" in column_names: + column_names.remove("dataset_name") batch["dataset_name"] = features[0]["dataset_name"] + if tuple(column_names) not in self._warned_columns: + self.maybe_warn_about_column_order(column_names) + # Extract the label column if it exists for label_column in self.valid_label_columns: - if label_column in columns: + if label_column in column_names: batch["label"] = torch.tensor([row[label_column] for row in features]) - columns.remove(label_column) + column_names.remove(label_column) break # Extract the feature columns cnt = 0 + cnt1 = 0 power2_len = [0, 0] - for column in columns: - tokenized = self.tokenize_fn([row[column] for row in features]) + for column_name in column_names: + # If the prompt length has been set, we should add it to the batch + if column_name.endswith("_prompt_length") and column_name[: -len("_prompt_length")] in column_names: + batch[column_name] = torch.tensor([row[column_name] for row in features], dtype=torch.int) + continue + + tokenized = self.tokenize_fn([row[column_name] for row in features]) for key, value in tokenized.items(): curr_tokenize_len = value.shape if 
curr_tokenize_len[1] > 4096: - power2_len[cnt % 2] = math.ceil(curr_tokenize_len[1] / 128) * 128 - additional_pad_len = math.ceil(curr_tokenize_len[1] / 128) * 128 - curr_tokenize_len[1] + power2_len[cnt1] = math.ceil(curr_tokenize_len[1] / 128) * 128 else: - power2_len[cnt % 2] = 2 ** math.ceil(math.log2(curr_tokenize_len[1])) - additional_pad_len = 2 ** math.ceil(math.log2(curr_tokenize_len[1])) - curr_tokenize_len[1] - - if (cnt % 2 == 1) and (power2_len[0] == power2_len[1]): - additional_pad_len = additional_pad_len + 1 + power2_len[cnt1] = 2 ** math.ceil(math.log2(curr_tokenize_len[1])) + additional_pad_len = power2_len[cnt1] - curr_tokenize_len[1] + if (cnt1 == 1) and (power2_len[0] == power2_len[1]): + additional_pad_len += 1 - batch[f"{column}_{key}"] = torch.cat( + batch[f"{column_name}_{key}"] = torch.cat( ( value, torch.zeros((curr_tokenize_len[0], additional_pad_len), dtype=torch.int8), ), -1, ) - cnt = cnt + 1 + cnt += 1 + cnt1 = cnt & 1 return batch diff --git a/optimum/habana/sentence_transformers/st_gaudi_encoder.py b/optimum/habana/sentence_transformers/st_gaudi_encoder.py index db253953db..df8d06956c 100644 --- a/optimum/habana/sentence_transformers/st_gaudi_encoder.py +++ b/optimum/habana/sentence_transformers/st_gaudi_encoder.py @@ -5,13 +5,11 @@ import numpy as np import torch -from numpy import ndarray from sentence_transformers.quantization import quantize_embeddings from sentence_transformers.util import ( batch_to_device, truncate_embeddings, ) -from torch import Tensor from tqdm.autonotebook import trange @@ -24,14 +22,15 @@ def st_gaudi_encode( prompt_name: Optional[str] = None, prompt: Optional[str] = None, batch_size: int = 32, - show_progress_bar: bool = None, + show_progress_bar: Optional[bool] = None, output_value: Optional[Literal["sentence_embedding", "token_embeddings"]] = "sentence_embedding", precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = "float32", convert_to_numpy: bool = True, convert_to_tensor: bool = False, - device: str = None, + device: Optional[str] = None, normalize_embeddings: bool = False, -) -> Union[List[Tensor], ndarray, Tensor]: + **kwargs, +) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]: """ Computes sentence embeddings. @@ -63,7 +62,7 @@ def st_gaudi_encode( the faster dot-product (util.dot_score) instead of cosine similarity can be used. Defaults to False. Returns: - Union[List[Tensor], ndarray, Tensor]: By default, a 2d numpy array with shape [num_inputs, output_dimension] is returned. + Union[List[torch.Tensor], np.ndarray, torch.Tensor]: By default, a 2d numpy array with shape [num_inputs, output_dimension] is returned. If only one string input is provided, then the output is a 1d array with shape [output_dimension]. If ``convert_to_tensor``, a torch Tensor is returned instead. If ``self.truncate_dim <= output_dimension`` then output_dimension is ``self.truncate_dim``. @@ -85,9 +84,10 @@ def st_gaudi_encode( print(embeddings.shape) # (3, 768) """ + self.eval() if show_progress_bar is None: - show_progress_bar = logger.getEffectiveLevel() == logging.INFO or logger.getEffectiveLevel() == logging.DEBUG + show_progress_bar = logger.getEffectiveLevel() in (logging.INFO, logging.DEBUG) if convert_to_tensor: convert_to_numpy = False @@ -119,6 +119,7 @@ def st_gaudi_encode( "Encode with either a `prompt`, a `prompt_name`, or neither, but not both. " "Ignoring the `prompt_name` in favor of `prompt`." 
) + extra_features = {} if prompt is not None: sentences = [prompt + sentence for sentence in sentences] @@ -132,6 +133,8 @@ def st_gaudi_encode( if device is None: device = self.device + self.to(device) + all_embeddings = [] length_sorted_idx = np.argsort([-self._text_length(sen) for sen in sentences]) sentences_sorted = [sentences[idx] for idx in length_sorted_idx] @@ -139,7 +142,6 @@ def st_gaudi_encode( for start_index in trange(0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar): sentences_batch = sentences_sorted[start_index : start_index + batch_size] features = self.tokenize(sentences_batch) - if self.device.type == "hpu": if "input_ids" in features: curr_tokenize_len = features["input_ids"].shape @@ -166,11 +168,12 @@ def st_gaudi_encode( ), -1, ) + features = batch_to_device(features, device) features.update(extra_features) with torch.no_grad(): - out_features = self.forward(features) + out_features = self.forward(features, **kwargs) if self.device.type == "hpu": out_features = copy.deepcopy(out_features) @@ -218,7 +221,7 @@ def st_gaudi_encode( all_embeddings = torch.Tensor() elif convert_to_numpy: if not isinstance(all_embeddings, np.ndarray): - if all_embeddings[0].dtype == torch.bfloat16: + if all_embeddings and all_embeddings[0].dtype == torch.bfloat16: all_embeddings = np.asarray([emb.float().numpy() for emb in all_embeddings]) else: all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings]) diff --git a/optimum/habana/sentence_transformers/st_gaudi_trainer.py b/optimum/habana/sentence_transformers/st_gaudi_trainer.py index ccbd2e1fb2..f7d73d231c 100644 --- a/optimum/habana/sentence_transformers/st_gaudi_trainer.py +++ b/optimum/habana/sentence_transformers/st_gaudi_trainer.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2022 The HuggingFace Team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,17 +12,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations + import logging import os -import warnings +from collections import OrderedDict from contextlib import nullcontext +from functools import partial from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union import torch from accelerate.utils import DistributedDataParallelKwargs +from packaging.version import parse as parse_version from sentence_transformers.data_collator import SentenceTransformerDataCollator from sentence_transformers.evaluation import SentenceEvaluator, SequentialEvaluator from sentence_transformers.losses.CoSENTLoss import CoSENTLoss +from sentence_transformers.model_card import ModelCardCallback +from sentence_transformers.models import Pooling from sentence_transformers.models.Transformer import Transformer from sentence_transformers.sampler import ( DefaultBatchSampler, @@ -38,21 +44,19 @@ from sentence_transformers.util import disable_logging, is_datasets_available from torch.utils.data import BatchSampler, ConcatDataset, DataLoader, SubsetRandomSampler from transformers import EvalPrediction, PreTrainedTokenizerBase, TrainerCallback +from transformers import __version__ as transformers_version from transformers.data.data_collator import DataCollator from transformers.integrations import WandbCallback -from transformers.modeling_utils import unwrap_model -from transformers.trainer import TRAINING_ARGS_NAME +from transformers.trainer import TRAINING_ARGS_NAME, _is_peft_model from transformers.trainer_utils import EvalLoopOutput from transformers.training_args import ParallelMode -from optimum.habana.transformers.trainer import _is_peft_model - from ..transformers import GaudiConfig, GaudiTrainer from .st_gaudi_training_args import SentenceTransformerGaudiTrainingArguments if is_datasets_available(): - from datasets import Dataset, DatasetDict + from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict, Value logger = logging.getLogger(__name__) @@ -62,33 +66,58 @@ class SentenceTransformerGaudiTrainer(GaudiTrainer): """ - Inherits from GaudiTrainer and adapted from: https://github.com/UKPLab/sentence-transformers/blob/v3.0.1/sentence_transformers/trainer.py + SentenceTransformerGaudiTrainer is a simple but feature-complete training and eval loop for PyTorch + based on the 🤗 Transformers :class:`~transformers.Trainer`. 
+ + It inherits from GaudiTrainer and adapted from: + https://github.com/UKPLab/sentence-transformers/blob/v3.3.1/sentence_transformers/trainer.py """ def __init__( self, - model: Optional["SentenceTransformer"] = None, - gaudi_config: GaudiConfig = None, - args: SentenceTransformerGaudiTrainingArguments = None, - train_dataset: Optional[Union["Dataset", "DatasetDict", Dict[str, "Dataset"]]] = None, - eval_dataset: Optional[Union["Dataset", "DatasetDict", Dict[str, "Dataset"]]] = None, + model: Optional[SentenceTransformer] = None, + gaudi_config: Optional[GaudiConfig] = None, + args: Optional[SentenceTransformerGaudiTrainingArguments] = None, + train_dataset: Optional[Union[Dataset, DatasetDict, IterableDataset, Dict[str, Dataset]]] = None, + eval_dataset: Optional[Union[Dataset, DatasetDict, IterableDataset, Dict[str, Dataset]]] = None, loss: Optional[ Union[ torch.nn.Module, Dict[str, torch.nn.Module], - Callable[["SentenceTransformer"], torch.nn.Module], - Dict[str, Callable[["SentenceTransformer"], torch.nn.Module]], + Callable[[SentenceTransformer], torch.nn.Module], + Dict[str, Callable[[SentenceTransformer], torch.nn.Module]], ] ] = None, evaluator: Optional[Union[SentenceEvaluator, List[SentenceEvaluator]]] = None, data_collator: Optional[DataCollator] = None, tokenizer: Optional[Union[PreTrainedTokenizerBase, Callable]] = None, - model_init: Optional[Callable[[], "SentenceTransformer"]] = None, + model_init: Optional[Callable[[], SentenceTransformer]] = None, compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, callbacks: Optional[List[TrainerCallback]] = None, - optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), + optimizers: Tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None), preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, ) -> None: + if not is_datasets_available(): + raise RuntimeError( + "To train a SentenceTransformerGaudiTrainer model, you need to install the `datasets` module. " + "To fix: pip install datasets" + ) + + if args is None: + output_dir = "tmp_trainer" + logger.info(f"No `TrainingArguments` passed, using `output_dir={output_dir}`.") + args = SentenceTransformerGaudiTrainingArguments( + output_dir=output_dir, + use_habana=True, + gaudi_config_name="Habana/distilbert-base-uncased", + use_lazy_mode=True, + use_hpu_graphs=True, + use_hpu_graphs_for_inference=False, + use_hpu_graphs_for_training=True, + ) + elif not isinstance(args, SentenceTransformerGaudiTrainingArguments): + raise ValueError("Please use `TrainingArguments` imported from `optimum.habana.sentence_transformers`.") + if model is None: if model_init is not None: self.model_init = model_init @@ -97,16 +126,33 @@ def __init__( raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument") else: if model_init is not None: - warnings.warn( + logger.warning( "`Trainer` requires either a `model` or `model_init` argument, but not both. `model_init` will" - " overwrite your model when calling the `train` method. This will become a fatal error in the next" - " release.", - FutureWarning, + " overwrite your model when calling the `train` method." 
) self.model_init = model_init - # If the model ID is set via the SentenceTransformerTrainingArguments, but not via the SentenceTransformerModelCardData, - # then we can set it here for the model card regardless + if compute_metrics is not None: + logger.warning( + "`compute_metrics` is currently not compatible with the SentenceTransformerGaudiTrainer. Please use the " + "`evaluator` argument instead for detailed evaluation metrics, or the `eval_dataset` argument for " + "the evaluation loss." + ) + + # Get a dictionary of the default training arguments, so we can determine which arguments have been changed + # for the model card + default_args_dict = SentenceTransformerGaudiTrainingArguments( + output_dir="unused", + use_habana=True, + gaudi_config_name="Habana/distilbert-base-uncased", + use_lazy_mode=True, + use_hpu_graphs=True, + use_hpu_graphs_for_inference=False, + use_hpu_graphs_for_training=True, + ).to_dict() + + # If the model ID is set via the SentenceTransformerGaudiTrainingArguments, but not via the + # SentenceTransformerModelCardData, then we can set it here for the model card regardless if args.hub_model_id and not model.model_card_data.model_id: model.model_card_data.set_model_id(args.hub_model_id) @@ -116,30 +162,57 @@ def __init__( if data_collator is None: data_collator = SentenceTransformerDataCollator(tokenize_fn=model.tokenize) + for dataset_name, dataset in zip(["train", "eval"], [train_dataset, eval_dataset]): + if isinstance(dataset, IterableDataset) and dataset.column_names is None: + sample = next(iter(dataset)) + naive_type_mapping = {str: "string", int: "int64", float: "float32", bool: "bool"} + example_features = { + key: Value(naive_type_mapping.get(type(value), "null")) for key, value in sample.items() + } + raise ValueError( + f"The provided `{dataset_name}_dataset` must have Features. Specify them with e.g.:\n" + f"{dataset_name}_dataset = {dataset_name}_dataset.cast(Features({example_features}))\n" + "or by providing the Features to the IterableDataset initialization method. 
See the Datasets " + "documentation for more information on dataset Features: " + "https://huggingface.co/docs/datasets/en/about_dataset_features" + ) + if isinstance(train_dataset, dict) and not isinstance(train_dataset, DatasetDict): train_dataset = DatasetDict(train_dataset) - if isinstance(eval_dataset, dict) and not isinstance(eval_dataset, Dataset): + if isinstance(eval_dataset, dict) and not isinstance(eval_dataset, DatasetDict): eval_dataset = DatasetDict(eval_dataset) + super_kwargs = { + "model": None if self.model_init else model, + "gaudi_config": gaudi_config, + "args": args, + "data_collator": data_collator, + "train_dataset": train_dataset, + "eval_dataset": eval_dataset if eval_dataset is not None or evaluator is None else "dummy", + "model_init": model_init, + "compute_metrics": compute_metrics, + "callbacks": callbacks, + "optimizers": optimizers, + "preprocess_logits_for_metrics": preprocess_logits_for_metrics, + } + # Transformers v4.46.0 changed the `tokenizer` argument to a more general `processing_class` argument + if parse_version(transformers_version) >= parse_version("4.46.0"): + super_kwargs["processing_class"] = tokenizer + else: + super_kwargs["tokenizer"] = tokenizer + super().__init__(**super_kwargs) - super().__init__( - model=None if self.model_init else model, - gaudi_config=gaudi_config, - args=args, - data_collator=data_collator, - train_dataset=train_dataset, - eval_dataset=eval_dataset, - tokenizer=tokenizer, - model_init=model_init, - compute_metrics=compute_metrics, - callbacks=callbacks, - optimizers=optimizers, - preprocess_logits_for_metrics=preprocess_logits_for_metrics, - ) + # Transformers v4.46.0 introduced a ValueError if `eval_dataset` is None while eval_strategy is not "no", + # but in Sentence Transformers you can also evaluate without an eval_dataset via an evaluator, so we set + # it to "dummy" in that case to avoid the ValueError + if self.eval_dataset == "dummy": + self.eval_dataset = None # Every Sentence Transformer model can always return a loss, so we set this to True # to avoid having to specify it in the data collator or model's forward self.can_return_loss = True + self._prompt_length_mapping = {} + self.model: SentenceTransformer self.args: SentenceTransformerGaudiTrainingArguments self.data_collator: SentenceTransformerDataCollator @@ -167,18 +240,49 @@ def __init__( ) else: self.loss = self.prepare_loss(loss, model) + # If evaluator is a list, we wrap it in a SequentialEvaluator if evaluator is not None and not isinstance(evaluator, SentenceEvaluator): evaluator = SequentialEvaluator(evaluator) self.evaluator = evaluator + if self.train_dataset is not None: + self.train_dataset = self.maybe_add_prompts_or_dataset_name_column( + train_dataset, args.prompts, dataset_name="train" + ) + if self.eval_dataset is not None: + self.eval_dataset = self.maybe_add_prompts_or_dataset_name_column( + eval_dataset, args.prompts, dataset_name="eval" + ) + self.add_model_card_callback(default_args_dict) + + def add_model_card_callback(self, default_args_dict: dict[str, Any]) -> None: + """ + Add a callback responsible for automatically tracking data required for the automatic model card generation + + This method is called in the ``__init__`` method of the + :class:`~sentence_transformers.trainer.SentenceTransformerTrainer` class. + + Args: + default_args_dict (Dict[str, Any]): A dictionary of the default training arguments, so we can determine + which arguments have been changed for the model card. + + .. 
note:: + + This method can be overriden by subclassing the trainer to remove/customize this callback in custom uses cases + """ + + model_card_callback = ModelCardCallback(self, default_args_dict) + self.add_callback(model_card_callback) + model_card_callback.on_init_end(self.args, self.state, self.control, self.model) + def _wrap_model(self, model, training=True, dataloader=None): """ Differs from GaudiTrainer._wrap_model: - `allow_unused_input=True` was added to `ht.hpu.ModuleCacher()` """ # train/eval could be run multiple-times - if already wrapped, don't re-wrap it again - if unwrap_model(model) is not model: + if self.accelerator.unwrap_model(model) is not model: return model # Note: in torch.distributed mode, there's no point in wrapping the model @@ -216,7 +320,7 @@ def _wrap_model(self, model, training=True, dataloader=None): return model - def call_model_init(self, trial=None) -> "SentenceTransformer": + def call_model_init(self, trial=None) -> SentenceTransformer: model = super().call_model_init(trial=trial) # If the Trainer already has a loss, then we'll want to override the model in the loss function if not hasattr(self, "loss"): @@ -241,7 +345,7 @@ def call_model_init(self, trial=None) -> "SentenceTransformer": self.loss = self.override_model_in_loss(self.loss, model) return model - def override_model_in_loss(self, loss: torch.nn.Module, model: "SentenceTransformer") -> torch.nn.Module: + def override_model_in_loss(self, loss: torch.nn.Module, model: SentenceTransformer) -> torch.nn.Module: from sentence_transformers import SentenceTransformer for name, child in loss.named_children(): @@ -255,14 +359,14 @@ def override_model_in_loss(self, loss: torch.nn.Module, model: "SentenceTransfor def prepare_loss( self, - loss: Union[Callable[["SentenceTransformer"], torch.nn.Module], torch.nn.Module], - model: "SentenceTransformer", + loss: Union[Callable[[SentenceTransformer], torch.nn.Module], torch.nn.Module], + model: SentenceTransformer, ) -> torch.nn.Module: if isinstance(loss, torch.nn.Module): return loss.to(model.device) return loss(model).to(model.device) - def add_dataset_name_column(self, dataset_dict: "DatasetDict") -> "DatasetDict": + def add_dataset_name_column(self, dataset_dict: DatasetDict) -> DatasetDict: for key, dataset in dataset_dict.items(): if "dataset_name" not in dataset.column_names: dataset_dict[key] = dataset.add_column("dataset_name", [key] * len(dataset)) @@ -270,9 +374,10 @@ def add_dataset_name_column(self, dataset_dict: "DatasetDict") -> "DatasetDict": def compute_loss( self, - model: "SentenceTransformer", + model: SentenceTransformer, inputs: Dict[str, Union[torch.Tensor, Any]], return_outputs: bool = False, + num_items_in_batch: Optional[int] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, Any]]]: """ Computes the loss for the SentenceTransformer model. @@ -287,6 +392,7 @@ def compute_loss( model (SentenceTransformer): The SentenceTransformer model. inputs (Dict[str, Union[torch.Tensor, Any]]): The input data for the model. return_outputs (bool, optional): Whether to return the outputs along with the loss. Defaults to False. + num_items_in_batch (int, optional): The number of items in the batch. Defaults to None. Unused, but required by the transformers Trainer. Returns: Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, Any]]]: The computed loss. If `return_outputs` is True, returns a tuple of loss and outputs. Otherwise, returns only the loss. @@ -300,7 +406,6 @@ def compute_loss( # Insert the wrapped (e.g. 
distributed or compiled) model into the loss function, # if the loss stores the model. Only called once per process - # from https://github.com/UKPLab/sentence-transformers/blob/v3.1.0/sentence_transformers/trainer.py#L337 if ( model == self.model_wrapped and model != self.model # Only if the model is wrapped @@ -312,7 +417,7 @@ def compute_loss( if return_outputs: # During prediction/evaluation, `compute_loss` will be called with `return_outputs=True`. # However, Sentence Transformer losses do not return outputs, so we return an empty dictionary. - # This does not result in any problems, as the SentenceTransformerTrainingArguments sets + # This does not result in any problems, as the SentenceTransformerGaudiTrainingArguments sets # `prediction_loss_only=True` which means that the output is not used. return loss, {} return loss @@ -354,13 +459,16 @@ def collect_features( def evaluate( self, - eval_dataset: Optional[Union["Dataset", Dict[str, "Dataset"]]] = None, + eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, ignore_keys: Optional[List[str]] = None, metric_key_prefix: str = "eval", ) -> Dict[str, float]: - eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset - if isinstance(eval_dataset, DatasetDict) and isinstance(self.loss, dict): - eval_dataset = self.add_dataset_name_column(eval_dataset) + if eval_dataset is not None: + eval_dataset = self.maybe_add_prompts_or_dataset_name_column( + eval_dataset, self.args.prompts, dataset_name="eval" + ) + else: + eval_dataset = self.eval_dataset return super().evaluate(eval_dataset, ignore_keys, metric_key_prefix) def evaluation_loop( @@ -420,7 +528,7 @@ def _load_best_model(self) -> None: except Exception: pass - # Override the model with the `tranformers`-based auto_model, and restore the original SentenceTransformers + # Override the model with the `transformers`-based auto_model, and restore the original SentenceTransformers # model with the loaded `transformers` model full_model = self.model self.model = self.model[0].auto_model @@ -431,7 +539,12 @@ def _load_best_model(self) -> None: self.model = full_model self.model[0].auto_model = loaded_auto_model - def validate_column_names(self, dataset: "Dataset", dataset_name: Optional[str] = None) -> bool: + def validate_column_names(self, dataset: Dataset, dataset_name: Optional[str] = None) -> None: + if isinstance(dataset, dict): + for dataset_name, dataset in dataset.items(): + self.validate_column_names(dataset, dataset_name=dataset_name) + return + if overlap := set(dataset.column_names) & {"return_loss", "dataset_name"}: raise ValueError( f"The following column names are invalid in your {dataset_name + ' ' if dataset_name else ''}dataset: {list(overlap)}." @@ -440,12 +553,36 @@ def validate_column_names(self, dataset: "Dataset", dataset_name: Optional[str] def get_batch_sampler( self, - dataset: "Dataset", + dataset: Dataset, batch_size: int, drop_last: bool, valid_label_columns: Optional[List[str]] = None, generator: Optional[torch.Generator] = None, - ) -> BatchSampler: + ) -> Optional[BatchSampler]: + """ + Returns the appropriate batch sampler based on the ``batch_sampler`` argument in ``self.args``. + This batch sampler class supports ``__len__`` and ``__iter__`` methods, and is used as the ``batch_sampler`` + to create the :class:`torch.utils.data.DataLoader`. + + .. note:: + Override this method to provide a custom batch sampler. + + Args: + dataset (Dataset): The dataset to sample from. + batch_size (int): Number of samples per batch. 
+ drop_last (bool): If True, drop the last incomplete batch if the dataset size + is not divisible by the batch size. + valid_label_columns (List[str]): List of column names to check for labels. + The first column name from ``valid_label_columns`` found in the dataset will + be used as the label column. + generator (torch.Generator, optional): Optional random number generator for shuffling + the indices. + """ + if isinstance(dataset, IterableDataset): + if self.args.batch_sampler != BatchSamplers.BATCH_SAMPLER: + logger.warning("When using an IterableDataset, you cannot specify a batch sampler.") + return None + if self.args.batch_sampler == BatchSamplers.NO_DUPLICATES: return NoDuplicatesBatchSampler( dataset=dataset, @@ -473,10 +610,24 @@ def get_batch_sampler( def get_multi_dataset_batch_sampler( self, dataset: ConcatDataset, - batch_samplers: List[BatchSampler], + batch_samplers: list[BatchSampler], generator: Optional[torch.Generator] = None, seed: Optional[int] = 0, ) -> BatchSampler: + """ + Returns the appropriate multi-dataset batch sampler based on the ``multi_dataset_batch_sampler`` argument + in ``self.args``. This batch sampler class supports ``__len__`` and ``__iter__`` methods, and is used as the + ``batch_sampler`` to create the :class:`torch.utils.data.DataLoader`. + + .. note:: + Override this method to provide a custom multi-dataset batch sampler. + + Args: + dataset (ConcatDataset): The concatenation of all datasets. + batch_samplers (List[BatchSampler]): List of batch samplers for each dataset in the concatenated dataset. + generator (torch.Generator, optional): Optional random number generator for shuffling the indices. + seed (int, optional): Optional seed for the random number generator + """ if self.args.multi_dataset_batch_sampler == MultiDatasetBatchSamplers.ROUND_ROBIN: return RoundRobinBatchSampler( dataset=dataset, @@ -503,7 +654,7 @@ def get_train_dataloader(self) -> DataLoader: Subclass and override this method if you want to inject some custom behavior. """ if self.train_dataset is None: - raise ValueError("Trainer: training requires a train_dataset.") + raise ValueError("Training requires specifying a train_dataset to the SentenceTransformerGaudiTrainer.") train_dataset = self.train_dataset data_collator = self.data_collator @@ -512,15 +663,40 @@ def get_train_dataloader(self) -> DataLoader: if self.args.seed: generator.manual_seed(self.args.seed) - if isinstance(train_dataset, DatasetDict): - for dataset_name, dataset in train_dataset.items(): - self.validate_column_names(dataset, dataset_name=dataset_name) - if isinstance(self.loss, dict): - train_dataset = self.add_dataset_name_column(train_dataset) + dataloader_params = { + "collate_fn": data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + "persistent_workers": self.args.dataloader_persistent_workers, + "prefetch_factor": self.args.dataloader_prefetch_factor, + } + + if isinstance(train_dataset, IterableDataset): + dataloader_params.update( + { + "batch_size": self.args.train_batch_size, + "drop_last": self.args.dataloader_drop_last, + } + ) + if self.args.batch_sampler != BatchSamplers.BATCH_SAMPLER: + logger.warning("When using an IterableDataset, you cannot specify a batch sampler.") + + elif isinstance(train_dataset, IterableDatasetDict): + raise ValueError( + "Sentence Transformers is not compatible with IterableDatasetDict. Please use a DatasetDict instead." 
+ ) + + elif isinstance(train_dataset, DatasetDict): + for dataset in train_dataset.values(): + if isinstance(dataset, IterableDataset): + raise ValueError( + "Sentence Transformers is not compatible with a DatasetDict containing an IterableDataset." + ) + batch_samplers = [ self.get_batch_sampler( dataset, - batch_size=self.args.per_device_train_batch_size, + batch_size=self.args.train_batch_size, drop_last=self.args.dataloader_drop_last, valid_label_columns=data_collator.valid_label_columns, generator=generator, @@ -535,10 +711,9 @@ def get_train_dataloader(self) -> DataLoader: generator=generator, seed=self.args.seed, ) + dataloader_params["batch_sampler"] = batch_sampler - else: - self.validate_column_names(train_dataset) - + elif isinstance(train_dataset, Dataset): batch_sampler = self.get_batch_sampler( train_dataset, batch_size=self.args.train_batch_size, @@ -546,15 +721,11 @@ def get_train_dataloader(self) -> DataLoader: valid_label_columns=data_collator.valid_label_columns, generator=generator, ) - - dataloader_params = { - "collate_fn": data_collator, - "num_workers": self.args.dataloader_num_workers, - "pin_memory": self.args.dataloader_pin_memory, - "persistent_workers": self.args.dataloader_persistent_workers, - "prefetch_factor": self.args.dataloader_prefetch_factor, - "batch_sampler": batch_sampler, - } + dataloader_params["batch_sampler"] = batch_sampler + else: + raise ValueError( + "Unsupported `train_dataset` type. Use a Dataset, DatasetDict, or IterableDataset for training." + ) # If 'even_batches' is True, it will use the initial few samples to pad out the last sample. This can # cause issues with multi-dataset training, so we want to set this to False. @@ -563,7 +734,9 @@ def get_train_dataloader(self) -> DataLoader: self._train_dataloader = self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) return self._train_dataloader - def get_eval_dataloader(self, eval_dataset: Union["Dataset", None] = None) -> DataLoader: + def get_eval_dataloader( + self, eval_dataset: Optional[Union[Dataset, DatasetDict, IterableDataset]] = None + ) -> DataLoader: """ Returns the evaluation [`~torch.utils.data.DataLoader`]. 
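For reference, a minimal, illustrative usage sketch of the trainer these hunks modify, following the module layout shown in this patch. The checkpoint name, toy dataset, and hyperparameter values are placeholders, and the HPU-specific arguments mirror the defaults the trainer builds for itself when `args` is None; adjust import paths if the classes are re-exported elsewhere in the package.

from datasets import Dataset
from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import MultipleNegativesRankingLoss

from optimum.habana.sentence_transformers.modeling_utils import adapt_sentence_transformers_to_gaudi
from optimum.habana.sentence_transformers.st_gaudi_trainer import SentenceTransformerGaudiTrainer
from optimum.habana.sentence_transformers.st_gaudi_training_args import (
    SentenceTransformerGaudiTrainingArguments,
)

# Swap in the Gaudi-optimized encode/tokenize/save/collator implementations from this patch.
adapt_sentence_transformers_to_gaudi()

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")  # placeholder checkpoint
train_dataset = Dataset.from_dict(
    {
        "anchor": ["What is the capital of France?", "How do plants make food?"],
        "positive": ["Paris is the capital of France.", "Plants make food via photosynthesis."],
    }
)
loss = MultipleNegativesRankingLoss(model)

args = SentenceTransformerGaudiTrainingArguments(
    output_dir="./st-gaudi-output",  # placeholder
    num_train_epochs=1,
    per_device_train_batch_size=2,
    use_habana=True,
    gaudi_config_name="Habana/distilbert-base-uncased",
    use_lazy_mode=True,
    use_hpu_graphs=True,
    use_hpu_graphs_for_inference=False,
    use_hpu_graphs_for_training=True,
)

trainer = SentenceTransformerGaudiTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,  # a single Dataset; a DatasetDict paired with a dict of losses also works
    loss=loss,
)
trainer.train()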
@@ -578,7 +751,8 @@ def get_eval_dataloader(self, eval_dataset: Union["Dataset", None] = None) -> Da # Prevent errors if the evaluator is set but no eval_dataset is provided if self.evaluator is not None: return DataLoader([]) - raise ValueError("Trainer: evaluation requires an eval_dataset.") + raise ValueError("Evaluation requires specifying an eval_dataset to the SentenceTransformerGaudiTrainer.") + eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset data_collator = self.data_collator @@ -586,14 +760,37 @@ def get_eval_dataloader(self, eval_dataset: Union["Dataset", None] = None) -> Da if self.args.seed: generator.manual_seed(self.args.seed) - # TODO: Correctly validate the column names for the eval_dataset - if isinstance(eval_dataset, DatasetDict): - if isinstance(self.loss, dict): - eval_dataset = self.add_dataset_name_column(eval_dataset) + dataloader_params = { + "collate_fn": data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + "persistent_workers": self.args.dataloader_persistent_workers, + "prefetch_factor": self.args.dataloader_prefetch_factor, + } + if isinstance(eval_dataset, IterableDataset): + dataloader_params.update( + { + "batch_size": self.args.eval_batch_size, + "drop_last": self.args.dataloader_drop_last, + } + ) + + elif isinstance(eval_dataset, IterableDatasetDict): + raise ValueError( + "Sentence Transformers is not compatible with IterableDatasetDict. Please use a DatasetDict instead." + ) + + elif isinstance(eval_dataset, DatasetDict): + for dataset in eval_dataset.values(): + if isinstance(dataset, IterableDataset): + raise ValueError( + "Sentence Transformers is not compatible with a DatasetDict containing an IterableDataset." + ) + batch_samplers = [ self.get_batch_sampler( dataset, - batch_size=self.args.per_device_eval_batch_size, + batch_size=self.args.eval_batch_size, drop_last=self.args.dataloader_drop_last, valid_label_columns=data_collator.valid_label_columns, generator=generator, @@ -608,23 +805,22 @@ def get_eval_dataloader(self, eval_dataset: Union["Dataset", None] = None) -> Da generator=generator, seed=self.args.seed, ) - else: + dataloader_params["batch_sampler"] = batch_sampler + + elif isinstance(eval_dataset, Dataset): batch_sampler = self.get_batch_sampler( eval_dataset, - batch_size=self.args.train_batch_size, + batch_size=self.args.eval_batch_size, drop_last=self.args.dataloader_drop_last, valid_label_columns=data_collator.valid_label_columns, generator=generator, ) + dataloader_params["batch_sampler"] = batch_sampler - dataloader_params = { - "collate_fn": data_collator, - "num_workers": self.args.dataloader_num_workers, - "pin_memory": self.args.dataloader_pin_memory, - "persistent_workers": self.args.dataloader_persistent_workers, - "prefetch_factor": self.args.dataloader_prefetch_factor, - "batch_sampler": batch_sampler, - } + else: + raise ValueError( + "Unsupported `eval_dataset` type. Use a Dataset, DatasetDict, or IterableDataset for evaluation." + ) # If 'even_batches' is True, it will use the initial few samples to pad out the last sample. This can # cause issues with multi-dataset training, so we want to set this to False during training. 
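Similarly, a sketch of the evaluator-only path guarded above (an evaluator is set but no `eval_dataset` is provided), reusing `model`, `args`, `train_dataset`, and `loss` from the sketch earlier; the sentences and scores are toy placeholders.

from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator

# Build a small STS-style evaluator; in real code these come from a dev split.
dev_evaluator = EmbeddingSimilarityEvaluator(
    sentences1=["A man is eating food.", "A plane is taking off."],
    sentences2=["Someone is eating.", "An airplane is departing."],
    scores=[0.9, 0.95],
    name="toy-sts-dev",
)

trainer = SentenceTransformerGaudiTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    loss=loss,
    evaluator=dev_evaluator,  # no eval_dataset: evaluation runs through the evaluator instead
)
trainer.train()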
@@ -632,7 +828,7 @@ def get_eval_dataloader(self, eval_dataset: Union["Dataset", None] = None) -> Da self.accelerator.even_batches = True return self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params)) - def get_test_dataloader(self, test_dataset: "Dataset") -> DataLoader: + def get_test_dataloader(self, test_dataset: Union[Dataset, DatasetDict, IterableDataset]) -> DataLoader: """ Returns the training [`~torch.utils.data.DataLoader`]. @@ -649,15 +845,38 @@ def get_test_dataloader(self, test_dataset: "Dataset") -> DataLoader: if self.args.seed: generator.manual_seed(self.args.seed) - if isinstance(test_dataset, DatasetDict): - for dataset_name, dataset in test_dataset.items(): - self.validate_column_names(dataset, dataset_name=dataset_name) - if isinstance(self.loss, dict): - test_dataset = self.add_dataset_name_column(test_dataset) + dataloader_params = { + "collate_fn": data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + "persistent_workers": self.args.dataloader_persistent_workers, + "prefetch_factor": self.args.dataloader_prefetch_factor, + } + + if isinstance(test_dataset, IterableDataset): + dataloader_params.update( + { + "batch_size": self.args.eval_batch_size, + "drop_last": self.args.dataloader_drop_last, + } + ) + + elif isinstance(test_dataset, IterableDatasetDict): + raise ValueError( + "Sentence Transformers is not compatible with IterableDatasetDict. Please use a DatasetDict instead." + ) + + elif isinstance(test_dataset, DatasetDict): + for dataset in test_dataset.values(): + if isinstance(dataset, IterableDataset): + raise ValueError( + "Sentence Transformers is not compatible with a DatasetDict containing an IterableDataset." + ) + batch_samplers = [ self.get_batch_sampler( dataset, - batch_size=self.args.per_device_train_batch_size, + batch_size=self.args.eval_batch_size, drop_last=self.args.dataloader_drop_last, valid_label_columns=data_collator.valid_label_columns, generator=generator, @@ -672,33 +891,28 @@ def get_test_dataloader(self, test_dataset: "Dataset") -> DataLoader: generator=generator, seed=self.args.seed, ) + dataloader_params["batch_sampler"] = batch_sampler - else: - self.validate_column_names(test_dataset) - + elif isinstance(test_dataset, Dataset): batch_sampler = self.get_batch_sampler( test_dataset, - batch_size=self.args.train_batch_size, + batch_size=self.args.eval_batch_size, drop_last=self.args.dataloader_drop_last, valid_label_columns=data_collator.valid_label_columns, generator=generator, ) + dataloader_params["batch_sampler"] = batch_sampler - dataloader_params = { - "collate_fn": data_collator, - "num_workers": self.args.dataloader_num_workers, - "pin_memory": self.args.dataloader_pin_memory, - "persistent_workers": self.args.dataloader_persistent_workers, - "prefetch_factor": self.args.dataloader_prefetch_factor, - "batch_sampler": batch_sampler, - } + else: + raise ValueError( + "Unsupported `test_dataset` type. Use a Dataset, DatasetDict, or IterableDataset for testing." + ) # If 'even_batches' is True, it will use the initial few samples to pad out the last sample. This can - # cause issues with multi-dataset training, so we want to set this to False. - # For evaluation, setting 'even_batches' to False results in hanging, so we keep it as True there. 
- self.accelerator.even_batches = False - self._train_dataloader = self.accelerator.prepare(DataLoader(test_dataset, **dataloader_params)) - return self._train_dataloader + # cause issues with multi-dataset training, so we want to set this to False during training. + # For evaluation, setting 'even_batches' to False results in hanging, so we keep it as True here. + self.accelerator.even_batches = True + return self.accelerator.prepare(DataLoader(test_dataset, **dataloader_params)) def _save(self, output_dir: Optional[str] = None, state_dict=None) -> None: # If we are executing this function, we are the process zero, so we don't check for that. @@ -708,8 +922,13 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None) -> None: self.model.save_pretrained(output_dir, safe_serialization=self.args.save_safetensors) - if self.tokenizer is not None: - self.tokenizer.save_pretrained(output_dir) + # Transformers v4.46.0 changed the `tokenizer` attribute to a more general `processing_class` attribute + if parse_version(transformers_version) >= parse_version("4.46.0"): + if self.processing_class is not None: + self.processing_class.save_pretrained(output_dir) + else: + if self.tokenizer is not None: + self.tokenizer.save_pretrained(output_dir) # Good practice: save your training arguments together with the trained model torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) @@ -717,20 +936,257 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None) -> None: def _load_from_checkpoint(self, checkpoint_path: str) -> None: from sentence_transformers import SentenceTransformer - loaded_model = SentenceTransformer(checkpoint_path) + loaded_model = SentenceTransformer(checkpoint_path, trust_remote_code=self.model.trust_remote_code) self.model.load_state_dict(loaded_model.state_dict()) + def _get_prompt_length(self, prompt: str) -> int: + try: + return self._prompt_length_mapping[prompt] + except KeyError: + prompt_length = self.model.tokenize([prompt])["input_ids"].shape[-1] - 1 + self._prompt_length_mapping[prompt] = prompt_length + return prompt_length + + def _include_prompt_length(self) -> bool: + """ + Return whether the prompt length should be passed to the model's forward method. + + True if the model does not include the prompt in the pooling layer. Can be + overridden by the user if it's useful to include the prompt length. + """ + for module in self.model: + if isinstance(module, Pooling): + return not module.include_prompt + return False + + @staticmethod + def add_prompts_or_dataset_name_transform( + batch: Dict[str, List[Any]], + prompts: Optional[Union[Dict[str, str], str]] = None, + prompt_lengths: Optional[Union[Dict[str, int], int]] = None, + dataset_name: Optional[str] = None, + transform: Optional[Callable[[Dict[str, List[Any]]], Dict[str, List[Any]]]] = None, + **kwargs, + ) -> Dict[str, List[Any]]: + """A transform/map function that adds prompts or dataset names to the batch. + + Args: + batch (dict[str, list[Any]]): The batch of data, where each key is a column name and each value + is a list of values. + prompts (dict[str, str] | str | None, optional): An optional mapping of column names to string + prompts, or a string prompt for all columns. Defaults to None. + prompt_lengths (dict[str, int] | int | None, optional): An optional mapping of prompts names to + prompt token length, or a prompt token length if the prompt is a string. Defaults to None. 
+ dataset_name (str | None, optional): The name of this dataset, only if there are multiple datasets + that use a different loss. Defaults to None. + transform (Callable[[dict[str, list[Any]]], dict[str, list[Any]]], optional): An optional transform + function to apply on the batch before adding prompts, etc. Defaults to None. + + Returns: + dict[str, list[Any]]: The "just-in-time" transformed batch with prompts and/or dataset names added. + """ + # If the dataset is a Dataset(Dict), then we use set_transform and we want to also apply any + # previous transform if it exists + if transform: + batch = transform(batch) + + # Return if the batch has no columns... + if not batch: + return batch + + # ... or if it's empty + first_column = list(batch.keys())[0] + if not batch[first_column]: + return batch + + # Apply one prompt to all columns... + if isinstance(prompts, str): + for column_name, column in list(batch.items()): + if isinstance(column[0], str): + batch[column_name] = [prompts + value for value in column] + + if prompt_lengths is not None: + batch[f"{column_name}_prompt_length"] = [prompt_lengths] * len(column) + + # ... or a column-specific prompt + if isinstance(prompts, dict): + for column_name, prompt in prompts.items(): + if column_name in batch: + batch[column_name] = [prompt + value for value in batch[column_name]] + + if prompt_lengths: + batch[f"{column_name}_prompt_length"] = [prompt_lengths[prompt]] * len(batch[column_name]) + + # If we have multiple losses, then we need to add the dataset name to the batch + if dataset_name: + batch["dataset_name"] = [dataset_name] * len(batch[first_column]) + + return batch + + def maybe_add_prompts_or_dataset_name_column( + self, + dataset_dict: Union[DatasetDict, Dataset, None], + prompts: Optional[Union[Dict[str, Dict[str, str]], Dict[str, str], str]] = None, + dataset_name: Optional[str] = None, + ) -> Union[DatasetDict, Dataset, None]: + """ + Maybe add prompts or dataset names to the dataset. We add the dataset_name column to the dataset if: + + 1. The loss is a dictionary and the dataset is a DatasetDict, or + 2. The prompts contain a mapping to dataset names. + + There are 4 cases for the prompts: + + 1. `str`: One prompt for all datasets and columns. + 2. `dict[str, str]`: A column to prompt mapping. + 3. `dict[str, str]`: A dataset to prompt mapping. + 4. `dict[str, dict[str, str]]`: A dataset to column to prompt mapping. + + And 2 cases for the dataset: + + A. `Dataset`: A single dataset. + B. `DatasetDict`: A dictionary of datasets. + + 3A is not allowed, and 2A doesn't make sense. + + Args: + dataset_dict (DatasetDict | Dataset | None): The dataset to add prompts or dataset names to. + + Returns: + DatasetDict | Dataset | None: The dataset with prompts or dataset names added. 
+ """ + if dataset_dict is None: + return None + + include_dataset_name = isinstance(self.loss, dict) + + # If we've already added the transform to this (iterable) dataset, don't add it again + if hasattr(dataset_dict, "_sentence_transformers_preprocessed"): + return dataset_dict + + # Ensure that there's no "dataset_name"/"return_loss" columns in the unprocessed datasets + self.validate_column_names(dataset_dict, dataset_name=dataset_name) + + # Only add if 1) we have prompts or 2) we need the dataset name for the loss dictionary + if prompts or include_dataset_name: + include_prompt_lengths = self._include_prompt_length() + dataset_dict = self.add_prompts_or_dataset_name_column( + dataset_dict, + prompts=prompts, + include_prompt_lengths=include_prompt_lengths, + include_dataset_name=include_dataset_name, + ) + return dataset_dict + + def add_prompts_or_dataset_name_column( + self, + dataset_dict: Union[DatasetDict, IterableDatasetDict, Dataset, IterableDataset], + prompts: Optional[Union[Dict[str, str], str]] = None, + dataset_name: Optional[str] = None, + include_prompt_lengths: bool = False, + include_dataset_name: bool = False, + ) -> Union[DatasetDict, Dataset, None]: + # If we have DatasetDict, recurse + if isinstance(dataset_dict, (IterableDatasetDict, DatasetDict)): + for dataset_name, dataset in dataset_dict.items(): + # If prompts is a dictionary that matches the dataset names, then take the nested prompts + nested_prompts = prompts.get(dataset_name, prompts) if isinstance(prompts, dict) else prompts + dataset_dict[dataset_name] = self.add_prompts_or_dataset_name_column( + dataset_dict=dataset, + prompts=nested_prompts, + dataset_name=dataset_name if include_dataset_name else None, + include_prompt_lengths=include_prompt_lengths, + include_dataset_name=include_dataset_name, + ) + return dataset_dict + + # Get the prompt lengths if needed for the pooling layer + prompt_lengths = None + if prompts: + if isinstance(prompts, str): + if include_prompt_lengths: + prompt_lengths = self._get_prompt_length(prompts) + elif isinstance(prompts, dict): + first_key = list(prompts.keys())[0] + if isinstance(prompts[first_key], dict): + raise ValueError( + "The prompts provided to the trainer are a nested dictionary. In this setting, the first " + "level of the dictionary should map to dataset names and the second level to column names. " + "However, as the provided dataset is a not a DatasetDict, no dataset names can be inferred. " + f"The keys to the provided prompts dictionary are {list(prompts.keys())!r}" + ) + if include_prompt_lengths: + # If prompt columns exist, add the prompt length column + prompt_lengths = { + prompt: self._get_prompt_length(prompt) + for column_name, prompt in prompts.items() + if column_name in dataset_dict.column_names + } + + # If we have a Dataset, we can set the transform directly... + if isinstance(dataset_dict, Dataset): + dataset_dict.set_transform( + partial( + self.add_prompts_or_dataset_name_transform, + prompts=prompts, + prompt_lengths=prompt_lengths, + dataset_name=dataset_name, + **dataset_dict._format_kwargs, + ) + ) + + # ... 
otherwise, we have an IterableDataset and we need to map it, which performs the same operation as above + elif isinstance(dataset_dict, IterableDataset): + # Update the features to include the new columns + features = dataset_dict.features + if dataset_name: + features["dataset_name"] = Value("string") + if prompt_lengths: + if isinstance(prompts, str): + for column_name in dataset_dict.column_names: + feature = features[column_name] + if isinstance(feature, Value) and feature.dtype in ("string", "large_string"): + features[f"{column_name}_prompt_length"] = Value("int16") + elif isinstance(prompts, dict): + for column_name, prompt in prompts.items(): + feature = features[column_name] + if ( + prompt in prompt_lengths + and isinstance(feature, Value) + and feature.dtype in ("string", "large_string") + ): + features[f"{column_name}_prompt_length"] = Value("int16") + + dataset_dict = dataset_dict.map( + partial( + self.add_prompts_or_dataset_name_transform, + prompts=prompts, + prompt_lengths=prompt_lengths, + dataset_name=dataset_name, + ), + batched=True, + features=features, + ) + + else: + raise ValueError("Unsupported dataset type.") + + # Add a tag to the dataset to indicate that it has been preprocessed, to ensure that we don't apply the map or + # transform multiple times. + dataset_dict._sentence_transformers_preprocessed = True + return dataset_dict + def create_model_card( self, language: Optional[str] = None, license: Optional[str] = None, - tags: Union[str, List[str], None] = None, + tags: Optional[Union[str, List[str]]] = None, model_name: Optional[str] = None, finetuned_from: Optional[str] = None, - tasks: Union[str, List[str], None] = None, - dataset_tags: Union[str, List[str], None] = None, - dataset: Union[str, List[str], None] = None, - dataset_args: Union[str, List[str], None] = None, + tasks: Optional[Union[str, List[str]]] = None, + dataset_tags: Optional[Union[str, List[str]]] = None, + dataset: Optional[Union[str, List[str]]] = None, + dataset_args: Optional[Union[str, List[str]]] = None, **kwargs, ) -> None: if not self.is_world_process_zero(): @@ -744,3 +1200,41 @@ def create_model_card( self.model.model_card_data.add_tags(tags) self.model._create_model_card(self.args.output_dir, model_name=model_name) + + def get_optimizer_cls_and_kwargs( + self, args: SentenceTransformerGaudiTrainingArguments, model: Optional[SentenceTransformer] = None + ) -> Tuple[Any, Any]: + """ + We have to override the optimizer_grouped_parameters because the Trainer superclass bases it on the `model` + itself, but the SentenceTransformer losses can have weights that should be updated as well, e.g. + SoftmaxLoss (see #2872). + + This method requires `transformers` >= 4.43.0. + """ + + if isinstance(self.loss, dict): + loss_model = torch.nn.Sequential(OrderedDict(self.loss)) + else: + loss_model = self.loss + optimizer_cls, optimizer_kwargs = super().get_optimizer_cls_and_kwargs(args, loss_model) + + # If the kwargs were not overridden by the super() call, then we should override them here so that the potential + # weights in the loss(es) can also be updated. 
+ if not {"params", "model", "optimizer_dict"} & set(optimizer_kwargs.keys()): + decay_parameters = self.get_decay_parameter_names(loss_model) + optimizer_kwargs["optimizer_dict"] = [ + { + "params": [ + p for n, p in loss_model.named_parameters() if (n in decay_parameters and p.requires_grad) + ], + "weight_decay": self.args.weight_decay, + }, + { + "params": [ + p for n, p in loss_model.named_parameters() if (n not in decay_parameters and p.requires_grad) + ], + "weight_decay": 0.0, + }, + ] + + return optimizer_cls, optimizer_kwargs diff --git a/optimum/habana/sentence_transformers/st_gaudi_training_args.py b/optimum/habana/sentence_transformers/st_gaudi_training_args.py index 07f98c3fbc..b47434c10a 100644 --- a/optimum/habana/sentence_transformers/st_gaudi_training_args.py +++ b/optimum/habana/sentence_transformers/st_gaudi_training_args.py @@ -14,7 +14,7 @@ # limitations under the License. import logging from dataclasses import dataclass, field -from typing import Union +from typing import Dict, Optional, Union from sentence_transformers.training_args import BatchSamplers, MultiDatasetBatchSamplers from transformers.training_args import ParallelMode @@ -28,9 +28,38 @@ @dataclass class SentenceTransformerGaudiTrainingArguments(GaudiTrainingArguments): """ - Inherits from GaudiTrainingArguments and adapted from: https://github.com/UKPLab/sentence-transformers/blob/v3.0.1/sentence_transformers/training_args.py + SentenceTransformerGaudiTrainingArguments extends :class:`~transformers.TrainingArguments` with additional arguments + specific to Sentence Transformers. See :class:`~transformers.TrainingArguments` for the complete list of + available arguments. + + It inherits from GaudiTrainingArguments and adapted from: + https://github.com/UKPLab/sentence-transformers/blob/v3.3.1/sentence_transformers/training_args.py + + Args: + output_dir (`str`): + The output directory where the model checkpoints will be written. + prompts (`Union[Dict[str, Dict[str, str]], Dict[str, str], str]`, *optional*): + The prompts to use for each column in the training, evaluation and test datasets. Four formats are accepted: + + 1. `str`: A single prompt to use for all columns in the datasets, regardless of whether the training/evaluation/test + datasets are :class:`datasets.Dataset` or a :class:`datasets.DatasetDict`. + 2. `Dict[str, str]`: A dictionary mapping column names to prompts, regardless of whether the training/evaluation/test + datasets are :class:`datasets.Dataset` or a :class:`datasets.DatasetDict`. + 3. `Dict[str, str]`: A dictionary mapping dataset names to prompts. This should only be used if your training/evaluation/test + datasets are a :class:`datasets.DatasetDict` or a dictionary of :class:`datasets.Dataset`. + 4. `Dict[str, Dict[str, str]]`: A dictionary mapping dataset names to dictionaries mapping column names to + prompts. This should only be used if your training/evaluation/test datasets are a + :class:`datasets.DatasetDict` or a dictionary of :class:`datasets.Dataset`. + + batch_sampler (Union[:class:`~sentence_transformers.training_args.BatchSamplers`, `str`], *optional*): + The batch sampler to use. See :class:`~sentence_transformers.training_args.BatchSamplers` for valid options. + Defaults to ``BatchSamplers.BATCH_SAMPLER``. + multi_dataset_batch_sampler (Union[:class:`~sentence_transformers.training_args.MultiDatasetBatchSamplers`, `str`], *optional*): + The multi-dataset batch sampler to use. 
See :class:`~sentence_transformers.training_args.MultiDatasetBatchSamplers` + for valid options. Defaults to ``MultiDatasetBatchSamplers.PROPORTIONAL``. """ + prompts: Optional[Union[Dict[str, Dict[str, str]], Dict[str, str], str]] = None batch_sampler: Union[BatchSamplers, str] = field( default=BatchSamplers.BATCH_SAMPLER, metadata={"help": "The batch sampler to use."} ) diff --git a/setup.py b/setup.py index 0bb36466ee..2fc42f7711 100644 --- a/setup.py +++ b/setup.py @@ -35,7 +35,7 @@ "accelerate >= 0.33.0, < 0.34.0", "diffusers >= 0.31.0, < 0.32.0", "huggingface_hub >= 0.24.7", - "sentence-transformers == 3.2.1", + "sentence-transformers == 3.3.1", ] TESTS_REQUIRE = [