Vision fix for non-clip models (#439)
* fix vision

* fix vision: lint
michaelfeil authored Oct 23, 2024
1 parent 02480ab commit 4fc18bf
Showing 2 changed files with 33 additions and 16 deletions.
29 changes: 20 additions & 9 deletions libs/infinity_emb/infinity_emb/transformer/vision/torch_vision.py
@@ -200,12 +200,13 @@ def encode_core(
         else:
             if "input_ids" in features:
                 text_embeds: "Tensor" = self.model.get_text_features(  # type: ignore
-                    input_ids=features.get("input_ids"),
+                    input_ids=features.get("input_ids"),  # requires int32
                     attention_mask=features.get("attention_mask"),
                 )
             if "pixel_values" in features:
                 image_embeds: "Tensor" = self.model.get_image_features(  # type: ignore
-                    pixel_values=features.get("pixel_values"),
+                    pixel_values=features.get("pixel_values").to(self.model.dtype),  # type: ignore
+                    # requires float32 or float16 or bfloat16
                 )
         return text_embeds, image_embeds, type_is_img  # type: ignore
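Why the cast: image processors emit float32 pixel_values, while a non-CLIP vision tower such as SigLIP may be loaded in float16 or bfloat16, so get_image_features would hit a dtype mismatch without the added .to(self.model.dtype). A minimal sketch of that situation, not part of the diff; the checkpoint is the one used in the new test, and loading it in float16 is an assumed setup.

# Not part of the diff: minimal sketch of the dtype mismatch the cast avoids.
# Checkpoint is the one from the new test; float16 loading is an assumption.
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor

name = "google/siglip-so400m-patch14-384"
model = AutoModel.from_pretrained(name, torch_dtype=torch.float16)
processor = AutoProcessor.from_pretrained(name)

batch = processor(images=[Image.new("RGB", (384, 384))], return_tensors="pt")
# the processor output is float32; cast to the model dtype as the patched code does
pixel_values = batch["pixel_values"].to(model.dtype)
with torch.no_grad():
    image_embeds = model.get_image_features(pixel_values=pixel_values)
print(image_embeds.shape, image_embeds.dtype)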

@@ -219,10 +220,20 @@ def encode_post(self, out_features) -> list[float]:
         return embeddings

     def tokenize_lengths(self, text_list: list[str]) -> list[int]:
-        preprocessed = self.processor(
-            text=text_list,
-            images=[self.mock_image] * len(text_list),
-            truncation=True,
-            max_length=self.max_length,
-        )
-        return [len(t) for t in preprocessed["input_ids"]]
+        if self.is_colipali:
+            preprocessed = self.processor(
+                text=text_list,
+                images=[self.mock_image] * len(text_list),
+                truncation=True,
+                max_length=self.max_length,
+            )
+            return [len(t) for t in preprocessed["input_ids"]]
+        else:
+            preprocessed = self.processor(
+                text=text_list,
+                return_tensors="pt",
+                padding=True,
+                truncation=True,
+                max_length=self.max_length,
+            )
+            return [len(t) for t in preprocessed["input_ids"]]
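The new else branch lets non-ColPali processors measure token lengths from a text-only call instead of pairing every prompt with a mock image. A rough equivalent outside the class, with an illustrative checkpoint and max_length:

# Not part of the diff: text-only token-length counting, mirroring the new else branch.
# Checkpoint and max_length are illustrative.
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("google/siglip-so400m-patch14-384")
preprocessed = processor(
    text=["a photo of a cat", "two dogs playing in the snow"],
    return_tensors="pt",
    padding=True,
    truncation=True,
    max_length=64,
)
# with padding=True these are padded lengths, one per input
print([len(t) for t in preprocessed["input_ids"]])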
20 changes: 13 additions & 7 deletions (vision test module)
@@ -3,14 +3,16 @@
 import torch
 from colpali_engine.models import ColPali, ColPaliProcessor  # type: ignore
 from PIL import Image  # type: ignore
-from transformers import CLIPModel, CLIPProcessor  # type: ignore
+from transformers import AutoModel, AutoProcessor  # type: ignore

 from infinity_emb.args import EngineArgs
 from infinity_emb.transformer.vision.torch_vision import TIMM


-def test_clip_like_model(image_sample):
-    model_name = pytest.DEFAULT_IMAGE_MODEL
+@pytest.mark.parametrize("model_name", ["default", "google/siglip-so400m-patch14-384"])
+def test_clip_like_model(image_sample, model_name: str):
+    if model_name == "default":
+        model_name = pytest.DEFAULT_IMAGE_MODEL
     model = TIMM(engine_args=EngineArgs(model_name_or_path=model_name, dtype="auto"))
     image = Image.open(image_sample[0].raw)
@@ -26,8 +28,8 @@ def test_clip_like_model(image_sample):
     assert isinstance(embeddings[0], np.ndarray)
     assert len(embeddings) == len(inputs)
     embeddings = torch.tensor(embeddings)
-    model = CLIPModel.from_pretrained(model_name)
-    processor = CLIPProcessor.from_pretrained(model_name)
+    model = AutoModel.from_pretrained(model_name)
+    processor = AutoProcessor.from_pretrained(model_name)

     inputs_clip = processor(
         text=["a photo of a cat"], images=[image], return_tensors="pt", padding=True
@@ -113,5 +115,9 @@ def test_colpali(dtype, image_sample):
 if __name__ == "__main__":
     import requests

-    test_colpali("int8", requests.get(pytest.DEFAULT_IMAGE_SAMPLE, stream=True))  # type: ignore
-    test_colpali("auto", requests.get(pytest.DEFAULT_IMAGE_SAMPLE, stream=True))  # type: ignore
+    pytest.IMAGE_SAMPLE_URL = "https://github.com/michaelfeil/infinity/raw/06fd1f4d8f0a869f4482fc1c78b62a75ccbb66a1/docs/assets/cats_coco_sample.jpg"
+
+    test_clip_like_model(
+        [requests.get(pytest.IMAGE_SAMPLE_URL, stream=True)], "google/siglip-so400m-patch14-384"
+    )  # type: ignore
+    test_colpali("auto", [requests.get(pytest.IMAGE_SAMPLE_URL, stream=True)])  # type: ignore
