add qwen3

This commit is contained in:
Chranos
2026-02-04 17:22:39 +08:00
parent d1c0f68ab4
commit 8511fe8530
1932 changed files with 300426 additions and 0 deletions

View File

View File

@@ -0,0 +1,53 @@
"""Compare the outputs of HF and vLLM when using beam search.
Run `pytest tests/samplers/test_beam_search.py`.
"""
import pytest
# FIXME(zhuohan): The test can not pass if we:
# 1. Increase max_tokens to 256.
# 2. Increase beam_width to 8.
# 3. Use the model "huggyllama/llama-7b".
MAX_TOKENS = [64]
BEAM_WIDTHS = [4]
MODELS = ["facebook/opt-125m"]
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", MAX_TOKENS)
@pytest.mark.parametrize("beam_width", BEAM_WIDTHS)
def test_beam_search_single_input(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
beam_width: int,
) -> None:
example_prompts = example_prompts[:1]
with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_beam_search(example_prompts, beam_width,
max_tokens)
with vllm_runner(model, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.generate_beam_search(example_prompts,
beam_width, max_tokens)
for i in range(len(example_prompts)):
hf_output_ids, hf_output_texts = hf_outputs[i]
vllm_output_ids, vllm_output_texts = vllm_outputs[i]
for i, (hf_text,
vllm_text) in enumerate(zip(hf_output_texts,
vllm_output_texts)):
print(f">>>{i}-th hf output:")
print(hf_text)
print(f">>>{i}-th vllm output:")
print(vllm_text)
assert len(hf_output_ids) == len(vllm_output_ids)
for j in range(len(hf_output_ids)):
assert hf_output_ids[j] == vllm_output_ids[j], (
f"Test{i} output{j}:\nHF: {hf_output_ids}\n"
f"vLLM: {vllm_output_ids}")

View File

@@ -0,0 +1,33 @@
"""Make sure ignore_eos works.
Run `pytest tests/samplers/test_ignore_eos.py`.
"""
import pytest
from vllm import SamplingParams
# We also test with llama because it has generation_config to specify EOS
# (past regression).
MODELS = ["facebook/opt-125m", "meta-llama/Llama-2-7b-hf"]
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [512])
def test_ignore_eos(
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
) -> None:
with vllm_runner(model, dtype=dtype) as vllm_model:
sampling_params = SamplingParams(max_tokens=max_tokens,
ignore_eos=True)
for prompt in example_prompts:
ignore_eos_output = vllm_model.model.generate(
prompt, sampling_params=sampling_params)
output_length = len(ignore_eos_output[0].outputs[0].token_ids)
assert output_length == max_tokens

View File

@@ -0,0 +1,59 @@
import pytest
import torch
from vllm import SamplingParams
MODELS = ["facebook/opt-125m"]
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
def test_logits_processor_force_generate(
vllm_runner,
example_prompts,
model: str,
dtype: str,
) -> None:
with vllm_runner(model, dtype=dtype) as vllm_model:
tokenizer = vllm_model.model.get_tokenizer()
repeat_times = 2
enforced_answers = " vLLM"
vllm_token_ids = tokenizer.encode(enforced_answers,
add_special_tokens=False)
max_tokens = len(vllm_token_ids) * repeat_times
def pick_vllm(token_ids, logits):
token_id = vllm_token_ids[len(token_ids) % len(vllm_token_ids)]
logits[token_id] = torch.finfo(logits.dtype).max
return logits
params_with_logprobs = SamplingParams(
logits_processors=[pick_vllm],
prompt_logprobs=3,
max_tokens=max_tokens,
)
# test logits_processors when prompt_logprobs is not None
vllm_model.model._add_request(
example_prompts[0],
params=params_with_logprobs,
)
# test prompt_logprobs is not None
vllm_model.model._add_request(
example_prompts[1],
params=SamplingParams(
prompt_logprobs=3,
max_tokens=max_tokens,
),
)
# test grouped requests
vllm_model.model._add_request(
example_prompts[2],
params=SamplingParams(max_tokens=max_tokens),
)
outputs = vllm_model.model._run_engine(use_tqdm=False)
assert outputs[0].outputs[0].text == enforced_answers * repeat_times

View File

@@ -0,0 +1,182 @@
from typing import List
import pytest
import torch
from vllm import SamplingParams
from ..conftest import VllmRunner
MODELS = ["facebook/opt-125m"]
'''
=============================
Modify by vllm_mlu
=============================
NOTES: chunked_prefill_token_size=1 contains some accuracy issue.
So we skip this case in mlu ut.
TODO(VLLM-662): fix accuracy error
'''
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype",
["float"]) # needed for comparing logprobs with HF
@pytest.mark.parametrize("chunked_prefill_token_size", [4, 16, -1])
@pytest.mark.parametrize("num_top_logprobs", [0, 6]) # 32000 == vocab_size
@pytest.mark.parametrize("detokenize", [True, False])
def test_get_prompt_logprobs(
hf_runner,
vllm_runner,
model,
dtype,
chunked_prefill_token_size: int,
num_top_logprobs: int,
detokenize: bool,
example_prompts,
):
max_num_seqs = 256
enable_chunked_prefill = False
max_num_batched_tokens = None
if chunked_prefill_token_size != -1:
enable_chunked_prefill = True
max_num_seqs = min(chunked_prefill_token_size, max_num_seqs)
max_num_batched_tokens = chunked_prefill_token_size
max_tokens = 5
with hf_runner(model, dtype=dtype) as hf_model:
hf_logprobs = hf_model.generate_greedy_logprobs(
example_prompts,
max_tokens=max_tokens,
)
with vllm_runner(
model,
dtype=dtype,
max_logprobs=num_top_logprobs,
enable_chunked_prefill=enable_chunked_prefill,
max_num_batched_tokens=max_num_batched_tokens,
max_num_seqs=max_num_seqs,
gpu_memory_utilization=0.6,
) as vllm_model:
vllm_sampling_params = SamplingParams(max_tokens=max_tokens,
logprobs=num_top_logprobs,
prompt_logprobs=num_top_logprobs,
temperature=0.0,
detokenize=detokenize)
vllm_results = vllm_model.model.generate(
example_prompts, sampling_params=vllm_sampling_params)
# Test whether logprobs are included in the results.
for result in vllm_results:
assert result.prompt_logprobs is not None
assert result.outputs[0].logprobs is not None
assert len(result.outputs[0].logprobs) == max_tokens
for logprobs in result.outputs[0].logprobs:
# If the output token is not included in the top X
# logprob, it can return 1 more data
assert (len(logprobs) == num_top_logprobs
or len(logprobs) == num_top_logprobs + 1)
output_text = result.outputs[0].text
output_string_from_most_likely_tokens_lst: List[str] = []
for top_logprobs in result.outputs[0].logprobs:
top_logprob = next(iter(top_logprobs.values()))
output_string_from_most_likely_tokens_lst.append(
top_logprob.decoded_token)
if detokenize:
output_string_from_most_likely_tokens = "".join(
output_string_from_most_likely_tokens_lst)
assert output_text == output_string_from_most_likely_tokens, (
"The output text from the top logprob for each token position "
"should be the same as the output text in the result.")
else:
assert output_text == ''
assert output_string_from_most_likely_tokens_lst == ([None] *
max_tokens)
# The first prompt logprob is always None
assert result.prompt_logprobs[0] is None
for prompt_logprobs in result.prompt_logprobs[1:]:
# If the prompt token is not included in the top X
# logprob, it can return 1 more data
assert (len(prompt_logprobs) == num_top_logprobs
or len(prompt_logprobs) == num_top_logprobs + 1)
# Test whether prompt logprobs are consistent with HF
for vllm_result, hf_logprob in zip(vllm_results, hf_logprobs):
# Check prompt logprobs
# The first prompt logprob is always None, so we compare it from 1:.
vllm_prompt_logprobs = vllm_result.prompt_logprobs[1:]
for i, vllm_prompt_logprob_dict in enumerate(vllm_prompt_logprobs):
for token_id, logprob in vllm_prompt_logprob_dict.items():
torch.testing.assert_close(logprob.logprob,
hf_logprob[0][i][token_id].item(),
atol=1e-2,
rtol=1e-2)
vllm_sample_logprobs = vllm_result.outputs[0].logprobs
for i, top_logprobs in enumerate(vllm_sample_logprobs):
for token_id, sample_logprob in top_logprobs.items():
logprob = sample_logprob.logprob
torch.testing.assert_close(logprob,
hf_logprob[i][-1][token_id].item(),
atol=1e-2,
rtol=1e-2)
if detokenize:
assert isinstance(sample_logprob.decoded_token, str), (
"The token should be decoded by the time it is returned"
" to the user.")
# Test if prompt logprobs are correctly set.
for vllm_result in vllm_results:
token_ids = vllm_result.prompt_token_ids
prompt_logprobs = vllm_result.prompt_logprobs
# The first token doesn't have logprob.
assert prompt_logprobs[0] is None
for token_id, logprob_dict in zip(token_ids[1:], prompt_logprobs[1:]):
assert token_id in logprob_dict
def test_max_logprobs():
runner = VllmRunner("facebook/opt-125m", max_logprobs=1)
vllm_sampling_params = SamplingParams(logprobs=1)
# should pass
runner.generate(["Hello world"], sampling_params=vllm_sampling_params)
bad_sampling_params = SamplingParams(logprobs=2)
with pytest.raises(ValueError):
runner.generate(["Hello world"], sampling_params=bad_sampling_params)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1])
@pytest.mark.parametrize("detokenize", [True, False])
def test_none_logprobs(vllm_runner, model, chunked_prefill_token_size: int,
detokenize: bool, example_prompts):
max_num_seqs = 256
enable_chunked_prefill = False
max_num_batched_tokens = None
if chunked_prefill_token_size != -1:
enable_chunked_prefill = True
max_num_seqs = min(chunked_prefill_token_size, max_num_seqs)
max_num_batched_tokens = chunked_prefill_token_size
max_tokens = 5
with vllm_runner(
model,
enable_chunked_prefill=enable_chunked_prefill,
max_num_batched_tokens=max_num_batched_tokens,
max_num_seqs=max_num_seqs,
gpu_memory_utilization=0.6,
) as vllm_model:
sampling_params_logprobs_none = SamplingParams(max_tokens=max_tokens,
logprobs=None,
temperature=0.0,
detokenize=detokenize)
results_logprobs_none = vllm_model.model.generate(
example_prompts, sampling_params=sampling_params_logprobs_none)
for i in range(len(results_logprobs_none)):
assert results_logprobs_none[i].outputs[0].logprobs is None
assert results_logprobs_none[i].outputs[0].cumulative_logprob is None

View File

@@ -0,0 +1,185 @@
"""Make sure bad_words works.
Run `pytest tests/samplers/test_no_bad_words.py`.
"""
from typing import List, Optional
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
def _generate(
model: LLM,
prompt: str,
num_prompt_tokens: int,
temperature: float = 0,
bad_words: Optional[List[str]] = None,
) -> List[int]:
sampling_params = SamplingParams(
temperature=temperature,
bad_words=bad_words,
)
# [([output_token_ids, ], [output_text, ]), ]
output = model.generate([prompt], sampling_params=sampling_params)
output_token_ids = output[0][0][0][num_prompt_tokens:]
# [0] first (and only) request output
# [0] token_ids (not text)
# [0] first (and only) output completion
return output_token_ids
class TestOneTokenBadWord:
MODEL = "meta-llama/Llama-2-7b-hf"
PROMPT = "Hi! How are"
TARGET_TOKEN = "you"
def setup_method(self, method):
self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL,
add_prefix_space=True)
self.num_prompt_tokens = len(self._encode(self.PROMPT))
self.target_token_id = self._encode(self.TARGET_TOKEN,
add_special_tokens=False)[0]
def test_one_token_bad_word(self, vllm_runner):
with vllm_runner(self.MODEL) as llm:
output_token_ids = self._generate(llm)
assert output_token_ids[0] == self.target_token_id
output_token_ids = self._generate(llm,
bad_words=[self.TARGET_TOKEN])
assert self.target_token_id not in output_token_ids
def _generate(self,
model: LLM,
bad_words: Optional[List[str]] = None) -> List[int]:
return _generate(
model=model,
prompt=self.PROMPT,
num_prompt_tokens=self.num_prompt_tokens,
bad_words=bad_words,
)
def _encode(self,
prompt: str,
add_special_tokens: bool = True) -> List[int]:
return self.tokenizer(prompt,
add_special_tokens=add_special_tokens).input_ids
class TestTwoTokenBadWord:
# Another model (with a different tokenizer behaviour)
MODEL = "openai-community/gpt2"
PROMPT = "How old are you? I am 10"
TARGET_TOKEN1 = "years"
TARGET_TOKEN2 = "old"
NEIGHBOUR_TOKEN2 = "older"
def setup_method(self, method):
self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL,
add_prefix_space=True)
self.num_prompt_tokens = len(self._encode(self.PROMPT))
self.target_token_id1 = self._encode(self.TARGET_TOKEN1,
add_special_tokens=False)[0]
self.target_token_id2 = self._encode(self.TARGET_TOKEN2,
add_special_tokens=False)[0]
self.neighbour_token_id2 = self._encode(self.NEIGHBOUR_TOKEN2,
add_special_tokens=False)[0]
def test_two_token_bad_word(self, vllm_runner):
with vllm_runner(self.MODEL) as llm:
output_token_ids = self._generate(llm)
assert output_token_ids[:2] == [
self.target_token_id1, self.target_token_id2
]
output_token_ids = self._generate(llm,
bad_words=[self.TARGET_TOKEN1])
assert self.target_token_id1 not in output_token_ids
output_token_ids = self._generate(llm,
bad_words=[self.TARGET_TOKEN2])
assert output_token_ids[0] == self.target_token_id1
assert self.target_token_id2 not in output_token_ids
output_token_ids = self._generate(
llm, bad_words=[f'{self.TARGET_TOKEN1} {self.TARGET_TOKEN2}'])
assert output_token_ids[0] == self.target_token_id1
assert output_token_ids[:2] != [
self.target_token_id1, self.target_token_id2
]
assert not self._contains(
output_token_ids,
[self.target_token_id1, self.target_token_id2])
# Model dependent behaviour
assert output_token_ids[:2] == [
self.target_token_id1, self.neighbour_token_id2
]
output_token_ids = self._generate(
llm,
bad_words=[
f'{self.TARGET_TOKEN1} {self.TARGET_TOKEN2}',
f'{self.TARGET_TOKEN1} {self.NEIGHBOUR_TOKEN2}'
])
assert output_token_ids[0] == self.target_token_id1
assert output_token_ids[:2] != [
self.target_token_id1, self.target_token_id2
]
assert not self._contains(
output_token_ids,
[self.target_token_id1, self.target_token_id2])
assert output_token_ids[:2] != [
self.target_token_id1, self.neighbour_token_id2
]
assert not self._contains(
output_token_ids,
[self.target_token_id1, self.neighbour_token_id2])
assert ((self.target_token_id2 in output_token_ids)
or (self.neighbour_token_id2 in output_token_ids))
def _generate(self,
model: LLM,
bad_words: Optional[List[str]] = None) -> List[int]:
return _generate(
model=model,
prompt=self.PROMPT,
num_prompt_tokens=self.num_prompt_tokens,
bad_words=bad_words,
)
@staticmethod
def _contains(sequence: List[int], subsequence: List[int]) -> bool:
searched = False
for start in range(len(sequence)):
end = start + len(subsequence)
current_subsequence = sequence[start:end]
if len(current_subsequence) < len(subsequence):
continue
searched = True
assert len(current_subsequence) == len(subsequence)
if current_subsequence == subsequence:
return True
assert searched, "All subsequences did not match in length..."
return False
def _encode(self,
prompt: str,
add_special_tokens: bool = True) -> List[int]:
return self.tokenizer(prompt,
add_special_tokens=add_special_tokens).input_ids

View File

@@ -0,0 +1,54 @@
import pytest
from vllm import SamplingParams
MODELS = ["facebook/opt-125m"]
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
def test_ranks(
vllm_runner,
model,
dtype,
example_prompts,
):
max_tokens = 5
num_top_logprobs = 5
num_prompt_logprobs = 5
with vllm_runner(model, dtype=dtype,
max_logprobs=num_top_logprobs) as vllm_model:
## Test greedy logprobs ranks
vllm_sampling_params = SamplingParams(
temperature=0.0,
top_p=1.0,
max_tokens=max_tokens,
logprobs=num_top_logprobs,
prompt_logprobs=num_prompt_logprobs)
vllm_results = vllm_model.generate_w_logprobs(example_prompts,
vllm_sampling_params)
## Test non-greedy logprobs ranks
sampling_params = SamplingParams(temperature=1.0,
top_p=1.0,
max_tokens=max_tokens,
logprobs=num_top_logprobs,
prompt_logprobs=num_prompt_logprobs)
res = vllm_model.generate_w_logprobs(example_prompts, sampling_params)
for result in vllm_results:
assert result[2] is not None
assert len(result[2]) == len(result[0])
# check whether all chosen tokens have ranks = 1
for token, logprobs in zip(result[0], result[2]):
assert token in logprobs
assert logprobs[token].rank == 1
for result in res:
assert result[2] is not None
assert len(result[2]) == len(result[0])
# check whether all chosen tokens have ranks
for token, logprobs in zip(result[0], result[2]):
assert logprobs[token].rank >= 1

View File

@@ -0,0 +1,511 @@
"""Tests for rejection sampling."""
from typing import List, Tuple
import pytest
import torch
import torch.nn.functional as F
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.utils import set_random_seed
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1)
]
def mock_causal_accepted_tensor(
k: int, last_accepted_indices: torch.Tensor) -> torch.Tensor:
"""Generate an "accepted" tensor which should yield causally-accepted tokens
up to last accepted indices.
Tokens after last_accepted_indices+1 may also be accepted, although they
will not be causally accepted.
"""
batch_size = last_accepted_indices.shape[0]
accepted = (torch.arange(k).expand(batch_size, k) <=
last_accepted_indices.unsqueeze(-1).broadcast_to(
batch_size, k))
# Sprinkle accepted values after the contiguous initial accepted values.
# This replicates the behavior of rejection sampling, which may "accept"
# a token that cannot be accepted because of causality.
sprinkle_candidates = (
torch.arange(k).expand(batch_size, k) >
last_accepted_indices.unsqueeze(-1).broadcast_to(batch_size, k) + 1)
sprinkle = torch.rand(batch_size, k) > 0.5
accepted[sprinkle_candidates] = sprinkle[sprinkle_candidates]
return accepted
@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize(
"which_tokens_accepted",
["all_tokens_accepted", "no_tokens_accepted", "some_tokens_accepted"])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("use_flashinfer", [True, False])
@torch.inference_mode()
def test_correct_output_format(which_tokens_accepted: str, seed: int,
device: str, use_flashinfer: bool):
"""Verify the output has correct format given predetermined accepted matrix.
"""
set_random_seed(seed)
torch.set_default_device(device)
batch_size = 10
k = 5
vocab_size = 3000
if which_tokens_accepted == "all_tokens_accepted":
accepted = mock_causal_accepted_tensor(
k, -1 + k * torch.ones((batch_size, ), dtype=torch.long))
elif which_tokens_accepted == "no_tokens_accepted":
accepted = mock_causal_accepted_tensor(
k, -torch.ones((batch_size, ), dtype=torch.long))
elif which_tokens_accepted == "some_tokens_accepted":
last_accepted_indices = torch.randint(low=-1,
high=k,
size=(batch_size, ))
accepted = mock_causal_accepted_tensor(k, last_accepted_indices)
else:
raise AssertionError()
recovered_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64)
draft_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64)
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64)
rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
rejection_sampler.init_gpu_tensors(device=device)
output_token_ids = rejection_sampler._create_output( # pylint: disable=protected-access
accepted,
recovered_token_ids,
draft_token_ids,
bonus_token_ids,
)
expected_bonus_token_ids = bonus_token_ids.clone()
if which_tokens_accepted == "all_tokens_accepted":
# Expect all tokens to be equal to draft tokens.
assert torch.equal(output_token_ids[:, :-1], draft_token_ids)
# Expect all bonus tokens to be included.
assert torch.equal(output_token_ids[:, -1:], expected_bonus_token_ids)
elif which_tokens_accepted == "no_tokens_accepted":
# Expect first token to be equal to recovered tokens.
assert torch.equal(output_token_ids[:, 0], recovered_token_ids[:, 0])
# Expect everything else to be -1.
assert torch.equal(output_token_ids[:, 1:],
torch.ones_like(output_token_ids[:, 1:]) * -1)
elif which_tokens_accepted == "some_tokens_accepted":
recovered_plus_bonus = torch.cat(
(recovered_token_ids, expected_bonus_token_ids), dim=-1)
# Assert first rejected token is a recovered token or bonus token.
assert torch.equal(
recovered_plus_bonus[torch.arange(0, batch_size),
last_accepted_indices + 1],
output_token_ids[torch.arange(0, batch_size),
last_accepted_indices + 1])
# Assert every subsequent token is -1.
subsequent_mask = torch.arange(0, k + 1).expand(
batch_size, k + 1) >= (last_accepted_indices + 2).unsqueeze(-1)
assert torch.all(output_token_ids[subsequent_mask] == -1)
'''
=============================
Modify by vllm_mlu
=============================
@brief(use_flashinfer): MLU device only support MLU_FLASH_ATTN backend
'''
@pytest.mark.parametrize("k", list(range(1, 6)))
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
@pytest.mark.parametrize("batch_size", list(range(1, 32)))
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("use_flashinfer", [False])
@torch.inference_mode()
def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
device: str, use_flashinfer: bool):
torch.set_default_device(device)
rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
rejection_sampler.init_gpu_tensors(device=device)
draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
target_probs = torch.rand(batch_size,
k + 1,
vocab_size,
dtype=torch.float32)
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64)
draft_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64)
rejection_sampler(target_probs, bonus_token_ids, draft_probs,
draft_token_ids)
'''
=============================
Modify by vllm_mlu
=============================
@brief(use_flashinfer): MLU device only support MLU_FLASH_ATTN backend
'''
@pytest.mark.parametrize("frac_seeded", [0.0, 0.25, 0.5, 1.0])
@pytest.mark.parametrize("k", [1, 3, 6])
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
@pytest.mark.parametrize("batch_size", [1, 8, 32, 128])
@pytest.mark.parametrize("n_rep", [100])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("use_flashinfer", [False])
@torch.inference_mode()
def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
frac_seeded: float, n_rep: int, device: str,
use_flashinfer: bool):
torch.set_default_device(device)
rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
rejection_sampler.init_gpu_tensors(device=device)
draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
target_probs = torch.rand(batch_size,
k + 1,
vocab_size,
dtype=torch.float32)
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64)
draft_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64)
seeded_mask = torch.rand(batch_size, dtype=torch.float32) <= frac_seeded
results = []
for _ in range(n_rep):
seeded_seqs = {
i: torch.Generator(device=device).manual_seed(i)
for i in range(batch_size) if seeded_mask[i]
}
results.append(
rejection_sampler(target_probs, bonus_token_ids, draft_probs,
draft_token_ids, seeded_seqs))
for i in range(batch_size):
if seeded_mask[i]:
for j in range(1, n_rep):
assert torch.equal(results[j][i], results[0][i])
@pytest.mark.skip("Skip flashinfer test case for MLU.")
@pytest.mark.parametrize("k", [1, 3, 6])
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
@pytest.mark.parametrize("batch_size", [1, 8, 32, 128])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_compare_nonflashinfer_backend(k: int, vocab_size: int,
batch_size: int, device: str):
"""
Test the flashinfer and nonflashinfer backend generate
the same output metrics.
"""
torch.set_default_device(device)
torch.manual_seed(0)
draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
target_probs = torch.rand(batch_size,
k + 1,
vocab_size,
dtype=torch.float32)
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64)
draft_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64)
num_accepted_tokens = []
num_emitted_tokens = []
num_draft_tokens = []
def get_seeded_seqs():
return {
i: torch.Generator(device=device).manual_seed(i)
for i in range(batch_size)
}
for use_flashinfer in [True, False]:
rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
rejection_sampler.init_gpu_tensors(device=device)
# We use seeded sequences to ensure the same tokens are accepted
# for both flashinfer and nonflashinfer backends.
seeded_seqs = get_seeded_seqs()
rejection_sampler(target_probs, bonus_token_ids, draft_probs,
draft_token_ids, seeded_seqs)
num_accepted_tokens.append(rejection_sampler.num_accepted_tokens)
num_emitted_tokens.append(rejection_sampler.num_emitted_tokens)
num_draft_tokens.append(rejection_sampler.num_draft_tokens)
assert num_accepted_tokens[0] == num_accepted_tokens[1]
assert num_emitted_tokens[0] == num_emitted_tokens[1]
assert num_draft_tokens[0] == num_draft_tokens[1]
@pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"])
@pytest.mark.parametrize("which_token_ids",
["bonus_token_ids", "draft_token_ids"])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("use_flashinfer", [True, False])
@torch.inference_mode()
def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
which_token_ids: str, device: str,
use_flashinfer: bool):
k = 3
batch_size = 5
vocab_size = 30_000
torch.set_default_device(device)
rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer,
strict_mode=True)
rejection_sampler.init_gpu_tensors(device=device)
draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
target_probs = torch.rand(batch_size,
k + 1,
vocab_size,
dtype=torch.float32)
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64)
draft_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64)
oob_token_ids = None
if which_token_ids == "bonus_token_ids":
oob_token_ids = bonus_token_ids
elif which_token_ids == "draft_token_ids":
oob_token_ids = draft_token_ids
else:
raise AssertionError()
if above_or_below_vocab_range == "above":
rogue_token_id = vocab_size + 1
elif above_or_below_vocab_range == "below":
rogue_token_id = -1
else:
raise AssertionError()
oob_token_ids[0][0] = rogue_token_id
with pytest.raises(AssertionError):
rejection_sampler(target_probs, bonus_token_ids, draft_probs,
draft_token_ids)
'''
=============================
Modify by vllm_mlu
=============================
@brief(use_flashinfer): MLU device only support MLU_FLASH_ATTN backend
'''
@pytest.mark.parametrize("draft_and_target_probs_equal", [True, False])
@pytest.mark.parametrize("seed", list(range(5)))
@pytest.mark.parametrize("use_flashinfer", [False])
@torch.inference_mode()
def test_rejection_sampling_approximates_target_distribution(
seed: int, draft_and_target_probs_equal: bool, use_flashinfer: bool):
"""Verify rejection sampling approximates target distribution,
despite sampling from a potentially distinct draft distribution.
This is done by first creating a random target probability
distribution and a random draft probability distribution. We then
sample token ids from the rejection sampler using these draft
and target distributions. The samples are used to estimate
the output probability distribution, which we expect to approximate
the target distribution.
A basic distance metric is used to determine similarity between
distributions.
We expect that as we increase the number of samples,
the distance between the observed distribution and the target
distribution decreases. To measure this, we compare the distance
of the observed distribution against both the target distribution
and a uniform random distribution. We expect the distance between
the observed distribution and the target distribution to improve
much more than the distance improvement between the observed
distribution and the random distribution.
When draft_and_target_probs_equal=True, the draft and target
probabilities are exactly equal. Rejection sampling should
still work without any NaNs or exceptions.
"""
torch.set_default_device("cpu")
set_random_seed(seed)
helper = _CorrectnessTestHelper(
vocab_size=10,
rejection_sampler=RejectionSampler(use_flashinfer=use_flashinfer),
)
draft_probs, target_probs, reference_probs = helper.generate_probs_for_test(
draft_and_target_probs_equal)
sample_sizes = [10, 100, 1_000, 10_000, 100_000]
distance_wrt_reference: List[float] = []
distance_wrt_target: List[float] = []
for num_samples in sample_sizes:
(reference_vs_rejsample_dist,
target_vs_rejsample_dist) = helper.run_and_compare_distributions(
draft_probs,
target_probs,
reference_probs,
num_samples,
)
distance_wrt_reference.append(reference_vs_rejsample_dist)
distance_wrt_target.append(target_vs_rejsample_dist)
relative_change_in_distance_wrt_target = get_ratio_first_to_last(
distance_wrt_target)
relative_change_in_distance_wrt_reference = get_ratio_first_to_last(
distance_wrt_reference)
print(f"{num_samples=} {target_vs_rejsample_dist=:.05f} "
f"{reference_vs_rejsample_dist=:.05f}")
print(f"{num_samples=} {relative_change_in_distance_wrt_target=:.02f} "
f"{relative_change_in_distance_wrt_reference=:.02f}")
relative_change_in_distance_wrt_target = get_ratio_first_to_last(
distance_wrt_target)
relative_change_in_distance_wrt_reference = get_ratio_first_to_last(
distance_wrt_reference)
expected_improvement_multiplier = 20
assert (relative_change_in_distance_wrt_target >
relative_change_in_distance_wrt_reference *
expected_improvement_multiplier)
def get_ratio_first_to_last(elements: List[float]) -> float:
return elements[0] / elements[-1]
class _CorrectnessTestHelper:
"""Class that packages together logic required for the unit-level
rejection sampling correctness test.
"""
def __init__(self, vocab_size: int, rejection_sampler: RejectionSampler):
self.rejection_sampler = rejection_sampler
self.vocab_size = vocab_size
self.vocab_range = (0, vocab_size)
self.rejection_sampler.init_gpu_tensors(device=0)
# Keep test simple, use k=1
self.k = 1
# Bonus tokens not used, but rejection sampler requires
# correct shape.
self.num_bonus_tokens = 1
def generate_probs_for_test(
self, draft_and_target_probs_equal: bool
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
draft_probs, target_probs = (F.softmax(
torch.rand(self.vocab_size, dtype=torch.float32),
dim=-1,
) for _ in range(2))
num_reference_probs = 100
reference_probs = F.softmax(
torch.rand(num_reference_probs,
self.vocab_size,
dtype=torch.float32),
dim=-1,
)
if draft_and_target_probs_equal:
target_probs = draft_probs.clone()
return draft_probs, target_probs, reference_probs
def run_and_compare_distributions(self, draft_probs: torch.Tensor,
target_probs: torch.Tensor,
reference_probs: torch.Tensor,
num_samples: int) -> Tuple[float, float]:
# Sample using rejection sampling.
rej_sample_probs = self._estimate_rejection_sampling_pdf(
draft_probs, target_probs, num_samples)
# Average distance from reference probs.
reference_vs_rejsample_dist = torch.dist(
reference_probs,
rej_sample_probs).item() / reference_probs.shape[0]
target_vs_rejsample_dist = torch.dist(target_probs,
rej_sample_probs).item()
return reference_vs_rejsample_dist, target_vs_rejsample_dist
def _estimate_rejection_sampling_pdf(
self,
draft_probs: torch.Tensor,
target_probs: torch.Tensor,
num_samples: int,
) -> torch.Tensor:
# Repeat draft probs num_samples times.
draft_probs = draft_probs.reshape(1, self.k, self.vocab_size).repeat(
num_samples, 1, 1)
# Repeat target probs num_samples * (k + 1) times.
# Rejection sampler requires bonus token probs, but they aren't used.
target_probs = target_probs.reshape(1, 1, self.vocab_size).repeat(
num_samples, self.k + 1, 1)
# Randomly sample draft token ids from draft probs.
draft_token_ids = torch.multinomial(draft_probs[:, 0, :],
num_samples=1,
replacement=True).reshape(
num_samples, self.k)
# Bonus tokens not used but required.
bonus_token_ids = torch.zeros((1, self.num_bonus_tokens),
dtype=torch.int64,
device="cuda").repeat(num_samples, 1)
# Get output tokens via rejection sampling.
output_token_ids = self.rejection_sampler(target_probs.to("cuda"),
bonus_token_ids.to("cuda"),
draft_probs.to("cuda"),
draft_token_ids.to("cuda"))
# Remove bonus tokens
output_token_ids = output_token_ids[:, :-1].flatten()
# Estimate probability density function
hist = torch.histogram(output_token_ids.to(dtype=torch.float,
device="cpu"),
bins=self.vocab_size,
range=self.vocab_range,
density=True)
return hist.hist

View File

@@ -0,0 +1,758 @@
import itertools
import random
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
from unittest.mock import Mock, patch
import pytest
import torch
from transformers import GenerationConfig, GenerationMixin
import vllm.envs as envs
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.utils import Counter, is_pin_memory_available
class MockLogitsSampler(Sampler):
def __init__(self, fake_logits: torch.Tensor):
super().__init__()
self.fake_logits = fake_logits
def forward(self, *args, **kwargs):
return super().forward(*args, **kwargs)
def _prepare_test(
batch_size: int
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler]:
input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16)
fake_logits = torch.full((batch_size, VOCAB_SIZE),
1e-2,
dtype=input_tensor.dtype)
sampler = MockLogitsSampler(fake_logits)
return input_tensor, fake_logits, sampler
VOCAB_SIZE = 32000
RANDOM_SEEDS = list(range(128))
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1)
]
def _do_sample(
batch_size: int,
input_tensor: torch.Tensor,
sampler: MockLogitsSampler,
sampling_params: SamplingParams,
device: str,
):
seq_group_metadata_list: List[SequenceGroupMetadata] = []
seq_lens: List[int] = []
for i in range(batch_size):
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData.from_seqs([1, 2, 3])},
sampling_params=sampling_params,
block_tables={0: [1]},
))
seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
seq_lens,
query_lens=seq_lens,
device=device,
pin_memory=is_pin_memory_available())
return sampler(logits=input_tensor, sampling_metadata=sampling_metadata)
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_greedy(seed: int, device: str):
set_random_seed(seed)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
input_tensor, fake_logits, sampler = _prepare_test(batch_size)
sampling_params = SamplingParams(temperature=0)
sampler_output = _do_sample(batch_size, fake_logits, sampler,
sampling_params, device)
expected = torch.argmax(fake_logits, dim=-1)
for i, sequence_output in enumerate(sampler_output):
for nth_output in sequence_output.samples:
assert nth_output.output_token == expected[i].item()
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random(seed: int, device: str):
set_random_seed(seed)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
_, fake_logits, sampler = _prepare_test(batch_size)
for i in range(batch_size):
fake_logits[i, i] = 1e2
sampling_params = SamplingParams(
temperature=1.0,
n=random.randint(1, 10),
)
sampler_output = _do_sample(batch_size, fake_logits, sampler,
sampling_params, device)
for i, sequence_output in enumerate(sampler_output):
for nth_output in sequence_output.samples:
assert nth_output.output_token == i
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random_seed(seed: int, device: str):
set_random_seed(seed)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
_, fake_logits, sampler = _prepare_test(batch_size)
for i in range(batch_size):
fake_logits[i, i] = 1e2
sampling_params = SamplingParams(
temperature=1.0,
n=random.randint(1, 10),
seed=random.randint(0, 10000),
)
sampler_output = _do_sample(batch_size, fake_logits, sampler,
sampling_params, device)
for i, sequence_output in enumerate(sampler_output):
for nth_output in sequence_output.samples:
assert nth_output.output_token == i
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random_seed_deterministic(seed: int, device: str):
set_random_seed(seed)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
_, fake_logits, sampler = _prepare_test(batch_size)
sampling_params = SamplingParams(
temperature=1.0,
n=random.randint(1, 10),
seed=random.randint(0, 10000),
)
first_sampler_output = _do_sample(batch_size, fake_logits, sampler,
sampling_params, device)
second_sampler_output = _do_sample(batch_size, fake_logits, sampler,
sampling_params, device)
assert first_sampler_output == second_sampler_output
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_min_tokens_penalty(seed: int, device: str):
seq_id_counter = Counter(start=random.randint(0, 100))
set_random_seed(seed)
torch.set_default_device(device)
def create_sampling_params(min_tokens,
eos_token_id=0,
*,
stop_token_ids: Optional[List[int]] = None,
prompt_logprobs: Optional[int] = None):
sampling_params = SamplingParams(
min_tokens=min_tokens,
max_tokens=9999, # keep higher than max of min_tokens
stop_token_ids=stop_token_ids,
# requesting prompt_logprobs changes the structure of `logits`
prompt_logprobs=prompt_logprobs,
)
sampling_params.all_stop_token_ids.add(eos_token_id)
return sampling_params
def create_sequence_data(num_input=3, num_generated=0):
seq_data = SequenceData.from_seqs(
random.choices(range(0, VOCAB_SIZE), k=num_input))
if num_generated > 0:
seq_data.output_token_ids = random.choices(range(0, VOCAB_SIZE),
k=num_generated)
return seq_data
def generate_test_case():
# generate multiple seq groups but limit total batch size
batch_size = random.randint(1, 128)
expected_penalization = []
sequence_metadata_list: List[SequenceGroupMetadata] = []
# 20% chance to generate seq group metadata list with all prompts
is_prompt = random.random() < 0.2
while batch_size > 0:
num_seqs = 1 if is_prompt else random.randint(1, batch_size)
eos_token_id = random.randint(0, VOCAB_SIZE - 1)
min_tokens = random.randint(0, 50)
num_stop_tokens = random.randint(0, 8)
if num_stop_tokens > 0:
stop_token_ids = random.choices(range(0, VOCAB_SIZE - 1),
k=num_stop_tokens)
else:
stop_token_ids = None
sampling_params = create_sampling_params(
min_tokens=min_tokens,
eos_token_id=eos_token_id,
stop_token_ids=stop_token_ids)
seq_data: Dict[int, SequenceData] = {}
seq_group_penalization: List[bool] = []
for _ in range(num_seqs):
num_input = random.randint(1, 100)
num_generated = 0 if is_prompt else random.randint(1, 100)
seq_data[next(seq_id_counter)] = create_sequence_data(
num_input=num_input, num_generated=num_generated)
seq_group_penalization.append(num_generated < min_tokens)
expected_penalization.extend(seq_group_penalization)
sequence_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{batch_size}",
is_prompt=is_prompt,
seq_data=seq_data,
sampling_params=sampling_params,
block_tables={},
))
batch_size -= num_seqs
return {
"expected_penalization": expected_penalization,
"seq_group_metadata_list": sequence_metadata_list,
}
# define some explicit test cases for edge case behavior
prompt_without_penalization = {
"expected_penalization": [False],
"seq_group_metadata_list": [
SequenceGroupMetadata(
request_id="test_1",
is_prompt=True,
seq_data={
next(seq_id_counter): create_sequence_data(),
},
sampling_params=create_sampling_params(0),
block_tables={},
),
]
}
prompt_with_penalization = {
"expected_penalization": [True],
"seq_group_metadata_list": [
SequenceGroupMetadata(
request_id="test_1",
is_prompt=True,
seq_data={
next(seq_id_counter): create_sequence_data(),
},
sampling_params=create_sampling_params(1),
block_tables={},
),
]
}
prompt_with_penalization_and_prompt_logprobs = {
"expected_penalization": [False, False, True],
"seq_group_metadata_list": [
SequenceGroupMetadata(
request_id="test_1",
is_prompt=True,
seq_data={
next(seq_id_counter): create_sequence_data(num_input=3),
},
sampling_params=create_sampling_params(1, prompt_logprobs=3),
block_tables={},
),
]
}
stop_penalizing_after_min_tokens = {
"expected_penalization": [False],
"seq_group_metadata_list": [
SequenceGroupMetadata(
request_id="test_1",
is_prompt=False,
seq_data={
next(seq_id_counter):
create_sequence_data(num_generated=1),
},
sampling_params=create_sampling_params(1),
block_tables={},
)
]
}
stop_token_ids = [42, 99, 42, 0] # intentional duplication
prompt_combination = {
"expected_penalization": [False, True, False],
"seq_group_metadata_list": [
SequenceGroupMetadata(
request_id="test_2",
is_prompt=True,
seq_data={
next(seq_id_counter): create_sequence_data(num_input=2),
},
sampling_params=create_sampling_params(1, prompt_logprobs=3),
block_tables={},
),
SequenceGroupMetadata(
request_id="test_3",
is_prompt=True,
seq_data={
next(seq_id_counter): create_sequence_data(),
},
sampling_params=create_sampling_params(
0, stop_token_ids=stop_token_ids),
block_tables={},
)
]
}
stop_token_ids = [1, 999, 37, 37] # intentional duplication
decode_combination = {
"expected_penalization": [True, False, False, True, False],
"seq_group_metadata_list": [
SequenceGroupMetadata(
request_id="test_1",
is_prompt=False,
seq_data={
next(seq_id_counter):
create_sequence_data(num_generated=1),
next(seq_id_counter):
create_sequence_data(num_generated=100),
},
sampling_params=create_sampling_params(
2, stop_token_ids=stop_token_ids),
block_tables={},
),
SequenceGroupMetadata(
request_id="test_2",
is_prompt=False,
seq_data={
next(seq_id_counter):
create_sequence_data(num_generated=20),
next(seq_id_counter):
create_sequence_data(num_generated=1),
next(seq_id_counter):
create_sequence_data(num_generated=10),
},
sampling_params=create_sampling_params(
10, prompt_logprobs=5, stop_token_ids=stop_token_ids),
block_tables={},
),
]
}
if seed == 0:
test_cases = [
prompt_without_penalization,
prompt_with_penalization,
prompt_with_penalization_and_prompt_logprobs,
stop_penalizing_after_min_tokens,
prompt_combination,
decode_combination,
]
else:
test_cases = [generate_test_case()]
def run_test_case(*, expected_penalization: List[bool],
seq_group_metadata_list: List[SequenceGroupMetadata]):
assert expected_penalization, \
"Invalid test case, need expected_penalization"
assert seq_group_metadata_list, \
"Invalid test case, need seq_group_metadata_list"
batch_size = 0
seq_lens: List[int] = []
sampling_params_per_row: List[SamplingParams] = []
for sgm in seq_group_metadata_list:
sampling_params = sgm.sampling_params
num_rows = len(sgm.seq_data)
if sgm.is_prompt:
# a prompt seq_group has only one sequence
seq_data = next(iter(sgm.seq_data.values()))
prompt_len = seq_data.get_prompt_len()
seq_lens.append(prompt_len)
assert sgm.sampling_params is not None
if sgm.sampling_params.prompt_logprobs:
# with prompt_logprobs each token in the prompt has a row in
# logits
num_rows = prompt_len
batch_size += num_rows
sampling_params_per_row.extend(
itertools.repeat(sampling_params, num_rows))
assert len(
expected_penalization
) == batch_size, \
("Invalid test case, expected_penalization does not match computed"
"batch size")
_, fake_logits, sampler = _prepare_test(batch_size)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
seq_lens=seq_lens if seq_lens else None,
query_lens=seq_lens if seq_lens else [1] * batch_size,
device=device,
pin_memory=is_pin_memory_available())
# the logits tensor is modified in-place by the sampler
_ = sampler(logits=fake_logits, sampling_metadata=sampling_metadata)
for logits_idx, (should_penalize, sampling_params) in enumerate(
zip(expected_penalization, sampling_params_per_row)):
tokens_to_check = sampling_params.all_stop_token_ids
if should_penalize:
for token_id in tokens_to_check:
assert fake_logits[logits_idx, token_id] == -float(
'inf'
), f"Expected token {token_id} for logits row {logits_idx}"
" to be penalized"
# no other tokens should be set to -inf
assert torch.count_nonzero(
fake_logits[logits_idx, :] == -float('inf')) == len(
tokens_to_check
), f"Expected only {len(tokens_to_check)} to be penalized"
else:
# no tokens should be set to -inf
assert torch.count_nonzero(
fake_logits[logits_idx, :] ==
-float('inf')) == 0, "No tokens should have been penalized"
for test_case in test_cases:
run_test_case(**test_case)
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_mixed(seed: int, device: str):
set_random_seed(seed)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
input_tensor, fake_logits, sampler = _prepare_test(batch_size)
seq_group_metadata_list: List[SequenceGroupMetadata] = []
expected_tokens: List[Optional[List[int]]] = []
seq_lens: List[int] = []
for i in range(batch_size):
expected: Optional[List[int]] = None
sampling_type = random.randint(0, 2)
if sampling_type == 0:
sampling_params = SamplingParams(temperature=0)
expected = [int(torch.argmax(fake_logits[i], dim=-1).item())]
elif sampling_type in (1, 2):
n = random.randint(1, 10)
sampling_params = SamplingParams(
temperature=random.random() + 0.1,
top_p=min(random.random() + 0.1, 1),
top_k=random.randint(0, 10) or -1,
n=n,
presence_penalty=random.randint(0, 1),
)
if sampling_type == 2:
sampling_params.seed = random.randint(0, 10000)
else:
for idx in range(n):
fake_logits[i, i + idx] = 1e2
expected = list(range(i, i + n))
expected_tokens.append(expected)
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData.from_seqs([1, 2, 3])},
sampling_params=sampling_params,
block_tables={0: [1]},
))
seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
generators: Dict[str, torch.Generator] = {}
def test_sampling():
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
seq_lens,
query_lens=seq_lens,
device=device,
pin_memory=is_pin_memory_available(),
generators=generators)
sampler_output = sampler(logits=fake_logits,
sampling_metadata=sampling_metadata)
for i, (sequence_output, metadata) in enumerate(
zip(sampler_output, seq_group_metadata_list)):
assert metadata.sampling_params is not None
if (metadata.sampling_params.seed is not None
and expected_tokens[i] is None):
# Record seeded random result to compare with results of
# second invocation
expected_tokens[i] = [
nth_output.output_token
for nth_output in sequence_output.samples
]
continue
expected_tokens_item = expected_tokens[i]
assert expected_tokens_item is not None
for n, nth_output in enumerate(sequence_output.samples):
assert metadata.sampling_params is not None
if (metadata.sampling_params.temperature == 0
or metadata.sampling_params.seed is not None):
# Ensure exact matches for greedy or random with seed
assert nth_output.output_token == expected_tokens_item[n]
else:
# For non-seeded random check that one of the high-logit
# tokens were chosen
assert nth_output.output_token in expected_tokens_item
# Test batch
test_sampling()
# Shuffle the batch and resample
target_index = list(range(batch_size))
for list_to_shuffle in (target_index, seq_group_metadata_list,
expected_tokens, seq_lens):
random.Random(seed).shuffle(list_to_shuffle)
target_index = torch.tensor(target_index)
input_tensor.data = input_tensor.index_select(0, target_index)
fake_logits.data = fake_logits.index_select(0, target_index)
# This time, results of seeded random samples will be compared with
# the corresponding sample in the pre-shuffled batch
test_sampling()
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_top_k_top_p(seed: int, device: str):
if seed == 40:
pytest.skip("skip cause diff accuracy between difference device.")
set_random_seed(seed)
batch_size = random.randint(1, 256)
top_k = random.randint(100, 500)
top_p = random.random() * 0.1
vocab_size = 32000
input_tensor = torch.rand((batch_size, 1024),
device=device,
dtype=torch.float16)
fake_logits = torch.normal(0,
5,
size=(batch_size, vocab_size),
device=input_tensor.device,
dtype=input_tensor.dtype)
sampler = MockLogitsSampler(fake_logits)
generation_model = GenerationMixin()
generation_config = GenerationConfig(top_k=top_k,
top_p=top_p,
do_sample=True)
@dataclass
class MockConfig:
is_encoder_decoder: bool = False
generation_model.config = MockConfig() # needed by the following method
generation_model._prepare_special_tokens(generation_config, device=device)
processors = generation_model._get_logits_processor(generation_config,
None,
None,
None, [],
device=device)
assert len(processors) == 2 # top_p and top_k
seq_group_metadata_list: List[SequenceGroupMetadata] = []
seq_lens: List[int] = []
for i in range(batch_size):
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData.from_seqs([1, 2, 3])},
sampling_params=SamplingParams(
temperature=1,
top_k=top_k,
top_p=top_p,
),
block_tables={0: [1]},
))
seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
seq_lens,
query_lens=seq_lens,
device=device,
pin_memory=is_pin_memory_available())
sample_probs = None
def mock_sample(probs, *args, **kwargs):
nonlocal sample_probs
sample_probs = probs
return ([[prob.topk(1, dim=-1).indices.tolist(), [0]]
for prob in probs], None)
# top-k and top-p is only calculated when flashinfer kernel is not available
with patch("vllm.model_executor.layers.sampler._sample", mock_sample), \
patch("vllm.model_executor.layers.sampler."
"flashinfer_top_k_top_p_sampling", None):
sampler(logits=fake_logits, sampling_metadata=sampling_metadata)
assert sample_probs is not None
hf_probs = processors(torch.zeros_like(fake_logits), fake_logits.clone())
hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float)
torch.testing.assert_close(hf_probs, sample_probs, rtol=0.0, atol=1e-5)
assert torch.equal(hf_probs.eq(0), sample_probs.eq(0))
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_flashinfer_fallback(seed: int, device: str):
if not envs.VLLM_USE_FLASHINFER_SAMPLER:
pytest.skip("Flashinfer sampler is disabled")
set_random_seed(seed)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
_, fake_logits, sampler = _prepare_test(batch_size)
def failing_flashinfer_sampling(*_args, **_kwargs):
return None, torch.zeros(batch_size, device=device, dtype=torch.int32)
sampling_params = SamplingParams(
temperature=1.0,
n=random.randint(1, 10),
seed=random.randint(0, 10000),
)
sampler_output = _do_sample(batch_size, fake_logits, sampler,
sampling_params, device)
with patch(
"vllm.model_executor.layers.sampler."
"flashinfer_top_k_top_p_sampling", failing_flashinfer_sampling):
fallback_sampler_output = _do_sample(batch_size, fake_logits, sampler,
sampling_params, device)
assert sampler_output == fallback_sampler_output
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_repetition_penalty_mixed(device: str):
vocab_size = 8
def test_sampling_params(sampling_params: List[SamplingParams]):
seq_group_metadata_list: List[SequenceGroupMetadata] = []
seq_lens: List[int] = []
for i in range(2):
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData.from_seqs([1, 2, 3])},
sampling_params=sampling_params[i],
block_tables={0: [1]},
))
seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
seq_lens,
query_lens=seq_lens,
device=device,
pin_memory=is_pin_memory_available())
fake_logits = torch.full((2, vocab_size),
1e-2,
device=device,
dtype=torch.float16)
fake_logits[:, 5] = 1.1e-2
fake_logits[:, 1] = 1.2e-2
sampler = MockLogitsSampler(fake_logits)
sampler_output = sampler(logits=fake_logits,
sampling_metadata=sampling_metadata)
generated_tokens = []
for output in sampler_output:
generated_tokens.append(output.samples[0].output_token)
return generated_tokens
# one configuration is greedy with repetition_penalty
sampling_params_rep = SamplingParams(
temperature=0.0,
repetition_penalty=2.0,
)
# other configuration is sampling w/o repetition_penalty
sampling_params_sample = SamplingParams(
temperature=1.0,
top_k=1,
seed=42,
)
tokens1 = test_sampling_params(
[sampling_params_rep, sampling_params_sample])
tokens2 = test_sampling_params(
[sampling_params_sample, sampling_params_rep])
assert tokens1[0] == tokens2[1]
assert tokens1[1] == tokens2[0]
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_include_gpu_probs_tensor(device: str):
set_random_seed(42)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
_, fake_logits, sampler = _prepare_test(batch_size)
sampler.include_gpu_probs_tensor = True
sampler.should_modify_greedy_probs_inplace = False
sampling_params = SamplingParams(temperature=0)
mock_inplace = Mock()
with patch(
"vllm.model_executor.layers.sampler._modify_greedy_probs_inplace",
mock_inplace):
sampler_output = _do_sample(batch_size, fake_logits, sampler,
sampling_params, device)
mock_inplace.assert_not_called()
assert sampler_output.sampled_token_probs is not None
assert sampler_output.logprobs is not None
assert sampler_output.sampled_token_ids is not None

View File

@@ -0,0 +1,77 @@
"""Verify that seeded random sampling is deterministic.
Run `pytest tests/samplers/test_seeded_generate.py`.
"""
import copy
import random
from itertools import combinations
import pytest
from vllm import SamplingParams
from vllm.model_executor.utils import set_random_seed
MODEL = "facebook/opt-125m"
RANDOM_SEEDS = list(range(5))
@pytest.fixture
def vllm_model(vllm_runner):
with vllm_runner(MODEL, dtype="half") as vllm_model:
yield vllm_model
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_random_sample_with_seed(
vllm_model,
example_prompts,
seed: int,
) -> None:
set_random_seed(seed)
sampling_params = SamplingParams(
# Parameters to ensure sufficient randomness
temperature=2.0,
top_p=min(random.random() + 0.3, 1),
top_k=random.randint(5, 20),
n=random.randint(1, 10),
presence_penalty=random.randint(0, 1),
max_tokens=8,
ignore_eos=True,
)
sampling_params_seed_1 = copy.deepcopy(sampling_params)
sampling_params_seed_1.seed = 100
sampling_params_seed_2 = copy.deepcopy(sampling_params)
sampling_params_seed_2.seed = 200
llm = vllm_model.model
for prompt in example_prompts:
for params in (
sampling_params,
sampling_params_seed_1,
sampling_params_seed_2,
sampling_params,
sampling_params_seed_1,
sampling_params_seed_2,
):
llm._add_request(prompt, params=params)
results = llm._run_engine(use_tqdm=False)
all_outputs = [[out.token_ids for out in output.outputs]
for output in results]
for i in range(0, len(example_prompts), 6):
outputs = all_outputs[i:i + 6]
# verify all non-seeded requests differ
for output_a, output_b in combinations(
(outputs[0], outputs[1], outputs[2], outputs[3]),
2,
):
assert output_a != output_b
# verify requests with the same seed match
assert outputs[1] == outputs[4]
assert outputs[2] == outputs[5]

View File

@@ -0,0 +1,470 @@
"""Tests for rejection sampling."""
import pytest
import torch
from vllm.model_executor.layers.typical_acceptance_sampler import (
TypicalAcceptanceSampler)
from vllm.model_executor.utils import set_random_seed
CUDA_DEVICES = [f"cuda:{i}" for i in range(1)]
def get_zero_temperature_prob_dist(batch_size, k, vocab_size):
"""
Generates a fake temperature zero probability distribution.
Returns:
1. A fake temperature zero probability distribution of shape
[batch_size, k, vocab_size]
2. Tensor of shape [batch_size, k] containing the token ids
of the probability 1.0 tokens at each position.
"""
# Simulate temperature 0 probability distribution for target probabilities
# and create target probabilities such that only 1 token id has
# probability 1.0
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
probs = torch.rand(batch_size, k, vocab_size)
_, zero_temperature_token_ids = torch.max(probs, dim=-1)
# set the probability of the tokens with ids in zero_temperature_token_ids
# to 1 and the rest to 0.
target_probs = torch.zeros_like(probs).scatter_(
-1, zero_temperature_token_ids.unsqueeze(-1), 1.0)
return target_probs, zero_temperature_token_ids
def get_draft_token_ids(batch_size: int, k: int, vocab_size: int,
token_ids_to_exclude: torch.Tensor):
"""
Returns a tensor of shape [batch_size, k] of fake draft token ids
drawn randomly from a vocab of size vocab_size. We however ensure
that token_ids from token_ids_to_exclude are excluded at the
corresponding positions.
"""
draft_token_ids = torch.empty(batch_size, k, dtype=torch.long)
for i in range(batch_size):
for j in range(k):
# Generate a random token ID excluding token_ids_to_exclude[i, j]
while True:
token_id = torch.randint(0, vocab_size, (1, )).item()
if token_id != token_ids_to_exclude[i, j]:
draft_token_ids[i, j] = token_id
break
return draft_token_ids
def get_acceptance_sampler(
posterior_threshold: float = 0.03,
posterior_alpha: float = 0.9,
strict_mode: bool = False,
) -> TypicalAcceptanceSampler:
"""
Initializes and returns a TypicalAcceptanceSampler.
"""
return TypicalAcceptanceSampler(posterior_threshold, posterior_alpha,
strict_mode)
@pytest.mark.parametrize("k", list(range(1, 6)))
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
@pytest.mark.parametrize("batch_size", list(range(1, 32)))
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
device: str):
"""
Tests that the TypicalAcceptancSampler forward succeeds for
different combinations of k, vocab_size, batch_size and num devices.
"""
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler()
typical_acceptance_sampler.init_gpu_tensors(device=device)
target_with_bonus_probs = torch.rand(batch_size,
k + 1,
vocab_size,
dtype=torch.float32)
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64)
draft_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64)
# Verify that sampling succeeds for all cases.
typical_acceptance_sampler(target_with_bonus_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
@pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"])
@pytest.mark.parametrize("which_token_ids",
["bonus_token_ids", "draft_token_ids"])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
which_token_ids: str, device: str):
"""
Tests that we throw an exception of the token ids fall outside
the bound of the provided vocabulary.
"""
k = 3
batch_size = 5
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(device=device)
target_with_bonus_probs = torch.rand(batch_size,
k + 1,
vocab_size,
dtype=torch.float32)
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64)
draft_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64)
# Verify that appropriate exceptions are thrown for out
# of bound vocabs.
oob_token_ids = None
if which_token_ids == "bonus_token_ids":
oob_token_ids = bonus_token_ids
elif which_token_ids == "draft_token_ids":
oob_token_ids = draft_token_ids
else:
raise AssertionError()
if above_or_below_vocab_range == "above":
rogue_token_id = vocab_size + 1
elif above_or_below_vocab_range == "below":
rogue_token_id = -1
else:
raise AssertionError()
oob_token_ids[0][0] = rogue_token_id
with pytest.raises(AssertionError):
typical_acceptance_sampler(target_with_bonus_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_uniform_target_distribution_accepts_all_tokens(
seed: int, device: str):
"""
Test the TypicalAcceptanceSampler with a uniform target probability
distribution.
This test verifies that when provided with a uniform target probability
distribution, the TypicalAcceptanceSampler accepts all draft tokens. The
entropy of the uniform target distribution being high should lead to all
draft tokens being accepted.
"""
set_random_seed(seed)
k = 3
batch_size = 5
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(device=device)
target_with_bonus_probs = torch.rand(batch_size,
k + 1,
vocab_size,
dtype=torch.float32)
draft_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64)
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64)
output_token_ids = typical_acceptance_sampler(
target_with_bonus_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
# We are using a uniform target probability distribution.
# For a uniform distribution the entropy is very high and it
# should lead to all draft tokens being accepted. Verify that.
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1)
assert torch.all(output_token_ids[:, -1] == bonus_token_ids.squeeze())
assert torch.all(output_token_ids[:, :k] == draft_token_ids)
@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_temperature_zero_target_distribution(seed: int, device: str):
"""
Test the TypicalAcceptanceSampler with a zero-temperature target
probability distribution.
This test verifies that when using a zero-temperature target probability
distribution, where only one token has a probability of 1.0, the
TypicalAcceptanceSampler correctly rejects all draft tokens that do not
match this probability. Additionally, it ensures that when all draft
tokens are rejected, the sampler falls back to greedy sampling to select a
single token from the target distribution.
"""
set_random_seed(seed)
k = 3
batch_size = 5
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(device=device)
# Simulate temperature 0 probability distribution for target probabilities
# and create target probabilities such that only 1 token id has
# probability 1.0
target_with_bonus_probs, zero_temperature_token_ids = \
get_zero_temperature_prob_dist(batch_size, k + 1, vocab_size)
zero_temperature_token_ids = zero_temperature_token_ids[:, :-1]
# Populate draft_token_ids such that they exclude the token_ids
# with probability = 1.0
draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size,
zero_temperature_token_ids)
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64)
# The target probaility distribution is a temperature zero distribution
# with zero entroy. Since our draft token ids don't match the probability
# 1.0 tokens in the target distribution we will reject all of them and
# fallback to the greedy sampling for selecting 1 token for each sequence.
# Verify the same.
output_token_ids = typical_acceptance_sampler(
target_with_bonus_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1)
assert torch.all(output_token_ids[:, -1] == -1)
assert torch.all(output_token_ids[:, 0] == zero_temperature_token_ids[:,
0])
@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_mixed_target_distribution(seed: int, device: str):
"""
Test the TypicalAcceptanceSampler with a mixed target probability
distribution.
This test ensures that the TypicalAcceptanceSampler handles a mixed
target probability distribution correctly. Specifically, it uses a
zero-temperature distribution for some sequences and a uniform
distribution for others. The test verifies that:
- For sequences with a zero-temperature distribution, only the token
with a probability of 1.0 is accepted, and all other tokens are rejected.
- For sequences with a uniform distribution, all draft tokens are
accepted.
"""
set_random_seed(seed)
k = 3
batch_size = 4
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(device=device)
# For sequences 0 and 2 set the distribution to a temperature
# zero distribution. For sequences 1 and 3 set it to a uniform
# distribution.
target_with_bonus_probs, zero_temperature_token_ids = \
get_zero_temperature_prob_dist(batch_size, k + 1, vocab_size)
zero_temperature_token_ids = zero_temperature_token_ids[:, :-1]
target_probs = target_with_bonus_probs[:, :-1]
draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size,
zero_temperature_token_ids)
uniform_probs = torch.rand(2, k, vocab_size, dtype=torch.float32)
target_probs[[1, 3]] = uniform_probs
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64)
output_token_ids = typical_acceptance_sampler(
target_with_bonus_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
# verify the shape of output_token_ids
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1)
# For sequences 0 and 2 verify that only 1 token is accepted
# which is the token with probability 1.0 in the target distribution
# at position 0.
assert torch.all(output_token_ids[[0, 2], 1:] == -1)
assert (torch.all(output_token_ids[[0, 2],
0] == zero_temperature_token_ids[[0, 2],
0]))
# For sequences 1 and 3 verify that all tokens are accepted since the
# target probability distribution is uniform. In addition verify that
# we also accept the bonus tokens.
assert torch.all(
output_token_ids[[1, 3], :-1] == draft_token_ids[[1, 3], :])
assert torch.all(output_token_ids[[1, 3], -1] != -1)
@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_accept_tokens_partially(seed: int, device: str):
"""
Test the TypicalAcceptanceSampler's behavior when only a subset of draft
tokens should be accepted.
This test verifies that the TypicalAcceptanceSampler correctly accepts or
rejects draft tokens based on a zero-temperature target probability
distribution. Specifically, it ensures that:
- When all draft tokens match tokens with a probability of 1.0 in the
target distribution, all draft tokens are accepted.
- When only some draft tokens match tokens with a probability of 1.0 in
the target distribution, only those matching tokens are accepted, and the
rest are rejected.
"""
set_random_seed(seed)
k = 5
batch_size = 1
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(device=device)
# Create a temperature zero target probability distribution and ensure
# all draft token ids correspond to the tokens with 1.0 probability.
# Verify that all of them are accepted.
target_with_bonus_probs, zero_temperature_token_ids = \
get_zero_temperature_prob_dist(batch_size, k + 1, vocab_size)
zero_temperature_token_ids = zero_temperature_token_ids[:, :-1]
draft_token_ids = zero_temperature_token_ids
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64)
output_token_ids = typical_acceptance_sampler(
target_with_bonus_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1)
assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids)
assert torch.all(output_token_ids[:, -1] == bonus_token_ids)
# Next only keep the first 2 draft tokens same as the zero temperature
# tokens. For the remaining 3 choose some other tokens. In the
# response we will expect the first 2 tokens to be the same as the
# draft tokens and the recovered token and rest as -1
draft_token_ids_to_replace = get_draft_token_ids(
batch_size, k, vocab_size, zero_temperature_token_ids)
draft_token_ids = torch.cat(
(draft_token_ids[:, :2], draft_token_ids_to_replace[:, -3:]), dim=1)
output_token_ids = typical_acceptance_sampler(
target_with_bonus_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1)
assert torch.all(output_token_ids[:, :2] == draft_token_ids[:, :2])
assert torch.all(
output_token_ids[:, 2] == target_with_bonus_probs.argmax(-1)[:, 2])
assert torch.all(output_token_ids[:, -3:] == -1)
@pytest.mark.parametrize("seed", list(range(1)))
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_accept_tokens_set_non_default_posteriors(seed: int, device: str):
"""
Test the TypicalAcceptanceSampler with custom posterior thresholds and
alpha values. This test verifies that by modifying the posterior
thresholds and alpha values we can change the acceptance behavior of the
sampler.
"""
set_random_seed(seed)
k = 5
batch_size = 1
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(device=device)
# Simulate temperature 0 probability distribution for target
# probabilities and create target probabilities such that only 1 token
# id has probability 1.0 and others have a very low probability of
# 0.00001. Populate draft_token_ids such that they exclude the token_ids
# with probability = 1.0. Without any changes to the posterior thresholds
# none of the draft tokens are accepted.
target_probs, zero_temperature_token_ids = get_zero_temperature_prob_dist(
batch_size, k + 1, vocab_size)
zero_temperature_token_ids = zero_temperature_token_ids[:, :-1]
target_probs[target_probs == 0] = 0.00001
draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size,
zero_temperature_token_ids)
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64)
output_token_ids = typical_acceptance_sampler(
target_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1)
assert torch.all(output_token_ids[:, 1:-1] == -1)
# Change the posterior threshold values to 0.0 so that we will
# now accept even draft tokens with very low probability in the
# target distribution. Simulate and verify the same.
typical_acceptance_sampler = TypicalAcceptanceSampler(
strict_mode=True, posterior_threshold=0.0, posterior_alpha=0.0)
typical_acceptance_sampler.init_gpu_tensors(device=device)
output_token_ids = typical_acceptance_sampler(
target_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
assert output_token_ids.shape[0] == batch_size
assert output_token_ids.shape[1] == (k + 1)
assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids)
assert torch.all(output_token_ids[:, -1] == bonus_token_ids)
@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_get_recovered_token_ids(seed: int, device: str):
"""
Test the TypicalAcceptanceSampler's method for generating
replacement token IDs.
This test verifies that the `_get_recovered_token_ids` method of the
TypicalAcceptanceSampler correctly identifies the token IDs to be used
as recovered token IDs based on the target probability distribution.
Specifically, it ensures that the method correctly identifies the
tokens with the highest probability for each sequence in the batch.
"""
set_random_seed(seed)
k = 10
batch_size = 5
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(device=device)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
expected_replacement_tokens = torch.argmax(target_probs, dim=-1)
actual_replacement_tokens = (
typical_acceptance_sampler._get_recovered_token_ids(target_probs))
assert torch.all(expected_replacement_tokens == actual_replacement_tokens)