add dtype-based loading (#461)
michaelfeil authored Nov 13, 2024
1 parent 4ab717b commit 0a688b6
Showing 3 changed files with 48 additions and 15 deletions.
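
The commit replaces ad-hoc post-load `.half()` casts with a loading strategy resolved upstream on `EngineArgs`. The diffs below consume three of its fields: `loading_dtype` (forwarded as `torch_dtype` at load time), `quantization_dtype` (triggers `quant_interface`), and `device_placement` (where the model lands). A minimal sketch of such an object follows; only the three field names come from this diff, while the class name `LoadingStrategy` and the defaults are assumptions:

# Hedged sketch: field names taken from the diffs below; the class
# name "LoadingStrategy" and the defaults are assumptions, not the
# repository's actual definition.
from dataclasses import dataclass
from typing import Optional, Union

import torch

@dataclass
class LoadingStrategy:
    loading_dtype: Optional[torch.dtype] = None   # e.g. torch.float16; forwarded as torch_dtype
    quantization_dtype: Optional[str] = None      # when set, quant_interface() is applied
    device_placement: Union[str, torch.device] = "cpu"  # e.g. "cuda:0"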
25 changes: 21 additions & 4 deletions libs/infinity_emb/infinity_emb/transformer/classifier/torch.py
@@ -1,14 +1,18 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2023-now michaelfeil

from infinity_emb._optional_imports import CHECK_TRANSFORMERS
from infinity_emb._optional_imports import CHECK_TRANSFORMERS, CHECK_TORCH
from infinity_emb.args import EngineArgs
from infinity_emb.log_handler import logger
from infinity_emb.transformer.abstract import BaseClassifer
from infinity_emb.transformer.acceleration import to_bettertransformer
from infinity_emb.transformer.quantization.interface import quant_interface
from infinity_emb.primitives import Device

if CHECK_TRANSFORMERS.is_available:
from transformers import AutoTokenizer, pipeline # type: ignore
if CHECK_TORCH.is_available:
import torch


class SentenceClassifier(BaseClassifer):
@@ -21,24 +25,37 @@ def __init__(
model_kwargs = {}
if engine_args.bettertransformer:
model_kwargs["attn_implementation"] = "eager"
ls = engine_args._loading_strategy
assert ls is not None

if ls.loading_dtype is not None: # type: ignore
model_kwargs["torch_dtype"] = ls.loading_dtype

self._pipe = pipeline(
task="text-classification",
model=engine_args.model_name_or_path,
trust_remote_code=engine_args.trust_remote_code,
device=engine_args.device.resolve(),
device=ls.device_placement,
top_k=None,
revision=engine_args.revision,
model_kwargs=model_kwargs,
)
if self._pipe.device.type != "cpu": # and engine_args.dtype == "float16":
self._pipe.model = self._pipe.model.half()

self._pipe.model = to_bettertransformer(
self._pipe.model,
engine_args,
logger,
)

if ls.quantization_dtype is not None:
self._pipe.model = quant_interface( # TODO: add ls.quantization_dtype and ls.placement
self._pipe.model, engine_args.dtype, device=Device[self._pipe.model.device.type]
)

if engine_args.compile:
logger.info("using torch.compile(dynamic=True)")
self._pipe.model = torch.compile(self._pipe.model, dynamic=True)

self._infinity_tokenizer = AutoTokenizer.from_pretrained(
engine_args.model_name_or_path,
revision=engine_args.revision,
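
For context, the classifier now fixes the dtype at load time via `model_kwargs["torch_dtype"]` instead of casting with `.half()` after the fact, which avoids briefly materializing fp32 weights. A standalone sketch of the same pattern with the public `transformers` API; the model name is a placeholder and a GPU is assumed:

# Hedged sketch of dtype-based loading with transformers.pipeline;
# the model name is only an example, and "cuda:0" assumes a GPU.
import torch
from transformers import pipeline

pipe = pipeline(
    task="text-classification",
    model="distilbert-base-uncased-finetuned-sst-2-english",  # example model
    device="cuda:0",
    top_k=None,
    model_kwargs={"torch_dtype": torch.float16},  # weights load directly in fp16
)
print(pipe.model.dtype)  # torch.float16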
33 changes: 23 additions & 10 deletions libs/infinity_emb/infinity_emb/transformer/crossencoder/torch.py
@@ -9,8 +9,11 @@
from infinity_emb._optional_imports import CHECK_SENTENCE_TRANSFORMERS, CHECK_TORCH
from infinity_emb.args import EngineArgs
from infinity_emb.log_handler import logger
from infinity_emb.primitives import Dtype
from infinity_emb.primitives import Device
from infinity_emb.transformer.abstract import BaseCrossEncoder
from infinity_emb.transformer.quantization.interface import (
quant_interface,
)

if CHECK_TORCH.is_available and CHECK_SENTENCE_TRANSFORMERS.is_available:
import torch
@@ -42,14 +45,20 @@ def __init__(self, *, engine_args: EngineArgs):
if engine_args.bettertransformer:
model_kwargs["attn_implementation"] = "eager"

ls = engine_args._loading_strategy
assert ls is not None

if ls.loading_dtype is not None: # type: ignore
model_kwargs["torch_dtype"] = ls.loading_dtype

super().__init__(
engine_args.model_name_or_path,
revision=engine_args.revision,
device=engine_args.device.resolve(), # type: ignore
trust_remote_code=engine_args.trust_remote_code,
device=ls.device_placement,
automodel_args=model_kwargs,
)
self.model.to(self._target_device) # type: ignore
self.model.to(ls.device_placement)

# make a copy of the tokenizer,
# to be able to count the tokens in another thread
@@ -64,12 +73,16 @@ def __init__(self, *, engine_args: EngineArgs):
logger,
)

if self._target_device.type == "cuda" and engine_args.dtype in [
Dtype.auto,
Dtype.float16,
]:
logger.info("Switching to half() precision (cuda: fp16). ")
self.model.to(dtype=torch.float16)
self.model.to(ls.loading_dtype)

if ls.quantization_dtype is not None:
self.model = quant_interface( # TODO: add ls.quantization_dtype and ls.placement
self.model, engine_args.dtype, device=Device[self.model.device.type]
)

if engine_args.compile:
logger.info("using torch.compile(dynamic=True)")
self.model = torch.compile(self.model, dynamic=True)

def encode_pre(self, input_tuples: list[tuple[str, str]]):
# return input_tuples
@@ -91,7 +104,7 @@ def encode_core(self, features: dict[str, "Tensor"]):
return out_features.detach().cpu()

def encode_post(self, out_features) -> list[float]:
return out_features.flatten()
return out_features.flatten().to(torch.float32).numpy()

def tokenize_lengths(self, sentences: list[str]) -> list[int]:
tks = self._infinity_tokenizer.batch_encode_plus(
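
The `encode_post` change upcasts before converting to numpy, presumably so callers always receive float32 scores even when the model now loads and runs in fp16. A tiny illustration with made-up values:

# Hedged illustration of the encode_post upcast; the tensor values
# are invented, only the dtype handling mirrors the diff above.
import torch

out_features = torch.tensor([[0.71], [0.12]], dtype=torch.float16)
scores = out_features.flatten().to(torch.float32).numpy()
print(scores.dtype)  # float32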
@@ -57,14 +57,17 @@ def __init__(self, *, engine_args=EngineArgs):
model_kwargs["attn_implementation"] = "eager"

ls = engine_args._loading_strategy
assert ls is not None

if ls.loading_dtype is not None:
model_kwargs["torch_dtype"] = ls.loading_dtype

super().__init__(
engine_args.model_name_or_path,
revision=engine_args.revision,
trust_remote_code=engine_args.trust_remote_code,
device=ls.device_placement,
model_kwargs=model_kwargs,
# TODO: set torch_dtype=ls.loading_dtype to save memory on loading.
)
self.to(ls.device_placement)
# make a copy of the tokenizer,
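
The hunk above suggests a sentence-transformers-based embedder (its file header is not captured in this view). The same load-time dtype pattern through the public `SentenceTransformer` API would look roughly like this; `model_kwargs` requires sentence-transformers >= 2.3, and the model name is a placeholder:

# Hedged sketch, assuming sentence-transformers >= 2.3 (which accepts
# model_kwargs) and a GPU; the model name is only an example.
import torch
from sentence_transformers import SentenceTransformer

model = SentenceTransformer(
    "sentence-transformers/all-MiniLM-L6-v2",  # example model
    device="cuda:0",
    model_kwargs={"torch_dtype": torch.float16},  # load weights directly in fp16
)
embeddings = model.encode(["hello world"])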
