diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py
index 47f1a6bfc4bf9..0ca4a86da1740 100644
--- a/vllm/multimodal/processing.py
+++ b/vllm/multimodal/processing.py
@@ -1,7 +1,8 @@
+import re
 from dataclasses import dataclass
 from functools import lru_cache
 from itertools import groupby
-from typing import (Any, Callable, Generic, List, Mapping, NamedTuple,
+from typing import (Any, Callable, Generic, Iterable, List, Mapping, NamedTuple,
                     Optional, TypeVar, Union, final)
 
 from transformers import BatchFeature
@@ -146,7 +147,7 @@ def _encode(
     return tokenizer.encode(text, add_special_tokens=add_special_tokens)
 
 
-_cached_encode = lru_cache(_encode)
+_cached_encode = lru_cache(maxsize=2048)(_encode)
 
 
 @lru_cache
@@ -159,61 +160,95 @@ class _TokenMatch(NamedTuple):
     end_idx: int
 
 
-def find_token_match(
+def iter_token_matches(
     token_ids: List[int],
     match_ids: List[int],
-) -> Optional[_TokenMatch]:
+) -> Iterable[_TokenMatch]:
     """
-    Find the first occurrence of :code:`match_ids` in :code:`token_ids`.
+    Yield each occurrence of :code:`match_ids` in :code:`token_ids`.
     """
     match_len = len(match_ids)
 
     for start_idx in range(len(token_ids) - match_len + 1):
         end_idx = start_idx + match_len
         if token_ids[start_idx:end_idx] == match_ids:
-            return _TokenMatch(start_idx, end_idx)
+            yield _TokenMatch(start_idx, end_idx)
 
-    return None
-
 
-class _TokenMatchFromTextCandidate(NamedTuple):
+class _TokenMatchFromText(NamedTuple):
     start_idx: int
     end_idx: int
 
+    match_prefix: List[int]
+    match_suffix: List[int]
+
     match_text_prefix: str
     match_text_suffix: str
 
-    @property
-    def distance(self) -> int:
-        return len(self.match_text_prefix) + len(self.match_text_suffix)
 
+def _iter_token_matches_by_window(
+    tokenizer: AnyTokenizer,
+    token_ids: List[int],
+    token_text: str,
+    match_text: str,
+    match: re.Match[str],
+    window_span: int,
+) -> Iterable[_TokenMatchFromText]:
+    min_idx = max(0, match.start() - window_span)
+    max_idx = min(match.end() + window_span, len(token_text))
+
+    for window_size in range(len(match_text), max_idx - min_idx):
+        for window_start_idx in range(min_idx, max_idx - window_size + 1):
+            window_end_idx = window_start_idx + window_size
+            window_text = token_text[window_start_idx:window_end_idx]
+            window_ids = _cached_encode(
+                tokenizer,
+                window_text,
+                add_special_tokens=False,
+            )
 
-class _TokenMatchFromText(NamedTuple):
-    start_idx: int
-    end_idx: int
+            (
+                window_match_text_prefix,
+                window_match_text_suffix,
+            ) = window_text.split(match_text, 1)
 
-    match_prefix: List[int]
-    match_suffix: List[int]
+            window_match_token_prefix = _cached_encode(
+                tokenizer,
+                window_match_text_prefix,
+                add_special_tokens=False,
+            )
+            window_match_token_suffix = _cached_encode(
+                tokenizer,
+                window_match_text_suffix,
+                add_special_tokens=False,
+            )
 
-    match_text_prefix: str
-    match_text_suffix: str
+            for window_match in iter_token_matches(token_ids, window_ids):
+                yield _TokenMatchFromText(
+                    window_match.start_idx,
+                    window_match.end_idx,
+                    match_prefix=window_match_token_prefix,
+                    match_suffix=window_match_token_suffix,
+                    match_text_prefix=window_match_text_prefix,
+                    match_text_suffix=window_match_text_suffix,
+                )
 
 
-def find_token_match_by_text(
+def iter_token_matches_by_text(
     tokenizer: AnyTokenizer,
     token_ids: List[int],
     token_text: str,
     match_text: str,
-) -> Optional[_TokenMatchFromText]:
+) -> Iterable[_TokenMatchFromText]:
     """
-    Find the first occurrence of the tokenized :code:`match_text` in
+    Yield each occurrence of the tokenized :code:`match_text` in
    :code:`token_ids`.
     """
     match_ids = _cached_encode(tokenizer, match_text, add_special_tokens=False)
-    if (match := find_token_match(token_ids, match_ids)):
-        return _TokenMatchFromText(
-            match.start_idx,
-            match.end_idx,
+    for token_match in iter_token_matches(token_ids, match_ids):
+        yield _TokenMatchFromText(
+            token_match.start_idx,
+            token_match.end_idx,
             match_prefix=[],
             match_suffix=[],
             match_text_prefix="",
@@ -221,89 +256,28 @@ def find_token_match_by_text(
         )
 
     # When `match_text` is not mapped to a special token ID,
-    # it may be tokenized differently based on the surrounding tokens
-    # as well as whether it is at the start/end of the string.
-    # Therefore, we need to use `token_text` as a reference.
-    text_start_idx = token_text.find(match_text)
-    if text_start_idx == -1:
-        return None
-
-    text_end_idx = text_start_idx + len(match_text)
-
-    # In case the left/right side of `match_text` is fused with the
+    # the left/right side of `match_text` may be fused with the
     # string immediately before/after it as a single token
-    text_buffer = _max_vocab_token_len(tokenizer) - 1
-    left_text = token_text[:max(0, text_start_idx - text_buffer)]
-    right_text = token_text[:text_end_idx + text_buffer]
-
-    left_idx = len(_encode(tokenizer, left_text, add_special_tokens=False))
-    right_idx = len(_encode(tokenizer, right_text, add_special_tokens=True))
-    window_size = len(match_ids)
-
-    best_distance = len(token_text)
-    best_candidate = None
-
-    for start_idx in range(left_idx, right_idx - window_size + 1):
-        end_idx = start_idx + window_size
-        candidate_text = tokenizer.decode(
-            token_ids[start_idx:end_idx],
-            # In case match_text is a special token
-            skip_special_tokens=False,
+    window_span = _max_vocab_token_len(tokenizer) - len(match_text)
+
+    for text_match in re.finditer(re.escape(match_text), token_text):
+        yield from _iter_token_matches_by_window(
+            tokenizer,
+            token_ids,
+            token_text,
+            match_text,
+            text_match,
+            window_span,
         )
 
-        if match_text in candidate_text:
-            candidate = _TokenMatchFromTextCandidate(
-                start_idx,
-                end_idx,
-                *candidate_text.split(match_text, 1),
-            )
-
-            if candidate.distance < best_distance:
-                best_candidate = candidate
-                best_distance = candidate.distance
-
-            if best_distance == 0:
-                break
-
-    assert best_candidate is not None, dict(
-        # To facilitate debugging
-        token_ids=token_ids,
-        match_ids=match_ids,
-        left_text=left_text,
-        right_text=right_text,
-        left_idx=left_idx,
-        right_idx=right_idx,
-    )
-
-    match_token_prefix = _cached_encode(
-        tokenizer,
-        best_candidate.match_text_prefix,
-        add_special_tokens=False,
-    )
-    match_token_suffix = _cached_encode(
-        tokenizer,
-        best_candidate.match_text_suffix,
-        add_special_tokens=False,
-    )
-
-    return _TokenMatchFromText(
-        start_idx=best_candidate.start_idx,
-        end_idx=best_candidate.end_idx,
-        match_prefix=match_token_prefix,
-        match_suffix=match_token_suffix,
-        match_text_prefix=best_candidate.match_text_prefix,
-        match_text_suffix=best_candidate.match_text_suffix,
-    )
-
 
 def replace_by_text(
     tokenizer: AnyTokenizer,
     token_ids: List[int],
     token_text: str,
-    match_text: str,
-    replacement_id: int,
-    replacement_count: int,
-) -> tuple[List[int], str, Optional[PlaceholderRange]]:
+    # match_text -> ((item_idx) -> repl_id, repl_count)
+    match_to_replacement: Mapping[str, Callable[[int], tuple[int, int]]],
+) -> tuple[List[int], str, List[PlaceholderRange]]:
     """
     Find the first occurrence of the tokenized :code:`match_text` in
     :code:`token_ids`, and replace it with
@@ -380,7 +354,7 @@ def apply(
         new_token_ids, = processed_inputs.pop("input_ids").tolist()
         mm_kwargs = MultiModalKwargs(processed_inputs)
 
-        new_prompt = prompt
+        new_prompt = tokenizer.decode(new_token_ids)
 
         mm_placeholders: Mapping[str, List[PlaceholderRange]] = {}
         for modality, orig_inputs in to_multi_format(mm_data).items():
@@ -400,9 +374,8 @@ def apply(
                 if new_token_id in repl_token_ids:
                     modality_placeholders.append(run_info)
 
-            if modality_placeholders:
-                new_prompt = tokenizer.decode(new_token_ids)
-            else:  # Otherwise, we insert them ourselves
+            # Otherwise, we insert them ourselves
+            if not modality_placeholders:
                 for item_idx, orig_item in enumerate(orig_inputs):
                     for match_str, replacement in placeholder_repls.items():
                         replacement_count = replacement["count"]
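Below is a minimal usage sketch of the new generator-based matcher introduced by this diff. It is illustrative only: it assumes the module is importable as vllm.multimodal.processing (per the diff header), and the token IDs are made-up values, with 32000 standing in for a placeholder token.

    from vllm.multimodal.processing import iter_token_matches

    # Made-up token IDs; 32000 stands in for an image placeholder token.
    token_ids = [1, 32000, 32000, 5, 32000]

    # Unlike the old find_token_match, every occurrence is yielded lazily
    # instead of returning only the first match (or None).
    matches = list(iter_token_matches(token_ids, [32000]))
    print([(m.start_idx, m.end_idx) for m in matches])
    # Expected: [(1, 2), (2, 3), (4, 5)]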