diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py
index 9a510ebc83255..65045c4b9c6b4 100644
--- a/tests/multimodal/test_processing.py
+++ b/tests/multimodal/test_processing.py
@@ -1,6 +1,9 @@
 import pytest
+from transformers import PreTrainedTokenizerBase
 
-from vllm.multimodal.processing import apply_placeholders, iter_token_runs
+from vllm.multimodal.processing import (find_token_match_by_text,
+                                        iter_token_runs)
+from vllm.multimodal.utils import cached_get_tokenizer
 
 
 # yapf: disable
@@ -34,108 +37,86 @@ def test_iter_token_runs(token_ids, expected):
     assert result == expected
 
 
-# yapf: disable
+@pytest.mark.parametrize("tokenizer_id", [
+    "llava-hf/llava-1.5-7b-hf",
+    "meta-llama/Llama-3.2-11B-Vision-Instruct",
+    "microsoft/Phi-3.5-vision-instruct",
+    "Qwen/Qwen2-VL-2B-Instruct",
+])
+@pytest.mark.parametrize("add_special_tokens", [True, False])
 @pytest.mark.parametrize(
-    (
-        "token_ids", "match_ids", "replacement_id", "replacement_count",
-        "expected_new_token_ids", "expected_range",
-    ),
+    "text",
+    [
+        "What is in this image?",
+        # LLaVA
+        "<image>What is in this image?",
+        "What is<image>in this image?",
+        "What is in this image?<image>",
+        # Llama-3.2
+        "<|image|>What is in this image?",
+        "What is<|image|>in this image?",
+        "What is in this image?<|image|>",
+        # Phi-3-vision
+        "<image_1>What is in this image?",
+        "What is<image_1>in this image?",
+        "What is in this image?<image_1>",
+        # Qwen2-VL
+        "<|vision_start|><|image_pad|><|vision_end|>What is in this image?",
+        "What is<|vision_start|><|image_pad|><|vision_end|>in this image?",
+        "What is in this image?<|vision_start|><|image_pad|><|vision_end|>",
+    ])
+@pytest.mark.parametrize(
+    "match_str",
     [
-        # Empty
-        (
-            [], [-1], +1, 0,
-            [], None,
-        ),
         # No match
-        (
-            [32000, 32000, 32000], [-1], +1, 0,
-            [32000, 32000, 32000], None,
-        ),
-        # Match first
-        (
-            [-1, 32000, 32000], [-1], +1, 0,
-            [32000, 32000], { "offset": 0, "length": 0 },
-        ),
-        (
-            [-1, 32000, 32000], [-1], +1, 1,
-            [+1, 32000, 32000], { "offset": 0, "length": 1 },
-        ),
-        (
-            [-1, 32000, 32000], [-1], +1, 2,
-            [+1, +1, 32000, 32000], { "offset": 0, "length": 2 },
-        ),
-        # Match middle
-        (
-            [32000, -1, 32000], [-1], +1, 0,
-            [32000, 32000], { "offset": 1, "length": 0 },
-        ),
-        (
-            [32000, -1, 32000], [-1], +1, 1,
-            [32000, +1, 32000], { "offset": 1, "length": 1 },
-        ),
-        (
-            [32000, -1, 32000], [-1], +1, 2,
-            [32000, +1, +1, 32000], { "offset": 1, "length": 2},
-        ),
-        # Match last
-        (
-            [32000, 32000, -1], [-1], +1, 0,
-            [32000, 32000], { "offset": 2, "length": 0 },
-        ),
-        (
-            [32000, 32000, -1], [-1], +1, 1,
-            [32000, 32000, +1], { "offset": 2, "length": 1 },
-        ),
-        (
-            [32000, 32000, -1], [-1], +1, 2,
-            [32000, 32000, +1, +1], { "offset": 2, "length": 2},
-        ),
-        # Match all
-        (
-            [32000, 32000, 32000], [32000], +1, 0,
-            [32000, 32000], { "offset": 0, "length": 0 },
-        ),
-        (
-            [32000, 32000, 32000], [32000], +1, 1,
-            [+1, 32000, 32000], { "offset": 0, "length": 1 },
-        ),
-        (
-            [32000, 32000, 32000], [32000], +1, 2,
-            [+1, +1, 32000, 32000], { "offset": 0, "length": 2 },
-        ),
-    ],
-)
-# yapf: enable
-def test_apply_placeholders(
-    token_ids,
-    match_ids,
-    replacement_id,
-    replacement_count,
-    expected_new_token_ids,
-    expected_range,
+        "No",
+        # Has match
+        "i",
+        "What",
+        "What is",
+        "image",
+        "image?",
+        "<image>",
+        "<|image|>",
+        "<image_1>",
+        "<|vision_start|><|image_pad|><|vision_end|>",
+        "<s>",
+        "</s>",
+    ])
+def test_token_match_by_text(
+    tokenizer_id,
+    add_special_tokens,
+    text,
+    match_str,
 ):
-    orig_token_ids = token_ids[:]
-
-    placeholder_range = apply_placeholders(
-        token_ids,
-        match_ids,
-        replacement_id,
-        replacement_count,
-    )
+    tokenizer = cached_get_tokenizer(tokenizer_id)
+    assert isinstance(tokenizer, PreTrainedTokenizerBase)
 
-    # Invariants
-    if placeholder_range is None:
-        assert orig_token_ids == token_ids
-    else:
-        offset = placeholder_range["offset"]
-        match_len = len(match_ids)
-        repl_len = placeholder_range["length"]
+    token_ids = tokenizer.encode(text, add_special_tokens=add_special_tokens)
+    match = find_token_match_by_text(tokenizer, token_ids, text, match_str)
 
-        assert orig_token_ids[offset:offset + match_len] == match_ids
+    # These are only shown in the output if the test fails
+    print("token_ids:", token_ids)
+    print("match:", match)
 
-        repl_ids = [replacement_id] * replacement_count
-        assert token_ids[offset:offset + repl_len] == repl_ids
+    # Invariants
+    if (match_str in text
+            or match_str in tokenizer.decode(token_ids,
+                                             skip_special_tokens=False)):
+        assert match is not None
+        match_start_idx, match_end_idx = match
 
-    # Manually constructed results
-    assert token_ids == expected_new_token_ids
-    assert placeholder_range == expected_range
+        assert match_str in tokenizer.decode(
+            token_ids[match_start_idx:match_end_idx],
+            skip_special_tokens=False,
+        )
+        assert match_str not in tokenizer.decode(
+            token_ids[match_start_idx + 1:match_end_idx],
+            skip_special_tokens=False,
+        )
+        assert match_str not in tokenizer.decode(
+            token_ids[match_start_idx:match_end_idx - 1],
+            skip_special_tokens=False,
+        )
+    else:
+        assert match is None
diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py
index 389b831841997..60eda586d2fdd 100644
--- a/vllm/multimodal/processing.py
+++ b/vllm/multimodal/processing.py
@@ -1,7 +1,8 @@
 from dataclasses import dataclass
 from functools import lru_cache
+from heapq import nsmallest
 from itertools import groupby
-from typing import (Any, Callable, Collection, Generic, List, Mapping,
+from typing import (Any, Callable, Generic, List, Mapping,
                     NamedTuple, Optional, TypeVar, Union, final)
 
 from transformers import BatchFeature
@@ -128,95 +129,148 @@ def iter_token_runs(token_ids: List[int]):
         start_idx += length
 
 
-def encode_no_special_tokens(
+def _encode(
     tokenizer: AnyTokenizer,
     text: str,
+    *,
+    add_special_tokens: bool = False,
 ) -> List[int]:
     """
     Backend-agnostic equivalent of HF's
-    :code:`tokenizer.encode(text, add_special_tokens=False)`.
+    :code:`tokenizer.encode(text, add_special_tokens=...)`.
     """
     if isinstance(tokenizer, MistralTokenizer):
-        return tokenizer.tokenizer.encode(text, bos=False, eos=False)
+        return tokenizer.tokenizer.encode(text,
+                                          bos=add_special_tokens,
+                                          eos=add_special_tokens)
 
-    return tokenizer.encode(text, add_special_tokens=False)
+    return tokenizer.encode(text, add_special_tokens=add_special_tokens)
 
 
 @lru_cache
-def candidate_placeholders(
-    tokenizer: AnyTokenizer,
-    placeholder_text: str,
-) -> Collection[List[int]]:
-    """Generate token ID sequences that may represent a placeholder text."""
-    # When the placeholder text is not mapped to a special token ID,
-    # it may be tokenized differently based on whether it is at the start/end
-    # of the string. So, we go through each combination of whether the text
-    # is at the start and end boundaries of the string
+def _max_vocab_token_len(tokenizer: AnyTokenizer) -> int:
+    return max(len(token_text) for token_text in tokenizer.get_vocab())
 
-    # Matches the placeholder when it is in the middle of the string
-    start_id, = encode_no_special_tokens(tokenizer, "a")
-    end_id, = encode_no_special_tokens(tokenizer, "b")
 
-    candidate_basic = encode_no_special_tokens(tokenizer, placeholder_text)
+class _TokenMatch(NamedTuple):
+    start_idx: int
+    end_idx: int
 
-    start_id_, *candidate_a = encode_no_special_tokens(
-        tokenizer,
-        f"a{placeholder_text}",
-    )
-    assert start_id == start_id_
 
-    start_id_, *candidate_ab, end_id_ = encode_no_special_tokens(
-        tokenizer,
-        f"a{placeholder_text}b",
-    )
-    assert start_id == start_id_ and end_id == end_id_
+def find_token_match(token_ids: List[int], match_ids: List[int]):
+    """
+    Find the first occurrence of :code:`match_ids` in :code:`token_ids`.
+    """
+    match_len = len(match_ids)
 
-    *candidate_b, end_id_ = encode_no_special_tokens(
-        tokenizer,
-        f"{placeholder_text}b",
-    )
-    assert end_id == end_id_
+    for start_idx in range(len(token_ids) - match_len + 1):
+        end_idx = start_idx + match_len
+        if token_ids[start_idx:end_idx] == match_ids:
+            return _TokenMatch(start_idx, end_idx)
+
+    return None
 
-    # Remove duplicates (need to convert to tuple to be hashable)
-    unique_candidates = {
-        tuple(c)
-        for c in [candidate_basic, candidate_a, candidate_ab, candidate_b]
-    }
-    # Convert back to list
-    return [list(c) for c in unique_candidates]
+class _Candidate(NamedTuple):
+    start_idx: int
+    end_idx: int
+    distance: int
+
+
+def find_token_match_by_text(
+    tokenizer: AnyTokenizer,
+    token_ids: List[int],
+    token_text: str,
+    match_text: str,
+):
+    """
+    Find the first occurrence of the tokenized :code:`match_text` in
+    :code:`token_ids`.
+    """
+    match_ids = _encode(tokenizer, match_text, add_special_tokens=False)
+    if (match := find_token_match(token_ids, match_ids)):
+        return match
+
+    # When `match_text` is not mapped to a special token ID,
+    # it may be tokenized differently based on the surrounding tokens
+    # as well as whether it is at the start/end of the string.
+    # Therefore, we need to use `token_text` as a reference.
+    text_start_idx = token_text.find(match_text)
+    if text_start_idx == -1:
+        return None
+
+    text_end_idx = text_start_idx + len(match_text)
+
+    # In case the left/right side of `match_text` is fused with the
+    # string immediately before/after it during tokenization
+    text_buffer = _max_vocab_token_len(tokenizer) - 1
+    left_text = token_text[:max(0, text_start_idx - text_buffer)]
+    right_text = token_text[:text_end_idx + text_buffer]
+
+    left_idx = len(_encode(tokenizer, left_text, add_special_tokens=False))
+    right_idx = len(_encode(tokenizer, right_text, add_special_tokens=True))
+
+    valid_candidates = list[_Candidate]()
+    for window_size in (len(match_ids) - 1, len(match_ids)):
+        for start_idx in range(left_idx, right_idx - window_size + 1):
+            end_idx = start_idx + window_size
+            candidate_text = tokenizer.decode(
+                token_ids[start_idx:end_idx],
+                skip_special_tokens=False,
+            )
+
+            if match_text in candidate_text:
+                candidate = _Candidate(
+                    start_idx=start_idx,
+                    end_idx=end_idx,
+                    distance=len(candidate_text) - len(match_text),
+                )
+                valid_candidates.append(candidate)
+
+    assert len(valid_candidates) > 0, dict(
+        # To facilitate debugging
+        token_ids=token_ids,
+        match_ids=match_ids,
+        left_text=left_text,
+        right_text=right_text,
+        left_idx=left_idx,
+        right_idx=right_idx,
+    )
+
+    best_candidate, = nsmallest(1, valid_candidates, key=lambda x: x.distance)
+    return best_candidate.start_idx, best_candidate.end_idx
 
 
 def apply_placeholders(
+    tokenizer: AnyTokenizer,
     token_ids: List[int],
-    match_ids: List[int],
+    token_text: str,
+    match_text: str,
     replacement_id: int,
     replacement_count: int,
 ) -> Optional[PlaceholderRange]:
     """
-    Find the first occurrence of :code:`placeholder_ids`,
-    and replace it with the output of :code:`get_replacement_ids`.
+    Find the first occurrence of the tokenized :code:`match_text` in
+    :code:`token_ids`, and replace it with
+    :code:`[replacement_id] * replacement_count`.
 
     This function updates :code:`token_ids` in place.
""" - if len(match_ids) == 0: - raise ValueError("Match tokens should not be empty") - if replacement_id in match_ids: - raise ValueError(f"Match tokens ({match_ids}) should not include " - f"replacement token ({replacement_id})") - - placeholder_length = len(match_ids) + match = find_token_match_by_text( + tokenizer, + token_ids, + token_text, + match_text, + ) - for start_idx in range(len(token_ids) - placeholder_length + 1): - end_idx = start_idx + placeholder_length + if match is None: + return None - if token_ids[start_idx:end_idx] == match_ids: - replacement_ids = [replacement_id] * replacement_count - token_ids[start_idx:end_idx] = replacement_ids + # TODO(youkaichao): Don't update new_token_ids + start_idx, end_idx = match + token_ids[start_idx:end_idx] = [replacement_id] * replacement_count - return PlaceholderRange(offset=start_idx, length=replacement_count) - - return None + return PlaceholderRange(offset=start_idx, length=replacement_count) class MultiModalProcessor: @@ -290,18 +344,17 @@ def apply( item_idx, ) - for match_ids in candidate_placeholders( - tokenizer, match_str): - # TODO(youkaichao): Don't update new_token_ids - placeholders = apply_placeholders( - new_token_ids, - match_ids, - replacement["token_id"], - replacement_count, - ) - - if placeholders is not None: - modality_placeholders.append(placeholders) + placeholders = apply_placeholders( + tokenizer, + new_token_ids, + prompt, + match_str, + replacement["token_id"], + replacement_count, + ) + + if placeholders is not None: + modality_placeholders.append(placeholders) mm_placeholders[modality] = modality_placeholders # type: ignore[index] # yapf: disable