init
This commit is contained in:
270
tests/engine/output_processor/test_multi_step.py
Normal file
270
tests/engine/output_processor/test_multi_step.py
Normal file
@@ -0,0 +1,270 @@
|
||||
import random
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
from tests.core.utils import create_seq_group
|
||||
from vllm.core.scheduler import Scheduler
|
||||
from vllm.engine.output_processor.multi_step import MultiStepOutputProcessor
|
||||
from vllm.engine.output_processor.stop_checker import StopChecker
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import (Logprob, SequenceGroupOutput, SequenceOutput,
|
||||
SequenceStatus)
|
||||
from vllm.transformers_utils.detokenizer import Detokenizer
|
||||
from vllm.utils import Counter
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seq_output_len", [128])
|
||||
@pytest.mark.parametrize("num_new_tokens", [1, 12])
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_appends_token_ids(num_new_tokens: int, seq_output_len: int):
|
||||
"""Verify multi-step decoding appends token ids correctly.
|
||||
|
||||
We append token ids and verify all the token ids were appended correctly.
|
||||
Note that ignore_eos=True.
|
||||
"""
|
||||
detokenizer = MagicMock(spec=Detokenizer)
|
||||
scheduler = MagicMock(spec=Scheduler)
|
||||
stop_checker = MagicMock(spec=StopChecker)
|
||||
seq_counter = Counter()
|
||||
|
||||
output_processor = MultiStepOutputProcessor(
|
||||
detokenizer=detokenizer,
|
||||
scheduler=scheduler,
|
||||
seq_counter=seq_counter,
|
||||
get_tokenizer_for_seq=lambda _: mock_tokenizer(),
|
||||
stop_checker=stop_checker,
|
||||
)
|
||||
|
||||
seq_group = create_seq_group(
|
||||
seq_prompt_len=1024,
|
||||
seq_output_lens=[seq_output_len],
|
||||
sampling_params=SamplingParams(max_tokens=seq_output_len +
|
||||
num_new_tokens,
|
||||
ignore_eos=True),
|
||||
)
|
||||
|
||||
seq = seq_group.get_seqs()[0]
|
||||
seq.status = SequenceStatus.RUNNING
|
||||
|
||||
new_token_ids = list(range(num_new_tokens))
|
||||
|
||||
outputs = [
|
||||
SequenceGroupOutput(
|
||||
samples=[
|
||||
SequenceOutput(
|
||||
parent_seq_id=seq.seq_id,
|
||||
output_token=output_token,
|
||||
logprobs={output_token: Logprob(0.0)},
|
||||
)
|
||||
],
|
||||
prompt_logprobs=None,
|
||||
) for output_token in new_token_ids
|
||||
]
|
||||
|
||||
assert seq.get_token_ids()[-len(new_token_ids):] != new_token_ids
|
||||
output_processor.process_outputs(seq_group, outputs)
|
||||
assert seq.get_token_ids()[-len(new_token_ids):] == new_token_ids
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seq_prompt_len", [1024])
|
||||
@pytest.mark.parametrize("seq_output_len", [128])
|
||||
@pytest.mark.parametrize("num_new_tokens", [5, 6, 7, 8])
|
||||
@pytest.mark.parametrize("max_tokens", [128 + 3])
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_respects_max_tokens(num_new_tokens: int, seq_prompt_len: int,
|
||||
seq_output_len: int, max_tokens: int):
|
||||
"""Verify tokens after max_tokens are dropped and not appended to the
|
||||
sequence.
|
||||
"""
|
||||
detokenizer = MagicMock(spec=Detokenizer)
|
||||
scheduler = MagicMock(spec=Scheduler)
|
||||
stop_checker = MagicMock(spec=StopChecker)
|
||||
seq_counter = Counter()
|
||||
|
||||
output_processor = MultiStepOutputProcessor(
|
||||
detokenizer=detokenizer,
|
||||
scheduler=scheduler,
|
||||
seq_counter=seq_counter,
|
||||
get_tokenizer_for_seq=lambda _: mock_tokenizer(),
|
||||
stop_checker=stop_checker,
|
||||
)
|
||||
|
||||
seq_group = create_seq_group(
|
||||
seq_prompt_len=seq_prompt_len,
|
||||
seq_output_lens=[seq_output_len],
|
||||
sampling_params=SamplingParams(max_tokens=max_tokens, ),
|
||||
)
|
||||
|
||||
seq = seq_group.get_seqs()[0]
|
||||
seq.status = SequenceStatus.RUNNING
|
||||
|
||||
new_token_ids = list(range(num_new_tokens))
|
||||
|
||||
outputs = [
|
||||
SequenceGroupOutput(
|
||||
samples=[
|
||||
SequenceOutput(
|
||||
parent_seq_id=seq.seq_id,
|
||||
output_token=output_token,
|
||||
logprobs={output_token: Logprob(0.0)},
|
||||
)
|
||||
],
|
||||
prompt_logprobs=None,
|
||||
) for output_token in new_token_ids
|
||||
]
|
||||
|
||||
assert seq.get_len() == seq_prompt_len + seq_output_len
|
||||
output_processor.process_outputs(seq_group, outputs)
|
||||
|
||||
# Expect the processed sequence to not go over max tokens in len.
|
||||
assert seq.get_len() == seq_prompt_len + max_tokens
|
||||
|
||||
# Expect the correct tokens were appended.
|
||||
expected_appended_tokens = new_token_ids[:max_tokens - seq_output_len]
|
||||
assert seq.get_token_ids(
|
||||
)[-len(expected_appended_tokens):] == expected_appended_tokens
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seq_prompt_len", [1024])
|
||||
@pytest.mark.parametrize("seq_output_len", [128])
|
||||
@pytest.mark.parametrize("num_new_tokens", [12])
|
||||
@pytest.mark.parametrize("seed", list(range(6)))
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_respects_eos_token_id(num_new_tokens: int, seq_prompt_len: int,
|
||||
seq_output_len: int, seed: int):
|
||||
"""Verify the eos token id is included in the sequence, but subsequent
|
||||
tokens are dropped (not appended to sequence).
|
||||
"""
|
||||
random.seed(seed)
|
||||
detokenizer = MagicMock(spec=Detokenizer)
|
||||
scheduler = MagicMock(spec=Scheduler)
|
||||
stop_checker = MagicMock(spec=StopChecker)
|
||||
seq_counter = Counter()
|
||||
|
||||
eos_token_id = 100
|
||||
|
||||
output_processor = MultiStepOutputProcessor(
|
||||
detokenizer=detokenizer,
|
||||
scheduler=scheduler,
|
||||
seq_counter=seq_counter,
|
||||
get_tokenizer_for_seq=lambda _: mock_tokenizer(eos_token_id),
|
||||
stop_checker=stop_checker,
|
||||
)
|
||||
|
||||
seq_group = create_seq_group(
|
||||
seq_prompt_len=seq_prompt_len,
|
||||
seq_output_lens=[seq_output_len],
|
||||
sampling_params=SamplingParams(
|
||||
# Ensure enough space.
|
||||
max_tokens=seq_output_len + num_new_tokens, ),
|
||||
)
|
||||
|
||||
seq = seq_group.get_seqs()[0]
|
||||
seq.status = SequenceStatus.RUNNING
|
||||
|
||||
new_token_ids = list(range(num_new_tokens))
|
||||
assert eos_token_id not in new_token_ids
|
||||
eos_index = random.randint(0, len(new_token_ids) - 1)
|
||||
new_token_ids[eos_index] = eos_token_id
|
||||
|
||||
outputs = [
|
||||
SequenceGroupOutput(
|
||||
samples=[
|
||||
SequenceOutput(
|
||||
parent_seq_id=seq.seq_id,
|
||||
output_token=output_token,
|
||||
logprobs={output_token: Logprob(0.0)},
|
||||
)
|
||||
],
|
||||
prompt_logprobs=None,
|
||||
) for output_token in new_token_ids
|
||||
]
|
||||
|
||||
assert seq.get_len() == seq_prompt_len + seq_output_len
|
||||
output_processor.process_outputs(seq_group, outputs)
|
||||
|
||||
# Expect the processed sequence to not go beyond provided eos.
|
||||
assert seq.get_len() == seq_prompt_len + seq_output_len + (eos_index + 1)
|
||||
|
||||
# Expect the correct tokens were appended.
|
||||
expected_appended_tokens = new_token_ids[:eos_index + 1]
|
||||
assert seq.get_token_ids(
|
||||
)[-len(expected_appended_tokens):] == expected_appended_tokens
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seq_prompt_len", [1024])
|
||||
@pytest.mark.parametrize("seq_output_len", [128])
|
||||
@pytest.mark.parametrize("num_new_tokens", [12])
|
||||
@pytest.mark.parametrize("seed", list(range(6)))
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_ignores_eos_token_id(num_new_tokens: int, seq_prompt_len: int,
|
||||
seq_output_len: int, seed: int):
|
||||
"""When sampling parameters dictate that we should ignore the eos token id,
|
||||
ensure all token ids are appended even if the eos token id is emitted.
|
||||
"""
|
||||
random.seed(seed)
|
||||
detokenizer = MagicMock(spec=Detokenizer)
|
||||
scheduler = MagicMock(spec=Scheduler)
|
||||
stop_checker = MagicMock(spec=StopChecker)
|
||||
seq_counter = Counter()
|
||||
|
||||
eos_token_id = 100
|
||||
|
||||
output_processor = MultiStepOutputProcessor(
|
||||
detokenizer=detokenizer,
|
||||
scheduler=scheduler,
|
||||
seq_counter=seq_counter,
|
||||
get_tokenizer_for_seq=lambda _: mock_tokenizer(eos_token_id),
|
||||
stop_checker=stop_checker,
|
||||
)
|
||||
|
||||
seq_group = create_seq_group(
|
||||
seq_prompt_len=seq_prompt_len,
|
||||
seq_output_lens=[seq_output_len],
|
||||
sampling_params=SamplingParams(
|
||||
# Ensure enough space.
|
||||
max_tokens=seq_output_len + num_new_tokens,
|
||||
ignore_eos=True,
|
||||
),
|
||||
)
|
||||
|
||||
seq = seq_group.get_seqs()[0]
|
||||
seq.status = SequenceStatus.RUNNING
|
||||
|
||||
new_token_ids = list(range(num_new_tokens))
|
||||
assert eos_token_id not in new_token_ids
|
||||
eos_index = random.randint(0, len(new_token_ids) - 1)
|
||||
new_token_ids[eos_index] = eos_token_id
|
||||
|
||||
outputs = [
|
||||
SequenceGroupOutput(
|
||||
samples=[
|
||||
SequenceOutput(
|
||||
parent_seq_id=seq.seq_id,
|
||||
output_token=output_token,
|
||||
logprobs={output_token: Logprob(0.0)},
|
||||
)
|
||||
],
|
||||
prompt_logprobs=None,
|
||||
) for output_token in new_token_ids
|
||||
]
|
||||
|
||||
assert seq.get_len() == seq_prompt_len + seq_output_len
|
||||
output_processor.process_outputs(seq_group, outputs)
|
||||
|
||||
# Expect the processed sequence to go beyond eos.
|
||||
assert seq.get_len() == seq_prompt_len + seq_output_len + num_new_tokens
|
||||
|
||||
# Expect the correct tokens were appended.
|
||||
expected_appended_tokens = new_token_ids[:seq_output_len + num_new_tokens -
|
||||
seq_output_len]
|
||||
assert seq.get_token_ids(
|
||||
)[-len(expected_appended_tokens):] == expected_appended_tokens
|
||||
|
||||
|
||||
def mock_tokenizer(eos_token_id=1000):
|
||||
tokenizer = MagicMock(spec=PreTrainedTokenizer)
|
||||
tokenizer.eos_token_id = eos_token_id
|
||||
return tokenizer
|
||||
Reference in New Issue
Block a user