diff --git a/libs/infinity_emb/infinity_emb/transformer/classifier/torch.py b/libs/infinity_emb/infinity_emb/transformer/classifier/torch.py
index cf49eb2d..c4ddb4d2 100644
--- a/libs/infinity_emb/infinity_emb/transformer/classifier/torch.py
+++ b/libs/infinity_emb/infinity_emb/transformer/classifier/torch.py
@@ -1,14 +1,18 @@
 # SPDX-License-Identifier: MIT
 # Copyright (c) 2023-now michaelfeil
 
-from infinity_emb._optional_imports import CHECK_TRANSFORMERS
+from infinity_emb._optional_imports import CHECK_TRANSFORMERS, CHECK_TORCH
 from infinity_emb.args import EngineArgs
 from infinity_emb.log_handler import logger
 from infinity_emb.transformer.abstract import BaseClassifer
 from infinity_emb.transformer.acceleration import to_bettertransformer
+from infinity_emb.transformer.quantization.interface import quant_interface
+from infinity_emb.primitives import Device
 
 if CHECK_TRANSFORMERS.is_available:
     from transformers import AutoTokenizer, pipeline  # type: ignore
+if CHECK_TORCH.is_available:
+    import torch
 
 
 class SentenceClassifier(BaseClassifer):
@@ -21,17 +25,21 @@ def __init__(
         model_kwargs = {}
         if engine_args.bettertransformer:
             model_kwargs["attn_implementation"] = "eager"
+        ls = engine_args._loading_strategy
+        assert ls is not None
+
+        if ls.loading_dtype is not None:  # type: ignore
+            model_kwargs["torch_dtype"] = ls.loading_dtype
+
         self._pipe = pipeline(
             task="text-classification",
             model=engine_args.model_name_or_path,
             trust_remote_code=engine_args.trust_remote_code,
-            device=engine_args.device.resolve(),
+            device=ls.device_placement,
             top_k=None,
             revision=engine_args.revision,
             model_kwargs=model_kwargs,
         )
-        if self._pipe.device.type != "cpu":  # and engine_args.dtype == "float16":
-            self._pipe.model = self._pipe.model.half()
 
         self._pipe.model = to_bettertransformer(
             self._pipe.model,
@@ -39,6 +47,15 @@ def __init__(
             logger,
         )
 
+        if ls.quantization_dtype is not None:
+            self._pipe.model = quant_interface(  # TODO: add ls.quantization_dtype and ls.placement
+                self._pipe.model, engine_args.dtype, device=Device[self._pipe.model.device.type]
+            )
+
+        if engine_args.compile:
+            logger.info("using torch.compile(dynamic=True)")
+            self._pipe.model = torch.compile(self._pipe.model, dynamic=True)
+
         self._infinity_tokenizer = AutoTokenizer.from_pretrained(
             engine_args.model_name_or_path,
             revision=engine_args.revision,
diff --git a/libs/infinity_emb/infinity_emb/transformer/crossencoder/torch.py b/libs/infinity_emb/infinity_emb/transformer/crossencoder/torch.py
index b425f163..6e740d51 100644
--- a/libs/infinity_emb/infinity_emb/transformer/crossencoder/torch.py
+++ b/libs/infinity_emb/infinity_emb/transformer/crossencoder/torch.py
@@ -9,8 +9,11 @@
 from infinity_emb._optional_imports import CHECK_SENTENCE_TRANSFORMERS, CHECK_TORCH
 from infinity_emb.args import EngineArgs
 from infinity_emb.log_handler import logger
-from infinity_emb.primitives import Dtype
+from infinity_emb.primitives import Device
 from infinity_emb.transformer.abstract import BaseCrossEncoder
+from infinity_emb.transformer.quantization.interface import (
+    quant_interface,
+)
 
 if CHECK_TORCH.is_available and CHECK_SENTENCE_TRANSFORMERS.is_available:
     import torch
@@ -42,14 +45,20 @@ def __init__(self, *, engine_args: EngineArgs):
         if engine_args.bettertransformer:
             model_kwargs["attn_implementation"] = "eager"
 
+        ls = engine_args._loading_strategy
+        assert ls is not None
+
+        if ls.loading_dtype is not None:  # type: ignore
+            model_kwargs["torch_dtype"] = ls.loading_dtype
+
         super().__init__(
             engine_args.model_name_or_path,
             revision=engine_args.revision,
-            device=engine_args.device.resolve(),  # type: ignore
             trust_remote_code=engine_args.trust_remote_code,
+            device=ls.device_placement,
             automodel_args=model_kwargs,
         )
-        self.model.to(self._target_device)  # type: ignore
+        self.model.to(ls.device_placement)
 
         # make a copy of the tokenizer,
         # to be able to could the tokens in another thread
@@ -64,12 +73,16 @@ def __init__(self, *, engine_args: EngineArgs):
             logger,
         )
 
-        if self._target_device.type == "cuda" and engine_args.dtype in [
-            Dtype.auto,
-            Dtype.float16,
-        ]:
-            logger.info("Switching to half() precision (cuda: fp16). ")
-            self.model.to(dtype=torch.float16)
+        self.model.to(ls.loading_dtype)
+
+        if ls.quantization_dtype is not None:
+            self.model = quant_interface(  # TODO: add ls.quantization_dtype and ls.placement
+                self.model, engine_args.dtype, device=Device[self.model.device.type]
+            )
+
+        if engine_args.compile:
+            logger.info("using torch.compile(dynamic=True)")
+            self.model = torch.compile(self.model, dynamic=True)
 
     def encode_pre(self, input_tuples: list[tuple[str, str]]):
         # return input_tuples
@@ -91,7 +104,7 @@ def encode_core(self, features: dict[str, "Tensor"]):
         return out_features.detach().cpu()
 
     def encode_post(self, out_features) -> list[float]:
-        return out_features.flatten()
+        return out_features.flatten().to(torch.float32).numpy()
 
     def tokenize_lengths(self, sentences: list[str]) -> list[int]:
         tks = self._infinity_tokenizer.batch_encode_plus(
diff --git a/libs/infinity_emb/infinity_emb/transformer/embedder/sentence_transformer.py b/libs/infinity_emb/infinity_emb/transformer/embedder/sentence_transformer.py
index 4d414a90..cf595f3c 100644
--- a/libs/infinity_emb/infinity_emb/transformer/embedder/sentence_transformer.py
+++ b/libs/infinity_emb/infinity_emb/transformer/embedder/sentence_transformer.py
@@ -57,6 +57,10 @@ def __init__(self, *, engine_args=EngineArgs):
             model_kwargs["attn_implementation"] = "eager"
 
         ls = engine_args._loading_strategy
+        assert ls is not None
+
+        if ls.loading_dtype is not None:
+            model_kwargs["torch_dtype"] = ls.loading_dtype
 
         super().__init__(
             engine_args.model_name_or_path,
@@ -64,7 +68,6 @@ def __init__(self, *, engine_args=EngineArgs):
             trust_remote_code=engine_args.trust_remote_code,
             device=ls.device_placement,
             model_kwargs=model_kwargs,
-            # TODO: set torch_dtype=ls.loading_dtype to save memory on loading.
         )
         self.to(ls.device_placement)
         # make a copy of the tokenizer,
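
Note: all three constructors above now read device placement, loading dtype, quantization, and torch.compile behavior from engine_args._loading_strategy instead of resolving them ad hoc. The snippet below is a minimal, hypothetical sketch of the attributes that object must expose, inferred only from the accesses in this diff; the dataclass name, field types, and defaults are assumptions for illustration, not the library's actual definition.

from dataclasses import dataclass
from typing import Optional, Union

import torch


@dataclass
class LoadingStrategySketch:
    # Mirrors the attributes read in the constructors above (names from the diff,
    # types assumed).
    device_placement: Union[str, torch.device]    # passed to pipeline(device=...) / .to(...)
    loading_dtype: Optional[torch.dtype] = None   # forwarded as model_kwargs["torch_dtype"]
    quantization_dtype: Optional[str] = None      # when set, quant_interface(...) is applied


# Example: load in fp16 on the first GPU, with no post-load quantization.
ls = LoadingStrategySketch(device_placement="cuda:0", loading_dtype=torch.float16)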