Sync from v0.13
This commit is contained in:
0
tests/samplers/__init__.py
Normal file
0
tests/samplers/__init__.py
Normal file
@@ -1,21 +1,26 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Compare the outputs of HF and vLLM when using beam search.
|
||||
|
||||
Run `pytest tests/samplers/test_beam_search.py`.
|
||||
"""
|
||||
import gc
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import AutoModelForSeq2SeqLM
|
||||
|
||||
from vllm.assets.audio import AudioAsset
|
||||
|
||||
# FIXME(zhuohan): The test can not pass if we:
|
||||
# 1. Increase max_tokens to 256.
|
||||
# 2. Increase beam_width to 8.
|
||||
# 3. Use the model "huggyllama/llama-7b".
|
||||
MAX_TOKENS = [128]
|
||||
MAX_TOKENS = [64]
|
||||
BEAM_WIDTHS = [4]
|
||||
MODELS = ["facebook/opt-125m"]
|
||||
MM_BEAM_WIDTHS = [2]
|
||||
MODELS = ["TinyLlama/TinyLlama-1.1B-Chat-v1.0"]
|
||||
|
||||
|
||||
@pytest.mark.skip_v1 # FIXME: This fails on V1 right now.
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", MAX_TOKENS)
|
||||
@@ -30,25 +35,152 @@ def test_beam_search_single_input(
|
||||
beam_width: int,
|
||||
) -> None:
|
||||
example_prompts = example_prompts[:1]
|
||||
hf_model = hf_runner(model, dtype=dtype)
|
||||
hf_outputs = hf_model.generate_beam_search(example_prompts, beam_width,
|
||||
max_tokens)
|
||||
del hf_model
|
||||
with hf_runner(model, dtype=dtype) as hf_model:
|
||||
hf_outputs = hf_model.generate_beam_search(
|
||||
example_prompts, beam_width, max_tokens
|
||||
)
|
||||
|
||||
vllm_model = vllm_runner(model, dtype=dtype)
|
||||
vllm_outputs = vllm_model.generate_beam_search(example_prompts, beam_width,
|
||||
max_tokens)
|
||||
del vllm_model
|
||||
# NOTE(woosuk): For some reason, the following GC is required to avoid
|
||||
# GPU OOM errors in the following tests using `vllm_runner`.
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
with vllm_runner(model, dtype=dtype) as vllm_model:
|
||||
vllm_outputs = vllm_model.generate_beam_search(
|
||||
example_prompts, beam_width, max_tokens
|
||||
)
|
||||
|
||||
for i in range(len(example_prompts)):
|
||||
hf_output_ids, _ = hf_outputs[i]
|
||||
vllm_output_ids, _ = vllm_outputs[i]
|
||||
hf_output_ids, hf_output_texts = hf_outputs[i]
|
||||
vllm_output_ids, vllm_output_texts = vllm_outputs[i]
|
||||
for j, (hf_text, vllm_text) in enumerate(
|
||||
zip(hf_output_texts, vllm_output_texts)
|
||||
):
|
||||
print(f">>>{j}-th hf output:")
|
||||
print(hf_text)
|
||||
print(f">>>{j}-th vllm output:")
|
||||
print(vllm_text)
|
||||
assert len(hf_output_ids) == len(vllm_output_ids)
|
||||
for j in range(len(hf_output_ids)):
|
||||
assert hf_output_ids[j] == vllm_output_ids[j], (
|
||||
f"Test{i} output{j}:\nHF: {hf_output_ids}\n"
|
||||
f"vLLM: {vllm_output_ids}")
|
||||
f"Test{i} output{j}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skip_v1 # FIXME: This fails on V1 right now.
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", MAX_TOKENS)
|
||||
@pytest.mark.parametrize("beam_width", BEAM_WIDTHS)
|
||||
def test_beam_search_with_concurrency_limit(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
beam_width: int,
|
||||
) -> None:
|
||||
# example_prompts[1]&[3]&[7] fails due to unknown reason even without
|
||||
# concurrency limit. skip them for now.
|
||||
example_prompts = example_prompts[:8]
|
||||
concurrency_limit = 2
|
||||
assert len(example_prompts) > concurrency_limit
|
||||
with vllm_runner(model, dtype=dtype) as vllm_model:
|
||||
outputs_with_limit = vllm_model.generate_beam_search(
|
||||
example_prompts, beam_width, max_tokens, concurrency_limit=concurrency_limit
|
||||
)
|
||||
outputs_without_limit = []
|
||||
|
||||
for i in range(0, len(example_prompts), concurrency_limit):
|
||||
outputs_without_limit.extend(
|
||||
vllm_model.generate_beam_search(
|
||||
example_prompts[i : i + concurrency_limit], beam_width, max_tokens
|
||||
)
|
||||
)
|
||||
|
||||
correct = True
|
||||
for i in range(len(example_prompts)):
|
||||
output_ids_with_limit, output_texts_with_limit = outputs_with_limit[i]
|
||||
output_ids_without_limit, output_texts_without_limit = outputs_without_limit[i]
|
||||
for j, (text_with_limit, text_without_limit) in enumerate(
|
||||
zip(output_texts_with_limit, output_texts_without_limit)
|
||||
):
|
||||
print(f">>>{j}-th with limit output:")
|
||||
print(text_with_limit)
|
||||
print(f">>>{j}-th without limit output:")
|
||||
print(text_without_limit)
|
||||
assert len(output_ids_with_limit) == len(output_ids_without_limit)
|
||||
for j in range(len(output_ids_with_limit)):
|
||||
if output_ids_with_limit[j] != output_ids_without_limit[j]:
|
||||
print(
|
||||
f"Test{i} output{j}:\n+limit: {output_ids_with_limit}\n"
|
||||
f"-limit: {output_ids_without_limit}"
|
||||
)
|
||||
correct = False
|
||||
assert correct
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", MAX_TOKENS)
|
||||
@pytest.mark.parametrize("beam_width", MM_BEAM_WIDTHS)
|
||||
def test_beam_search_passes_multimodal_data(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
beam_width: int,
|
||||
) -> None:
|
||||
"""Ensure that beam search passes multimodal data through correctly."""
|
||||
# NOTE - this test is primarily to check that mm data is passed to beams
|
||||
# correctly. As such, we just need to check one extra modality to make
|
||||
# sure things pass through properly.
|
||||
audios = [AudioAsset("mary_had_lamb").audio_and_sample_rate]
|
||||
model = "Qwen/Qwen2-Audio-7B-Instruct"
|
||||
audio_seq = "<|audio_bos|><|AUDIO|><|audio_eos|>"
|
||||
prompts = [
|
||||
f"<|im_start|>user\n{audio_seq}Can you transcribe this?<|im_end|>\n<|im_start|>assistant\n" # noqa: E501
|
||||
]
|
||||
|
||||
with hf_runner(model, dtype=dtype, auto_cls=AutoModelForSeq2SeqLM) as hf_model:
|
||||
audio_token_id = hf_model.config.audio_token_index
|
||||
eos_token_id = hf_model.tokenizer.eos_token_id # <|im_end|>
|
||||
hf_outputs = hf_model.generate_beam_search(
|
||||
prompts,
|
||||
beam_width=beam_width,
|
||||
max_tokens=max_tokens,
|
||||
audios=audios,
|
||||
)
|
||||
|
||||
with vllm_runner(model, dtype=dtype) as vllm_model:
|
||||
vllm_outputs = vllm_model.generate_beam_search(
|
||||
prompts,
|
||||
beam_width=beam_width,
|
||||
max_tokens=max_tokens,
|
||||
audios=audios,
|
||||
)
|
||||
|
||||
seq_with_no_audio_toks = lambda seq: [tok for tok in seq if tok != audio_token_id]
|
||||
|
||||
for i in range(len(prompts)):
|
||||
hf_output_ids, hf_output_texts = hf_outputs[i]
|
||||
vllm_output_ids, vllm_output_texts = vllm_outputs[i]
|
||||
|
||||
for j, (hf_text, vllm_text) in enumerate(
|
||||
zip(hf_output_texts, vllm_output_texts)
|
||||
):
|
||||
print(f">>>{j}-th hf output [NOTE: special tokens are filtered]:")
|
||||
print(hf_text)
|
||||
print(f">>>{j}-th vllm output:")
|
||||
print(vllm_text)
|
||||
assert len(hf_output_ids) == len(vllm_output_ids)
|
||||
|
||||
for j in range(len(hf_output_ids)):
|
||||
# Compare everything except for the audio tokens; we do this since
|
||||
# the IDs returned from the transformers helper expands the audio
|
||||
# token to match features, while the vLLM helper maintains the
|
||||
# single audio token in the input text
|
||||
filtered_hf_output_ids = seq_with_no_audio_toks(hf_output_ids[j])
|
||||
filtered_vllm_output_ids = seq_with_no_audio_toks(vllm_output_ids[j])
|
||||
|
||||
# HF output IDs may contain the end of sequence
|
||||
if len(filtered_hf_output_ids) == len(filtered_vllm_output_ids) + 1:
|
||||
assert filtered_hf_output_ids[-1] == eos_token_id
|
||||
filtered_hf_output_ids = filtered_hf_output_ids[:-1]
|
||||
|
||||
assert filtered_hf_output_ids == filtered_vllm_output_ids
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Make sure ignore_eos works.
|
||||
|
||||
Run `pytest tests/samplers/test_ignore_eos.py`.
|
||||
@@ -7,25 +9,27 @@ import pytest
|
||||
|
||||
from vllm import SamplingParams
|
||||
|
||||
MODELS = ["facebook/opt-125m"]
|
||||
# We also test with llama because it has generation_config to specify EOS
|
||||
# (past regression).
|
||||
MODELS = ["distilbert/distilgpt2", "meta-llama/Llama-3.2-1B"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", [1024])
|
||||
def test_beam_search_single_input(
|
||||
@pytest.mark.parametrize("max_tokens", [512])
|
||||
def test_ignore_eos(
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
) -> None:
|
||||
example_prompts = "1 + 1 is"
|
||||
with vllm_runner(model, dtype=dtype) as vllm_model:
|
||||
sampling_params = SamplingParams(max_tokens=max_tokens, ignore_eos=True)
|
||||
|
||||
vllm_model = vllm_runner(model, dtype=dtype)
|
||||
sampling_params = SamplingParams(max_tokens=max_tokens, ignore_eos=True)
|
||||
ignore_eos_output = vllm_model.model.generate(
|
||||
example_prompts, sampling_params=sampling_params)
|
||||
print(len(ignore_eos_output[0].outputs[0].token_ids))
|
||||
assert max_tokens - len(ignore_eos_output[0].outputs[0].token_ids) < 10
|
||||
assert max_tokens - len(ignore_eos_output[0].outputs[0].token_ids) >= 0
|
||||
for prompt in example_prompts:
|
||||
ignore_eos_output = vllm_model.llm.generate(
|
||||
prompt, sampling_params=sampling_params
|
||||
)
|
||||
output_length = len(ignore_eos_output[0].outputs[0].token_ids)
|
||||
assert output_length == max_tokens
|
||||
|
||||
@@ -1,62 +0,0 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm import SamplingParams
|
||||
|
||||
MODELS = ["facebook/opt-125m"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
def test_logits_processor_force_generate(
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
) -> None:
|
||||
vllm_model = vllm_runner(model, dtype=dtype)
|
||||
tokenizer = vllm_model.model.get_tokenizer()
|
||||
repeat_times = 2
|
||||
enforced_answers = " vLLM"
|
||||
vllm_token_ids = tokenizer.encode(enforced_answers,
|
||||
add_special_tokens=False)
|
||||
max_tokens = len(vllm_token_ids) * repeat_times
|
||||
|
||||
def pick_vllm(token_ids, logits):
|
||||
token_id = vllm_token_ids[len(token_ids) % len(vllm_token_ids)]
|
||||
logits[token_id] = torch.finfo(logits.dtype).max
|
||||
return logits
|
||||
|
||||
params_with_logprobs = SamplingParams(
|
||||
logits_processors=[pick_vllm],
|
||||
prompt_logprobs=3,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
||||
# test logits_processors when prompt_logprobs is not None
|
||||
vllm_model.model._add_request(
|
||||
prompt=example_prompts[0],
|
||||
sampling_params=params_with_logprobs,
|
||||
prompt_token_ids=None,
|
||||
)
|
||||
|
||||
# test prompt_logprobs is not None
|
||||
vllm_model.model._add_request(
|
||||
prompt=example_prompts[1],
|
||||
sampling_params=SamplingParams(
|
||||
prompt_logprobs=3,
|
||||
max_tokens=max_tokens,
|
||||
),
|
||||
prompt_token_ids=None,
|
||||
)
|
||||
|
||||
# test grouped requests
|
||||
vllm_model.model._add_request(
|
||||
prompt=example_prompts[2],
|
||||
sampling_params=SamplingParams(max_tokens=max_tokens),
|
||||
prompt_token_ids=None,
|
||||
)
|
||||
|
||||
outputs = vllm_model.model._run_engine(False)
|
||||
|
||||
assert outputs[0].outputs[0].text == enforced_answers * repeat_times
|
||||
@@ -1,124 +1,91 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.conftest import VllmRunner
|
||||
from vllm import SamplingParams
|
||||
from vllm.logprobs import FlatLogprobs
|
||||
|
||||
MODELS = ["facebook/opt-125m"]
|
||||
MODELS = ["distilbert/distilgpt2"]
|
||||
MAX_TOKENS = 5
|
||||
NUM_TOP_LOGPROBS = 5
|
||||
NUM_PROMPT_LOGPROBS = 7
|
||||
MAX_LOGPROBS = max(NUM_TOP_LOGPROBS, NUM_PROMPT_LOGPROBS)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1])
|
||||
@pytest.mark.parametrize("num_top_logprobs", [6]) # 32000 == vocab_size
|
||||
def test_get_prompt_logprobs(
|
||||
hf_runner,
|
||||
@pytest.mark.parametrize("greedy", [True, False])
|
||||
@pytest.mark.parametrize("flat_logprobs", [True, False])
|
||||
def test_ranks(
|
||||
vllm_runner,
|
||||
model,
|
||||
dtype,
|
||||
chunked_prefill_token_size: int,
|
||||
num_top_logprobs: int,
|
||||
greedy,
|
||||
flat_logprobs,
|
||||
example_prompts,
|
||||
):
|
||||
max_num_seqs = 256
|
||||
enable_chunked_prefill = False
|
||||
max_num_batched_tokens = None
|
||||
if chunked_prefill_token_size != -1:
|
||||
enable_chunked_prefill = True
|
||||
max_num_seqs = min(chunked_prefill_token_size, max_num_seqs)
|
||||
max_num_batched_tokens = chunked_prefill_token_size
|
||||
with vllm_runner(model, dtype=dtype, max_logprobs=MAX_LOGPROBS) as vllm_model:
|
||||
tokenizer = vllm_model.llm.get_tokenizer()
|
||||
example_prompt_tokens = [tokenizer.encode(prompt) for prompt in example_prompts]
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0.0 if greedy else 1.0,
|
||||
top_p=1.0,
|
||||
max_tokens=MAX_TOKENS,
|
||||
logprobs=NUM_TOP_LOGPROBS,
|
||||
prompt_logprobs=NUM_PROMPT_LOGPROBS,
|
||||
flat_logprobs=flat_logprobs,
|
||||
)
|
||||
results = vllm_model.generate_w_logprobs(example_prompts, sampling_params)
|
||||
|
||||
max_tokens = 5
|
||||
hf_model = hf_runner(model, dtype=dtype)
|
||||
hf_logprobs = hf_model.generate_greedy_logprobs(
|
||||
example_prompts,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
del hf_model
|
||||
assert len(results) == len(example_prompt_tokens)
|
||||
for i, (result, prompt_tokens) in enumerate(zip(results, example_prompt_tokens)):
|
||||
decode_tokens, _, decode_logprobs, prompt_logprobs = result
|
||||
|
||||
vllm_model = vllm_runner(
|
||||
model,
|
||||
dtype=dtype,
|
||||
max_logprobs=num_top_logprobs,
|
||||
enable_chunked_prefill=enable_chunked_prefill,
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
max_num_seqs=max_num_seqs,
|
||||
)
|
||||
vllm_sampling_params = SamplingParams(max_tokens=max_tokens,
|
||||
logprobs=num_top_logprobs,
|
||||
prompt_logprobs=num_top_logprobs,
|
||||
temperature=0.0)
|
||||
vllm_results = vllm_model.model.generate(
|
||||
example_prompts, sampling_params=vllm_sampling_params)
|
||||
# Ensure the return type of logprobs is accurate
|
||||
assert isinstance(prompt_logprobs, FlatLogprobs if flat_logprobs else list)
|
||||
assert isinstance(decode_logprobs, FlatLogprobs if flat_logprobs else list)
|
||||
|
||||
# Test whether logprobs are included in the results.
|
||||
for result in vllm_results:
|
||||
assert result.prompt_logprobs is not None
|
||||
assert result.outputs[0].logprobs is not None
|
||||
assert len(result.outputs[0].logprobs) == max_tokens
|
||||
for logprobs in result.outputs[0].logprobs:
|
||||
assert len(logprobs) == num_top_logprobs
|
||||
output_text = result.outputs[0].text
|
||||
output_string_from_most_likely_tokens = []
|
||||
for top_logprobs in result.outputs[0].logprobs:
|
||||
top_logprob = next(iter(top_logprobs.values()))
|
||||
output_string_from_most_likely_tokens.append(
|
||||
top_logprob.decoded_token)
|
||||
output_string_from_most_likely_tokens = "".join(
|
||||
output_string_from_most_likely_tokens)
|
||||
assert output_text == output_string_from_most_likely_tokens, (
|
||||
"The output text from the top logprob for each token position "
|
||||
"should be the same as the output text in the result.")
|
||||
|
||||
# The first prompt logprob is always None
|
||||
assert result.prompt_logprobs[0] is None
|
||||
for prompt_logprobs in result.prompt_logprobs[1:]:
|
||||
# If the prompt token is not included in the top X
|
||||
# logprob, it can return 1 more data
|
||||
assert (len(prompt_logprobs) == num_top_logprobs
|
||||
or len(prompt_logprobs) == num_top_logprobs + 1)
|
||||
|
||||
# Test whether prompt logprobs are consistent with HF
|
||||
for vllm_result, hf_logprob in zip(vllm_results, hf_logprobs):
|
||||
########################
|
||||
# Check prompt logprobs
|
||||
# The first prompt logprob is always None, so we compare it from 1:.
|
||||
vllm_prompt_logprobs = vllm_result.prompt_logprobs[1:]
|
||||
for i, vllm_prompt_logprob_dict in enumerate(vllm_prompt_logprobs):
|
||||
for token_id, logprob in vllm_prompt_logprob_dict.items():
|
||||
torch.testing.assert_close(logprob.logprob,
|
||||
hf_logprob[0][i][token_id].item(),
|
||||
atol=1e-2,
|
||||
rtol=1e-2)
|
||||
vllm_sample_logprobs = vllm_result.outputs[0].logprobs
|
||||
for i, top_logprobs in enumerate(vllm_sample_logprobs):
|
||||
for token_id, sample_logprob in top_logprobs.items():
|
||||
logprob = sample_logprob.logprob
|
||||
torch.testing.assert_close(logprob,
|
||||
hf_logprob[i][-1][token_id].item(),
|
||||
atol=1e-2,
|
||||
rtol=1e-2)
|
||||
assert isinstance(sample_logprob.decoded_token, str), (
|
||||
"The token should be decoded by the time it is returned "
|
||||
" to the user.")
|
||||
########################
|
||||
assert len(prompt_tokens) == len(prompt_logprobs)
|
||||
# No logprob for first prompt token
|
||||
assert not prompt_logprobs[0]
|
||||
for position, (token, logprobs) in enumerate(
|
||||
zip(prompt_tokens[1:], prompt_logprobs[1:]), start=1
|
||||
):
|
||||
# Ensure logprobs of prompt token is always returned
|
||||
logprob = logprobs.get(token)
|
||||
assert logprob is not None
|
||||
assert logprob.rank >= 1
|
||||
# Ensure # of returned logprobs should be
|
||||
# either NUM_PROMPT_LOGPROBS or NUM_PROMPT_LOGPROBS+1
|
||||
assert NUM_PROMPT_LOGPROBS <= len(logprobs) <= NUM_PROMPT_LOGPROBS + 1
|
||||
# Ensure top NUM_PROMPT_LOGPROBS is always extracted
|
||||
assert set(range(1, NUM_PROMPT_LOGPROBS + 1)).issubset(
|
||||
{logprob.rank for logprob in logprobs.values()}
|
||||
)
|
||||
|
||||
# Test if prompt logprobs are correctly set.
|
||||
for vllm_result in vllm_results:
|
||||
token_ids = vllm_result.prompt_token_ids
|
||||
prompt_logprobs = vllm_result.prompt_logprobs
|
||||
|
||||
# The first token doesn't have logprob.
|
||||
assert prompt_logprobs[0] is None
|
||||
|
||||
for token_id, logprob_dict in zip(token_ids[1:], prompt_logprobs[1:]):
|
||||
assert token_id in logprob_dict
|
||||
|
||||
|
||||
def test_max_logprobs():
|
||||
runner = VllmRunner("facebook/opt-125m", max_logprobs=1)
|
||||
vllm_sampling_params = SamplingParams(logprobs=1)
|
||||
# should pass
|
||||
runner.generate(["Hello world"], sampling_params=vllm_sampling_params)
|
||||
|
||||
bad_sampling_params = SamplingParams(logprobs=2)
|
||||
with pytest.raises(ValueError):
|
||||
runner.generate(["Hello world"], sampling_params=bad_sampling_params)
|
||||
########################
|
||||
# Check sample logprobs
|
||||
########################
|
||||
assert len(decode_tokens) == len(decode_logprobs)
|
||||
for position, (token, logprobs) in enumerate(
|
||||
zip(decode_tokens, decode_logprobs)
|
||||
):
|
||||
# Ensure logprobs of chosen token is always returned
|
||||
logprob = logprobs.get(token)
|
||||
assert logprob is not None
|
||||
if greedy:
|
||||
# For greedy sampling, all chosen logprob should be top ranked
|
||||
assert logprob.rank == 1
|
||||
else:
|
||||
assert logprob.rank >= 1
|
||||
# Ensure # of returned logprobs should be
|
||||
# either NUM_TOP_LOGPROBS or NUM_TOP_LOGPROBS+1
|
||||
assert NUM_TOP_LOGPROBS <= len(logprobs) <= NUM_TOP_LOGPROBS + 1
|
||||
# Ensure top NUM_TOP_LOGPROBS logprobs is always extracted
|
||||
assert set(range(1, NUM_TOP_LOGPROBS + 1)).issubset(
|
||||
{logprob.rank for logprob in logprobs.values()}
|
||||
)
|
||||
|
||||
185
tests/samplers/test_no_bad_words.py
Normal file
185
tests/samplers/test_no_bad_words.py
Normal file
@@ -0,0 +1,185 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Make sure bad_words works.
|
||||
|
||||
Run `pytest tests/samplers/test_no_bad_words.py`.
|
||||
|
||||
"""
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
|
||||
def _generate(
|
||||
llm: LLM,
|
||||
prompt: str,
|
||||
num_prompt_tokens: int,
|
||||
temperature: float = 0,
|
||||
bad_words: list[str] | None = None,
|
||||
) -> list[int]:
|
||||
sampling_params = SamplingParams(
|
||||
temperature=temperature,
|
||||
bad_words=bad_words,
|
||||
)
|
||||
|
||||
# [([output_token_ids, ], [output_text, ]), ]
|
||||
output = llm.generate([prompt], sampling_params=sampling_params)
|
||||
|
||||
output_token_ids = output[0][0][0][num_prompt_tokens:]
|
||||
# [0] first (and only) request output
|
||||
# [0] token_ids (not text)
|
||||
# [0] first (and only) output completion
|
||||
|
||||
return output_token_ids
|
||||
|
||||
|
||||
class TestOneTokenBadWord:
|
||||
MODEL = "hmellor/tiny-random-LlamaForCausalLM"
|
||||
|
||||
PROMPT = "How old are "
|
||||
TARGET_TOKEN = "mn"
|
||||
|
||||
def setup_method(self, method):
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL)
|
||||
|
||||
self.num_prompt_tokens = len(self._encode(self.PROMPT))
|
||||
self.target_token_id = self._encode(
|
||||
self.TARGET_TOKEN, add_special_tokens=False
|
||||
)[0]
|
||||
|
||||
def test_one_token_bad_word(self, vllm_runner):
|
||||
with vllm_runner(self.MODEL) as llm:
|
||||
output_token_ids = self._generate(llm)
|
||||
assert output_token_ids[0] == self.target_token_id
|
||||
|
||||
output_token_ids = self._generate(llm, bad_words=[self.TARGET_TOKEN])
|
||||
assert self.target_token_id not in output_token_ids
|
||||
|
||||
def _generate(self, llm: LLM, bad_words: list[str] | None = None) -> list[int]:
|
||||
return _generate(
|
||||
llm=llm,
|
||||
prompt=self.PROMPT,
|
||||
num_prompt_tokens=self.num_prompt_tokens,
|
||||
bad_words=bad_words,
|
||||
)
|
||||
|
||||
def _encode(self, prompt: str, add_special_tokens: bool = True) -> list[int]:
|
||||
return self.tokenizer(prompt, add_special_tokens=add_special_tokens).input_ids
|
||||
|
||||
|
||||
class TestTwoTokenBadWord:
|
||||
# Another model (with a different tokenizer behaviour)
|
||||
MODEL = "distilbert/distilgpt2"
|
||||
|
||||
PROMPT = "How old are you? I am 10"
|
||||
TARGET_TOKEN1 = "years"
|
||||
TARGET_TOKEN2 = "old"
|
||||
NEIGHBOUR_TOKEN2 = "older"
|
||||
|
||||
def setup_method(self, method):
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.MODEL, add_prefix_space=True
|
||||
)
|
||||
|
||||
self.num_prompt_tokens = len(self._encode(self.PROMPT))
|
||||
self.target_token_id1 = self._encode(
|
||||
self.TARGET_TOKEN1, add_special_tokens=False
|
||||
)[0]
|
||||
self.target_token_id2 = self._encode(
|
||||
self.TARGET_TOKEN2, add_special_tokens=False
|
||||
)[0]
|
||||
self.neighbour_token_id2 = self._encode(
|
||||
self.NEIGHBOUR_TOKEN2, add_special_tokens=False
|
||||
)[0]
|
||||
|
||||
def test_two_token_bad_word(self, vllm_runner):
|
||||
with vllm_runner(self.MODEL, dtype="half") as llm:
|
||||
output_token_ids = self._generate(llm)
|
||||
assert output_token_ids[:2] == [
|
||||
self.target_token_id1,
|
||||
self.target_token_id2,
|
||||
]
|
||||
|
||||
output_token_ids = self._generate(llm, bad_words=[self.TARGET_TOKEN1])
|
||||
assert self.target_token_id1 not in output_token_ids
|
||||
|
||||
output_token_ids = self._generate(llm, bad_words=[self.TARGET_TOKEN2])
|
||||
assert output_token_ids[0] == self.target_token_id1
|
||||
assert self.target_token_id2 not in output_token_ids
|
||||
|
||||
output_token_ids = self._generate(
|
||||
llm, bad_words=[f"{self.TARGET_TOKEN1} {self.TARGET_TOKEN2}"]
|
||||
)
|
||||
assert output_token_ids[0] == self.target_token_id1
|
||||
assert output_token_ids[:2] != [
|
||||
self.target_token_id1,
|
||||
self.target_token_id2,
|
||||
]
|
||||
assert not self._contains(
|
||||
output_token_ids, [self.target_token_id1, self.target_token_id2]
|
||||
)
|
||||
# Model dependent behaviour
|
||||
assert output_token_ids[:2] == [
|
||||
self.target_token_id1,
|
||||
self.neighbour_token_id2,
|
||||
]
|
||||
|
||||
output_token_ids = self._generate(
|
||||
llm,
|
||||
bad_words=[
|
||||
f"{self.TARGET_TOKEN1} {self.TARGET_TOKEN2}",
|
||||
f"{self.TARGET_TOKEN1} {self.NEIGHBOUR_TOKEN2}",
|
||||
],
|
||||
)
|
||||
assert output_token_ids[0] == self.target_token_id1
|
||||
assert output_token_ids[:2] != [
|
||||
self.target_token_id1,
|
||||
self.target_token_id2,
|
||||
]
|
||||
assert not self._contains(
|
||||
output_token_ids, [self.target_token_id1, self.target_token_id2]
|
||||
)
|
||||
assert output_token_ids[:2] != [
|
||||
self.target_token_id1,
|
||||
self.neighbour_token_id2,
|
||||
]
|
||||
assert not self._contains(
|
||||
output_token_ids, [self.target_token_id1, self.neighbour_token_id2]
|
||||
)
|
||||
assert (self.target_token_id2 in output_token_ids) or (
|
||||
self.neighbour_token_id2 in output_token_ids
|
||||
)
|
||||
|
||||
def _generate(self, llm: LLM, bad_words: list[str] | None = None) -> list[int]:
|
||||
return _generate(
|
||||
llm=llm,
|
||||
prompt=self.PROMPT,
|
||||
num_prompt_tokens=self.num_prompt_tokens,
|
||||
bad_words=bad_words,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _contains(sequence: list[int], subsequence: list[int]) -> bool:
|
||||
searched = False
|
||||
|
||||
for start in range(len(sequence)):
|
||||
end = start + len(subsequence)
|
||||
current_subsequence = sequence[start:end]
|
||||
|
||||
if len(current_subsequence) < len(subsequence):
|
||||
continue
|
||||
|
||||
searched = True
|
||||
|
||||
assert len(current_subsequence) == len(subsequence)
|
||||
|
||||
if current_subsequence == subsequence:
|
||||
return True
|
||||
|
||||
assert searched, "All subsequences did not match in length..."
|
||||
|
||||
return False
|
||||
|
||||
def _encode(self, prompt: str, add_special_tokens: bool = True) -> list[int]:
|
||||
return self.tokenizer(prompt, add_special_tokens=add_special_tokens).input_ids
|
||||
@@ -1,50 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from vllm import SamplingParams
|
||||
|
||||
MODELS = ["facebook/opt-125m"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
def test_ranks(
|
||||
vllm_runner,
|
||||
model,
|
||||
dtype,
|
||||
example_prompts,
|
||||
):
|
||||
max_tokens = 5
|
||||
num_top_logprobs = 5
|
||||
num_prompt_logprobs = 5
|
||||
|
||||
vllm_model = vllm_runner(model, dtype=dtype, max_logprobs=num_top_logprobs)
|
||||
|
||||
## Test greedy logprobs ranks
|
||||
vllm_sampling_params = SamplingParams(temperature=0.0,
|
||||
top_p=1.0,
|
||||
max_tokens=max_tokens,
|
||||
logprobs=num_top_logprobs,
|
||||
prompt_logprobs=num_prompt_logprobs)
|
||||
vllm_results = vllm_model.generate_w_logprobs(example_prompts,
|
||||
vllm_sampling_params)
|
||||
for result in vllm_results:
|
||||
assert result[2] is not None
|
||||
assert len(result[2]) == len(result[0])
|
||||
# check whether all chosen tokens have ranks = 1
|
||||
for token, logprobs in zip(result[0], result[2]):
|
||||
assert token in logprobs
|
||||
assert logprobs[token].rank == 1
|
||||
|
||||
## Test non-greedy logprobs ranks
|
||||
sampling_params = SamplingParams(temperature=1.0,
|
||||
top_p=1.0,
|
||||
max_tokens=max_tokens,
|
||||
logprobs=num_top_logprobs,
|
||||
prompt_logprobs=num_prompt_logprobs)
|
||||
res = vllm_model.generate_w_logprobs(example_prompts, sampling_params)
|
||||
for result in res:
|
||||
assert result[2] is not None
|
||||
assert len(result[2]) == len(result[0])
|
||||
# check whether all chosen tokens have ranks
|
||||
for token, logprobs in zip(result[0], result[2]):
|
||||
assert logprobs[token].rank >= 1
|
||||
@@ -1,385 +0,0 @@
|
||||
"""Tests for rejection sampling."""
|
||||
from typing import List, Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
|
||||
CUDA_DEVICES = [
|
||||
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
|
||||
]
|
||||
|
||||
|
||||
def mock_causal_accepted_tensor(
|
||||
k: int, last_accepted_indices: torch.Tensor) -> torch.Tensor:
|
||||
"""Generate an "accepted" tensor which should yield causally-accepted tokens
|
||||
up to last accepted indices.
|
||||
|
||||
Tokens after last_accepted_indices+1 may also be accepted, although they
|
||||
will not be causally accepted.
|
||||
"""
|
||||
batch_size = last_accepted_indices.shape[0]
|
||||
|
||||
accepted = (torch.arange(k).expand(batch_size, k) <=
|
||||
last_accepted_indices.unsqueeze(-1).broadcast_to(
|
||||
batch_size, k)).to(device="cuda")
|
||||
|
||||
# Sprinkle accepted values after the contiguous initial accepted values.
|
||||
# This replicates the behavior of rejection sampling, which may "accept"
|
||||
# a token that cannot be accepted because of causality.
|
||||
sprinkle_candidates = (
|
||||
torch.arange(k).expand(batch_size, k) >
|
||||
last_accepted_indices.unsqueeze(-1).broadcast_to(batch_size, k) + 1)
|
||||
sprinkle = torch.rand(batch_size, k, device="cuda") > 0.5
|
||||
accepted[sprinkle_candidates] = sprinkle[sprinkle_candidates]
|
||||
return accepted
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", list(range(10)))
|
||||
@pytest.mark.parametrize(
|
||||
"which_tokens_accepted",
|
||||
["all_tokens_accepted", "no_tokens_accepted", "some_tokens_accepted"])
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_correct_output_format(which_tokens_accepted: str, seed: int,
|
||||
device: str):
|
||||
"""Verify the output has correct format given predetermined accepted matrix.
|
||||
"""
|
||||
set_random_seed(seed)
|
||||
torch.set_default_device(device)
|
||||
|
||||
batch_size = 10
|
||||
k = 5
|
||||
vocab_size = 3000
|
||||
|
||||
if which_tokens_accepted == "all_tokens_accepted":
|
||||
accepted = mock_causal_accepted_tensor(
|
||||
k, -1 + k * torch.ones((batch_size, ), dtype=torch.long))
|
||||
elif which_tokens_accepted == "no_tokens_accepted":
|
||||
accepted = mock_causal_accepted_tensor(
|
||||
k, -torch.ones((batch_size, ), dtype=torch.long))
|
||||
elif which_tokens_accepted == "some_tokens_accepted":
|
||||
last_accepted_indices = torch.randint(low=-1,
|
||||
high=k,
|
||||
size=(batch_size, ))
|
||||
accepted = mock_causal_accepted_tensor(k, last_accepted_indices)
|
||||
else:
|
||||
raise AssertionError()
|
||||
|
||||
recovered_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, k),
|
||||
dtype=torch.int64)
|
||||
draft_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, k),
|
||||
dtype=torch.int64)
|
||||
bonus_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, 1),
|
||||
dtype=torch.int64)
|
||||
|
||||
rejection_sampler = RejectionSampler()
|
||||
rejection_sampler.init_gpu_tensors(rank=0)
|
||||
output_token_ids = rejection_sampler._create_output( # pylint: disable=protected-access
|
||||
accepted,
|
||||
recovered_token_ids,
|
||||
draft_token_ids,
|
||||
bonus_token_ids,
|
||||
)
|
||||
|
||||
# Bonus tokens are currently disabled. Verify they're set to -1.
|
||||
# See https://github.com/vllm-project/vllm/issues/4212
|
||||
expected_bonus_token_ids = bonus_token_ids.clone() * 0 - 1
|
||||
|
||||
if which_tokens_accepted == "all_tokens_accepted":
|
||||
# Expect all tokens to be equal to draft tokens.
|
||||
assert torch.equal(output_token_ids[:, :-1], draft_token_ids)
|
||||
|
||||
# Expect all bonus tokens to be included.
|
||||
assert torch.equal(output_token_ids[:, -1:], expected_bonus_token_ids)
|
||||
elif which_tokens_accepted == "no_tokens_accepted":
|
||||
# Expect first token to be equal to recovered tokens.
|
||||
assert torch.equal(output_token_ids[:, 0], recovered_token_ids[:, 0])
|
||||
|
||||
# Expect everything else to be -1.
|
||||
assert torch.equal(output_token_ids[:, 1:],
|
||||
torch.ones_like(output_token_ids[:, 1:]) * -1)
|
||||
elif which_tokens_accepted == "some_tokens_accepted":
|
||||
recovered_plus_bonus = torch.cat(
|
||||
(recovered_token_ids, expected_bonus_token_ids), dim=-1)
|
||||
# Assert first rejected token is a recovered token or bonus token.
|
||||
assert torch.equal(
|
||||
recovered_plus_bonus[torch.arange(0, batch_size),
|
||||
last_accepted_indices + 1],
|
||||
output_token_ids[torch.arange(0, batch_size),
|
||||
last_accepted_indices + 1])
|
||||
|
||||
# Assert every subsequent token is -1.
|
||||
subsequent_mask = torch.arange(0, k + 1).expand(
|
||||
batch_size, k + 1) >= (last_accepted_indices + 2).unsqueeze(-1)
|
||||
assert torch.all(output_token_ids[subsequent_mask] == -1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("k", list(range(1, 6)))
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
@pytest.mark.parametrize("batch_size", list(range(1, 32)))
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
                                    device: str):
    """Smoke-test the rejection sampler over a grid of tensor shapes.

    No output values are asserted; the test passes as long as the sampler
    accepts every (batch_size, k, vocab_size) combination without raising.
    """
    torch.set_default_device(device)
    sampler = RejectionSampler()
    sampler.init_gpu_tensors(rank=0)

    # Random probabilities and token ids with the shapes the sampler expects.
    draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
    target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
    bonus_token_ids = torch.randint(
        low=0, high=vocab_size, size=(batch_size, 1), dtype=torch.int64)
    draft_token_ids = torch.randint(
        low=0, high=vocab_size, size=(batch_size, k), dtype=torch.int64)

    sampler(target_probs, bonus_token_ids, draft_probs, draft_token_ids)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"])
@pytest.mark.parametrize("which_token_ids",
                         ["bonus_token_ids", "draft_token_ids"])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
                               which_token_ids: str, device: str):
    """In strict mode, an out-of-vocabulary token id must trigger an assert.

    One entry of either the bonus or the draft token ids is corrupted to lie
    just above or just below the valid vocab range; the sampler call is then
    expected to fail with an AssertionError.
    """
    k = 3
    batch_size = 5
    vocab_size = 30_000
    torch.set_default_device(device)

    # strict_mode enables the input-validation asserts under test.
    sampler = RejectionSampler(strict_mode=True)
    sampler.init_gpu_tensors(rank=0)

    draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
    target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
    bonus_token_ids = torch.randint(
        low=0, high=vocab_size, size=(batch_size, 1), dtype=torch.int64)
    draft_token_ids = torch.randint(
        low=0, high=vocab_size, size=(batch_size, k), dtype=torch.int64)

    # Pick which tensor to corrupt.
    oob_token_ids = None
    if which_token_ids == "bonus_token_ids":
        oob_token_ids = bonus_token_ids
    elif which_token_ids == "draft_token_ids":
        oob_token_ids = draft_token_ids
    else:
        raise AssertionError()

    # Pick a token id just outside the valid range on the requested side.
    if above_or_below_vocab_range == "above":
        rogue_token_id = vocab_size + 1
    elif above_or_below_vocab_range == "below":
        rogue_token_id = -1
    else:
        raise AssertionError()

    oob_token_ids[0][0] = rogue_token_id

    with pytest.raises(AssertionError):
        sampler(target_probs, bonus_token_ids, draft_probs, draft_token_ids)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("draft_and_target_probs_equal", [True, False])
@pytest.mark.parametrize("seed", list(range(5)))
@torch.inference_mode()
def test_rejection_sampling_approximates_target_distribution(
        seed: int, draft_and_target_probs_equal: bool):
    """Verify rejection sampling approximates target distribution,
    despite sampling from a potentially distinct draft distribution.

    This is done by first creating a random target probability
    distribution and a random draft probability distribution. We then
    sample token ids from the rejection sampler using these draft
    and target distributions. The samples are used to estimate
    the output probability distribution, which we expect to approximate
    the target distribution.

    A basic distance metric is used to determine similarity between
    distributions.

    We expect that as we increase the number of samples,
    the distance between the observed distribution and the target
    distribution decreases. To measure this, we compare the distance
    of the observed distribution against both the target distribution
    and a uniform random distribution. We expect the distance between
    the observed distribution and the target distribution to improve
    much more than the distance improvement between the observed
    distribution and the random distribution.

    When draft_and_target_probs_equal=True, the draft and target
    probabilities are exactly equal. Rejection sampling should
    still work without any NaNs or exceptions.
    """
    torch.set_default_device("cpu")
    set_random_seed(seed)

    helper = _CorrectnessTestHelper(
        vocab_size=10,
        rejection_sampler=RejectionSampler(),
    )

    draft_probs, target_probs, reference_probs = helper.generate_probs_for_test(
        draft_and_target_probs_equal)

    sample_sizes = [10, 100, 1_000, 10_000, 100_000]
    distance_wrt_reference = []
    distance_wrt_target = []

    for num_samples in sample_sizes:
        (reference_vs_rejsample_dist,
         target_vs_rejsample_dist) = helper.run_and_compare_distributions(
             draft_probs,
             target_probs,
             reference_probs,
             num_samples,
         )

        distance_wrt_reference.append(reference_vs_rejsample_dist)
        distance_wrt_target.append(target_vs_rejsample_dist)

    # Improvement from the smallest to the largest sample size, for both the
    # target distribution and the (unrelated) reference distribution.
    # (Previously these two ratios were computed twice, verbatim; the
    # duplicated second computation has been removed.)
    relative_change_in_distance_wrt_target = get_ratio_first_to_last(
        distance_wrt_target)
    relative_change_in_distance_wrt_reference = get_ratio_first_to_last(
        distance_wrt_reference)

    # num_samples and the *_dist locals hold the values of the final loop
    # iteration here.
    print(f"{num_samples=} {target_vs_rejsample_dist=:.05f} "
          f"{reference_vs_rejsample_dist=:.05f}")
    print(f"{num_samples=} {relative_change_in_distance_wrt_target=:.02f} "
          f"{relative_change_in_distance_wrt_reference=:.02f}")

    # The observed distribution must converge toward the target much faster
    # than toward an arbitrary reference distribution.
    expected_improvement_multiplier = 20
    assert (relative_change_in_distance_wrt_target >
            relative_change_in_distance_wrt_reference *
            expected_improvement_multiplier)
|
||||
|
||||
|
||||
def get_ratio_first_to_last(elements: List[float]) -> float:
    """Return the ratio of the first element of *elements* to its last."""
    first, last = elements[0], elements[-1]
    return first / last
|
||||
|
||||
|
||||
class _CorrectnessTestHelper:
    """Class that packages together logic required for the unit-level
    rejection sampling correctness test.
    """

    def __init__(self, vocab_size: int, rejection_sampler: RejectionSampler):
        self.rejection_sampler = rejection_sampler
        self.vocab_size = vocab_size
        # Half-open histogram range used by torch.histogram below.
        self.vocab_range = (0, vocab_size)

        self.rejection_sampler.init_gpu_tensors(rank=0)

        # Keep test simple, use k=1
        self.k = 1

        # Bonus tokens not used, but rejection sampler requires
        # correct shape.
        self.num_bonus_tokens = 1

    def generate_probs_for_test(
        self, draft_and_target_probs_equal: bool
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Create random draft/target distributions plus a batch of
        unrelated reference distributions.

        When draft_and_target_probs_equal is True, the target is a clone of
        the draft so both distributions are exactly equal.
        """
        draft_probs, target_probs = [
            F.softmax(
                torch.rand(self.vocab_size, dtype=torch.float32),
                dim=-1,
            ) for _ in range(2)
        ]

        # Independent random distributions used as a "how far is random"
        # baseline by the caller.
        num_reference_probs = 100
        reference_probs = F.softmax(
            torch.rand(num_reference_probs,
                       self.vocab_size,
                       dtype=torch.float32),
            dim=-1,
        )

        if draft_and_target_probs_equal:
            target_probs = draft_probs.clone()

        return draft_probs, target_probs, reference_probs

    def run_and_compare_distributions(self, draft_probs: torch.Tensor,
                                      target_probs: torch.Tensor,
                                      reference_probs: torch.Tensor,
                                      num_samples: int) -> Tuple[float, float]:
        """Estimate the rejection-sampled pdf from num_samples draws and
        return its distance to the reference and the target distributions.
        """
        # Sample using rejection sampling.
        rej_sample_probs = self._estimate_rejection_sampling_pdf(
            draft_probs, target_probs, num_samples)

        # Average distance from reference probs.
        reference_vs_rejsample_dist = torch.dist(
            reference_probs,
            rej_sample_probs).item() / reference_probs.shape[0]
        target_vs_rejsample_dist = torch.dist(target_probs,
                                              rej_sample_probs).item()

        return reference_vs_rejsample_dist, target_vs_rejsample_dist

    def _estimate_rejection_sampling_pdf(
        self,
        draft_probs: torch.Tensor,
        target_probs: torch.Tensor,
        num_samples: int,
    ) -> torch.Tensor:
        """Run the rejection sampler num_samples times (batched) and return
        a normalized histogram of the emitted token ids as a pdf estimate.
        """
        # Repeat draft probs num_samples times.
        draft_probs = draft_probs.reshape(1, self.k, self.vocab_size).repeat(
            num_samples, 1, 1)

        # Repeat target probs num_samples * k times.
        # Rejection sampler requires bonus token probs, but they aren't used.
        target_probs = target_probs.reshape(1, 1, self.vocab_size).repeat(
            num_samples, self.k, 1)

        # Randomly sample draft token ids from draft probs.
        draft_token_ids = torch.multinomial(draft_probs[:, 0, :],
                                            num_samples=1,
                                            replacement=True).reshape(
                                                num_samples, self.k)

        # Bonus tokens not used but required.
        bonus_token_ids = torch.zeros((1, self.num_bonus_tokens),
                                      dtype=torch.int64,
                                      device="cuda").repeat(num_samples, 1)

        # Get output tokens via rejection sampling.
        output_token_ids = self.rejection_sampler(target_probs.to("cuda"),
                                                  bonus_token_ids.to("cuda"),
                                                  draft_probs.to("cuda"),
                                                  draft_token_ids.to("cuda"))

        # Remove bonus tokens
        output_token_ids = output_token_ids[:, :-1].flatten()

        # Estimate probability density function
        hist = torch.histogram(output_token_ids.to(dtype=torch.float,
                                                   device="cpu"),
                               bins=self.vocab_size,
                               range=self.vocab_range,
                               density=True)

        return hist.hist
|
||||
@@ -1,661 +0,0 @@
|
||||
import itertools
|
||||
import random
|
||||
from typing import List, Optional, Tuple
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import GenerationConfig, GenerationMixin
|
||||
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
|
||||
from vllm.utils import Counter
|
||||
from vllm.worker.model_runner import ModelRunner
|
||||
|
||||
|
||||
class MockLogitsSampler(Sampler):
    """Sampler subclass used by the sampler tests.

    It carries a pre-built ``fake_logits`` tensor that individual tests
    mutate and feed back in through ``forward``; the sampling behavior
    itself is inherited from :class:`Sampler` unchanged.
    """

    def __init__(self, fake_logits: torch.Tensor):
        super().__init__()
        # Shared logits tensor owned by the test that built this sampler.
        self.fake_logits = fake_logits

    def forward(self, *args, **kwargs):
        # No-op override: delegates straight to the base-class sampler.
        return super().forward(*args, **kwargs)
|
||||
|
||||
|
||||
def _prepare_test(
    batch_size: int
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, ModelRunner]:
    """Build the shared fixtures for a sampler test.

    Returns a fake hidden-state tensor, a uniform small-logits tensor
    (individual tests raise selected entries), a mock sampler wrapping those
    logits, and a bare ModelRunner with all configs set to None.
    """
    hidden_states = torch.rand((batch_size, 1024), dtype=torch.float16)
    logits = torch.full((batch_size, VOCAB_SIZE),
                        1e-2,
                        dtype=hidden_states.dtype)
    mock_sampler = MockLogitsSampler(logits)
    runner = ModelRunner(
        model_config=None,
        parallel_config=None,
        scheduler_config=None,
        device_config=None,
        load_config=None,
        lora_config=None,
    )
    return hidden_states, logits, mock_sampler, runner
|
||||
|
||||
|
||||
# Vocabulary size used for every fake-logits tensor in this file.
VOCAB_SIZE = 32000
# Broad seed sweep to surface seed-dependent flakiness.
RANDOM_SEEDS = list(range(128))
# First CUDA device, or the first two when more than one GPU is present.
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
|
||||
|
||||
|
||||
def _do_sample(
    batch_size: int,
    input_tensor: torch.Tensor,
    sampler: MockLogitsSampler,
    model_runner: ModelRunner,
    sampling_params: SamplingParams,
    device: str,
):
    """Run the sampler once over a batch of identical single-sequence
    prompt groups that all share *sampling_params*.

    Returns whatever the sampler produces for the prepared metadata.
    """
    seq_group_metadata_list = []
    seq_lens = []
    for i in range(batch_size):
        # Each group holds one 3-token prompt sequence (ids [1, 2, 3]).
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=sampling_params,
                block_tables={0: [1]},
            ))
        seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    sampling_metadata = SamplingMetadata.prepare(
        seq_group_metadata_list,
        seq_lens,
        query_lens=seq_lens,
        device=device,
        pin_memory=model_runner.pin_memory)
    return sampler(logits=input_tensor, sampling_metadata=sampling_metadata)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_greedy(seed: int, device: str):
    """Greedy sampling (temperature=0) must return the argmax token for
    every sequence in the batch."""
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
        batch_size)

    sampling_params = SamplingParams(temperature=0)
    sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
                                sampling_params, device)
    # Greedy decoding should reproduce the per-row argmax exactly.
    expected = torch.argmax(fake_logits, dim=-1)
    for i, sequence_output in enumerate(sampler_output):
        for nth_output in sequence_output.samples:
            assert nth_output.output_token == expected[i].item()

    del model_runner
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random(seed: int, device: str):
    """Random sampling with one dominant logit per row must pick it.

    Row i is given a single huge logit at column i, so every sample for
    sequence group i should be token i.
    """
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
        batch_size)

    # Make token id i overwhelmingly likely for row i.
    for i in range(batch_size):
        fake_logits[i, i] = 1e2

    sampling_params = SamplingParams(
        temperature=1.0,
        n=random.randint(1, 10),
    )
    sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
                                sampling_params, device)

    for i, sequence_output in enumerate(sampler_output):
        for nth_output in sequence_output.samples:
            assert nth_output.output_token == i

    del model_runner
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random_seed(seed: int, device: str):
    """Same dominant-logit setup as test_sampler_all_random, but with a
    per-request sampling seed set — results must still be token i."""
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler, model_runner = _prepare_test(batch_size)

    # Make token id i overwhelmingly likely for row i.
    for i in range(batch_size):
        fake_logits[i, i] = 1e2

    sampling_params = SamplingParams(
        temperature=1.0,
        n=random.randint(1, 10),
        seed=random.randint(0, 10000),
    )
    sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
                                sampling_params, device)

    for i, sequence_output in enumerate(sampler_output):
        for nth_output in sequence_output.samples:
            assert nth_output.output_token == i

    del model_runner
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random_seed_deterministic(seed: int, device: str):
    """Two sampler invocations with the same per-request seed must produce
    identical outputs."""
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler, model_runner = _prepare_test(batch_size)

    sampling_params = SamplingParams(
        temperature=1.0,
        n=random.randint(1, 10),
        seed=random.randint(0, 10000),
    )
    first_sampler_output = _do_sample(batch_size, fake_logits, sampler,
                                      model_runner, sampling_params, device)

    second_sampler_output = _do_sample(batch_size, fake_logits, sampler,
                                       model_runner, sampling_params, device)

    # Seeded sampling is expected to be fully deterministic across calls.
    assert first_sampler_output == second_sampler_output

    del model_runner
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_beam(seed: int, device: str):
    """Exercise the sampler with an all-beam-search batch.

    This is a crash test only; no output values are asserted (see the
    comment below).
    """
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler, model_runner = _prepare_test(batch_size)

    sampling_params = SamplingParams(
        temperature=0,
        best_of=2,
        use_beam_search=True,
    )
    _do_sample(batch_size, fake_logits, sampler, model_runner, sampling_params,
               device)
    # no assertion here as I am not sure how to determine whether
    # the outputs are expected - in other words, this just tests
    # whether there are no exceptions in the sampler
    # when handling an all-beam search case.
    del model_runner
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_min_tokens_penalty(seed: int, device: str):
    """Verify that stop/EOS tokens are masked to -inf exactly for sequences
    that have generated fewer than ``min_tokens`` tokens.

    For seed 0 a set of hand-written edge cases is run; for every other seed
    a randomized batch of sequence groups is generated. Each case supplies
    the list of per-logits-row expectations ("should this row be penalized")
    alongside its SequenceGroupMetadata list.
    """
    seq_id_counter = Counter(start=random.randint(0, 100))
    set_random_seed(seed)
    torch.set_default_device(device)

    def create_sampling_params(min_tokens,
                               eos_token_id=0,
                               *,
                               stop_token_ids: Optional[List[int]] = None,
                               prompt_logprobs: Optional[int] = None):
        # Build SamplingParams and register the EOS token as a stop token,
        # since min_tokens penalization covers EOS as well.
        sampling_params = SamplingParams(
            min_tokens=min_tokens,
            max_tokens=9999,  # keep higher than max of min_tokens
            stop_token_ids=stop_token_ids,
            # requesting prompt_logprobs changes the structure of `logits`
            prompt_logprobs=prompt_logprobs,
        )
        sampling_params.all_stop_token_ids.add(eos_token_id)
        return sampling_params

    def create_sequence_data(num_input=3, num_generated=0):
        # Random prompt of num_input tokens, optionally with num_generated
        # already-produced output tokens (i.e. a decode-phase sequence).
        seq_data = SequenceData(
            random.choices(range(0, VOCAB_SIZE), k=num_input))
        if num_generated > 0:
            seq_data.output_token_ids = random.choices(range(0, VOCAB_SIZE),
                                                       k=num_generated)
        return seq_data

    def generate_test_case():
        # generate multiple seq groups but limit total batch size
        batch_size = random.randint(1, 128)

        expected_penalization = []
        sequence_metadata_list = []
        # 20% chance to generate seq group metadata list with all prompts
        is_prompt = random.random() < 0.2
        while batch_size > 0:
            num_seqs = 1 if is_prompt else random.randint(1, batch_size)

            eos_token_id = random.randint(0, VOCAB_SIZE - 1)
            min_tokens = random.randint(0, 50)
            num_stop_tokens = random.randint(0, 8)
            if num_stop_tokens > 0:
                stop_token_ids = random.choices(range(0, VOCAB_SIZE - 1),
                                                k=num_stop_tokens)
            else:
                stop_token_ids = None

            sampling_params = create_sampling_params(
                min_tokens=min_tokens,
                eos_token_id=eos_token_id,
                stop_token_ids=stop_token_ids)

            seq_data = {}
            seq_group_penalization = []
            for _ in range(num_seqs):
                num_input = random.randint(1, 100)
                num_generated = 0 if is_prompt else random.randint(1, 100)
                seq_data[next(seq_id_counter)] = create_sequence_data(
                    num_input=num_input, num_generated=num_generated)
                # A sequence is penalized while it has produced fewer than
                # min_tokens tokens.
                seq_group_penalization.append(num_generated < min_tokens)

            expected_penalization.extend(seq_group_penalization)
            sequence_metadata_list.append(
                SequenceGroupMetadata(
                    request_id=f"test_{batch_size}",
                    is_prompt=is_prompt,
                    seq_data=seq_data,
                    sampling_params=sampling_params,
                    block_tables={},
                ))
            batch_size -= num_seqs

        return {
            "expected_penalization": expected_penalization,
            "seq_group_metadata_list": sequence_metadata_list,
        }

    # define some explicit test cases for edge case behavior
    prompt_without_penalization = {
        "expected_penalization": [False],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_1",
                is_prompt=True,
                seq_data={
                    next(seq_id_counter): create_sequence_data(),
                },
                sampling_params=create_sampling_params(0),
                block_tables={},
            ),
        ]
    }

    prompt_with_penalization = {
        "expected_penalization": [True],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_1",
                is_prompt=True,
                seq_data={
                    next(seq_id_counter): create_sequence_data(),
                },
                sampling_params=create_sampling_params(1),
                block_tables={},
            ),
        ]
    }

    # With prompt_logprobs, each prompt token owns a logits row; only the
    # final row (the next-token position) is penalized.
    prompt_with_penalization_and_prompt_logprobs = {
        "expected_penalization": [False, False, True],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_1",
                is_prompt=True,
                seq_data={
                    next(seq_id_counter): create_sequence_data(num_input=3),
                },
                sampling_params=create_sampling_params(1, prompt_logprobs=3),
                block_tables={},
            ),
        ]
    }

    stop_penalizing_after_min_tokens = {
        "expected_penalization": [False],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_1",
                is_prompt=False,
                seq_data={
                    next(seq_id_counter):
                    create_sequence_data(num_generated=1),
                },
                sampling_params=create_sampling_params(1),
                block_tables={},
            )
        ]
    }

    stop_token_ids = [42, 99, 42, 0]  # intentional duplication
    prompt_combination = {
        "expected_penalization": [False, True, False],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_2",
                is_prompt=True,
                seq_data={
                    next(seq_id_counter): create_sequence_data(num_input=2),
                },
                sampling_params=create_sampling_params(1, prompt_logprobs=3),
                block_tables={},
            ),
            SequenceGroupMetadata(
                request_id="test_3",
                is_prompt=True,
                seq_data={
                    next(seq_id_counter): create_sequence_data(),
                },
                sampling_params=create_sampling_params(
                    0, stop_token_ids=stop_token_ids),
                block_tables={},
            )
        ]
    }

    stop_token_ids = [1, 999, 37, 37]  # intentional duplication
    decode_combination = {
        "expected_penalization": [True, False, False, True, False],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_1",
                is_prompt=False,
                seq_data={
                    next(seq_id_counter):
                    create_sequence_data(num_generated=1),
                    next(seq_id_counter):
                    create_sequence_data(num_generated=100),
                },
                sampling_params=create_sampling_params(
                    2, stop_token_ids=stop_token_ids),
                block_tables={},
            ),
            SequenceGroupMetadata(
                request_id="test_2",
                is_prompt=False,
                seq_data={
                    next(seq_id_counter):
                    create_sequence_data(num_generated=20),
                    next(seq_id_counter):
                    create_sequence_data(num_generated=1),
                    next(seq_id_counter):
                    create_sequence_data(num_generated=10),
                },
                sampling_params=create_sampling_params(
                    10, prompt_logprobs=5, stop_token_ids=stop_token_ids),
                block_tables={},
            ),
        ]
    }

    if seed == 0:
        test_cases = [
            prompt_without_penalization,
            prompt_with_penalization,
            prompt_with_penalization_and_prompt_logprobs,
            stop_penalizing_after_min_tokens,
            prompt_combination,
            decode_combination,
        ]
    else:
        test_cases = [generate_test_case()]

    def run_test_case(*,
                      expected_penalization=None,
                      seq_group_metadata_list=None):
        # Run the sampler over one test case and check which logits rows had
        # their stop tokens masked to -inf.
        assert expected_penalization, \
            "Invalid test case, need expected_penalization"
        assert seq_group_metadata_list, \
            "Invalid test case, need seq_group_metadata_list"

        batch_size = 0
        seq_lens = []
        sampling_params_per_row = []
        for sgm in seq_group_metadata_list:
            sampling_params = sgm.sampling_params

            num_rows = len(sgm.seq_data)
            if sgm.is_prompt:
                # a prompt seq_group has only one sequence
                seq_data = next(iter(sgm.seq_data.values()))
                prompt_len = seq_data.get_prompt_len()
                seq_lens.append(prompt_len)

                if sgm.sampling_params.prompt_logprobs:
                    # with prompt_logprobs each token in the prompt has a row in
                    # logits
                    num_rows = prompt_len

            batch_size += num_rows
            sampling_params_per_row.extend(
                itertools.repeat(sampling_params, num_rows))

        assert len(
            expected_penalization
        ) == batch_size, \
            ("Invalid test case, expected_penalization does not match computed"
             "batch size")

        _, fake_logits, sampler, model_runner = _prepare_test(batch_size)
        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            seq_lens=seq_lens if seq_lens else None,
            query_lens=seq_lens if seq_lens else None,
            device=device,
            pin_memory=model_runner.pin_memory)
        # the logits tensor is modified in-place by the sampler
        _ = sampler(logits=fake_logits, sampling_metadata=sampling_metadata)

        for logits_idx, (should_penalize, sampling_params) in enumerate(
                zip(expected_penalization, sampling_params_per_row)):

            tokens_to_check = sampling_params.all_stop_token_ids

            if should_penalize:
                for token_id in tokens_to_check:
                    assert fake_logits[logits_idx, token_id] == -float(
                        'inf'
                    ), f"Expected token {token_id} for logits row {logits_idx}"
                    " to be penalized"
                    # NOTE(review): the bare string literal above is a no-op
                    # statement — it is NOT appended to the assert message,
                    # which is silently truncated. Intended fix: merge it
                    # into the f-string on the preceding line.
                # no other tokens should be set to -inf
                assert torch.count_nonzero(
                    fake_logits[logits_idx, :] == -float('inf')) == len(
                        tokens_to_check
                    ), f"Expected only {len(tokens_to_check)} to be penalized"
            else:
                # no tokens should be set to -inf
                assert torch.count_nonzero(
                    fake_logits[logits_idx, :] ==
                    -float('inf')) == 0, "No tokens should have been penalized"

        del model_runner

    for test_case in test_cases:
        run_test_case(**test_case)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_mixed(seed: int, device: str):
    """Mix greedy, random, seeded-random and beam-search groups in one batch
    and check each one independently; then shuffle the batch and verify the
    (seeded) results follow their groups.
    """
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
        batch_size)

    seq_group_metadata_list = []
    # Per-group expected token ids; None means "record on first invocation"
    # (seeded random) or "skip" (beam search).
    expected_tokens: List[Optional[List[int]]] = []
    seq_lens = []
    for i in range(batch_size):
        expected: Optional[List[int]] = None
        # 0 = greedy, 1 = random, 2 = seeded random, 3 = beam search.
        sampling_type = random.randint(0, 3)
        if sampling_type == 0:
            sampling_params = SamplingParams(temperature=0)
            expected = [torch.argmax(fake_logits[i], dim=-1).item()]
        elif sampling_type in (1, 2):
            n = random.randint(1, 10)
            sampling_params = SamplingParams(
                temperature=random.random() + 0.1,
                top_p=min(random.random() + 0.1, 1),
                top_k=random.randint(0, 10) or -1,
                n=n,
                presence_penalty=random.randint(0, 1),
            )
            if sampling_type == 2:
                sampling_params.seed = random.randint(0, 10000)
            else:
                # Boost n contiguous token ids so unseeded random sampling
                # must pick one of them.
                for idx in range(n):
                    fake_logits[i, i + idx] = 1e2
                expected = list(range(i, i + n))
        else:
            sampling_params = SamplingParams(temperature=0,
                                             use_beam_search=True,
                                             best_of=2)
        expected_tokens.append(expected)
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=sampling_params,
                block_tables={0: [1]},
            ))
        seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    def test_sampling(model_runner: ModelRunner):
        # Run one sampler pass over the current batch ordering and verify
        # each non-beam group against its expectation.
        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            seq_lens,
            query_lens=seq_lens,
            device=device,
            pin_memory=model_runner.pin_memory)
        sampler_output = sampler(logits=fake_logits,
                                 sampling_metadata=sampling_metadata)

        for i, (sequence_output, metadata) in enumerate(
                zip(sampler_output, seq_group_metadata_list)):
            if metadata.sampling_params.use_beam_search:
                continue

            if (metadata.sampling_params.seed is not None
                    and expected_tokens[i] is None):
                # Record seeded random result to compare with results of
                # second invocation
                expected_tokens[i] = [
                    nth_output.output_token
                    for nth_output in sequence_output.samples
                ]
                continue

            for n, nth_output in enumerate(sequence_output.samples):
                if (metadata.sampling_params.temperature == 0
                        or metadata.sampling_params.seed is not None):
                    # Ensure exact matches for greedy or random with seed
                    assert nth_output.output_token == expected_tokens[i][n]
                else:
                    # For non-seeded random check that one of the high-logit
                    # tokens were chosen
                    assert nth_output.output_token in expected_tokens[i]

    # Test batch
    test_sampling(model_runner)

    # Shuffle the batch and resample
    target_index = list(range(batch_size))
    # Shuffle all parallel lists with the same seeded RNG so they stay
    # aligned with each other.
    for list_to_shuffle in (target_index, seq_group_metadata_list,
                            expected_tokens, seq_lens):
        random.Random(seed).shuffle(list_to_shuffle)
    target_index = torch.tensor(target_index)
    input_tensor.data = input_tensor.index_select(0, target_index)
    fake_logits.data = fake_logits.index_select(0, target_index)

    # This time, results of seeded random samples will be compared with
    # the corresponding sample in the pre-shuffled batch
    test_sampling(model_runner)

    del model_runner
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_top_k_top_p(seed: int, device: str):
    """Compare vLLM's top-k/top-p filtered probabilities against the
    HuggingFace logits warpers on the same logits.

    The internal `_sample` function is patched so we can capture the
    probability tensor the sampler would actually draw from.
    """
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    top_k = random.randint(100, 500)
    top_p = random.random() * 0.1
    vocab_size = 32000
    input_tensor = torch.rand((batch_size, 1024),
                              device=device,
                              dtype=torch.float16)
    fake_logits = torch.normal(0,
                               5,
                               size=(batch_size, vocab_size),
                               device=input_tensor.device,
                               dtype=input_tensor.dtype)
    sampler = MockLogitsSampler(fake_logits)
    model_runner = ModelRunner(model_config=None,
                               parallel_config=None,
                               scheduler_config=None,
                               device_config=None,
                               load_config=None,
                               lora_config=None)

    # NOTE(review): _get_logits_warper is a private HF API and may change
    # across transformers versions — confirm on upgrades.
    generation_model = GenerationMixin()
    generation_config = GenerationConfig(top_k=top_k,
                                         top_p=top_p,
                                         do_sample=True)
    warpers = generation_model._get_logits_warper(generation_config)
    assert len(warpers) == 2  # top_p and top_k

    seq_group_metadata_list = []
    seq_lens = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=SamplingParams(
                    temperature=1,
                    top_k=top_k,
                    top_p=top_p,
                ),
                block_tables={0: [1]},
            ))
        seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    sampling_metadata = SamplingMetadata.prepare(
        seq_group_metadata_list,
        seq_lens,
        query_lens=seq_lens,
        device=device,
        pin_memory=model_runner.pin_memory)

    # Captured by mock_sample below: the post-filtering probabilities.
    sample_probs = None

    def mock_sample(probs, *args, **kwargs):
        nonlocal sample_probs
        sample_probs = probs
        return ([[prob.topk(1, dim=-1).indices.tolist(), [0]]
                 for prob in probs], None)

    with patch("vllm.model_executor.layers.sampler._sample", mock_sample):
        sampler(logits=fake_logits, sampling_metadata=sampling_metadata)
    hf_probs = warpers(torch.zeros_like(fake_logits), fake_logits.clone())
    hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float)
    # Probabilities must match numerically AND have identical zero masks
    # (i.e. both implementations filtered the same token set).
    assert torch.allclose(hf_probs, sample_probs, atol=1e-5)
    assert torch.equal(hf_probs.eq(0), sample_probs.eq(0))

    del model_runner
|
||||
@@ -1,82 +0,0 @@
|
||||
"""Verify that seeded random sampling is deterministic.
|
||||
|
||||
Run `pytest tests/samplers/test_seeded_generate.py`.
|
||||
"""
|
||||
import copy
|
||||
import random
|
||||
from itertools import combinations
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
|
||||
MODEL = "facebook/opt-125m"
|
||||
RANDOM_SEEDS = list(range(5))
|
||||
|
||||
|
||||
@pytest.fixture
def vllm_model(vllm_runner):
    """Yield a half-precision vLLM runner for MODEL, released after the test."""
    model = vllm_runner(MODEL, dtype="half")
    yield model
    # Drop the reference so GPU memory can be reclaimed before the next test.
    del model
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_random_sample_with_seed(
    vllm_model,
    example_prompts,
    seed: int,
) -> None:
    """Seeded sampling must be reproducible; unseeded sampling must not be.

    For every prompt, six requests are queued in a fixed cycle:
    unseeded, seed=100, seed=200, then the same three again. Outputs with
    the same seed must match exactly; all other pairs must differ.
    """
    set_random_seed(seed)

    base_params = SamplingParams(
        # Parameters to ensure sufficient randomness
        temperature=2.0,
        top_p=min(random.random() + 0.3, 1),
        top_k=random.randint(5, 20),
        n=random.randint(1, 10),
        presence_penalty=random.randint(0, 1),
        max_tokens=8,
        ignore_eos=True,
    )

    params_seed_100 = copy.deepcopy(base_params)
    params_seed_100.seed = 100
    params_seed_200 = copy.deepcopy(base_params)
    params_seed_200.seed = 200

    llm = vllm_model.model
    request_cycle = (
        base_params,
        params_seed_100,
        params_seed_200,
        base_params,
        params_seed_100,
        params_seed_200,
    )

    for prompt in example_prompts:
        for params in request_cycle:
            llm._add_request(
                prompt=prompt,
                prompt_token_ids=None,
                sampling_params=params,
            )

    results = llm._run_engine(use_tqdm=False)
    all_outputs = [[completion.token_ids for completion in result.outputs]
                   for result in results]

    for start in range(0, len(example_prompts), 6):
        group = all_outputs[start:start + 6]

        # The two unseeded and the two distinctly-seeded requests
        # (first four in the cycle) should all disagree pairwise.
        for left, right in combinations(group[:4], 2):
            assert left != right

        # Requests sharing a seed must produce identical outputs.
        assert group[1] == group[4]
        assert group[2] == group[5]
|
||||
Reference in New Issue
Block a user