Commit 9eb2ba6: Draft

Signed-off-by: DarkLight1337 <[email protected]>
DarkLight1337 committed Nov 21, 2024
1 parent 43043ca commit 9eb2ba6

Showing 1 changed file with 78 additions and 105 deletions.

vllm/multimodal/processing.py (183 changes: 78 additions & 105 deletions)

@@ -1,7 +1,8 @@
+import re
 from dataclasses import dataclass
 from functools import lru_cache
 from itertools import groupby
-from typing import (Any, Callable, Generic, List, Mapping, NamedTuple,
+from typing import (Any, Callable, Generic, Iterable, List, Mapping, NamedTuple,
                     Optional, TypeVar, Union, final)

 from transformers import BatchFeature

Check failure on line 6 in vllm/multimodal/processing.py (GitHub Actions / ruff (3.12)): Ruff (F401) vllm/multimodal/processing.py:6:21: `typing.Optional` imported but unused

@@ -146,7 +147,7 @@ def _encode(
     return tokenizer.encode(text, add_special_tokens=add_special_tokens)


-_cached_encode = lru_cache(_encode)
+_cached_encode = lru_cache(maxsize=2048)(_encode)


 @lru_cache
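
For context on the cache change above: passing `_encode` directly to `lru_cache` uses the decorator's default maxsize of 128 entries, while `lru_cache(maxsize=2048)(_encode)` bounds the cache at 2048 entries instead. A minimal, self-contained sketch of the same pattern (the `_encode_demo` helper is a hypothetical stand-in, not vLLM code):

from functools import lru_cache

def _encode_demo(text: str, *, add_special_tokens: bool = False) -> list[int]:
    # Hypothetical stand-in for a tokenizer encode call.
    return [ord(c) for c in text]

# Mirrors the change above: an explicit, larger bound on the cache
# (lru_cache's default maxsize is 128; maxsize=None would make it unbounded).
_cached_encode_demo = lru_cache(maxsize=2048)(_encode_demo)

print(_cached_encode_demo("<image>"))    # miss: computed
print(_cached_encode_demo("<image>"))    # hit: served from the cache
print(_cached_encode_demo.cache_info())  # CacheInfo(hits=1, misses=1, maxsize=2048, currsize=1)

In the diff itself, `_encode` also takes the tokenizer as its first argument, so cached entries are keyed per tokenizer and per text.
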
@@ -159,151 +160,124 @@ class _TokenMatch(NamedTuple):
     end_idx: int


-def find_token_match(
+def iter_token_matches(
     token_ids: List[int],
     match_ids: List[int],
-) -> Optional[_TokenMatch]:
+) -> Iterable[_TokenMatch]:
     """
-    Find the first occurrence of :code:`match_ids` in :code:`token_ids`.
+    Yield each occurrence of :code:`match_ids` in :code:`token_ids`.
     """
     match_len = len(match_ids)

     for start_idx in range(len(token_ids) - match_len + 1):
         end_idx = start_idx + match_len
         if token_ids[start_idx:end_idx] == match_ids:
-            return _TokenMatch(start_idx, end_idx)
-
-    return None
+            yield _TokenMatch(start_idx, end_idx)


-class _TokenMatchFromTextCandidate(NamedTuple):
+class _TokenMatchFromText(NamedTuple):
     start_idx: int
     end_idx: int

+    match_prefix: List[int]
+    match_suffix: List[int]
+
     match_text_prefix: str
     match_text_suffix: str
-
-    @property
-    def distance(self) -> int:
-        return len(self.match_text_prefix) + len(self.match_text_suffix)
-
-
-class _TokenMatchFromText(NamedTuple):
-    start_idx: int
-    end_idx: int
-
-    match_prefix: List[int]
-    match_suffix: List[int]
-
-    match_text_prefix: str
-    match_text_suffix: str


-def find_token_match_by_text(
+def _iter_token_matches_by_window(
+    tokenizer: AnyTokenizer,
+    token_ids: List[int],
+    token_text: str,
+    match_text: str,
+    match: re.Match[str],
+    window_span: int,
+) -> Iterable[_TokenMatchFromText]:
+    min_idx = max(0, match.start() - window_span)
+    max_idx = min(match.end() + window_span, len(token_text))
+
+    for window_size in range(len(match_text), max_idx - min_idx):
+        for window_start_idx in range(min_idx, max_idx - window_size + 1):
+            window_end_idx = window_start_idx + window_size
+            window_text = token_text[window_start_idx:window_end_idx]
+            window_ids = _cached_encode(
+                tokenizer,
+                window_text,
+                add_special_tokens=False,
+            )
+
+            (
+                window_match_text_prefix,
+                window_match_text_suffix,
+            ) = window_text.split(match_text, 1)
+
+            window_match_token_prefix = _cached_encode(
+                tokenizer,
+                window_match_text_prefix,
+                add_special_tokens=False,
+            )
+            window_match_token_suffix = _cached_encode(
+                tokenizer,
+                window_match_text_suffix,
+                add_special_tokens=False,
+            )
+
+            for window_match in iter_token_matches(token_ids, window_ids):
+                yield _TokenMatchFromText(
+                    window_match.start_idx,
+                    window_match.end_idx,
+                    match_prefix=window_match_token_prefix,
+                    match_suffix=window_match_token_suffix,
+                    match_text_prefix=window_match_text_prefix,
+                    match_text_suffix=window_match_text_suffix,
+                )
+
+
+def iter_token_matches_by_text(
     tokenizer: AnyTokenizer,
     token_ids: List[int],
     token_text: str,
     match_text: str,
-) -> Optional[_TokenMatchFromText]:
+) -> Iterable[_TokenMatchFromText]:
     """
-    Find the first occurrence of the tokenized :code:`match_text` in
+    Yield each occurrence of the tokenized :code:`match_text` in
     :code:`token_ids`.
     """
     match_ids = _cached_encode(tokenizer, match_text, add_special_tokens=False)
-    if (match := find_token_match(token_ids, match_ids)):
-        return _TokenMatchFromText(
-            match.start_idx,
-            match.end_idx,
+    for token_match in iter_token_matches(token_ids, match_ids):
+        yield _TokenMatchFromText(
+            token_match.start_idx,
+            token_match.end_idx,
             match_prefix=[],
             match_suffix=[],
             match_text_prefix="",
             match_text_suffix="",
         )

     # When `match_text` is not mapped to a special token ID,
     # it may be tokenized differently based on the surrounding tokens
     # as well as whether it is at the start/end of the string.
     # Therefore, we need to use `token_text` as a reference.
-    text_start_idx = token_text.find(match_text)
-    if text_start_idx == -1:
-        return None
-
-    text_end_idx = text_start_idx + len(match_text)
-
-    # In case the left/right side of `match_text` is fused with the
+    # the left/right side of `match_text` may be fused with the
     # string immediately before/after it as a single token
-    text_buffer = _max_vocab_token_len(tokenizer) - 1
-    left_text = token_text[:max(0, text_start_idx - text_buffer)]
-    right_text = token_text[:text_end_idx + text_buffer]
-
-    left_idx = len(_encode(tokenizer, left_text, add_special_tokens=False))
-    right_idx = len(_encode(tokenizer, right_text, add_special_tokens=True))
-    window_size = len(match_ids)
-
-    best_distance = len(token_text)
-    best_candidate = None
-
-    for start_idx in range(left_idx, right_idx - window_size + 1):
-        end_idx = start_idx + window_size
-        candidate_text = tokenizer.decode(
-            token_ids[start_idx:end_idx],
-            # In case match_text is a special token
-            skip_special_tokens=False,
-        )
-
-        if match_text in candidate_text:
-            candidate = _TokenMatchFromTextCandidate(
-                start_idx,
-                end_idx,
-                *candidate_text.split(match_text, 1),
-            )
-
-            if candidate.distance < best_distance:
-                best_candidate = candidate
-                best_distance = candidate.distance
-
-            if best_distance == 0:
-                break
-
-    assert best_candidate is not None, dict(
-        # To facilitate debugging
-        token_ids=token_ids,
-        match_ids=match_ids,
-        left_text=left_text,
-        right_text=right_text,
-        left_idx=left_idx,
-        right_idx=right_idx,
-    )
-
-    match_token_prefix = _cached_encode(
-        tokenizer,
-        best_candidate.match_text_prefix,
-        add_special_tokens=False,
-    )
-    match_token_suffix = _cached_encode(
-        tokenizer,
-        best_candidate.match_text_suffix,
-        add_special_tokens=False,
-    )
-
-    return _TokenMatchFromText(
-        start_idx=best_candidate.start_idx,
-        end_idx=best_candidate.end_idx,
-        match_prefix=match_token_prefix,
-        match_suffix=match_token_suffix,
-        match_text_prefix=best_candidate.match_text_prefix,
-        match_text_suffix=best_candidate.match_text_suffix,
-    )
+    window_span = _max_vocab_token_len(tokenizer) - len(match_text)
+
+    for text_match in re.finditer(re.escape(match_text), token_text):
+        yield from _iter_token_matches_by_window(
+            tokenizer,
+            token_ids,
+            token_text,
+            match_text,
+            text_match,
+            window_span,
+        )


 def replace_by_text(
     tokenizer: AnyTokenizer,
     token_ids: List[int],
     token_text: str,
-    match_text: str,
-    replacement_id: int,
-    replacement_count: int,
-) -> tuple[List[int], str, Optional[PlaceholderRange]]:
+    # match_text -> ((item_idx) -> repl_id, repl_count)
+    match_to_replacement: Mapping[str, Callable[[int], tuple[int, int]]],
+) -> tuple[List[int], str, List[PlaceholderRange]]:
     """
     Find the first occurrence of the tokenized :code:`match_text` in
     :code:`token_ids`, and replace it with
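
To make the shift from `find_token_match` to `iter_token_matches` concrete, here is a small standalone sketch of the same sliding-window scan (reimplemented for illustration, not imported from vLLM; the token IDs are made up). Where the old function returned only the first match or `None`, the generator yields every occurrence, which lines up with `replace_by_text` now returning a `List[PlaceholderRange]` instead of an `Optional[PlaceholderRange]`:

from typing import Iterable, List, NamedTuple

class TokenMatch(NamedTuple):
    start_idx: int
    end_idx: int

def iter_token_matches(
    token_ids: List[int],
    match_ids: List[int],
) -> Iterable[TokenMatch]:
    # Slide a window of len(match_ids) over token_ids and yield every hit.
    match_len = len(match_ids)
    for start_idx in range(len(token_ids) - match_len + 1):
        end_idx = start_idx + match_len
        if token_ids[start_idx:end_idx] == match_ids:
            yield TokenMatch(start_idx, end_idx)

# A placeholder tokenized as [32000, 32001] that occurs twice in the prompt:
matches = list(iter_token_matches([1, 32000, 32001, 7, 32000, 32001, 2], [32000, 32001]))
print(matches)
# [TokenMatch(start_idx=1, end_idx=3), TokenMatch(start_idx=4, end_idx=6)]
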
@@ -380,7 +354,7 @@ def apply(
         new_token_ids, = processed_inputs.pop("input_ids").tolist()
         mm_kwargs = MultiModalKwargs(processed_inputs)

-        new_prompt = prompt
+        new_prompt = tokenizer.decode(new_token_ids)
         mm_placeholders: Mapping[str, List[PlaceholderRange]] = {}

         for modality, orig_inputs in to_multi_format(mm_data).items():

@@ -400,9 +374,8 @@
                 if new_token_id in repl_token_ids:
                     modality_placeholders.append(run_info)

-            if modality_placeholders:
-                new_prompt = tokenizer.decode(new_token_ids)
-            else:  # Otherwise, we insert them ourselves
+            # Otherwise, we insert them ourselves
+            if not modality_placeholders:
                 for item_idx, orig_item in enumerate(orig_inputs):
                     for match_str, replacement in placeholder_repls.items():
                         replacement_count = replacement["count"]
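
The last two hunks move the prompt decode out of the placeholder-matching branch: `new_prompt` is now always derived from the processed token IDs, so the old `if modality_placeholders: ... else: ...` collapses into a single `if not modality_placeholders:` fallback. A loose, hypothetical sketch of that control flow (simplified names, not the actual `apply()` implementation):

def apply_sketch(tokenizer, new_token_ids, find_placeholders, insert_placeholders):
    # Always decode the prompt from the processed token IDs
    # (previously: new_prompt = prompt, decoded only when matches were found).
    new_prompt = tokenizer.decode(new_token_ids)

    modality_placeholders = find_placeholders(new_token_ids)

    # Otherwise, we insert them ourselves
    if not modality_placeholders:
        modality_placeholders = insert_placeholders(new_token_ids)

    return new_prompt, modality_placeholders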