forked from EngineX-Cambricon/enginex-mlu370-vllm
add qwen3
This commit is contained in:
0
vllm-v0.6.2/tests/samplers/__init__.py
Normal file
0
vllm-v0.6.2/tests/samplers/__init__.py
Normal file
53
vllm-v0.6.2/tests/samplers/test_beam_search.py
Normal file
53
vllm-v0.6.2/tests/samplers/test_beam_search.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""Compare the outputs of HF and vLLM when using beam search.
|
||||
|
||||
Run `pytest tests/samplers/test_beam_search.py`.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
# FIXME(zhuohan): The test can not pass if we:
|
||||
# 1. Increase max_tokens to 256.
|
||||
# 2. Increase beam_width to 8.
|
||||
# 3. Use the model "huggyllama/llama-7b".
|
||||
MAX_TOKENS = [64]
|
||||
BEAM_WIDTHS = [4]
|
||||
MODELS = ["facebook/opt-125m"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", MAX_TOKENS)
@pytest.mark.parametrize("beam_width", BEAM_WIDTHS)
def test_beam_search_single_input(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
    beam_width: int,
) -> None:
    """Compare HF and vLLM beam-search outputs token-by-token.

    Only the first example prompt is used (see FIXME at module level:
    larger max_tokens/beam_width or bigger models make the comparison
    fail).
    """
    example_prompts = example_prompts[:1]
    with hf_runner(model, dtype=dtype) as hf_model:
        hf_outputs = hf_model.generate_beam_search(example_prompts, beam_width,
                                                   max_tokens)

    with vllm_runner(model, dtype=dtype) as vllm_model:
        vllm_outputs = vllm_model.generate_beam_search(example_prompts,
                                                       beam_width, max_tokens)

    for i in range(len(example_prompts)):
        hf_output_ids, hf_output_texts = hf_outputs[i]
        vllm_output_ids, vllm_output_texts = vllm_outputs[i]
        # BUGFIX: the original reused `i` as the enumerate index here,
        # shadowing the outer prompt index, so the assertion message below
        # reported the wrong prompt index. Use a distinct name instead.
        for beam_idx, (hf_text,
                       vllm_text) in enumerate(zip(hf_output_texts,
                                                   vllm_output_texts)):
            print(f">>>{beam_idx}-th hf output:")
            print(hf_text)
            print(f">>>{beam_idx}-th vllm output:")
            print(vllm_text)
        assert len(hf_output_ids) == len(vllm_output_ids)
        for j in range(len(hf_output_ids)):
            assert hf_output_ids[j] == vllm_output_ids[j], (
                f"Test{i} output{j}:\nHF: {hf_output_ids}\n"
                f"vLLM: {vllm_output_ids}")
33
vllm-v0.6.2/tests/samplers/test_ignore_eos.py
Normal file
33
vllm-v0.6.2/tests/samplers/test_ignore_eos.py
Normal file
@@ -0,0 +1,33 @@
|
||||
"""Make sure ignore_eos works.
|
||||
|
||||
Run `pytest tests/samplers/test_ignore_eos.py`.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm import SamplingParams
|
||||
|
||||
# We also test with llama because it has generation_config to specify EOS
|
||||
# (past regression).
|
||||
MODELS = ["facebook/opt-125m", "meta-llama/Llama-2-7b-hf"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [512])
def test_ignore_eos(
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
) -> None:
    """With ignore_eos=True every completion must run to exactly max_tokens."""
    with vllm_runner(model, dtype=dtype) as vllm_model:
        params = SamplingParams(max_tokens=max_tokens, ignore_eos=True)

        for prompt in example_prompts:
            request_outputs = vllm_model.model.generate(
                prompt, sampling_params=params)
            generated_ids = request_outputs[0].outputs[0].token_ids
            # EOS must not have stopped generation early.
            assert len(generated_ids) == max_tokens
59
vllm-v0.6.2/tests/samplers/test_logits_processor.py
Normal file
59
vllm-v0.6.2/tests/samplers/test_logits_processor.py
Normal file
@@ -0,0 +1,59 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm import SamplingParams
|
||||
|
||||
MODELS = ["facebook/opt-125m"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
def test_logits_processor_force_generate(
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
) -> None:
    """Force the sampler to emit a fixed answer via a logits processor.

    The processor gives the next token of the enforced answer the largest
    finite logit for the dtype, so decoding must reproduce the answer
    exactly.  Also exercises logits_processors combined with
    prompt_logprobs, and a mixed (grouped) batch of requests submitted
    through the engine's private _add_request/_run_engine API.
    """
    with vllm_runner(model, dtype=dtype) as vllm_model:
        tokenizer = vllm_model.model.get_tokenizer()
        repeat_times = 2
        enforced_answers = " vLLM"
        vllm_token_ids = tokenizer.encode(enforced_answers,
                                          add_special_tokens=False)
        # Generate exactly enough tokens for `repeat_times` copies of the
        # enforced answer.
        max_tokens = len(vllm_token_ids) * repeat_times

        def pick_vllm(token_ids, logits):
            # Choose which token of the enforced answer comes next based on
            # how many tokens have been generated so far, and make it the
            # argmax by assigning the dtype's maximum finite value.
            token_id = vllm_token_ids[len(token_ids) % len(vllm_token_ids)]
            logits[token_id] = torch.finfo(logits.dtype).max
            return logits

        params_with_logprobs = SamplingParams(
            logits_processors=[pick_vllm],
            prompt_logprobs=3,
            max_tokens=max_tokens,
        )

        # test logits_processors when prompt_logprobs is not None
        vllm_model.model._add_request(
            example_prompts[0],
            params=params_with_logprobs,
        )

        # test prompt_logprobs is not None
        vllm_model.model._add_request(
            example_prompts[1],
            params=SamplingParams(
                prompt_logprobs=3,
                max_tokens=max_tokens,
            ),
        )

        # test grouped requests
        vllm_model.model._add_request(
            example_prompts[2],
            params=SamplingParams(max_tokens=max_tokens),
        )

        outputs = vllm_model.model._run_engine(use_tqdm=False)

        # Only the first request used the logits processor; its text must be
        # the enforced answer repeated `repeat_times` times.
        assert outputs[0].outputs[0].text == enforced_answers * repeat_times
182
vllm-v0.6.2/tests/samplers/test_logprobs.py
Normal file
182
vllm-v0.6.2/tests/samplers/test_logprobs.py
Normal file
@@ -0,0 +1,182 @@
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm import SamplingParams
|
||||
|
||||
from ..conftest import VllmRunner
|
||||
|
||||
MODELS = ["facebook/opt-125m"]
|
||||
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
NOTES: chunked_prefill_token_size=1 contains some accuracy issue.
|
||||
So we skip this case in mlu ut.
|
||||
TODO(VLLM-662): fix accuracy error
|
||||
'''
|
||||
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype",
                         ["float"])  # needed for comparing logprobs with HF
@pytest.mark.parametrize("chunked_prefill_token_size", [4, 16, -1])
@pytest.mark.parametrize("num_top_logprobs", [0, 6])  # 32000 == vocab_size
@pytest.mark.parametrize("detokenize", [True, False])
def test_get_prompt_logprobs(
    hf_runner,
    vllm_runner,
    model,
    dtype,
    chunked_prefill_token_size: int,
    num_top_logprobs: int,
    detokenize: bool,
    example_prompts,
):
    """Check sample/prompt logprobs shape, content, and agreement with HF.

    chunked_prefill_token_size == -1 disables chunked prefill; any other
    value enables it and caps both max_num_seqs and max_num_batched_tokens.
    (chunked_prefill_token_size=1 is skipped on MLU — see module note.)
    """
    max_num_seqs = 256
    enable_chunked_prefill = False
    max_num_batched_tokens = None
    if chunked_prefill_token_size != -1:
        enable_chunked_prefill = True
        max_num_seqs = min(chunked_prefill_token_size, max_num_seqs)
        max_num_batched_tokens = chunked_prefill_token_size

    max_tokens = 5
    # Reference logprobs from HuggingFace greedy decoding.
    with hf_runner(model, dtype=dtype) as hf_model:
        hf_logprobs = hf_model.generate_greedy_logprobs(
            example_prompts,
            max_tokens=max_tokens,
        )

    with vllm_runner(
            model,
            dtype=dtype,
            max_logprobs=num_top_logprobs,
            enable_chunked_prefill=enable_chunked_prefill,
            max_num_batched_tokens=max_num_batched_tokens,
            max_num_seqs=max_num_seqs,
            gpu_memory_utilization=0.6,
    ) as vllm_model:
        vllm_sampling_params = SamplingParams(max_tokens=max_tokens,
                                              logprobs=num_top_logprobs,
                                              prompt_logprobs=num_top_logprobs,
                                              temperature=0.0,
                                              detokenize=detokenize)
        vllm_results = vllm_model.model.generate(
            example_prompts, sampling_params=vllm_sampling_params)

    # Test whether logprobs are included in the results.
    for result in vllm_results:
        assert result.prompt_logprobs is not None
        assert result.outputs[0].logprobs is not None
        assert len(result.outputs[0].logprobs) == max_tokens
        for logprobs in result.outputs[0].logprobs:
            # If the output token is not included in the top X
            # logprob, it can return 1 more data
            assert (len(logprobs) == num_top_logprobs
                    or len(logprobs) == num_top_logprobs + 1)
        output_text = result.outputs[0].text
        output_string_from_most_likely_tokens_lst: List[str] = []
        for top_logprobs in result.outputs[0].logprobs:
            # Dict iteration order: first entry is the sampled/top token.
            top_logprob = next(iter(top_logprobs.values()))
            output_string_from_most_likely_tokens_lst.append(
                top_logprob.decoded_token)

        if detokenize:
            output_string_from_most_likely_tokens = "".join(
                output_string_from_most_likely_tokens_lst)
            assert output_text == output_string_from_most_likely_tokens, (
                "The output text from the top logprob for each token position "
                "should be the same as the output text in the result.")
        else:
            # Without detokenization no text (or decoded tokens) is produced.
            assert output_text == ''
            assert output_string_from_most_likely_tokens_lst == ([None] *
                                                                 max_tokens)

        # The first prompt logprob is always None
        assert result.prompt_logprobs[0] is None
        for prompt_logprobs in result.prompt_logprobs[1:]:
            # If the prompt token is not included in the top X
            # logprob, it can return 1 more data
            assert (len(prompt_logprobs) == num_top_logprobs
                    or len(prompt_logprobs) == num_top_logprobs + 1)

    # Test whether prompt logprobs are consistent with HF
    for vllm_result, hf_logprob in zip(vllm_results, hf_logprobs):
        # Check prompt logprobs
        # The first prompt logprob is always None, so we compare it from 1:.
        vllm_prompt_logprobs = vllm_result.prompt_logprobs[1:]
        for i, vllm_prompt_logprob_dict in enumerate(vllm_prompt_logprobs):
            for token_id, logprob in vllm_prompt_logprob_dict.items():
                torch.testing.assert_close(logprob.logprob,
                                           hf_logprob[0][i][token_id].item(),
                                           atol=1e-2,
                                           rtol=1e-2)
        vllm_sample_logprobs = vllm_result.outputs[0].logprobs
        for i, top_logprobs in enumerate(vllm_sample_logprobs):
            for token_id, sample_logprob in top_logprobs.items():
                logprob = sample_logprob.logprob
                torch.testing.assert_close(logprob,
                                           hf_logprob[i][-1][token_id].item(),
                                           atol=1e-2,
                                           rtol=1e-2)
                if detokenize:
                    assert isinstance(sample_logprob.decoded_token, str), (
                        "The token should be decoded by the time it is returned"
                        " to the user.")

    # Test if prompt logprobs are correctly set.
    for vllm_result in vllm_results:
        token_ids = vllm_result.prompt_token_ids
        prompt_logprobs = vllm_result.prompt_logprobs

        # The first token doesn't have logprob.
        assert prompt_logprobs[0] is None

        # Every prompt token must appear in its own logprob dict.
        for token_id, logprob_dict in zip(token_ids[1:], prompt_logprobs[1:]):
            assert token_id in logprob_dict
|
||||
|
||||
def test_max_logprobs():
    """Requesting more logprobs than the engine's max_logprobs must raise."""
    runner = VllmRunner("facebook/opt-125m", max_logprobs=1)

    # A request exactly at the configured limit is accepted.
    ok_params = SamplingParams(logprobs=1)
    runner.generate(["Hello world"], sampling_params=ok_params)

    # One past the limit is rejected with ValueError.
    bad_params = SamplingParams(logprobs=2)
    with pytest.raises(ValueError):
        runner.generate(["Hello world"], sampling_params=bad_params)
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1])
@pytest.mark.parametrize("detokenize", [True, False])
def test_none_logprobs(vllm_runner, model, chunked_prefill_token_size: int,
                       detokenize: bool, example_prompts):
    """logprobs=None must yield no logprobs and no cumulative logprob.

    chunked_prefill_token_size == -1 disables chunked prefill; any other
    value enables it and caps max_num_seqs / max_num_batched_tokens.
    """
    max_num_seqs = 256
    enable_chunked_prefill = False
    max_num_batched_tokens = None
    if chunked_prefill_token_size != -1:
        enable_chunked_prefill = True
        max_num_seqs = min(chunked_prefill_token_size, max_num_seqs)
        max_num_batched_tokens = chunked_prefill_token_size
    max_tokens = 5

    with vllm_runner(
            model,
            enable_chunked_prefill=enable_chunked_prefill,
            max_num_batched_tokens=max_num_batched_tokens,
            max_num_seqs=max_num_seqs,
            gpu_memory_utilization=0.6,
    ) as vllm_model:
        sampling_params_logprobs_none = SamplingParams(max_tokens=max_tokens,
                                                       logprobs=None,
                                                       temperature=0.0,
                                                       detokenize=detokenize)
        results_logprobs_none = vllm_model.model.generate(
            example_prompts, sampling_params=sampling_params_logprobs_none)

    # No per-token logprobs and no cumulative logprob may be attached.
    for i in range(len(results_logprobs_none)):
        assert results_logprobs_none[i].outputs[0].logprobs is None
        assert results_logprobs_none[i].outputs[0].cumulative_logprob is None
185
vllm-v0.6.2/tests/samplers/test_no_bad_words.py
Normal file
185
vllm-v0.6.2/tests/samplers/test_no_bad_words.py
Normal file
@@ -0,0 +1,185 @@
|
||||
"""Make sure bad_words works.
|
||||
|
||||
Run `pytest tests/samplers/test_no_bad_words.py`.
|
||||
|
||||
"""
|
||||
from typing import List, Optional
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
|
||||
def _generate(
    model: LLM,
    prompt: str,
    num_prompt_tokens: int,
    temperature: float = 0,
    bad_words: Optional[List[str]] = None,
) -> List[int]:
    """Generate a completion for one prompt and return only the new token ids.

    The prompt's own tokens (the first ``num_prompt_tokens``) are stripped
    from the returned sequence.
    """
    params = SamplingParams(temperature=temperature, bad_words=bad_words)

    # Result layout: [([output_token_ids, ], [output_text, ]), ]
    result = model.generate([prompt], sampling_params=params)

    # [0] -> first (and only) request output
    # [0] -> token ids (not text)
    # [0] -> first (and only) output completion
    full_token_ids = result[0][0][0]

    return full_token_ids[num_prompt_tokens:]
|
||||
|
||||
class TestOneTokenBadWord:
    """A single-token bad word must never appear in the output."""

    MODEL = "meta-llama/Llama-2-7b-hf"

    # Prompt/token pair: the asserts below rely on the model generating
    # TARGET_TOKEN first when it is not banned.
    PROMPT = "Hi! How are"
    TARGET_TOKEN = "you"

    def setup_method(self, method):
        # Build tokenizer and precompute prompt length and target token id.
        self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL,
                                                       add_prefix_space=True)

        self.num_prompt_tokens = len(self._encode(self.PROMPT))
        self.target_token_id = self._encode(self.TARGET_TOKEN,
                                            add_special_tokens=False)[0]

    def test_one_token_bad_word(self, vllm_runner):
        with vllm_runner(self.MODEL) as llm:
            # Without bad_words the target token is generated first.
            output_token_ids = self._generate(llm)
            assert output_token_ids[0] == self.target_token_id

            # With it banned, it must not appear anywhere in the output.
            output_token_ids = self._generate(llm,
                                              bad_words=[self.TARGET_TOKEN])
            assert self.target_token_id not in output_token_ids

    def _generate(self,
                  model: LLM,
                  bad_words: Optional[List[str]] = None) -> List[int]:
        # Thin wrapper over the module-level _generate with this class's
        # prompt baked in.
        return _generate(
            model=model,
            prompt=self.PROMPT,
            num_prompt_tokens=self.num_prompt_tokens,
            bad_words=bad_words,
        )

    def _encode(self,
                prompt: str,
                add_special_tokens: bool = True) -> List[int]:
        return self.tokenizer(prompt,
                              add_special_tokens=add_special_tokens).input_ids
|
||||
|
||||
class TestTwoTokenBadWord:
    """Multi-token bad words must be banned as whole phrases, not tokens."""

    # Another model (with a different tokenizer behaviour)
    MODEL = "openai-community/gpt2"

    # The asserts below rely on the model generating
    # "years old" after this prompt; "older" shares a prefix with "old".
    PROMPT = "How old are you? I am 10"
    TARGET_TOKEN1 = "years"
    TARGET_TOKEN2 = "old"
    NEIGHBOUR_TOKEN2 = "older"

    def setup_method(self, method):
        # Build tokenizer and precompute prompt length and the three
        # single-token ids used by the assertions.
        self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL,
                                                       add_prefix_space=True)

        self.num_prompt_tokens = len(self._encode(self.PROMPT))
        self.target_token_id1 = self._encode(self.TARGET_TOKEN1,
                                             add_special_tokens=False)[0]
        self.target_token_id2 = self._encode(self.TARGET_TOKEN2,
                                             add_special_tokens=False)[0]
        self.neighbour_token_id2 = self._encode(self.NEIGHBOUR_TOKEN2,
                                                add_special_tokens=False)[0]

    def test_two_token_bad_word(self, vllm_runner):
        with vllm_runner(self.MODEL) as llm:
            # Baseline: the model continues with "years old".
            output_token_ids = self._generate(llm)
            assert output_token_ids[:2] == [
                self.target_token_id1, self.target_token_id2
            ]

            # Banning a single token removes it everywhere.
            output_token_ids = self._generate(llm,
                                              bad_words=[self.TARGET_TOKEN1])
            assert self.target_token_id1 not in output_token_ids

            output_token_ids = self._generate(llm,
                                              bad_words=[self.TARGET_TOKEN2])
            assert output_token_ids[0] == self.target_token_id1
            assert self.target_token_id2 not in output_token_ids

            # Banning the two-token phrase only forbids the exact sequence.
            output_token_ids = self._generate(
                llm, bad_words=[f'{self.TARGET_TOKEN1} {self.TARGET_TOKEN2}'])
            assert output_token_ids[0] == self.target_token_id1
            assert output_token_ids[:2] != [
                self.target_token_id1, self.target_token_id2
            ]
            assert not self._contains(
                output_token_ids,
                [self.target_token_id1, self.target_token_id2])
            # Model dependent behaviour
            assert output_token_ids[:2] == [
                self.target_token_id1, self.neighbour_token_id2
            ]

            # Banning both phrases forbids both exact sequences, yet the
            # second tokens may still appear in other positions.
            output_token_ids = self._generate(
                llm,
                bad_words=[
                    f'{self.TARGET_TOKEN1} {self.TARGET_TOKEN2}',
                    f'{self.TARGET_TOKEN1} {self.NEIGHBOUR_TOKEN2}'
                ])
            assert output_token_ids[0] == self.target_token_id1
            assert output_token_ids[:2] != [
                self.target_token_id1, self.target_token_id2
            ]
            assert not self._contains(
                output_token_ids,
                [self.target_token_id1, self.target_token_id2])
            assert output_token_ids[:2] != [
                self.target_token_id1, self.neighbour_token_id2
            ]
            assert not self._contains(
                output_token_ids,
                [self.target_token_id1, self.neighbour_token_id2])
            assert ((self.target_token_id2 in output_token_ids)
                    or (self.neighbour_token_id2 in output_token_ids))

    def _generate(self,
                  model: LLM,
                  bad_words: Optional[List[str]] = None) -> List[int]:
        # Thin wrapper over the module-level _generate with this class's
        # prompt baked in.
        return _generate(
            model=model,
            prompt=self.PROMPT,
            num_prompt_tokens=self.num_prompt_tokens,
            bad_words=bad_words,
        )

    @staticmethod
    def _contains(sequence: List[int], subsequence: List[int]) -> bool:
        """Return True iff `subsequence` occurs contiguously in `sequence`."""
        searched = False

        for start in range(len(sequence)):
            end = start + len(subsequence)
            current_subsequence = sequence[start:end]

            # Window runs past the end of the sequence: nothing to compare.
            if len(current_subsequence) < len(subsequence):
                continue

            searched = True

            assert len(current_subsequence) == len(subsequence)

            if current_subsequence == subsequence:
                return True

        # Sanity check: at least one full-length window must have existed.
        assert searched, "All subsequences did not match in length..."

        return False
54
vllm-v0.6.2/tests/samplers/test_ranks.py
Normal file
54
vllm-v0.6.2/tests/samplers/test_ranks.py
Normal file
@@ -0,0 +1,54 @@
|
||||
import pytest
|
||||
|
||||
from vllm import SamplingParams
|
||||
|
||||
MODELS = ["facebook/opt-125m"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
def test_ranks(
    vllm_runner,
    model,
    dtype,
    example_prompts,
):
    """Check logprob ranks: greedy-chosen tokens rank 1, sampled ones >= 1."""
    max_tokens = 5
    num_top_logprobs = 5
    num_prompt_logprobs = 5

    with vllm_runner(model, dtype=dtype,
                     max_logprobs=num_top_logprobs) as vllm_model:

        ## Test greedy logprobs ranks
        vllm_sampling_params = SamplingParams(
            temperature=0.0,
            top_p=1.0,
            max_tokens=max_tokens,
            logprobs=num_top_logprobs,
            prompt_logprobs=num_prompt_logprobs)
        vllm_results = vllm_model.generate_w_logprobs(example_prompts,
                                                      vllm_sampling_params)

        ## Test non-greedy logprobs ranks
        sampling_params = SamplingParams(temperature=1.0,
                                         top_p=1.0,
                                         max_tokens=max_tokens,
                                         logprobs=num_top_logprobs,
                                         prompt_logprobs=num_prompt_logprobs)
        res = vllm_model.generate_w_logprobs(example_prompts, sampling_params)

    # result layout (from generate_w_logprobs): [0] token ids, [2] logprobs
    # — presumably; verify against the runner fixture.
    for result in vllm_results:
        assert result[2] is not None
        assert len(result[2]) == len(result[0])
        # check whether all chosen tokens have ranks = 1
        for token, logprobs in zip(result[0], result[2]):
            assert token in logprobs
            assert logprobs[token].rank == 1

    for result in res:
        assert result[2] is not None
        assert len(result[2]) == len(result[0])
        # check whether all chosen tokens have ranks
        for token, logprobs in zip(result[0], result[2]):
            assert logprobs[token].rank >= 1
511
vllm-v0.6.2/tests/samplers/test_rejection_sampler.py
Normal file
511
vllm-v0.6.2/tests/samplers/test_rejection_sampler.py
Normal file
@@ -0,0 +1,511 @@
|
||||
"""Tests for rejection sampling."""
|
||||
from typing import List, Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
|
||||
CUDA_DEVICES = [
|
||||
f"cuda:{i}" for i in range(1)
|
||||
]
|
||||
|
||||
|
||||
def mock_causal_accepted_tensor(
        k: int, last_accepted_indices: torch.Tensor) -> torch.Tensor:
    """Build a boolean "accepted" mask with a causal accepted prefix.

    For each row b, positions <= last_accepted_indices[b] are accepted.
    Positions strictly after last_accepted_indices[b] + 1 are filled with
    random accept/reject values: rejection sampling may "accept" such
    tokens even though causality prevents them from being emitted.
    """
    batch_size = last_accepted_indices.shape[0]
    positions = torch.arange(k).expand(batch_size, k)
    limits = last_accepted_indices.unsqueeze(-1).broadcast_to(batch_size, k)

    # Contiguous causal prefix of accepted tokens.
    accepted = positions <= limits

    # Randomize everything beyond the first rejected position, mimicking
    # non-causal "accepts" produced by rejection sampling.
    sprinkle_candidates = positions > limits + 1
    sprinkle = torch.rand(batch_size, k) > 0.5
    accepted[sprinkle_candidates] = sprinkle[sprinkle_candidates]
    return accepted
|
||||
|
||||
@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize(
    "which_tokens_accepted",
    ["all_tokens_accepted", "no_tokens_accepted", "some_tokens_accepted"])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("use_flashinfer", [True, False])
@torch.inference_mode()
def test_correct_output_format(which_tokens_accepted: str, seed: int,
                               device: str, use_flashinfer: bool):
    """Verify the output has correct format given predetermined accepted matrix.
    """
    set_random_seed(seed)
    torch.set_default_device(device)

    batch_size = 10
    k = 5
    vocab_size = 3000

    # Build the "accepted" mask for each scenario: everything accepted,
    # nothing accepted, or a random causal prefix per sequence.
    if which_tokens_accepted == "all_tokens_accepted":
        accepted = mock_causal_accepted_tensor(
            k, -1 + k * torch.ones((batch_size, ), dtype=torch.long))
    elif which_tokens_accepted == "no_tokens_accepted":
        accepted = mock_causal_accepted_tensor(
            k, -torch.ones((batch_size, ), dtype=torch.long))
    elif which_tokens_accepted == "some_tokens_accepted":
        last_accepted_indices = torch.randint(low=-1,
                                              high=k,
                                              size=(batch_size, ))
        accepted = mock_causal_accepted_tensor(k, last_accepted_indices)
    else:
        raise AssertionError()

    recovered_token_ids = torch.randint(low=0,
                                        high=vocab_size,
                                        size=(batch_size, k),
                                        dtype=torch.int64)
    draft_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, k),
                                    dtype=torch.int64)
    bonus_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64)

    rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
    rejection_sampler.init_gpu_tensors(device=device)
    output_token_ids = rejection_sampler._create_output(  # pylint: disable=protected-access
        accepted,
        recovered_token_ids,
        draft_token_ids,
        bonus_token_ids,
    )

    expected_bonus_token_ids = bonus_token_ids.clone()

    if which_tokens_accepted == "all_tokens_accepted":
        # Expect all tokens to be equal to draft tokens.
        assert torch.equal(output_token_ids[:, :-1], draft_token_ids)

        # Expect all bonus tokens to be included.
        assert torch.equal(output_token_ids[:, -1:], expected_bonus_token_ids)
    elif which_tokens_accepted == "no_tokens_accepted":
        # Expect first token to be equal to recovered tokens.
        assert torch.equal(output_token_ids[:, 0], recovered_token_ids[:, 0])

        # Expect everything else to be -1.
        assert torch.equal(output_token_ids[:, 1:],
                           torch.ones_like(output_token_ids[:, 1:]) * -1)
    elif which_tokens_accepted == "some_tokens_accepted":
        recovered_plus_bonus = torch.cat(
            (recovered_token_ids, expected_bonus_token_ids), dim=-1)
        # Assert first rejected token is a recovered token or bonus token.
        assert torch.equal(
            recovered_plus_bonus[torch.arange(0, batch_size),
                                 last_accepted_indices + 1],
            output_token_ids[torch.arange(0, batch_size),
                             last_accepted_indices + 1])

        # Assert every subsequent token is -1.
        subsequent_mask = torch.arange(0, k + 1).expand(
            batch_size, k + 1) >= (last_accepted_indices + 2).unsqueeze(-1)
        assert torch.all(output_token_ids[subsequent_mask] == -1)
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief(use_flashinfer): MLU device only support MLU_FLASH_ATTN backend
|
||||
'''
|
||||
@pytest.mark.parametrize("k", list(range(1, 6)))
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
@pytest.mark.parametrize("batch_size", list(range(1, 32)))
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("use_flashinfer", [False])
@torch.inference_mode()
def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
                                    device: str, use_flashinfer: bool):
    """Smoke test: rejection sampling runs for a sweep of shapes.

    Only checks that no exception is raised across varying k, vocab and
    batch sizes; output values are not inspected.
    """
    torch.set_default_device(device)
    sampler = RejectionSampler(use_flashinfer=use_flashinfer)
    sampler.init_gpu_tensors(device=device)

    # Random inputs of the shapes the sampler expects: draft probs for k
    # speculative steps, target probs for k + 1 (includes the bonus step).
    draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
    target_probs = torch.rand(batch_size, k + 1, vocab_size,
                              dtype=torch.float32)
    bonus_token_ids = torch.randint(0, vocab_size, (batch_size, 1),
                                    dtype=torch.int64)
    draft_token_ids = torch.randint(0, vocab_size, (batch_size, k),
                                    dtype=torch.int64)

    sampler(target_probs, bonus_token_ids, draft_probs, draft_token_ids)
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief(use_flashinfer): MLU device only support MLU_FLASH_ATTN backend
|
||||
'''
|
||||
@pytest.mark.parametrize("frac_seeded", [0.0, 0.25, 0.5, 1.0])
@pytest.mark.parametrize("k", [1, 3, 6])
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
@pytest.mark.parametrize("batch_size", [1, 8, 32, 128])
@pytest.mark.parametrize("n_rep", [100])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("use_flashinfer", [False])
@torch.inference_mode()
def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
                                   frac_seeded: float, n_rep: int, device: str,
                                   use_flashinfer: bool):
    """Sequences given a per-sequence seeded generator must produce
    identical outputs across `n_rep` repeated sampler calls; unseeded
    sequences are left unchecked."""
    torch.set_default_device(device)
    rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
    rejection_sampler.init_gpu_tensors(device=device)

    draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
    target_probs = torch.rand(batch_size,
                              k + 1,
                              vocab_size,
                              dtype=torch.float32)
    bonus_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64)
    draft_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, k),
                                    dtype=torch.int64)

    # Randomly pick which sequences are seeded (fraction ~= frac_seeded).
    seeded_mask = torch.rand(batch_size, dtype=torch.float32) <= frac_seeded

    results = []
    for _ in range(n_rep):
        # Recreate generators each repetition so seeded sequences restart
        # from the same seed every time.
        seeded_seqs = {
            i: torch.Generator(device=device).manual_seed(i)
            for i in range(batch_size) if seeded_mask[i]
        }
        results.append(
            rejection_sampler(target_probs, bonus_token_ids, draft_probs,
                              draft_token_ids, seeded_seqs))

    # Seeded rows must match the first repetition exactly.
    for i in range(batch_size):
        if seeded_mask[i]:
            for j in range(1, n_rep):
                assert torch.equal(results[j][i], results[0][i])
|
||||
|
||||
@pytest.mark.skip("Skip flashinfer test case for MLU.")
@pytest.mark.parametrize("k", [1, 3, 6])
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
@pytest.mark.parametrize("batch_size", [1, 8, 32, 128])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_compare_nonflashinfer_backend(k: int, vocab_size: int,
                                       batch_size: int, device: str):
    """
    Test the flashinfer and nonflashinfer backend generate
    the same output metrics.
    """
    torch.set_default_device(device)
    torch.manual_seed(0)
    draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
    target_probs = torch.rand(batch_size,
                              k + 1,
                              vocab_size,
                              dtype=torch.float32)
    bonus_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64)
    draft_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, k),
                                    dtype=torch.int64)

    # Metrics collected per backend: [flashinfer, nonflashinfer].
    num_accepted_tokens = []
    num_emitted_tokens = []
    num_draft_tokens = []

    def get_seeded_seqs():
        # One dedicated generator per sequence, re-seeded identically for
        # each backend so both draw the same random numbers.
        return {
            i: torch.Generator(device=device).manual_seed(i)
            for i in range(batch_size)
        }

    for use_flashinfer in [True, False]:
        rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
        rejection_sampler.init_gpu_tensors(device=device)
        # We use seeded sequences to ensure the same tokens are accepted
        # for both flashinfer and nonflashinfer backends.
        seeded_seqs = get_seeded_seqs()
        rejection_sampler(target_probs, bonus_token_ids, draft_probs,
                          draft_token_ids, seeded_seqs)
        num_accepted_tokens.append(rejection_sampler.num_accepted_tokens)
        num_emitted_tokens.append(rejection_sampler.num_emitted_tokens)
        num_draft_tokens.append(rejection_sampler.num_draft_tokens)

    # Both backends must report identical acceptance/emission statistics.
    assert num_accepted_tokens[0] == num_accepted_tokens[1]
    assert num_emitted_tokens[0] == num_emitted_tokens[1]
    assert num_draft_tokens[0] == num_draft_tokens[1]
|
||||
|
||||
@pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"])
@pytest.mark.parametrize("which_token_ids",
                         ["bonus_token_ids", "draft_token_ids"])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("use_flashinfer", [True, False])
@torch.inference_mode()
def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
                               which_token_ids: str, device: str,
                               use_flashinfer: bool):
    """In strict_mode, a token id outside [0, vocab_size) in either the
    bonus or draft token tensor must trigger an AssertionError."""
    k = 3
    batch_size = 5
    vocab_size = 30_000
    torch.set_default_device(device)

    # strict_mode enables the input validation this test exercises.
    rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer,
                                         strict_mode=True)
    rejection_sampler.init_gpu_tensors(device=device)

    draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
    target_probs = torch.rand(batch_size,
                              k + 1,
                              vocab_size,
                              dtype=torch.float32)
    bonus_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64)
    draft_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, k),
                                    dtype=torch.int64)

    # Select which tensor receives the out-of-bounds token id.
    oob_token_ids = None
    if which_token_ids == "bonus_token_ids":
        oob_token_ids = bonus_token_ids
    elif which_token_ids == "draft_token_ids":
        oob_token_ids = draft_token_ids
    else:
        raise AssertionError()

    # Pick an id just above the vocab, or just below zero.
    if above_or_below_vocab_range == "above":
        rogue_token_id = vocab_size + 1
    elif above_or_below_vocab_range == "below":
        rogue_token_id = -1
    else:
        raise AssertionError()

    # Corrupt a single entry of the selected tensor in place.
    oob_token_ids[0][0] = rogue_token_id

    with pytest.raises(AssertionError):
        rejection_sampler(target_probs, bonus_token_ids, draft_probs,
                          draft_token_ids)
|
||||
'''
=============================
Modify by vllm_mlu
=============================
@brief(use_flashinfer): MLU devices only support the MLU_FLASH_ATTN backend,
so use_flashinfer is parametrized as [False] only.
'''
|
||||
@pytest.mark.parametrize("draft_and_target_probs_equal", [True, False])
@pytest.mark.parametrize("seed", list(range(5)))
@pytest.mark.parametrize("use_flashinfer", [False])
@torch.inference_mode()
def test_rejection_sampling_approximates_target_distribution(
        seed: int, draft_and_target_probs_equal: bool, use_flashinfer: bool):
    """Verify rejection sampling approximates target distribution,
    despite sampling from a potentially distinct draft distribution.

    This is done by first creating a random target probability
    distribution and a random draft probability distribution. We then
    sample token ids from the rejection sampler using these draft
    and target distributions. The samples are used to estimate
    the output probability distribution, which we expect to approximate
    the target distribution.

    A basic distance metric is used to determine similarity between
    distributions.

    We expect that as we increase the number of samples,
    the distance between the observed distribution and the target
    distribution decreases. To measure this, we compare the distance
    of the observed distribution against both the target distribution
    and a uniform random distribution. We expect the distance between
    the observed distribution and the target distribution to improve
    much more than the distance improvement between the observed
    distribution and the random distribution.

    When draft_and_target_probs_equal=True, the draft and target
    probabilities are exactly equal. Rejection sampling should
    still work without any NaNs or exceptions.
    """
    torch.set_default_device("cpu")
    set_random_seed(seed)
    helper = _CorrectnessTestHelper(
        vocab_size=10,
        rejection_sampler=RejectionSampler(use_flashinfer=use_flashinfer),
    )

    draft_probs, target_probs, reference_probs = helper.generate_probs_for_test(
        draft_and_target_probs_equal)

    # Progressively larger sample counts; the estimated pdf should converge
    # toward the target distribution as the count grows.
    sample_sizes = [10, 100, 1_000, 10_000, 100_000]
    distance_wrt_reference: List[float] = []
    distance_wrt_target: List[float] = []

    for num_samples in sample_sizes:
        (reference_vs_rejsample_dist,
         target_vs_rejsample_dist) = helper.run_and_compare_distributions(
             draft_probs,
             target_probs,
             reference_probs,
             num_samples,
         )

        distance_wrt_reference.append(reference_vs_rejsample_dist)
        distance_wrt_target.append(target_vs_rejsample_dist)

        # Intermediate ratios are computed only for the progress prints below.
        relative_change_in_distance_wrt_target = get_ratio_first_to_last(
            distance_wrt_target)
        relative_change_in_distance_wrt_reference = get_ratio_first_to_last(
            distance_wrt_reference)

        print(f"{num_samples=} {target_vs_rejsample_dist=:.05f} "
              f"{reference_vs_rejsample_dist=:.05f}")
        print(f"{num_samples=} {relative_change_in_distance_wrt_target=:.02f} "
              f"{relative_change_in_distance_wrt_reference=:.02f}")

    # Final ratios over the full sweep, used for the assertion.
    relative_change_in_distance_wrt_target = get_ratio_first_to_last(
        distance_wrt_target)
    relative_change_in_distance_wrt_reference = get_ratio_first_to_last(
        distance_wrt_reference)

    # Distance to the target must shrink at least 20x faster than distance
    # to the random reference distributions.
    expected_improvement_multiplier = 20
    assert (relative_change_in_distance_wrt_target >
            relative_change_in_distance_wrt_reference *
            expected_improvement_multiplier)
|
||||
|
||||
|
||||
def get_ratio_first_to_last(elements: List[float]) -> float:
    """Return the ratio of the first element of *elements* to its last."""
    first = elements[0]
    last = elements[-1]
    return first / last
|
||||
|
||||
|
||||
class _CorrectnessTestHelper:
    """Class that packages together logic required for the unit-level
    rejection sampling correctness test.
    """

    def __init__(self, vocab_size: int, rejection_sampler: RejectionSampler):
        self.rejection_sampler = rejection_sampler
        self.vocab_size = vocab_size
        # (low, high) range handed to torch.histogram when estimating the pdf.
        self.vocab_range = (0, vocab_size)

        self.rejection_sampler.init_gpu_tensors(device=0)

        # Keep test simple, use k=1
        self.k = 1

        # Bonus tokens not used, but rejection sampler requires
        # correct shape.
        self.num_bonus_tokens = 1

    def generate_probs_for_test(
            self, draft_and_target_probs_equal: bool
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Create one draft, one target, and 100 reference distributions.

        When draft_and_target_probs_equal is True, the target is an exact
        clone of the draft distribution.
        """
        # Two independent random categorical distributions over the vocab.
        draft_probs, target_probs = (F.softmax(
            torch.rand(self.vocab_size, dtype=torch.float32),
            dim=-1,
        ) for _ in range(2))

        # A bank of random distributions used as a "no better than random"
        # baseline for the distance comparison.
        num_reference_probs = 100
        reference_probs = F.softmax(
            torch.rand(num_reference_probs,
                       self.vocab_size,
                       dtype=torch.float32),
            dim=-1,
        )

        if draft_and_target_probs_equal:
            target_probs = draft_probs.clone()

        return draft_probs, target_probs, reference_probs

    def run_and_compare_distributions(self, draft_probs: torch.Tensor,
                                      target_probs: torch.Tensor,
                                      reference_probs: torch.Tensor,
                                      num_samples: int) -> Tuple[float, float]:
        """Return (avg distance to references, distance to target) of the
        empirical rejection-sampling distribution."""
        # Sample using rejection sampling.
        rej_sample_probs = self._estimate_rejection_sampling_pdf(
            draft_probs, target_probs, num_samples)

        # Average distance from reference probs.
        reference_vs_rejsample_dist = torch.dist(
            reference_probs,
            rej_sample_probs).item() / reference_probs.shape[0]
        target_vs_rejsample_dist = torch.dist(target_probs,
                                              rej_sample_probs).item()

        return reference_vs_rejsample_dist, target_vs_rejsample_dist

    def _estimate_rejection_sampling_pdf(
        self,
        draft_probs: torch.Tensor,
        target_probs: torch.Tensor,
        num_samples: int,
    ) -> torch.Tensor:
        """Draw num_samples tokens through the rejection sampler and return
        the normalized histogram (empirical pdf) over the vocabulary."""
        # Repeat draft probs num_samples times.
        draft_probs = draft_probs.reshape(1, self.k, self.vocab_size).repeat(
            num_samples, 1, 1)

        # Repeat target probs num_samples * (k + 1) times.
        # Rejection sampler requires bonus token probs, but they aren't used.
        target_probs = target_probs.reshape(1, 1, self.vocab_size).repeat(
            num_samples, self.k + 1, 1)

        # Randomly sample draft token ids from draft probs.
        draft_token_ids = torch.multinomial(draft_probs[:, 0, :],
                                            num_samples=1,
                                            replacement=True).reshape(
                                                num_samples, self.k)

        # Bonus tokens not used but required.
        bonus_token_ids = torch.zeros((1, self.num_bonus_tokens),
                                      dtype=torch.int64,
                                      device="cuda").repeat(num_samples, 1)

        # Get output tokens via rejection sampling.
        # NOTE(review): device is hard-coded to "cuda" here even though the
        # file carries an MLU-adaptation banner — confirm this is intended.
        output_token_ids = self.rejection_sampler(target_probs.to("cuda"),
                                                  bonus_token_ids.to("cuda"),
                                                  draft_probs.to("cuda"),
                                                  draft_token_ids.to("cuda"))

        # Remove bonus tokens
        output_token_ids = output_token_ids[:, :-1].flatten()

        # Estimate probability density function
        hist = torch.histogram(output_token_ids.to(dtype=torch.float,
                                                   device="cpu"),
                               bins=self.vocab_size,
                               range=self.vocab_range,
                               density=True)

        return hist.hist
|
||||
758
vllm-v0.6.2/tests/samplers/test_sampler.py
Normal file
758
vllm-v0.6.2/tests/samplers/test_sampler.py
Normal file
@@ -0,0 +1,758 @@
|
||||
import itertools
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import GenerationConfig, GenerationMixin
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
|
||||
from vllm.utils import Counter, is_pin_memory_available
|
||||
|
||||
|
||||
class MockLogitsSampler(Sampler):
    """Sampler subclass that keeps a handle on the fake logits tensor.

    The forward pass is a pure pass-through to ``Sampler.forward``; the class
    only stores ``fake_logits`` (presumably for inspection while debugging —
    no test in this file reads the attribute directly).
    """

    def __init__(self, fake_logits: torch.Tensor):
        super().__init__()
        self.fake_logits = fake_logits

    def forward(self, *args, **kwargs):
        # No behavior change; makes the override point explicit for tests
        # that patch sampler internals.
        return super().forward(*args, **kwargs)
|
||||
|
||||
|
||||
def _prepare_test(
        batch_size: int
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler]:
    """Build a fake hidden-state tensor, uniform fake logits, and a sampler.

    The logits start uniformly tiny (1e-2) so individual tests can bump
    selected entries to force specific tokens to win.
    """
    hidden_states = torch.rand((batch_size, 1024), dtype=torch.float16)
    uniform_logits = torch.full((batch_size, VOCAB_SIZE),
                                1e-2,
                                dtype=hidden_states.dtype)
    return hidden_states, uniform_logits, MockLogitsSampler(uniform_logits)
|
||||
|
||||
|
||||
# Vocabulary size used for every fake-logits tensor in this file.
VOCAB_SIZE = 32000
# Seeds swept by the parametrized tests below.
RANDOM_SEEDS = list(range(128))
# NOTE(review): only a single device index is exercised here; presumably
# narrowed for the MLU fork — confirm against the upstream test.
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1)
]
|
||||
|
||||
|
||||
def _do_sample(
    batch_size: int,
    input_tensor: torch.Tensor,
    sampler: MockLogitsSampler,
    sampling_params: SamplingParams,
    device: str,
):
    """Run ``sampler`` over a batch of identical 3-token prompt groups.

    Every sequence group shares the same ``sampling_params``; the logits
    handed to the sampler are ``input_tensor``.
    """
    seq_group_metadata_list: List[SequenceGroupMetadata] = [
        SequenceGroupMetadata(
            request_id=f"test_{i}",
            is_prompt=True,
            seq_data={0: SequenceData.from_seqs([1, 2, 3])},
            sampling_params=sampling_params,
            block_tables={0: [1]},
        ) for i in range(batch_size)
    ]
    seq_lens: List[int] = [
        metadata.seq_data[0].get_len()
        for metadata in seq_group_metadata_list
    ]

    sampling_metadata = SamplingMetadata.prepare(
        seq_group_metadata_list,
        seq_lens,
        query_lens=seq_lens,
        device=device,
        pin_memory=is_pin_memory_available())
    return sampler(logits=input_tensor, sampling_metadata=sampling_metadata)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_greedy(seed: int, device: str):
    """Greedy sampling (temperature=0) must return the argmax token per row."""
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    # Fix: the hidden-state tensor returned by _prepare_test was bound to an
    # unused local; discard it with `_` for consistency with the sibling tests.
    _, fake_logits, sampler = _prepare_test(batch_size)

    sampling_params = SamplingParams(temperature=0)
    sampler_output = _do_sample(batch_size, fake_logits, sampler,
                                sampling_params, device)
    # Each row's sampled token must equal that row's argmax logit.
    expected = torch.argmax(fake_logits, dim=-1)
    for i, sequence_output in enumerate(sampler_output):
        for nth_output in sequence_output.samples:
            assert nth_output.output_token == expected[i].item()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random(seed: int, device: str):
    """Random sampling with one dominant logit per row must return it."""
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler = _prepare_test(batch_size)

    # Make token ``row`` overwhelmingly likely for row ``row``.
    for row in range(batch_size):
        fake_logits[row, row] = 1e2

    sampling_params = SamplingParams(
        temperature=1.0,
        n=random.randint(1, 10),
    )
    sampler_output = _do_sample(batch_size, fake_logits, sampler,
                                sampling_params, device)

    for row, sequence_output in enumerate(sampler_output):
        for sample in sequence_output.samples:
            assert sample.output_token == row
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random_seed(seed: int, device: str):
    """Seeded random sampling with one dominant logit per row returns it."""
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler = _prepare_test(batch_size)

    # Boost the diagonal so row ``row`` must sample token ``row``.
    for row in range(batch_size):
        fake_logits[row, row] = 1e2

    sampling_params = SamplingParams(
        temperature=1.0,
        n=random.randint(1, 10),
        seed=random.randint(0, 10000),
    )
    sampler_output = _do_sample(batch_size, fake_logits, sampler,
                                sampling_params, device)

    for row, sequence_output in enumerate(sampler_output):
        for sample in sequence_output.samples:
            assert sample.output_token == row
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random_seed_deterministic(seed: int, device: str):
    """Two sampling passes with the same SamplingParams seed must match."""
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler = _prepare_test(batch_size)

    sampling_params = SamplingParams(
        temperature=1.0,
        n=random.randint(1, 10),
        seed=random.randint(0, 10000),
    )
    # Run the identical sampling request twice and compare outputs.
    outputs = [
        _do_sample(batch_size, fake_logits, sampler, sampling_params, device)
        for _ in range(2)
    ]
    assert outputs[0] == outputs[1]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_min_tokens_penalty(seed: int, device: str):
    """Verify the min_tokens penalty: while a sequence has generated fewer
    than ``min_tokens`` tokens, all of its stop tokens (including EOS) must
    be masked to -inf in the logits; otherwise no token may be masked.

    Seed 0 runs hand-written edge cases; every other seed runs one randomly
    generated batch.
    """
    seq_id_counter = Counter(start=random.randint(0, 100))
    set_random_seed(seed)
    torch.set_default_device(device)

    def create_sampling_params(min_tokens,
                               eos_token_id=0,
                               *,
                               stop_token_ids: Optional[List[int]] = None,
                               prompt_logprobs: Optional[int] = None):
        # Build SamplingParams and register the EOS id as a stop token so the
        # penalty check below covers it too.
        sampling_params = SamplingParams(
            min_tokens=min_tokens,
            max_tokens=9999,  # keep higher than max of min_tokens
            stop_token_ids=stop_token_ids,
            # requesting prompt_logprobs changes the structure of `logits`
            prompt_logprobs=prompt_logprobs,
        )
        sampling_params.all_stop_token_ids.add(eos_token_id)
        return sampling_params

    def create_sequence_data(num_input=3, num_generated=0):
        # Random prompt of num_input tokens, plus optional generated tokens
        # (decode phase) of num_generated tokens.
        seq_data = SequenceData.from_seqs(
            random.choices(range(0, VOCAB_SIZE), k=num_input))
        if num_generated > 0:
            seq_data.output_token_ids = random.choices(range(0, VOCAB_SIZE),
                                                       k=num_generated)
        return seq_data

    def generate_test_case():
        # generate multiple seq groups but limit total batch size
        batch_size = random.randint(1, 128)

        expected_penalization = []
        sequence_metadata_list: List[SequenceGroupMetadata] = []
        # 20% chance to generate seq group metadata list with all prompts
        is_prompt = random.random() < 0.2
        while batch_size > 0:
            # Prompt groups hold a single sequence; decode groups may hold
            # several.
            num_seqs = 1 if is_prompt else random.randint(1, batch_size)

            eos_token_id = random.randint(0, VOCAB_SIZE - 1)
            min_tokens = random.randint(0, 50)
            num_stop_tokens = random.randint(0, 8)
            if num_stop_tokens > 0:
                stop_token_ids = random.choices(range(0, VOCAB_SIZE - 1),
                                                k=num_stop_tokens)
            else:
                stop_token_ids = None

            sampling_params = create_sampling_params(
                min_tokens=min_tokens,
                eos_token_id=eos_token_id,
                stop_token_ids=stop_token_ids)

            seq_data: Dict[int, SequenceData] = {}
            seq_group_penalization: List[bool] = []
            for _ in range(num_seqs):
                num_input = random.randint(1, 100)
                num_generated = 0 if is_prompt else random.randint(1, 100)
                seq_data[next(seq_id_counter)] = create_sequence_data(
                    num_input=num_input, num_generated=num_generated)
                # A sequence is penalized while it has generated fewer than
                # min_tokens tokens.
                seq_group_penalization.append(num_generated < min_tokens)

            expected_penalization.extend(seq_group_penalization)
            sequence_metadata_list.append(
                SequenceGroupMetadata(
                    request_id=f"test_{batch_size}",
                    is_prompt=is_prompt,
                    seq_data=seq_data,
                    sampling_params=sampling_params,
                    block_tables={},
                ))
            batch_size -= num_seqs

        return {
            "expected_penalization": expected_penalization,
            "seq_group_metadata_list": sequence_metadata_list,
        }

    # define some explicit test cases for edge case behavior
    prompt_without_penalization = {
        "expected_penalization": [False],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_1",
                is_prompt=True,
                seq_data={
                    next(seq_id_counter): create_sequence_data(),
                },
                sampling_params=create_sampling_params(0),
                block_tables={},
            ),
        ]
    }

    prompt_with_penalization = {
        "expected_penalization": [True],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_1",
                is_prompt=True,
                seq_data={
                    next(seq_id_counter): create_sequence_data(),
                },
                sampling_params=create_sampling_params(1),
                block_tables={},
            ),
        ]
    }

    # With prompt_logprobs, every prompt token owns a logits row; only the
    # final row (the one actually sampled from) is penalized.
    prompt_with_penalization_and_prompt_logprobs = {
        "expected_penalization": [False, False, True],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_1",
                is_prompt=True,
                seq_data={
                    next(seq_id_counter): create_sequence_data(num_input=3),
                },
                sampling_params=create_sampling_params(1, prompt_logprobs=3),
                block_tables={},
            ),
        ]
    }

    stop_penalizing_after_min_tokens = {
        "expected_penalization": [False],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_1",
                is_prompt=False,
                seq_data={
                    next(seq_id_counter):
                    create_sequence_data(num_generated=1),
                },
                sampling_params=create_sampling_params(1),
                block_tables={},
            )
        ]
    }

    stop_token_ids = [42, 99, 42, 0]  # intentional duplication
    prompt_combination = {
        "expected_penalization": [False, True, False],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_2",
                is_prompt=True,
                seq_data={
                    next(seq_id_counter): create_sequence_data(num_input=2),
                },
                sampling_params=create_sampling_params(1, prompt_logprobs=3),
                block_tables={},
            ),
            SequenceGroupMetadata(
                request_id="test_3",
                is_prompt=True,
                seq_data={
                    next(seq_id_counter): create_sequence_data(),
                },
                sampling_params=create_sampling_params(
                    0, stop_token_ids=stop_token_ids),
                block_tables={},
            )
        ]
    }

    stop_token_ids = [1, 999, 37, 37]  # intentional duplication
    decode_combination = {
        "expected_penalization": [True, False, False, True, False],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_1",
                is_prompt=False,
                seq_data={
                    next(seq_id_counter):
                    create_sequence_data(num_generated=1),
                    next(seq_id_counter):
                    create_sequence_data(num_generated=100),
                },
                sampling_params=create_sampling_params(
                    2, stop_token_ids=stop_token_ids),
                block_tables={},
            ),
            SequenceGroupMetadata(
                request_id="test_2",
                is_prompt=False,
                seq_data={
                    next(seq_id_counter):
                    create_sequence_data(num_generated=20),
                    next(seq_id_counter):
                    create_sequence_data(num_generated=1),
                    next(seq_id_counter):
                    create_sequence_data(num_generated=10),
                },
                sampling_params=create_sampling_params(
                    10, prompt_logprobs=5, stop_token_ids=stop_token_ids),
                block_tables={},
            ),
        ]
    }

    if seed == 0:
        test_cases = [
            prompt_without_penalization,
            prompt_with_penalization,
            prompt_with_penalization_and_prompt_logprobs,
            stop_penalizing_after_min_tokens,
            prompt_combination,
            decode_combination,
        ]
    else:
        test_cases = [generate_test_case()]

    def run_test_case(*, expected_penalization: List[bool],
                      seq_group_metadata_list: List[SequenceGroupMetadata]):
        assert expected_penalization, \
            "Invalid test case, need expected_penalization"
        assert seq_group_metadata_list, \
            "Invalid test case, need seq_group_metadata_list"

        # Compute how many logits rows each seq group contributes, and the
        # SamplingParams owning each row.
        batch_size = 0
        seq_lens: List[int] = []
        sampling_params_per_row: List[SamplingParams] = []
        for sgm in seq_group_metadata_list:
            sampling_params = sgm.sampling_params

            num_rows = len(sgm.seq_data)
            if sgm.is_prompt:
                # a prompt seq_group has only one sequence
                seq_data = next(iter(sgm.seq_data.values()))
                prompt_len = seq_data.get_prompt_len()
                seq_lens.append(prompt_len)

                assert sgm.sampling_params is not None
                if sgm.sampling_params.prompt_logprobs:
                    # with prompt_logprobs each token in the prompt has a row in
                    # logits
                    num_rows = prompt_len

            batch_size += num_rows
            sampling_params_per_row.extend(
                itertools.repeat(sampling_params, num_rows))

        assert len(
            expected_penalization
        ) == batch_size, \
            ("Invalid test case, expected_penalization does not match computed"
             "batch size")

        _, fake_logits, sampler = _prepare_test(batch_size)
        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            seq_lens=seq_lens if seq_lens else None,
            query_lens=seq_lens if seq_lens else [1] * batch_size,
            device=device,
            pin_memory=is_pin_memory_available())
        # the logits tensor is modified in-place by the sampler
        _ = sampler(logits=fake_logits, sampling_metadata=sampling_metadata)

        for logits_idx, (should_penalize, sampling_params) in enumerate(
                zip(expected_penalization, sampling_params_per_row)):

            tokens_to_check = sampling_params.all_stop_token_ids

            if should_penalize:
                for token_id in tokens_to_check:
                    # NOTE(review): the bare string literal on the next
                    # statement line is a no-op — it is NOT part of this
                    # assert's message (pre-existing quirk, left unchanged).
                    assert fake_logits[logits_idx, token_id] == -float(
                        'inf'
                    ), f"Expected token {token_id} for logits row {logits_idx}"
                    " to be penalized"
                # no other tokens should be set to -inf
                assert torch.count_nonzero(
                    fake_logits[logits_idx, :] == -float('inf')) == len(
                        tokens_to_check
                    ), f"Expected only {len(tokens_to_check)} to be penalized"
            else:
                # no tokens should be set to -inf
                assert torch.count_nonzero(
                    fake_logits[logits_idx, :] ==
                    -float('inf')) == 0, "No tokens should have been penalized"

    for test_case in test_cases:
        run_test_case(**test_case)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_mixed(seed: int, device: str):
    """Mix greedy, unseeded-random, and seeded-random groups in one batch,
    then shuffle the batch and resample: greedy and seeded results must be
    reproduced exactly, unseeded results must stay within the boosted set.
    """
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler = _prepare_test(batch_size)

    seq_group_metadata_list: List[SequenceGroupMetadata] = []
    # expected_tokens[i] is None for seeded-random groups until the first
    # invocation records their output.
    expected_tokens: List[Optional[List[int]]] = []
    seq_lens: List[int] = []
    for i in range(batch_size):
        expected: Optional[List[int]] = None
        # 0 = greedy, 1 = unseeded random, 2 = seeded random.
        sampling_type = random.randint(0, 2)
        if sampling_type == 0:
            sampling_params = SamplingParams(temperature=0)
            expected = [int(torch.argmax(fake_logits[i], dim=-1).item())]
        elif sampling_type in (1, 2):
            n = random.randint(1, 10)
            sampling_params = SamplingParams(
                temperature=random.random() + 0.1,
                top_p=min(random.random() + 0.1, 1),
                top_k=random.randint(0, 10) or -1,
                n=n,
                presence_penalty=random.randint(0, 1),
            )
            if sampling_type == 2:
                sampling_params.seed = random.randint(0, 10000)
            else:
                # Boost n tokens so the unseeded sample must land among them.
                for idx in range(n):
                    fake_logits[i, i + idx] = 1e2
                expected = list(range(i, i + n))

        expected_tokens.append(expected)
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData.from_seqs([1, 2, 3])},
                sampling_params=sampling_params,
                block_tables={0: [1]},
            ))
        seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    # Shared across both invocations so seeded generators persist.
    generators: Dict[str, torch.Generator] = {}

    def test_sampling():
        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            seq_lens,
            query_lens=seq_lens,
            device=device,
            pin_memory=is_pin_memory_available(),
            generators=generators)
        sampler_output = sampler(logits=fake_logits,
                                 sampling_metadata=sampling_metadata)

        for i, (sequence_output, metadata) in enumerate(
                zip(sampler_output, seq_group_metadata_list)):
            assert metadata.sampling_params is not None

            if (metadata.sampling_params.seed is not None
                    and expected_tokens[i] is None):
                # Record seeded random result to compare with results of
                # second invocation
                expected_tokens[i] = [
                    nth_output.output_token
                    for nth_output in sequence_output.samples
                ]
                continue

            expected_tokens_item = expected_tokens[i]
            assert expected_tokens_item is not None

            for n, nth_output in enumerate(sequence_output.samples):
                assert metadata.sampling_params is not None

                if (metadata.sampling_params.temperature == 0
                        or metadata.sampling_params.seed is not None):
                    # Ensure exact matches for greedy or random with seed
                    assert nth_output.output_token == expected_tokens_item[n]
                else:
                    # For non-seeded random check that one of the high-logit
                    # tokens were chosen
                    assert nth_output.output_token in expected_tokens_item

    # Test batch
    test_sampling()

    # Shuffle the batch and resample
    target_index = list(range(batch_size))
    # Shuffle every per-row structure with the same seed so rows stay aligned.
    for list_to_shuffle in (target_index, seq_group_metadata_list,
                            expected_tokens, seq_lens):
        random.Random(seed).shuffle(list_to_shuffle)
    target_index = torch.tensor(target_index)
    input_tensor.data = input_tensor.index_select(0, target_index)
    fake_logits.data = fake_logits.index_select(0, target_index)

    # This time, results of seeded random samples will be compared with
    # the corresponding sample in the pre-shuffled batch
    test_sampling()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_top_k_top_p(seed: int, device: str):
    """Compare vLLM's top-k/top-p masked probabilities against HuggingFace's
    logits processors on identical logits; the two must match within 1e-5
    and zero out exactly the same entries.
    """
    # NOTE(review): seed 40 is skipped by this fork — presumably an MLU
    # numeric-accuracy divergence; confirm against the device team.
    if seed == 40:
        pytest.skip("skip cause diff accuracy between difference device.")
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    top_k = random.randint(100, 500)
    top_p = random.random() * 0.1
    vocab_size = 32000
    input_tensor = torch.rand((batch_size, 1024),
                              device=device,
                              dtype=torch.float16)
    # Normally-distributed logits so top-k/top-p actually mask entries.
    fake_logits = torch.normal(0,
                               5,
                               size=(batch_size, vocab_size),
                               device=input_tensor.device,
                               dtype=input_tensor.dtype)
    sampler = MockLogitsSampler(fake_logits)

    # Build the HF reference processors for the same top_k/top_p.
    generation_model = GenerationMixin()
    generation_config = GenerationConfig(top_k=top_k,
                                         top_p=top_p,
                                         do_sample=True)

    @dataclass
    class MockConfig:
        is_encoder_decoder: bool = False

    generation_model.config = MockConfig()  # needed by the following method
    generation_model._prepare_special_tokens(generation_config, device=device)
    processors = generation_model._get_logits_processor(generation_config,
                                                        None,
                                                        None,
                                                        None, [],
                                                        device=device)
    assert len(processors) == 2  # top_p and top_k

    seq_group_metadata_list: List[SequenceGroupMetadata] = []
    seq_lens: List[int] = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData.from_seqs([1, 2, 3])},
                sampling_params=SamplingParams(
                    temperature=1,
                    top_k=top_k,
                    top_p=top_p,
                ),
                block_tables={0: [1]},
            ))
        seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    sampling_metadata = SamplingMetadata.prepare(
        seq_group_metadata_list,
        seq_lens,
        query_lens=seq_lens,
        device=device,
        pin_memory=is_pin_memory_available())

    # Captured probabilities handed to the (mocked) sampling step.
    sample_probs = None

    def mock_sample(probs, *args, **kwargs):
        nonlocal sample_probs
        sample_probs = probs
        return ([[prob.topk(1, dim=-1).indices.tolist(), [0]]
                 for prob in probs], None)

    # top-k and top-p is only calculated when flashinfer kernel is not available
    with patch("vllm.model_executor.layers.sampler._sample", mock_sample), \
        patch("vllm.model_executor.layers.sampler."
              "flashinfer_top_k_top_p_sampling", None):
        sampler(logits=fake_logits, sampling_metadata=sampling_metadata)

    assert sample_probs is not None

    # HF reference: apply processors, then softmax, and compare.
    hf_probs = processors(torch.zeros_like(fake_logits), fake_logits.clone())
    hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float)
    torch.testing.assert_close(hf_probs, sample_probs, rtol=0.0, atol=1e-5)
    # Masked-out (zero-probability) positions must agree exactly.
    assert torch.equal(hf_probs.eq(0), sample_probs.eq(0))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_flashinfer_fallback(seed: int, device: str):
    """When the flashinfer top-k/top-p kernel fails, the sampler must fall
    back to the PyTorch path and produce identical (seeded) output.
    """
    if not envs.VLLM_USE_FLASHINFER_SAMPLER:
        pytest.skip("Flashinfer sampler is disabled")

    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler = _prepare_test(batch_size)

    def failing_flashinfer_sampling(*_args, **_kwargs):
        # First return value None signals kernel failure and triggers the
        # fallback path inside the sampler.
        return None, torch.zeros(batch_size, device=device, dtype=torch.int32)

    # Seeded so the reference and fallback runs are comparable.
    sampling_params = SamplingParams(
        temperature=1.0,
        n=random.randint(1, 10),
        seed=random.randint(0, 10000),
    )
    sampler_output = _do_sample(batch_size, fake_logits, sampler,
                                sampling_params, device)

    with patch(
            "vllm.model_executor.layers.sampler."
            "flashinfer_top_k_top_p_sampling", failing_flashinfer_sampling):
        fallback_sampler_output = _do_sample(batch_size, fake_logits, sampler,
                                             sampling_params, device)

    assert sampler_output == fallback_sampler_output
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_repetition_penalty_mixed(device: str):
    """A batch mixing a repetition-penalized greedy request with a plain
    seeded top-k request must give each request the same token regardless
    of its position in the batch."""

    vocab_size = 8

    def run_batch(params_pair: List[SamplingParams]):
        """Sample once with a two-request batch and return the chosen
        token id for each request, in batch order."""
        seq_group_metadata_list: List[SequenceGroupMetadata] = []
        seq_lens: List[int] = []
        for i, params in enumerate(params_pair):
            seq_group_metadata_list.append(
                SequenceGroupMetadata(
                    request_id=f"test_{i}",
                    is_prompt=True,
                    seq_data={0: SequenceData.from_seqs([1, 2, 3])},
                    sampling_params=params,
                    block_tables={0: [1]},
                ))
            seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            seq_lens,
            query_lens=seq_lens,
            device=device,
            pin_memory=is_pin_memory_available())

        # Near-flat logits with two mildly preferred tokens (1 over 5).
        fake_logits = torch.full((2, vocab_size),
                                 1e-2,
                                 device=device,
                                 dtype=torch.float16)
        fake_logits[:, 5] = 1.1e-2
        fake_logits[:, 1] = 1.2e-2

        sampler = MockLogitsSampler(fake_logits)
        sampler_output = sampler(logits=fake_logits,
                                 sampling_metadata=sampling_metadata)
        return [out.samples[0].output_token for out in sampler_output]

    # One configuration: greedy decoding with a repetition penalty.
    sampling_params_rep = SamplingParams(
        temperature=0.0,
        repetition_penalty=2.0,
    )
    # Other configuration: seeded sampling without repetition penalty.
    sampling_params_sample = SamplingParams(
        temperature=1.0,
        top_k=1,
        seed=42,
    )

    tokens1 = run_batch([sampling_params_rep, sampling_params_sample])
    tokens2 = run_batch([sampling_params_sample, sampling_params_rep])

    # Each request's result must be independent of its batch position.
    assert tokens1[0] == tokens2[1]
    assert tokens1[1] == tokens2[0]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_include_gpu_probs_tensor(device: str):
    """With include_gpu_probs_tensor enabled the sampler output must carry
    GPU probability, logprob, and token-id tensors; with
    should_modify_greedy_probs_inplace disabled, the in-place greedy-probs
    helper must never be invoked."""
    set_random_seed(42)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler = _prepare_test(batch_size)
    sampler.include_gpu_probs_tensor = True
    sampler.should_modify_greedy_probs_inplace = False

    sampling_params = SamplingParams(temperature=0)

    mock_inplace = Mock()
    inplace_target = ("vllm.model_executor.layers.sampler."
                      "_modify_greedy_probs_inplace")
    with patch(inplace_target, mock_inplace):
        sampler_output = _do_sample(batch_size, fake_logits, sampler,
                                    sampling_params, device)
        mock_inplace.assert_not_called()

    assert sampler_output.sampled_token_probs is not None
    assert sampler_output.logprobs is not None
    assert sampler_output.sampled_token_ids is not None
|
||||
77
vllm-v0.6.2/tests/samplers/test_seeded_generate.py
Normal file
77
vllm-v0.6.2/tests/samplers/test_seeded_generate.py
Normal file
@@ -0,0 +1,77 @@
|
||||
"""Verify that seeded random sampling is deterministic.
|
||||
|
||||
Run `pytest tests/samplers/test_seeded_generate.py`.
|
||||
"""
|
||||
import copy
|
||||
import random
|
||||
from itertools import combinations
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
|
||||
MODEL = "facebook/opt-125m"
|
||||
RANDOM_SEEDS = list(range(5))
|
||||
|
||||
|
||||
@pytest.fixture
def vllm_model(vllm_runner):
    """Yield a vLLM runner for MODEL in fp16; torn down after each test."""
    with vllm_runner(MODEL, dtype="half") as vllm_model:
        yield vllm_model
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_random_sample_with_seed(
    vllm_model,
    example_prompts,
    seed: int,
) -> None:
    """Requests that share a seed must reproduce identical outputs, while
    requests with differing (or no) seeds must diverge."""
    set_random_seed(seed)

    base_params = SamplingParams(
        # Parameters to ensure sufficient randomness
        temperature=2.0,
        top_p=min(random.random() + 0.3, 1),
        top_k=random.randint(5, 20),
        n=random.randint(1, 10),
        presence_penalty=random.randint(0, 1),
        max_tokens=8,
        ignore_eos=True,
    )

    params_seed_100 = copy.deepcopy(base_params)
    params_seed_100.seed = 100
    params_seed_200 = copy.deepcopy(base_params)
    params_seed_200.seed = 200

    llm = vllm_model.model

    # Six requests per prompt: unseeded / seed 100 / seed 200, twice each.
    request_params = (
        base_params,
        params_seed_100,
        params_seed_200,
        base_params,
        params_seed_100,
        params_seed_200,
    )
    for prompt in example_prompts:
        for params in request_params:
            llm._add_request(prompt, params=params)

    results = llm._run_engine(use_tqdm=False)
    all_outputs = [[completion.token_ids for completion in result.outputs]
                   for result in results]

    for start in range(0, len(example_prompts), 6):
        group = all_outputs[start:start + 6]

        # The first four requests have no seed in common (two unseeded,
        # one per distinct seed) — all must differ pairwise.
        for left, right in combinations(group[:4], 2):
            assert left != right

        # Re-running with an identical seed must reproduce the output.
        assert group[1] == group[4]
        assert group[2] == group[5]
|
||||
470
vllm-v0.6.2/tests/samplers/test_typical_acceptance_sampler.py
Normal file
470
vllm-v0.6.2/tests/samplers/test_typical_acceptance_sampler.py
Normal file
@@ -0,0 +1,470 @@
|
||||
"""Tests for rejection sampling."""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.typical_acceptance_sampler import (
|
||||
TypicalAcceptanceSampler)
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
|
||||
CUDA_DEVICES = [f"cuda:{i}" for i in range(1)]
|
||||
|
||||
|
||||
def get_zero_temperature_prob_dist(batch_size, k, vocab_size):
    """
    Generate a fake temperature-zero (one-hot) probability distribution.

    Returns:
        1. A tensor of shape [batch_size, k, vocab_size] in which exactly
           one token per position has probability 1.0 and all others 0.0.
        2. A tensor of shape [batch_size, k] containing the token ids of
           the probability-1.0 tokens at each position.
    """
    # Pick a random "winning" token id per (batch, position) by taking the
    # argmax of uniform noise.  (A previous initial `torch.rand` assignment
    # to target_probs was dead code — unconditionally overwritten below —
    # and has been removed.)
    probs = torch.rand(batch_size, k, vocab_size)
    _, zero_temperature_token_ids = torch.max(probs, dim=-1)
    # One-hot encode the winners: probability 1.0 at the chosen ids,
    # 0.0 everywhere else.
    target_probs = torch.zeros_like(probs).scatter_(
        -1, zero_temperature_token_ids.unsqueeze(-1), 1.0)
    return target_probs, zero_temperature_token_ids
|
||||
|
||||
|
||||
def get_draft_token_ids(batch_size: int, k: int, vocab_size: int,
                        token_ids_to_exclude: torch.Tensor):
    """
    Build a [batch_size, k] tensor of fake draft token ids drawn uniformly
    from a vocabulary of size vocab_size, guaranteeing that the id at each
    position differs from token_ids_to_exclude at that position.
    """
    draft_token_ids = torch.empty(batch_size, k, dtype=torch.long)
    for row in range(batch_size):
        for col in range(k):
            # Rejection-sample: redraw until the candidate differs from
            # the excluded id for this (row, col) position.
            candidate = torch.randint(0, vocab_size, (1, )).item()
            while candidate == token_ids_to_exclude[row, col]:
                candidate = torch.randint(0, vocab_size, (1, )).item()
            draft_token_ids[row, col] = candidate
    return draft_token_ids
|
||||
|
||||
|
||||
def get_acceptance_sampler(
    posterior_threshold: float = 0.03,
    posterior_alpha: float = 0.9,
    strict_mode: bool = False,
) -> TypicalAcceptanceSampler:
    """Build a TypicalAcceptanceSampler with the given acceptance knobs."""
    return TypicalAcceptanceSampler(posterior_threshold,
                                    posterior_alpha,
                                    strict_mode=strict_mode)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("k", list(range(1, 6)))
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
@pytest.mark.parametrize("batch_size", list(range(1, 32)))
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
                                    device: str):
    """
    Smoke test: TypicalAcceptanceSampler.forward must succeed for many
    combinations of k, vocab_size, batch_size, and device.
    """
    torch.set_default_device(device)
    sampler_under_test = get_acceptance_sampler()
    sampler_under_test.init_gpu_tensors(device=device)
    # Random (unnormalized) target distribution covering the k draft
    # positions plus the bonus slot.
    target_with_bonus_probs = torch.rand(batch_size,
                                         k + 1,
                                         vocab_size,
                                         dtype=torch.float32)
    bonus_token_ids = torch.randint(0,
                                    vocab_size, (batch_size, 1),
                                    dtype=torch.int64)
    draft_token_ids = torch.randint(0,
                                    vocab_size, (batch_size, k),
                                    dtype=torch.int64)
    # Only verifying that the forward pass does not raise.
    sampler_under_test(target_with_bonus_probs,
                       bonus_token_ids,
                       draft_probs=None,
                       draft_token_ids=draft_token_ids)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"])
@pytest.mark.parametrize("which_token_ids",
                         ["bonus_token_ids", "draft_token_ids"])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
                               which_token_ids: str, device: str):
    """
    In strict mode, token ids outside the vocabulary bounds must trigger
    an AssertionError inside the sampler.
    """
    k = 3
    batch_size = 5
    vocab_size = 30_000
    torch.set_default_device(device)
    sampler_under_test = get_acceptance_sampler(strict_mode=True)
    sampler_under_test.init_gpu_tensors(device=device)
    target_with_bonus_probs = torch.rand(batch_size,
                                         k + 1,
                                         vocab_size,
                                         dtype=torch.float32)
    bonus_token_ids = torch.randint(0,
                                    vocab_size, (batch_size, 1),
                                    dtype=torch.int64)
    draft_token_ids = torch.randint(0,
                                    vocab_size, (batch_size, k),
                                    dtype=torch.int64)
    # Select the tensor to corrupt and the rogue value to plant, driven by
    # the parametrized test configuration.
    oob_token_ids = {
        "bonus_token_ids": bonus_token_ids,
        "draft_token_ids": draft_token_ids,
    }[which_token_ids]
    rogue_token_id = {
        "above": vocab_size + 1,
        "below": -1,
    }[above_or_below_vocab_range]
    oob_token_ids[0][0] = rogue_token_id

    with pytest.raises(AssertionError):
        sampler_under_test(target_with_bonus_probs,
                           bonus_token_ids,
                           draft_probs=None,
                           draft_token_ids=draft_token_ids)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_uniform_target_distribution_accepts_all_tokens(
        seed: int, device: str):
    """
    With a (near-)uniform target probability distribution the entropy is
    high, so the typical-acceptance rule should accept every draft token
    and then append the bonus token for each sequence.
    """
    set_random_seed(seed)
    k, batch_size, vocab_size = 3, 5, 30_000
    torch.set_default_device(device)
    sampler_under_test = get_acceptance_sampler(strict_mode=True)
    sampler_under_test.init_gpu_tensors(device=device)
    # Draw order matters for seeded reproducibility:
    # target probs, then draft ids, then bonus ids.
    target_with_bonus_probs = torch.rand(batch_size,
                                         k + 1,
                                         vocab_size,
                                         dtype=torch.float32)
    draft_token_ids = torch.randint(0,
                                    vocab_size, (batch_size, k),
                                    dtype=torch.int64)
    bonus_token_ids = torch.randint(0,
                                    vocab_size, (batch_size, 1),
                                    dtype=torch.int64)
    output_token_ids = sampler_under_test(target_with_bonus_probs,
                                          bonus_token_ids,
                                          draft_probs=None,
                                          draft_token_ids=draft_token_ids)
    # All k draft tokens accepted, bonus token appended in the last slot.
    assert output_token_ids.shape[0] == batch_size
    assert output_token_ids.shape[1] == (k + 1)
    assert torch.all(output_token_ids[:, -1] == bonus_token_ids.squeeze())
    assert torch.all(output_token_ids[:, :k] == draft_token_ids)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_temperature_zero_target_distribution(seed: int, device: str):
    """
    With a one-hot (temperature-zero) target distribution and draft tokens
    chosen to never match the probability-1.0 token, every draft token must
    be rejected.  The sampler then falls back to greedy selection: slot 0
    holds the target's argmax token and every later slot — including the
    bonus slot — is -1.
    """
    set_random_seed(seed)
    k, batch_size, vocab_size = 3, 5, 30_000
    torch.set_default_device(device)

    sampler_under_test = get_acceptance_sampler(strict_mode=True)
    sampler_under_test.init_gpu_tensors(device=device)
    # One-hot target distribution over k + 1 positions (k drafts + bonus):
    # exactly one token id per position carries probability 1.0.
    target_with_bonus_probs, zero_temperature_token_ids = \
        get_zero_temperature_prob_dist(batch_size, k + 1, vocab_size)
    zero_temperature_token_ids = zero_temperature_token_ids[:, :-1]
    # Draft tokens deliberately avoid the probability-1.0 ids, so none of
    # them can be accepted.
    draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size,
                                          zero_temperature_token_ids)
    bonus_token_ids = torch.randint(0,
                                    vocab_size, (batch_size, 1),
                                    dtype=torch.int64)
    output_token_ids = sampler_under_test(target_with_bonus_probs,
                                          bonus_token_ids,
                                          draft_probs=None,
                                          draft_token_ids=draft_token_ids)
    assert output_token_ids.shape[0] == batch_size
    assert output_token_ids.shape[1] == (k + 1)
    # All drafts rejected: bonus slot is -1, slot 0 is the greedy fallback.
    assert torch.all(output_token_ids[:, -1] == -1)
    assert torch.all(
        output_token_ids[:, 0] == zero_temperature_token_ids[:, 0])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_mixed_target_distribution(seed: int, device: str):
    """
    Test the TypicalAcceptanceSampler with a mixed target probability
    distribution.

    This test ensures that the TypicalAcceptanceSampler handles a mixed
    target probability distribution correctly. Specifically, it uses a
    zero-temperature distribution for some sequences and a uniform
    distribution for others. The test verifies that:

    - For sequences with a zero-temperature distribution, only the token
    with a probability of 1.0 is accepted, and all other tokens are rejected.
    - For sequences with a uniform distribution, all draft tokens are
    accepted.
    """
    set_random_seed(seed)
    k = 3
    batch_size = 4
    vocab_size = 30_000
    torch.set_default_device(device)
    typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
    typical_acceptance_sampler.init_gpu_tensors(device=device)
    # For sequences 0 and 2 set the distribution to a temperature
    # zero distribution. For sequences 1 and 3 set it to a uniform
    # distribution.
    target_with_bonus_probs, zero_temperature_token_ids = \
        get_zero_temperature_prob_dist(batch_size, k + 1, vocab_size)
    # Drop the bonus position; only the k draft positions matter here.
    zero_temperature_token_ids = zero_temperature_token_ids[:, :-1]
    # NOTE: target_probs is a *view* of target_with_bonus_probs (basic
    # slicing), so the fancy-indexed assignment below writes through to
    # the tensor that is actually passed to the sampler.
    target_probs = target_with_bonus_probs[:, :-1]
    # Draft tokens never match the probability-1.0 ids, so for the
    # one-hot rows (0 and 2) they will all be rejected.
    draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size,
                                          zero_temperature_token_ids)
    uniform_probs = torch.rand(2, k, vocab_size, dtype=torch.float32)
    target_probs[[1, 3]] = uniform_probs
    bonus_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64)
    output_token_ids = typical_acceptance_sampler(
        target_with_bonus_probs,
        bonus_token_ids,
        draft_probs=None,
        draft_token_ids=draft_token_ids)
    # verify the shape of output_token_ids
    assert output_token_ids.shape[0] == batch_size
    assert output_token_ids.shape[1] == (k + 1)
    # For sequences 0 and 2 verify that only 1 token is accepted
    # which is the token with probability 1.0 in the target distribution
    # at position 0.
    assert torch.all(output_token_ids[[0, 2], 1:] == -1)
    assert (torch.all(output_token_ids[[0, 2],
                                       0] == zero_temperature_token_ids[[0, 2],
                                                                        0]))
    # For sequences 1 and 3 verify that all tokens are accepted since the
    # target probability distribution is uniform. In addition verify that
    # we also accept the bonus tokens.
    assert torch.all(
        output_token_ids[[1, 3], :-1] == draft_token_ids[[1, 3], :])
    assert torch.all(output_token_ids[[1, 3], -1] != -1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_accept_tokens_partially(seed: int, device: str):
    """
    Test the TypicalAcceptanceSampler's behavior when only a subset of draft
    tokens should be accepted.

    This test verifies that the TypicalAcceptanceSampler correctly accepts or
    rejects draft tokens based on a zero-temperature target probability
    distribution. Specifically, it ensures that:

    - When all draft tokens match tokens with a probability of 1.0 in the
    target distribution, all draft tokens are accepted.
    - When only some draft tokens match tokens with a probability of 1.0 in
    the target distribution, only those matching tokens are accepted, and the
    rest are rejected.
    """
    set_random_seed(seed)
    k = 5
    batch_size = 1
    vocab_size = 30_000
    torch.set_default_device(device)
    typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
    typical_acceptance_sampler.init_gpu_tensors(device=device)
    # Create a temperature zero target probability distribution and ensure
    # all draft token ids correspond to the tokens with 1.0 probability.
    # Verify that all of them are accepted.
    target_with_bonus_probs, zero_temperature_token_ids = \
        get_zero_temperature_prob_dist(batch_size, k + 1, vocab_size)
    # Drop the bonus position; only the k draft positions are drafted.
    zero_temperature_token_ids = zero_temperature_token_ids[:, :-1]
    draft_token_ids = zero_temperature_token_ids
    bonus_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64)
    output_token_ids = typical_acceptance_sampler(
        target_with_bonus_probs,
        bonus_token_ids,
        draft_probs=None,
        draft_token_ids=draft_token_ids)
    # All k drafts matched the one-hot targets: everything accepted and
    # the bonus token is appended in the final slot.
    assert output_token_ids.shape[0] == batch_size
    assert output_token_ids.shape[1] == (k + 1)
    assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids)
    assert torch.all(output_token_ids[:, -1] == bonus_token_ids)
    # Next only keep the first 2 draft tokens same as the zero temperature
    # tokens. For the remaining 3 choose some other tokens. In the
    # response we will expect the first 2 tokens to be the same as the
    # draft tokens and the recovered token and rest as -1
    draft_token_ids_to_replace = get_draft_token_ids(
        batch_size, k, vocab_size, zero_temperature_token_ids)
    # Positions 0-1 still match the one-hot targets; positions 2-4 do not.
    draft_token_ids = torch.cat(
        (draft_token_ids[:, :2], draft_token_ids_to_replace[:, -3:]), dim=1)
    output_token_ids = typical_acceptance_sampler(
        target_with_bonus_probs,
        bonus_token_ids,
        draft_probs=None,
        draft_token_ids=draft_token_ids)
    assert output_token_ids.shape[0] == batch_size
    assert output_token_ids.shape[1] == (k + 1)
    # First 2 drafts accepted; slot 2 holds the greedily recovered token
    # (argmax of the target distribution); the remaining slots are -1.
    assert torch.all(output_token_ids[:, :2] == draft_token_ids[:, :2])
    assert torch.all(
        output_token_ids[:, 2] == target_with_bonus_probs.argmax(-1)[:, 2])
    assert torch.all(output_token_ids[:, -3:] == -1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", list(range(1)))
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_accept_tokens_set_non_default_posteriors(seed: int, device: str):
    """
    Test the TypicalAcceptanceSampler with custom posterior thresholds and
    alpha values. This test verifies that by modifying the posterior
    thresholds and alpha values we can change the acceptance behavior of the
    sampler.
    """
    set_random_seed(seed)
    k = 5
    batch_size = 1
    vocab_size = 30_000
    torch.set_default_device(device)
    typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
    typical_acceptance_sampler.init_gpu_tensors(device=device)
    # Simulate temperature 0 probability distribution for target
    # probabilities and create target probabilities such that only 1 token
    # id has probability 1.0 and others have a very low probability of
    # 0.00001. Populate draft_token_ids such that they exclude the token_ids
    # with probability = 1.0. Without any changes to the posterior thresholds
    # none of the draft tokens are accepted.
    # NOTE: despite the name, target_probs spans k + 1 positions here
    # (k drafts plus the bonus slot), matching the sampler's expected input.
    target_probs, zero_temperature_token_ids = get_zero_temperature_prob_dist(
        batch_size, k + 1, vocab_size)
    zero_temperature_token_ids = zero_temperature_token_ids[:, :-1]
    # In-place edit: lift every zero entry to a tiny nonzero probability.
    target_probs[target_probs == 0] = 0.00001
    draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size,
                                          zero_temperature_token_ids)
    bonus_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64)
    output_token_ids = typical_acceptance_sampler(
        target_probs,
        bonus_token_ids,
        draft_probs=None,
        draft_token_ids=draft_token_ids)
    # With default posteriors: drafts after position 0 are all rejected.
    assert output_token_ids.shape[0] == batch_size
    assert output_token_ids.shape[1] == (k + 1)
    assert torch.all(output_token_ids[:, 1:-1] == -1)

    # Change the posterior threshold values to 0.0 so that we will
    # now accept even draft tokens with very low probability in the
    # target distribution. Simulate and verify the same.
    typical_acceptance_sampler = TypicalAcceptanceSampler(
        strict_mode=True, posterior_threshold=0.0, posterior_alpha=0.0)
    typical_acceptance_sampler.init_gpu_tensors(device=device)
    output_token_ids = typical_acceptance_sampler(
        target_probs,
        bonus_token_ids,
        draft_probs=None,
        draft_token_ids=draft_token_ids)
    # With zeroed posteriors: every draft accepted plus the bonus token.
    assert output_token_ids.shape[0] == batch_size
    assert output_token_ids.shape[1] == (k + 1)
    assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids)
    assert torch.all(output_token_ids[:, -1] == bonus_token_ids)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_get_recovered_token_ids(seed: int, device: str):
    """
    _get_recovered_token_ids must pick, for every position of every
    sequence, the highest-probability token of the target distribution —
    i.e. it must behave exactly like an argmax over the vocab axis.
    """
    set_random_seed(seed)
    k, batch_size, vocab_size = 10, 5, 30_000
    torch.set_default_device(device)
    sampler_under_test = get_acceptance_sampler(strict_mode=True)
    sampler_under_test.init_gpu_tensors(device=device)
    target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
    expected_replacements = torch.argmax(target_probs, dim=-1)
    actual_replacements = (
        sampler_under_test._get_recovered_token_ids(target_probs))
    assert torch.all(expected_replacements == actual_replacements)
|
||||
Reference in New Issue
Block a user