Commit: Merge pull request #10 from michaelfeil/add-fastembed

Refactor model dir

Showing 18 changed files with 344 additions and 200 deletions.
@@ -0,0 +1,25 @@
from abc import ABC, abstractmethod
from typing import Any, List

from infinity_emb.inference.primitives import NpEmbeddingType

INPUT_FEATURE = Any
OUT_FEATURES = Any


class BaseTransformer(ABC):  # Inherit from ABC (abstract base class)
    @abstractmethod  # Decorator to define an abstract method
    def encode_pre(self, sentences: List[str]) -> INPUT_FEATURE:
        pass

    @abstractmethod
    def encode_core(self, features: INPUT_FEATURE) -> OUT_FEATURES:
        pass

    @abstractmethod
    def encode_post(self, embedding: OUT_FEATURES) -> NpEmbeddingType:
        pass

    @abstractmethod
    def tokenize_lengths(self, sentences: List[str]) -> List[int]:
        pass
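For orientation, a minimal sketch (not part of the diff) of how the three abstract stages are intended to compose: encode_pre prepares the raw sentences, encode_core runs the model, and encode_post converts the output into NpEmbeddingType. The helper name embed is hypothetical.

# Sketch only, not from the commit: composing the three BaseTransformer stages.
def embed(model: BaseTransformer, sentences: List[str]) -> NpEmbeddingType:
    features = model.encode_pre(sentences)   # prepare / tokenize the inputs
    out = model.encode_core(features)        # run the model forward pass
    return model.encode_post(out)            # post-process into numpy embeddings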
libs/infinity_emb/infinity_emb/transformer/dummytransformer.py (27 additions, 0 deletions)
@@ -0,0 +1,27 @@
from typing import List

import numpy as np

from infinity_emb.inference.primitives import NpEmbeddingType
from infinity_emb.transformer.abstract import BaseTransformer


class DummyTransformer(BaseTransformer):
    """fixed 13-dimension embedding, filled with the length of each sentence"""

    def __init__(self, *args, **kwargs) -> None:
        pass

    def encode_pre(self, sentences: List[str]) -> np.ndarray:
        return np.asarray(sentences)

    def encode_core(self, features: np.ndarray) -> NpEmbeddingType:
        lengths = np.array([[len(s) for s in features]])
        # embedding of size 13
        return np.ones([len(features), 13]) * lengths.T

    def encode_post(self, embedding: NpEmbeddingType):
        return embedding

    def tokenize_lengths(self, sentences: List[str]) -> List[int]:
        return [len(s) for s in sentences]
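As a quick illustration (again not part of the diff), running two sentences through the dummy engine yields a (2, 13) array whose rows are filled with the respective sentence lengths:

# Hypothetical usage of DummyTransformer as defined above.
dummy = DummyTransformer()
features = dummy.encode_pre(["hello", "infinity"])
embeddings = dummy.encode_post(dummy.encode_core(features))
print(embeddings.shape)   # (2, 13)
print(embeddings[0][0])   # 5.0, the length of "hello"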
@@ -0,0 +1,52 @@
# from typing import List, Dict
# from infinity_emb.inference.primitives import NpEmbeddingType
# from infinity_emb.transformer.abstract import BaseTransformer
# import numpy as np
# import copy

# class FlagEmbeddingFake:
#     def __init__(self, *args, **kwargs) -> None:
#         pass

# try:
#     from fastembed.embedding import FlagEmbedding, normalize
# except:
#     FlagEmbedding = FlagEmbeddingFake

# class FastEmbed(FlagEmbedding, BaseTransformer):
#     def __init__(self, *args, **kwargs):
#         FlagEmbedding.__init__(self)(*args, **kwargs)
#         if FlagEmbedding == FlagEmbeddingFake:
#             raise ImportError("fastembed is not installed.")
#         self._infinity_tokenizer = copy.deepcopy(self.tokenizer)

#     def encode_pre(self, sentences: List[str]) -> Dict[str, np.ndarray[int]]:
#         encoded = self.tokenizer.encode_batch(sentences)
#         input_ids = np.array([e.ids for e in encoded])
#         attention_mask = np.array([e.attention_mask for e in encoded])

#         onnx_input = {
#             "input_ids": np.array(input_ids, dtype=np.int64),
#             "attention_mask": np.array(attention_mask, dtype=np.int64),
#         }

#         if not self.exclude_token_type_ids:
#             onnx_input["token_type_ids"] = np.array(
#                 [np.zeros(len(e), dtype=np.int64) for e in input_ids], dtype=np.int64
#             )
#         return onnx_input

#     def encode_core(self, features: Dict[str, np.ndarray[int]]) -> np.ndarray:
#         model_output = self.model.run(None, features)
#         last_hidden_state = model_output[0][:, 0]
#         return last_hidden_state

#     def encode_post(self, embedding: np.ndarray) -> NpEmbeddingType:
#         return normalize(embedding).astype(np.float32)

#     def tokenize_lengths(self, sentences: List[str]) -> List[int]:
#         # tks = self._infinity_tokenizer.encode_batch(
#         #     sentences,
#         # )
#         # return [len(t.tokens) for t in tks]
#         return [len(s) for s in sentences]
@@ -0,0 +1,39 @@
from enum import Enum
from typing import Callable, Dict, List, Tuple

from infinity_emb.transformer.dummytransformer import DummyTransformer
from infinity_emb.transformer.sentence_transformer import (
    CT2SentenceTransformer,
    SentenceTransformerPatched,
)

# from infinity_emb.transformer.fastembed import FastEmbed
__all__ = [
    "InferenceEngine",
    "InferenceEngineTypeHint",
    "length_tokenizer",
    "get_lengths_with_tokenize",
]


class InferenceEngine(Enum):
    torch = SentenceTransformerPatched
    ctranslate2 = CT2SentenceTransformer
    debugengine = DummyTransformer


types: Dict[str, str] = {e.name: e.name for e in InferenceEngine}
InferenceEngineTypeHint = Enum("InferenceEngineTypeHint", types)  # type: ignore


def length_tokenizer(
    _sentences: List[str],
) -> List[int]:
    return [len(i) for i in _sentences]


def get_lengths_with_tokenize(
    _sentences: List[str], tokenize: Callable = length_tokenizer
) -> Tuple[List[int], int]:
    _lengths = tokenize(_sentences)
    return _lengths, sum(_lengths)
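A short sketch (not part of the diff) of how these helpers fit together: the enum name selects an engine class, and get_lengths_with_tokenize returns the per-sentence lengths together with their sum, here using the dummy engine's character-based tokenize_lengths.

# Hypothetical usage, assuming the symbols exported by this __init__ module.
engine_cls = InferenceEngine["debugengine"].value  # resolves to DummyTransformer
model = engine_cls()

lengths, total = get_lengths_with_tokenize(
    ["hello", "infinity"], tokenize=model.tokenize_lengths
)
print(lengths, total)  # [5, 8] 13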