Commit

Test and fix candidates detection
Signed-off-by: DarkLight1337 <[email protected]>
DarkLight1337 committed Nov 20, 2024
1 parent 47d3dcc commit a85c542
Showing 2 changed files with 202 additions and 168 deletions.
177 changes: 79 additions & 98 deletions tests/multimodal/test_processing.py
@@ -1,6 +1,9 @@
import pytest
from transformers import PreTrainedTokenizerBase

from vllm.multimodal.processing import apply_placeholders, iter_token_runs
from vllm.multimodal.processing import (find_token_match_by_text,
                                        iter_token_runs)
from vllm.multimodal.utils import cached_get_tokenizer


# yapf: disable
@@ -34,108 +37,86 @@ def test_iter_token_runs(token_ids, expected):
    assert result == expected


# yapf: disable
@pytest.mark.parametrize("tokenizer_id", [
"llava-hf/llava-1.5-7b-hf",
"meta-llama/Llama-3.2-11B-Vision-Instruct",
"microsoft/Phi-3.5-vision-instruct",
"Qwen/Qwen2-VL-2B-Instruct",
])
@pytest.mark.parametrize("add_special_tokens", [True, False])
@pytest.mark.parametrize(
(
"token_ids", "match_ids", "replacement_id", "replacement_count",
"expected_new_token_ids", "expected_range",
),
"text",
[
"What is in this image?",
# LLaVA
"<image>What is in this image?",
"What is<image>in this image?",
"What is in this image?<image>",
# LLama-3.2
"<|image|>What is in this image?",
"What is<|image|>in this image?",
"What is in this image?<|image|>",
# Phi-3-vision
"<image_1>What is in this image?",
"What is<image_1>in this image?",
"What is in this image?<image_1>",
# Qwen2-VL
"<|vision_start|><|image_pad|><|vision_end|>What is in this image?",
"What is<|vision_start|><|image_pad|><|vision_end|>in this image?",
"What is in this image?<|vision_start|><|image_pad|><|vision_end|>",
])
@pytest.mark.parametrize(
"match_str",
[
# Empty
(
[], [-1], +1, 0,
[], None,
),
# No match
(
[32000, 32000, 32000], [-1], +1, 0,
[32000, 32000, 32000], None,
),
# Match first
(
[-1, 32000, 32000], [-1], +1, 0,
[32000, 32000], { "offset": 0, "length": 0 },
),
(
[-1, 32000, 32000], [-1], +1, 1,
[+1, 32000, 32000], { "offset": 0, "length": 1 },
),
(
[-1, 32000, 32000], [-1], +1, 2,
[+1, +1, 32000, 32000], { "offset": 0, "length": 2 },
),
# Match middle
(
[32000, -1, 32000], [-1], +1, 0,
[32000, 32000], { "offset": 1, "length": 0 },
),
(
[32000, -1, 32000], [-1], +1, 1,
[32000, +1, 32000], { "offset": 1, "length": 1 },
),
(
[32000, -1, 32000], [-1], +1, 2,
[32000, +1, +1, 32000], { "offset": 1, "length": 2},
),
# Match last
(
[32000, 32000, -1], [-1], +1, 0,
[32000, 32000], { "offset": 2, "length": 0 },
),
(
[32000, 32000, -1], [-1], +1, 1,
[32000, 32000, +1], { "offset": 2, "length": 1 },
),
(
[32000, 32000, -1], [-1], +1, 2,
[32000, 32000, +1, +1], { "offset": 2, "length": 2},
),
# Match all
(
[32000, 32000, 32000], [32000], +1, 0,
[32000, 32000], { "offset": 0, "length": 0 },
),
(
[32000, 32000, 32000], [32000], +1, 1,
[+1, 32000, 32000], { "offset": 0, "length": 1 },
),
(
[32000, 32000, 32000], [32000], +1, 2,
[+1, +1, 32000, 32000], { "offset": 0, "length": 2 },
),
],
)
# yapf: enable
def test_apply_placeholders(
    token_ids,
    match_ids,
    replacement_id,
    replacement_count,
    expected_new_token_ids,
    expected_range,
        "No",
        # Has match
        "i",
        "What",
        "What is",
        "image",
        "image?",
        "<image>",
        "<|image|>",
        "<image_1>",
        "<|vision_start|><|image_pad|><|vision_end|>",
        "<s>",
        "</s>",
    ])
def test_token_match_by_text(
    tokenizer_id,
    add_special_tokens,
    text,
    match_str,
):
    orig_token_ids = token_ids[:]

    placeholder_range = apply_placeholders(
        token_ids,
        match_ids,
        replacement_id,
        replacement_count,
    )
    tokenizer = cached_get_tokenizer(tokenizer_id)
    assert isinstance(tokenizer, PreTrainedTokenizerBase)

    # Invariants
    if placeholder_range is None:
        assert orig_token_ids == token_ids
    else:
        offset = placeholder_range["offset"]
        match_len = len(match_ids)
        repl_len = placeholder_range["length"]
    token_ids = tokenizer.encode(text, add_special_tokens=add_special_tokens)
    match = find_token_match_by_text(tokenizer, token_ids, text, match_str)

        assert orig_token_ids[offset:offset + match_len] == match_ids
    # These are only shown in the output if the test fails
    print("token_ids:", token_ids)
    print("match:", match)

        repl_ids = [replacement_id] * replacement_count
        assert token_ids[offset:offset + repl_len] == repl_ids
    # Invariants
    if (match_str in text
            or match_str in tokenizer.decode(token_ids,
                                             skip_special_tokens=False)):
        assert match is not None
        match_start_idx, match_end_idx = match

    # Manually constructed results
    assert token_ids == expected_new_token_ids
    assert placeholder_range == expected_range
        assert match_str in tokenizer.decode(
            token_ids[match_start_idx:match_end_idx],
            skip_special_tokens=False,
        )
        assert match_str not in tokenizer.decode(
            token_ids[match_start_idx + 1:match_end_idx],
            skip_special_tokens=False,
        )
        assert match_str not in tokenizer.decode(
            token_ids[match_start_idx:match_end_idx - 1],
            skip_special_tokens=False,
        )
    else:
        assert match is None

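For orientation when skimming the diff: below is a minimal standalone sketch, not part of this commit, of how the newly tested helper appears to be used. It is based solely on the imports, the call find_token_match_by_text(tokenizer, token_ids, text, match_str), and the assertions in the new test above; the tokenizer id and prompt are example values taken from the test's parameters.

from transformers import PreTrainedTokenizerBase

from vllm.multimodal.processing import find_token_match_by_text
from vllm.multimodal.utils import cached_get_tokenizer

# Example values only; the test parametrizes several tokenizers and prompts.
tokenizer = cached_get_tokenizer("llava-hf/llava-1.5-7b-hf")
assert isinstance(tokenizer, PreTrainedTokenizerBase)

text = "<image>What is in this image?"
token_ids = tokenizer.encode(text, add_special_tokens=False)

# Per the test's invariants, the helper returns None when match_str cannot be
# located, and otherwise a (match_start_idx, match_end_idx) pair indexing into
# token_ids such that decoding that slice contains match_str, while trimming a
# token from either end of the slice drops the match.
match = find_token_match_by_text(tokenizer, token_ids, text, "<image>")
if match is not None:
    match_start_idx, match_end_idx = match
    print(tokenizer.decode(token_ids[match_start_idx:match_end_idx],
                           skip_special_tokens=False))
else:
    print("no match")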