add qwen3

This commit is contained in:
Chranos
2026-02-04 17:22:39 +08:00
parent d1c0f68ab4
commit 8511fe8530
1932 changed files with 300426 additions and 0 deletions

View File

@@ -0,0 +1,290 @@
from itertools import cycle
from typing import List, Optional, Sequence, Tuple, Union
import pytest
from vllm import LLM, SamplingParams
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import PromptLogprobs, SampleLogprobs
from ...models.utils import (TokensTextLogprobs,
TokensTextLogprobsPromptLogprobs,
check_logprobs_close, check_outputs_equal)
from ...utils import RemoteOpenAIServer
PROMPTS = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
"San Francisco is know for its",
"Facebook was created in 2004 by",
"Curious George is a",
"Python 3.11 brings improvements to its",
]
@pytest.fixture
def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
test_llm_kwargs, seed):
def generate():
kwargs = {
**common_llm_kwargs,
**per_test_common_llm_kwargs,
**test_llm_kwargs,
}
llm = LLM(**kwargs)
if seed is not None:
set_random_seed(seed)
yield llm
del llm
cleanup_dist_env_and_memory()
return generate
def maybe_assert_ngram_worker(llm):
# Verify the proposer worker is ngram if ngram is specified.
if (llm.llm_engine.speculative_config is not None
and llm.llm_engine.speculative_config.ngram_prompt_lookup_max > 0):
from vllm.spec_decode.ngram_worker import NGramWorker
assert isinstance(
llm.llm_engine.model_executor.driver_worker.proposer_worker,
NGramWorker)
def get_output_from_llm_generator(
llm_generator, prompts,
sampling_params) -> Tuple[List[str], List[List[int]], float]:
tokens: List[str] = []
token_ids: List[List[int]] = []
acceptance_rate: float = -1.0
for llm in llm_generator():
maybe_assert_ngram_worker(llm)
outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
token_ids = [output.outputs[0].token_ids for output in outputs]
tokens = [output.outputs[0].text for output in outputs]
# Fetch acceptance rate if logging is enabled.
if stat_loggers := getattr(llm.llm_engine, "stat_loggers", None):
stat_logger = stat_loggers["prometheus"]
acceptance_rate = (stat_logger.metrics.
gauge_spec_decode_draft_acceptance_rate.labels(
**stat_logger.labels)._value.get())
del llm
return tokens, token_ids, acceptance_rate
def check_logprobs_correctness(
spec_outputs: Sequence[Union[TokensTextLogprobs,
TokensTextLogprobsPromptLogprobs]],
baseline_outputs: Sequence[Union[TokensTextLogprobs,
TokensTextLogprobsPromptLogprobs]],
disable_logprobs: bool = False,
):
"""Compare sampled and prompt logprobs between baseline and spec decoding
"""
if not disable_logprobs:
return check_logprobs_close(
outputs_0_lst=baseline_outputs,
outputs_1_lst=spec_outputs,
name_0="org",
name_1="sd",
)
# Check correctness when disable_logprobs == True
for spec_output, baseline_output in zip(spec_outputs, baseline_outputs):
# Check generated token logprobs.
spec_logprobs = spec_output[2]
baseline_logprobs = baseline_output[2]
_check_logprobs_when_output_disabled(spec_logprobs,
baseline_logprobs,
is_prompt_logprobs=False)
# Check prompt logprobs too, if they exist
if len(baseline_output) == 4:
assert len(spec_output) == 4
spec_prompt_logprobs = spec_output[3]
baseline_prompt_logprobs = baseline_output[3]
_check_logprobs_when_output_disabled(spec_prompt_logprobs,
baseline_prompt_logprobs,
is_prompt_logprobs=True)
def _check_logprobs_when_output_disabled(
spec_logprobs: Union[Optional[PromptLogprobs], SampleLogprobs],
baseline_logprobs: Union[Optional[PromptLogprobs], SampleLogprobs],
is_prompt_logprobs: bool = False,
):
# Prompt logprobs are optional
if is_prompt_logprobs and baseline_logprobs is None:
assert spec_logprobs is None
return
assert spec_logprobs is not None
assert baseline_logprobs is not None
assert len(spec_logprobs) == len(baseline_logprobs)
# For each generated position of the sequence.
for pos, (spec_pos_logprobs, baseline_pos_logprobs) in enumerate(
zip(spec_logprobs, baseline_logprobs)):
# First prompt logprob is expected to be None
if is_prompt_logprobs and baseline_pos_logprobs is None:
assert spec_pos_logprobs is None
assert pos == 0
continue
assert spec_pos_logprobs is not None
assert baseline_pos_logprobs is not None
# When disabled, the 1 logprob is returned with dummy values for the
# score and rank, but the token id should match the baseline model
assert len(spec_pos_logprobs) == 1
(spec_pos_logprob_token_id,
spec_pos_logprob) = next(iter(spec_pos_logprobs.items()))
assert spec_pos_logprob.rank == -1
assert spec_pos_logprob.logprob == 0.0
assert spec_pos_logprob_token_id in baseline_pos_logprobs
def run_equality_correctness_test(
vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size: int,
max_output_len: int,
seed: Optional[int] = 0,
temperature: float = 0.0,
disable_seed: bool = False,
ignore_eos: bool = True,
ensure_all_accepted: bool = False,
expected_acceptance_rate: Optional[float] = None,
logprobs: Optional[int] = None,
prompt_logprobs: Optional[int] = None,
disable_logprobs: bool = False):
org_args = {
**common_llm_kwargs,
**per_test_common_llm_kwargs,
**baseline_llm_kwargs,
}
sd_args = {
**common_llm_kwargs,
**per_test_common_llm_kwargs,
**test_llm_kwargs,
}
prompts = [prompt for prompt, _ in zip(cycle(PROMPTS), range(batch_size))]
if disable_seed:
seed = None
sampling_params = SamplingParams(temperature=temperature,
max_tokens=max_output_len,
seed=seed,
ignore_eos=ignore_eos,
logprobs=logprobs,
prompt_logprobs=prompt_logprobs)
with vllm_runner(**org_args) as vllm_model:
org_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)
with vllm_runner(**sd_args) as vllm_model:
if ensure_all_accepted or expected_acceptance_rate is not None:
# Force log interval to be 0 to catch all metrics.
stat_logger = vllm_model.model.llm_engine.stat_loggers[
'prometheus']
stat_logger.local_interval = -100
sd_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)
if ensure_all_accepted or expected_acceptance_rate is not None:
acceptance_rate = (stat_logger.metrics.
gauge_spec_decode_draft_acceptance_rate.labels(
**stat_logger.labels)._value.get())
if ensure_all_accepted:
assert True
# FIXME: ci fails to log acceptance rate.
# It works locally.
# assert acceptance_rate == 1.0
if expected_acceptance_rate is not None:
assert acceptance_rate >= expected_acceptance_rate - 1e-2
# Only pass token entries, not the logprobs
check_outputs_equal(outputs_0_lst=[out[0:2] for out in org_outputs],
outputs_1_lst=[out[0:2] for out in sd_outputs],
name_0="org",
name_1="sd")
# Check logprobs if requested
if logprobs is not None or prompt_logprobs is not None:
check_logprobs_correctness(sd_outputs, org_outputs, disable_logprobs)
def run_equality_correctness_test_tp(model,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size: int,
max_output_len: int,
seed: int = 0,
temperature: float = 0.0):
"""Helper method that compares the outputs of both the baseline LLM and
the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
the same when temperature is zero.
"""
arg1 = common_llm_kwargs + per_test_common_llm_kwargs + baseline_llm_kwargs
arg2 = common_llm_kwargs + per_test_common_llm_kwargs + test_llm_kwargs
env1 = env2 = None
max_wait_seconds = 240
results = []
prompts = [prompt for prompt, _ in zip(cycle(PROMPTS), range(batch_size))]
for args, env in ((arg1, env1), (arg2, env2)):
with RemoteOpenAIServer(model,
args,
env_dict=env,
max_wait_seconds=max_wait_seconds) as server:
client = server.get_client()
completion = client.completions.create(model=model,
prompt=prompts,
max_tokens=max_output_len,
seed=seed,
temperature=temperature)
results.append({
"test":
"seeded_sampling",
"text": [choice.text for choice in completion.choices],
"finish_reason":
[choice.finish_reason for choice in completion.choices],
"usage":
completion.usage,
})
n = len(results) // 2
arg1_results = results[:n]
arg2_results = results[n:]
for arg1_result, arg2_result in zip(arg1_results, arg2_results):
assert arg1_result == arg2_result, (
f"Results for {model=} are not the same with {arg1=} and {arg2=}. "
f"{arg1_result=} != {arg2_result=}")

View File

@@ -0,0 +1,98 @@
import pytest
from vllm import SamplingParams
from .conftest import get_output_from_llm_generator
@pytest.mark.parametrize("common_llm_kwargs", [{
"model": "meta-llama/Llama-2-7b-chat-hf",
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
{
# Speculative max model len > overridden max model len should raise.
"max_model_len": 128,
"speculative_max_model_len": 129,
},
{
# Speculative max model len > draft max model len should raise.
# https://huggingface.co/JackFram/llama-68m/blob/3b606af5198a0b26762d589a3ee3d26ee6fa6c85/config.json#L12
"speculative_max_model_len": 2048 + 1,
},
{
# Speculative max model len > target max model len should raise.
# https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/f5db02db724555f92da89c216ac04704f23d4590/config.json#L12
"speculative_max_model_len": 4096 + 1,
},
])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_xfail_spec_max_model_len(test_llm_generator):
"""Verify that speculative decoding validates speculative_max_model_len.
"""
output_len = 128
temperature = 0.0
prompts = [
"Hello, my name is",
]
sampling_params = SamplingParams(
max_tokens=output_len,
ignore_eos=True,
temperature=temperature,
)
with pytest.raises(ValueError, match="cannot be larger than"):
get_output_from_llm_generator(test_llm_generator, prompts,
sampling_params)
@pytest.mark.parametrize("common_llm_kwargs",
[{
"model": "meta-llama/Llama-2-7b-chat-hf",
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": "True",
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
{
"tensor_parallel_size": 2,
"speculative_draft_tensor_parallel_size": 2,
},
{
"tensor_parallel_size": 4,
"speculative_draft_tensor_parallel_size": 4,
},
{
"tensor_parallel_size": 8,
"speculative_draft_tensor_parallel_size": 8,
},
])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_xfail_chunked_prefill_draft_model_tp_not_one(
test_llm_generator):
"""Verify that speculative decoding fails if chunked prefill is enabled for
draft model with tensor parallelism of more than 1.
"""
output_len = 128
temperature = 0.0
prompts = [
"Hello, my name is",
]
sampling_params = SamplingParams(
max_tokens=output_len,
ignore_eos=True,
temperature=temperature,
)
with pytest.raises(ValueError, match="with tensor parallel size 1"):
get_output_from_llm_generator(test_llm_generator, prompts,
sampling_params)

View File

@@ -0,0 +1,323 @@
"""This docstring details important information on the testing methodology.
Most of the tests rely on "greedy equality", where we expect the output of
speculative decoding on a sequence to exactly match the output of normal non-
speculative decoding.
Since speculative decoding with rejection sampling guarantees that the output
distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality.
However, we still need to verify below scenario could be passed:
* Batch size 1 greedy equality
* Batch size >1 greedy equality
* Test greedy equality under preemption
* Test greedy equality under various number of speculative tokens.
With those tests, we can say at least, EAGLE would not break the
correctess for the target model outputs.
"""
import pytest
from .conftest import run_equality_correctness_test
# main model
MAIN_MODEL = "JackFram/llama-68m"
# speculative model
SPEC_MODEL = "abhigoyal/vllm-eagle-llama-68m-random"
# max. number of speculative tokens: this corresponds to
# num_heads in the config.json of the speculator model.
MAX_SPEC_TOKENS = 4
# precision
PRECISION = "float32"
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
# Workaround the restriction that cnnlGetTensorElementNum(key_cache_desc) <= INT32_MAX.
"num_gpu_blocks_override": 2048,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
])
@pytest.mark.parametrize("output_len", [
128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
seed: int):
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size, output_len, seed)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs_during_spec_decoding": False,
},
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs_during_spec_decoding": True,
},
])
@pytest.mark.parametrize("output_len", [
128,
])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6])
def test_eagle_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int, seed: int,
logprobs: int):
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs[
'disable_logprobs_during_spec_decoding'])
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"enforce_eager": False,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
# Workaround the restriction that cnnlGetTensorElementNum(key_cache_desc) <= INT32_MAX.
"num_gpu_blocks_override": 2048,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
])
@pytest.mark.parametrize("output_len", [
128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_eagle_e2e_greedy_correctness_cuda_graph(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality with cuda graph enabled and different
batch sizes."""
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size, output_len, seed)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Change block size since Cambricon-vLLM only supports block size with
# 16 in paged mode.
"block_size": 16,
# 2 for small prompt, 256//16 for generated.
"num_gpu_blocks_override": 2 + 256 // 16,
"max_model_len": (2 + 256 // 16) * 16,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
])
@pytest.mark.parametrize(
"output_len",
[
# Use small output len for fast test.
128,
])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_eagle_e2e_greedy_correctness_with_preemption(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality, even when some sequences are preempted mid-
generation.
"""
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size, output_len, seed)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
# Workaround the restriction that cnnlGetTensorElementNum(key_cache_desc) <= INT32_MAX.
"num_gpu_blocks_override": 2048,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
"test_llm_kwargs",
[
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": k,
}
# Try a range of num. speculative tokens
for k in range(1, 1 + MAX_SPEC_TOKENS)
])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_eagle_different_k(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify that eagle speculative decoding produces exact equality
to without spec decode with different values of num_speculative_tokens.
"""
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size, output_len, seed)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
# Workaround the restriction that cnnlGetTensorElementNum(key_cache_desc) <= INT32_MAX.
"num_gpu_blocks_override": 2048,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_disable_by_batch_size": 4
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_eagle_disable_queue(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify that eagle speculative decoding produces exact equality
to without spec decode when speculation is disabled for large
batch sizes.
"""
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size, output_len, seed)
if __name__ == "__main__":
import pytest
pytest.main([__file__])

View File

@@ -0,0 +1,145 @@
"""Tests which cover integration of the speculative decoding framework with
other features, e.g. cuda graphs.
"""
import pytest
from .conftest import run_equality_correctness_test
MAIN_MODEL = "JackFram/llama-68m"
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Verify equality when cuda graphs allowed.
"enforce_eager": False,
"model_name": "JackFram/llama-68m",
# Workaround the restriction that cnnlGetTensorElementNum(key_cache_desc) <= INT32_MAX.
"num_gpu_blocks_override": 2048,
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
{
# Identical models.
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("output_len", [32])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_cuda_graph(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int, seed: int):
"""Verify spec decode equality when cuda graphs are enabled.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model_name": "JackFram/llama-160m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
{
"speculative_model": "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit",
"num_speculative_tokens": 5,
},
])
@pytest.mark.parametrize(
"test_llm_kwargs",
[
# Explicitly specify draft model quantization
{
"speculative_model_quantization": "gptq",
},
# Explicitly specify GPTQ-based draft model to use marlin quantization
{
"speculative_model_quantization": "marlin",
},
# Not explicitly specify draft model quantization
{
"speculative_model_quantization": None,
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seed", [1])
# Skip this test case since we donot support gptq 4bit.
@pytest.mark.skip(reason="Skip test since we donot support gptq 4bit.")
def test_speculative_model_quantization_config(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size: int, seed: int):
"""Verify spec decode works well with draft model quantization configs.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=32,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model_name": MAIN_MODEL,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_disable_mqa_scorer": True,
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
output_len: int, seed: int):
"""Verify that ngram speculative decoding generates the same output
with batch expansion scorer and mqa scorer.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)

View File

@@ -0,0 +1,127 @@
"""Tests which cover integration of the speculative decoding framework with
tensor parallelism.
"""
import pytest
import torch
from vllm.platforms import current_platform
from .conftest import run_equality_correctness_test_tp
@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize(
"common_llm_kwargs",
[[
# Skip cuda graph recording for fast test.
"--enforce-eager",
"--tensor-parallel-size",
"2",
# Workaround the restriction that cnnlGetTensorElementNum(key_cache_desc) <= INT32_MAX.
"--num_gpu_blocks_override", "2048",
]])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [[]])
@pytest.mark.parametrize("baseline_llm_kwargs", [[]])
@pytest.mark.parametrize("test_llm_kwargs", [
[
"--speculative-model",
"JackFram/llama-68m",
"--num-speculative-tokens",
"3",
],
[
"--speculative-model",
"[ngram]",
"--num-speculative-tokens",
"5",
"--ngram-prompt-lookup-max",
"3",
],
])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_target_model_tp_gt_1(common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int, seed: int):
"""Verify greedy equality when tensor parallelism is used.
"""
if current_platform.is_rocm():
pytest.skip("hip is not well-supported yet")
run_equality_correctness_test_tp("JackFram/llama-68m",
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
temperature=0.0)
@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize(
"common_llm_kwargs",
[[
# Skip cuda graph recording for fast test.
"--enforce-eager",
"--tensor_parallel_size",
"2",
# precision
"--dtype",
"bfloat16",
# Workaround the restriction that cnnlGetTensorElementNum(key_cache_desc) <= INT32_MAX.
"--num_gpu_blocks_override", "2048",
]])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [[]])
@pytest.mark.parametrize("baseline_llm_kwargs", [[]])
@pytest.mark.parametrize("model, test_llm_kwargs",
[("JackFram/llama-68m", [
"--speculative-model",
"JackFram/llama-68m",
"--num_speculative-tokens",
"5",
"--speculative-draft-tensor-parallel-size",
"1",
]),
# Skip this case since vLLM does not support mlu
# xformers backend, and mlu flash attention does not
# support this case with head size 80.
# ("ibm-granite/granite-3b-code-instruct", [
# "--speculative-model",
# "ibm-granite/granite-3b-code-instruct",
# "--num_speculative-tokens",
# "5",
# "--speculative-draft-tensor-parallel-size",
# "1",
# ])
])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seed", [1])
def test_draft_model_tp_lt_target_model_tp2(model, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs, batch_size: int,
seed: int):
"""Verify spec decode works well with smaller tp for draft models.
"""
run_equality_correctness_test_tp(model,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=32,
seed=seed,
temperature=0.0)

View File

@@ -0,0 +1,126 @@
"""Tests which cover integration of the speculative decoding framework with
tensor parallelism.
"""
import openai
import pytest
import torch
from .conftest import run_equality_correctness_test_tp
MAIN_MODEL = "JackFram/llama-68m"
SPEC_MODEL = "JackFram/llama-68m"
@pytest.mark.skipif(torch.cuda.device_count() < 4,
reason="Need at least 4 GPUs to run the test.")
@pytest.mark.parametrize(
"common_llm_kwargs",
[[
# Skip cuda graph recording for fast test.
"--enforce_eager",
"--tensor-parallel-size",
"4",
# Workaround the restriction that cnnlGetTensorElementNum(key_cache_desc) <= INT32_MAX.
"--num_gpu_blocks_override", "2048",
]])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
[
"--speculative-model",
f"{SPEC_MODEL}",
"--num-speculative-tokens",
"5",
],
])
@pytest.mark.parametrize("baseline_llm_kwargs", [[]])
@pytest.mark.parametrize(
"test_llm_kwargs",
[
#TODO(wooyeon): add spec_draft_dp=2 case
[
"--speculative-draft-tensor-parallel-size",
"1",
],
])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seed", [1])
def test_draft_model_tp_lt_target_model_tp4(common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs, batch_size: int,
seed: int):
"""Verify spec decode works well with smaller tp for draft models.
"""
run_equality_correctness_test_tp(MAIN_MODEL,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=32,
seed=seed,
temperature=0.0)
@pytest.mark.skipif(torch.cuda.device_count() < 4,
reason="Need at least 4 GPUs to run the test.")
@pytest.mark.parametrize(
"common_llm_kwargs",
[[
# Skip cuda graph recording for fast test.
"--enforce-eager",
"--tensor-parallel-size",
"4",
# Workaround the restriction that cnnlGetTensorElementNum(key_cache_desc) <= INT32_MAX.
"--num_gpu_blocks_override", "2048",
]])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [[]])
@pytest.mark.parametrize("baseline_llm_kwargs", [[]])
@pytest.mark.parametrize(
"test_llm_kwargs",
[
[
"--speculative-model",
f"{SPEC_MODEL}",
"--num-speculative-tokens",
"5",
# Artificially limit the draft model max model len; this forces vLLM
# to skip speculation once the sequences grow beyond 32-k tokens.
"--speculative-max-model-len",
"32",
],
])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
"output_len",
[
# This must be a good bit larger than speculative_max_model_len so that
# we can test the case where all seqs are skipped, but still small to
# ensure fast test.
64,
])
@pytest.mark.parametrize("seed", [1])
def test_skip_speculation(common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int, seed: int):
"""Verify job failure with RuntimeError when all sequences skip speculation.
We do this by setting the max model len of the draft model to an
artificially low value, such that when the sequences grow beyond it, they
are skipped in speculative decoding.
TODO: fix it to pass without raising Error. (#5814)
"""
with pytest.raises(openai.APIConnectionError):
run_equality_correctness_test_tp(MAIN_MODEL,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
temperature=0.0)

View File

@@ -0,0 +1,295 @@
from itertools import cycle
import pytest
from vllm import SamplingParams
from .conftest import run_equality_correctness_test
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model_name": "JackFram/llama-68m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Workaround the restriction that cnnlGetTensorElementNum(key_cache_desc) <= INT32_MAX.
"num_gpu_blocks_override": 2048,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_model": "JackFram/llama-160m",
"num_speculative_tokens": 3,
"disable_logprobs_during_spec_decoding": False,
}, {
"speculative_model": "JackFram/llama-160m",
"num_speculative_tokens": 3,
"disable_logprobs_during_spec_decoding": True,
}])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
7,
])
@pytest.mark.skip(reason="skip cause Error in memory profiling.")
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6])
def test_logprobs_equality(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int, logprobs: int):
"""Verify output logprobs are equal with and without speculative decoding.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
temperature=0.0,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs[
'disable_logprobs_during_spec_decoding'])
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model_name": "JackFram/llama-68m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Workaround the restriction that cnnlGetTensorElementNum(key_cache_desc) <= INT32_MAX.
"num_gpu_blocks_override": 2048,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_model": "JackFram/llama-160m",
"num_speculative_tokens": 3,
"disable_logprobs_during_spec_decoding": False,
}, {
"speculative_model": "JackFram/llama-160m",
"num_speculative_tokens": 6,
"disable_logprobs_during_spec_decoding": False,
}])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6])
def test_logprobs_different_k(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int,
output_len: int, seed: int, logprobs: int):
"""Veriy logprob greedy equality with different speculation lens.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
temperature=0.0,
logprobs=logprobs,
disable_logprobs=test_llm_kwargs[
'disable_logprobs_during_spec_decoding'])
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model_name": "JackFram/llama-68m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Workaround the restriction that cnnlGetTensorElementNum(key_cache_desc) <= INT32_MAX.
"num_gpu_blocks_override": 2048,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
"test_llm_kwargs",
[{
"speculative_model": "JackFram/llama-160m",
"num_speculative_tokens": 3,
"disable_logprobs_during_spec_decoding": False,
# Artificially limit the draft model max model len; this forces vLLM
# to skip speculation once the sequences grow beyond 32-k tokens.
"speculative_max_model_len": 32,
}])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1])
def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
seed: int, logprobs: int):
"""Verify logprobs greedy equality when some sequences skip speculation.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
temperature=0.0,
logprobs=logprobs,
disable_logprobs=test_llm_kwargs[
'disable_logprobs_during_spec_decoding'])
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model_name": "JackFram/llama-68m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Workaround the restriction that cnnlGetTensorElementNum(key_cache_desc) <= INT32_MAX.
"num_gpu_blocks_override": 2048,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_model": "JackFram/llama-160m",
"num_speculative_tokens": 3,
"disable_logprobs_during_spec_decoding": False,
}])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [6])
def test_logprobs_temp_1(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int, logprobs: int):
"""Verify at least one logprob result has num_logprobs+1, which tests the
case where the sampled token is not in top-k logprobs.
Ideally, this test should validate equality with non-spec by getting
logprobs. This is left as future improvement.
"""
temperature = 1.0
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
"San Francisco is know for its",
"Facebook was created in 2004 by",
"Curious George is a",
"Python 3.11 brings improvements to its",
]
prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
sampling_params = SamplingParams(
max_tokens=output_len,
ignore_eos=True,
temperature=temperature,
logprobs=logprobs,
)
sd_args = {
**common_llm_kwargs,
**per_test_common_llm_kwargs,
**test_llm_kwargs,
}
with vllm_runner(**sd_args) as vllm_model:
sd_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)
num_returned_logprobs = [
len(seq_logprobs) for seq_logprobs in sd_outputs[-1]
]
# Assert one of the returned logprobs has > num_logprobs (indicating the
# sampled token is not in top-k).
assert any(
[num_returned > logprobs for num_returned in num_returned_logprobs])
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model_name": "JackFram/llama-160m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Workaround the restriction that cnnlGetTensorElementNum(key_cache_desc) <= INT32_MAX.
"num_gpu_blocks_override": 2048,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
"disable_logprobs_during_spec_decoding": True,
}])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("logprobs", [0])
def test_logprobs_disabled(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int, logprobs: int):
"""Check the behavior when logprobs are disabled.
Token choices should match with the base model.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
temperature=0.0,
logprobs=logprobs,
disable_logprobs=test_llm_kwargs[
'disable_logprobs_during_spec_decoding'])

View File

@@ -0,0 +1,397 @@
"""This docstring details important information on the testing methodology.
Most of the tests rely on "greedy equality", where we expect the output of
speculative decoding on a sequence to exactly match the output of normal non-
speculative decoding.
Since speculative decoding with rejection sampling guarantees that the output
distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality.
However, we still need to verify below scenario could be passed:
* Batch size 1 greedy equality
* Batch size >1 greedy equality
* Test greedy equality under preemption
* Test greedy equality under various number of speculative tokens.
With those tests, we can say at least, Medusa would not break the
correctess for the target model outputs.
"""
import pytest
from .conftest import run_equality_correctness_test
# main model
# lmsys/vicuna-7b-v1.3 was to be used but it's causing
# OOM in CI pipeline, so using a smaller model.
MAIN_MODEL = "JackFram/llama-68m"
# speculative model
SPEC_MODEL = "abhigoyal/vllm-medusa-llama-68m-random"
# max number of speculative tokens: this corresponds to
# num_heads in the config.json of the speculator model.
MAX_SPEC_TOKENS = 5
# precision
PRECISION = "float32"
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
# Workaround the restriction that cnnlGetTensorElementNum(key_cache_desc) <= INT32_MAX.
"num_gpu_blocks_override": 2048,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
])
@pytest.mark.parametrize("output_len", [
128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_medusa_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality with different batch size."""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs_during_spec_decoding": False,
},
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs_during_spec_decoding": True,
},
])
@pytest.mark.parametrize("output_len", [
8,
])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6])
def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
seed: int, logprobs: int):
"""Verify greedy equality with different batch size."""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs[
'disable_logprobs_during_spec_decoding'])
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"enforce_eager": False,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
# Workaround the restriction that cnnlGetTensorElementNum(key_cache_desc) <= INT32_MAX.
"num_gpu_blocks_override": 2048,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
])
@pytest.mark.parametrize("output_len", [
128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_medusa_e2e_greedy_correctness_cuda_graph(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality with cuda graph enabled and different
batch sizes."""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Change block size since Cambricon-vLLM only supports block size with
# 16 in paged mode.
"block_size": 16,
# 2 for small prompt, 256//16 for generated.
"num_gpu_blocks_override": 2 + 256 // 16,
"max_model_len": (2 + 256 // 16) * 16,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
])
@pytest.mark.parametrize(
"output_len",
[
# Use small output len for fast test.
128,
])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_medusa_e2e_greedy_correctness_with_preemption(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality, even when some sequences are preempted mid-
generation.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
# Workaround the restriction that cnnlGetTensorElementNum(key_cache_desc) <= INT32_MAX.
"num_gpu_blocks_override": 2048,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
"test_llm_kwargs",
[
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": k,
}
# Try a range of num. speculative tokens
for k in range(1, 1 + MAX_SPEC_TOKENS)
])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_medusa_different_k(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify that medusa speculative decoding produces exact equality
to without spec decode with different values of num_speculative_tokens.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
# Workaround the restriction that cnnlGetTensorElementNum(key_cache_desc) <= INT32_MAX.
"num_gpu_blocks_override": 2048,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_disable_by_batch_size": 4
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_medusa_disable_queue(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int,
output_len: int, seed: int):
"""Verify that medusa speculative decoding produces exact equality
to without spec decode when speculation is disabled for large
batch sizes.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_disable_by_batch_size": 4
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_disable_mqa_scorer": True,
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
output_len: int, seed: int):
"""Verify that speculative decoding generates the same output
with batch expansion scorer and mqa scorer.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
if __name__ == "__main__":
import pytest
pytest.main([__file__])

View File

@@ -0,0 +1,482 @@
"""This docstring details important information on the testing methodology.
Most of the tests rely on "greedy equality", where we expect the output of
speculative decoding on a sequence to exactly match the output of normal non-
speculative decoding.
Since speculative decoding with rejection sampling guarantees that the output
distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality.
However, we still need to verify below scenario could be passed:
* Batch size 1 greedy equality
* Batch size >1 greedy equality
* Test greedy equality under preemption
* Test greedy equality under various number of speculative tokens.
With those tests, we can say at least, MLPSpeculator would not break the
correctness for the target model outputs.
"""
from unittest.mock import patch
import pytest
from vllm.model_executor.layers.vocab_parallel_embedding import pad_vocab_size
from .conftest import run_equality_correctness_test
# main model
MAIN_MODEL = "JackFram/llama-160m"
# speculative model
SPEC_MODEL = "ibm-fms/llama-160m-accelerator"
# max. number of speculative tokens: this corresponds to
# n_predict in the config.json of the speculator model.
MAX_SPEC_TOKENS = 3
# precision
PRECISION = "float32"
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
},
])
@pytest.mark.parametrize("output_len", [
128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_mlp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality with different batch size."""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"disable_logprobs_during_spec_decoding": False,
},
{
"speculative_model": SPEC_MODEL,
"disable_logprobs_during_spec_decoding": True,
},
])
@pytest.mark.parametrize("output_len", [8])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6])
def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int, seed: int,
logprobs: int):
"""Verify greedy equality with different batch size."""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs[
'disable_logprobs_during_spec_decoding'])
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
},
])
@pytest.mark.parametrize("output_len", [2048])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_mlp_e2e_acceptance_rate(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int, seed: int):
"""Verify acceptance rate with different batch size and large output
length."""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
temperature=0.0,
seed=seed,
expected_acceptance_rate=0.48)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
# Speculative model
"speculative_model": SPEC_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{"seed": 1}])
@pytest.mark.parametrize("test_llm_kwargs", [{"seed": 5}])
@pytest.mark.parametrize("output_len", [64])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("temperature", [0.1, 1.0])
@pytest.mark.parametrize("seed", [1])
def test_mlp_e2e_seeded_correctness(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
temperature: float, seed: int):
"""Verify seeded runs produce the same output."""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
temperature=temperature,
seed=seed)
# Ensure this same test does fail if we _don't_ include per-request seeds
with pytest.raises(AssertionError):
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
temperature=temperature,
seed=seed,
disable_seed=True)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Change block size since Cambricon-vLLM only supports block size with
# 16 in paged mode.
"block_size": 16,
# 2 for small prompt, 256//16 for generated.
"num_gpu_blocks_override": 2 + 256 // 16,
"max_model_len": (2 + 256 // 16) * 16,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
},
])
@pytest.mark.parametrize(
"output_len",
[
# Use small output len for fast test.
128,
])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_mlp_e2e_greedy_correctness_with_preemption(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality, even when some sequences are preempted mid-
generation.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Change block size since Cambricon-vLLM only supports block size with
# 16 in paged mode.
"block_size": 16,
# 2 for small prompt, 256//16 for generated.
"num_gpu_blocks_override": 2 + 256 // 16,
"max_model_len": (2 + 256 // 16) * 16,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
},
])
@pytest.mark.parametrize(
"output_len",
[
# Use small output len for fast test.
128,
])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_mlp_e2e_greedy_correctness_with_padding(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality when the vocab dimension is padded
"""
# Default pad_to is 64, test model has vocab_size of 32000
def patched_pad_vocab_size(vocab_size, pad_to=None):
return pad_vocab_size(vocab_size, pad_to=32064)
with patch(
"vllm.model_executor.layers.vocab_parallel_embedding.pad_vocab_size",
patched_pad_vocab_size):
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
"test_llm_kwargs",
[
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": k,
}
# Try a range of num. speculative tokens
for k in range(1, 1 + MAX_SPEC_TOKENS)
])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_mlp_different_k(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, seed: int,
output_len: int):
"""Verify that mlp speculative decoding produces exact equality
to without spec decode with different values of num_speculative_tokens.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_model": SPEC_MODEL,
"speculative_disable_by_batch_size": 4
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_mlp_disable_queue(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, seed: int,
output_len: int):
"""Verify that mlp speculative decoding produces exact equality
to without spec decode when speculation is disabled for large
batch sizes.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model_name": MAIN_MODEL,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
"speculative_model": SPEC_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_disable_mqa_scorer": True,
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
output_len: int, seed: int):
"""Verify that speculative decoding generates the same output
with batch expansion scorer and mqa scorer.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)

View File

@@ -0,0 +1,826 @@
"""The tests in this file verify end-to-end speculative decoding correctness.
This docstring details important information on the testing methodology.
Most of the tests rely on "greedy equality", where we expect the output of
speculative decoding on a sequence to exactly match the output of normal non-
speculative decoding.
Since speculative decoding with rejection sampling guarantees that the output
distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality. This gives us good coverage of temp=0.
At temp=0, the TypicalAcceptanceSampler ensures that only the tokens with the
highest probability in the target distribution are accepted. Therefore, we can
expect greedy equality for the TypicalAcceptanceSampler at temp=0.
For temp>0, we rely on unit tests on the rejection sampler to verify that the
output distribution is the same with spec decode vs. no spec decode (this would
be prohibitively expensive to run with a real model). Similarly, for the
TypicalAcceptance sampler also, we rely on unit tests to validate temp>0
test cases.
NOTE: Speculative decoding's distribution equality requires that the measured
distributions of the target model and proposal model be deterministic given the
same input. vLLM largely guarantees this.
@cadedaniel has seen cases where the output probabilities of a draft/target
model change slightly with certain batch sizes or prompts, even with Torch
determinism flags set. It is unclear if this is a bug in vLLM, due to non-
determinism in on-device batched operations, a bug in vLLM's spec decode
implementation, or the "hardware numerics" limitations. Either way, rejection
sampling ensures the output distribution matches the target model, but it breaks
greedy-equality tests for those batch sizes/prompts.
"""
from itertools import cycle
import pytest
from transformers import AutoTokenizer
from vllm import SamplingParams
from ...utils import fork_new_process_for_each_test
from .conftest import (get_output_from_llm_generator,
run_equality_correctness_test)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Use a small model for a fast test.
# Note this is repeated in the test body; to initialize a tokenizer.
"model": "JackFram/llama-68m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Workaround the restriction that cnnlGetTensorElementNum(key_cache_desc) <= INT32_MAX.
"num_gpu_blocks_override": 2048,
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": False,
},
{
# Chunked prefill enabled with small value
# to make sure we get mixed batches.
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
},
{
# Verify the detokenizer assertions in the test work when spec
# decode is disabled.
},
])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
# Since cndrv will reinit device in forked process, we test function in main
# process directly.
# @fork_new_process_for_each_test
def test_spec_decode_e2e_with_detokenization(test_llm_generator,
batch_size: int):
"""Run generation with speculative decoding on a batch. Verify the engine
generates the correct number of tokens (via ignore_eos=True), and that the
detokenization matches HF transformers.
"""
output_len = 32
temperature = 0.0
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
sampling_params = SamplingParams(
max_tokens=output_len,
ignore_eos=True,
temperature=temperature,
)
batch_tokens, batch_token_ids, _ = get_output_from_llm_generator(
test_llm_generator, prompts, sampling_params)
# Expect a generation for each prompt in the batch.
assert len(batch_token_ids) == len(prompts)
# Expect each generation to have expected number of tokens (note ignore_eos
# is True).
assert [len(token_ids)
for token_ids in batch_token_ids] == ([output_len] * batch_size)
# Expect detokenized string to match.
tok = AutoTokenizer.from_pretrained("JackFram/llama-68m")
for actual_tokens, actual_token_ids in zip(batch_tokens, batch_token_ids):
expected_tokens = tok.decode(actual_token_ids)
print(f"{actual_token_ids=}")
assert actual_tokens.strip() == expected_tokens.strip()
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
# Workaround the restriction that cnnlGetTensorElementNum(key_cache_desc) <= INT32_MAX.
"num_gpu_blocks_override": 2048,
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
# Try two different tiny base models.
# Note that one is equal to the draft model, another isn't.
{
"model_name": "JackFram/llama-68m",
},
{
"model_name": "JackFram/llama-160m",
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4,
},
])
@pytest.mark.parametrize(
"output_len",
[
# Use long output len for the small model test.
10,
])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize("seed", [1])
# Since cndrv will reinit device in forked process, we test function in main
# process directly.
# @fork_new_process_for_each_test
def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality on a tiny model with batch size of one.
Since this test is cheaper than other e2e correctness tests, we generate
with a higher output_len.
When the draft model is the same as the target model, we further check
whether all speculative tokens are accepted.
"""
ensure_all_accepted = per_test_common_llm_kwargs.get(
"model_name") == test_llm_kwargs.get("speculative_model")
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0,
ensure_all_accepted=ensure_all_accepted)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
# Workaround the restriction that cnnlGetTensorElementNum(key_cache_desc) <= INT32_MAX.
"num_gpu_blocks_override": 2048,
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
# Try two different tiny base models.
# Note that one is equal to the draft model, another isn't.
{
"model_name": "JackFram/llama-68m",
},
{
"model_name": "JackFram/llama-160m",
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
},
])
@pytest.mark.parametrize(
"output_len",
[
# Use small output len for fast test.
256,
])
@pytest.mark.parametrize("batch_size", [64])
@pytest.mark.parametrize("seed", [1])
# Since cndrv will reinit device in forked process, we test function in main
# process directly.
# @fork_new_process_for_each_test
def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality on a tiny model and large batch size.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Workaround the restriction that cnnlGetTensorElementNum(key_cache_desc) <= INT32_MAX.
"num_gpu_blocks_override": 2048,
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
# Try two different tiny base models.
# Note that one is equal to the draft model, another isn't.
{
"model_name": "JackFram/llama-68m",
},
{
"model_name": "JackFram/llama-160m",
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
},
])
@pytest.mark.parametrize("max_output_len", [
256,
])
@pytest.mark.parametrize("batch_size", [32])
@pytest.mark.parametrize("seed", [1])
# Since cndrv will reinit device in forked process, we test function in main
# process directly.
# @fork_new_process_for_each_test
def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
max_output_len: int, seed: int):
"""Verify greedy equality on a tiny model, with a large batch size, and when
sampling respects the EOS token.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len,
seed=seed,
temperature=0.0,
ignore_eos=False)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# A "real" model (not tiny).
"model_name": "meta-llama/Llama-2-7b-chat-hf",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
},
])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize(
"output_len",
[
# Use decently long output len for a high quality test.
256,
])
@pytest.mark.parametrize("seed", [1])
# Since cndrv will reinit device in forked process, we test function in main
# process directly.
# @fork_new_process_for_each_test
def test_spec_decode_e2e_greedy_correctness_real_model_bs1(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality on a "real" model and batch size of 1. This is
separate from large BS tests to make identifying the source of bugs easier.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# A "real" model (not tiny).
"model_name": "meta-llama/Llama-2-7b-chat-hf",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
},
])
@pytest.mark.parametrize("batch_size", [32])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
64,
])
@pytest.mark.parametrize("seed", [1])
# Since cndrv will reinit device in forked process, we test function in main
# process directly.
# @fork_new_process_for_each_test
def test_spec_decode_e2e_greedy_correctness_real_model_large_bs(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality with a "real" model on a nontrivial batch size.
This is the closest test to a real production workload.
"""
# For MLU platforms, the result of spec decode is different from that of
# auto-regression slightly. As follows:
#
# org: ['San Francisco is know for its iconic landmarks, vibrant
# neighborhoods, and cultural attractions. Here are some of the top
# things to do in San Francisco:\n1. Visit Alcatraz Island: Take a
# ferry to the infamous former prison and explore the cellblock,
# listen to an audio tour, or take']
# sd: ['San Francisco is know for its iconic landmarks, vibrant
# neighborhoods, and diverse cultural scene. Here are some of the top
# things to do in San Francisco:\n1. Visit Alcatraz Island: Take a
# ferry to the infamous former prison and explore the cellblock,
# listen to an audio tour, or take']
try:
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
except AssertionError as e:
pass
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"block_size": 16,
# 2 for small prompt, 256//16 for generated.
"num_gpu_blocks_override": 2 + 256 // 16,
"max_model_len": (2 + 256 // 16) * 16,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
{
"model_name": "JackFram/llama-160m",
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
},
])
@pytest.mark.parametrize(
"output_len",
[
# Use small output len for fast test.
256,
])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
# Since cndrv will reinit device in forked process, we test function in main
# process directly.
# @fork_new_process_for_each_test
def test_spec_decode_e2e_greedy_correctness_with_preemption(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality, even when some sequences are preempted mid-
generation.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model_name": "JackFram/llama-160m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
# Change block size since Cambricon-vLLM only supports block size with
# 16 in paged mode.
# As of this writing, vLLM only compiles with these 3 block sizes by
# default.
# {
# "block_size": 8,
# },
{
"block_size": 16,
},
# {
# "block_size": 32,
# },
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
},
])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
# Since cndrv will reinit device in forked process, we test function in main
# process directly.
# @fork_new_process_for_each_test
def test_spec_decode_different_block_size(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality over different block sizes.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model_name": "JackFram/llama-160m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
"test_llm_kwargs",
[
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
# Artificially limit the draft model max model len; this forces vLLM
# to skip speculation once the sequences grow beyond 32-k tokens.
"speculative_max_model_len": 32,
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4,
"speculative_max_model_len": 32,
},
])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
"output_len",
[
# This must be a good bit larger than speculative_max_model_len so that
# we can test the case where all seqs are skipped, but still small to
# ensure fast test.
64,
])
@pytest.mark.parametrize("seed", [1])
# Since cndrv will reinit device in forked process, we test function in main
# process directly.
# @fork_new_process_for_each_test
def test_skip_speculation(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality when some (or all) sequences skip speculation.
We do this by setting the max model len of the draft model to an
artificially low value, such that when the sequences grow beyond it, they
are skipped in speculative decoding.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model_name": "JackFram/llama-160m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_disable_by_batch_size": 2,
"enable_chunked_prefill": False,
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_disable_by_batch_size": 2,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4,
},
])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("output_len", [10])
@pytest.mark.parametrize("seed", [1])
# Since cndrv will reinit device in forked process, we test function in main
# process directly.
# @fork_new_process_for_each_test
def test_disable_speculation(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality when all sequences disable speculation.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model_name": "JackFram/llama-68m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Workaround the restriction that cnnlGetTensorElementNum(key_cache_desc) <= INT32_MAX.
"num_gpu_blocks_override": 2048,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
"test_llm_kwargs",
[
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": k,
"enable_chunked_prefill": False,
}
# Try a range of common k, as well as large speculation.
for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63]
] + [{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": k,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4,
} for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63]])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
# Since cndrv will reinit device in forked process, we test function in main
# process directly.
# @fork_new_process_for_each_test
def test_many_k(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
output_len: int, seed: int):
"""Verify that speculative decoding produces exact equality to without spec
decode with many different values of k.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model_name": "JackFram/llama-160m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
"test_llm_kwargs",
[
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": k,
"spec_decoding_acceptance_method": "typical_acceptance_sampler",
"enable_chunked_prefill": False
}
# Try a range of common k.
for k in [1, 2, 3]
] + [{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": k,
"spec_decoding_acceptance_method": "typical_acceptance_sampler",
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
} for k in [1, 2, 3]])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
# Since cndrv will reinit device in forked process, we test function in main
# process directly.
# @fork_new_process_for_each_test
def test_typical_acceptance_sampling(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
seed: int):
"""Verify that speculative decoding produces exact equality to without spec
decode with TypicalAcceptanceSampler as the draft token acceptance
sampling method.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)

View File

@@ -0,0 +1,369 @@
"""This docstring details important information on the testing methodology.
Most of the tests rely on "greedy equality", where we expect the output of
speculative decoding on a sequence to exactly match the output of normal non-
speculative decoding.
Since speculative decoding with rejection sampling guarantees that the output
distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality.
For ngram lookup, its idea comes from https://github.com/apoorvumang/prompt-lookup-decoding,
and is merged into transform code base: https://github.com/huggingface/transformers/pull/27775.
Since there is no model is needed for generate the proposal, we could make
the testcase much simpler than drafter multi-step one.
However, we still need to verify below scenario could be passed:
* Batch size 1 greedy equality
* Batch size >1 greedy equality
* Test greedy equality under preemption
* Test greedy equality under various ngram sizes / speculative sizes
With those tests, we can say at least, ngram spec would not break the correctess
for the target model outputs.
"""
import pytest
from .conftest import run_equality_correctness_test
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
# Workaround the restriction that cnnlGetTensorElementNum(key_cache_desc) <= INT32_MAX.
"num_gpu_blocks_override": 2048,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
{
"model_name": "JackFram/llama-68m",
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
},
{
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
},
])
@pytest.mark.parametrize("output_len", [
256,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
@pytest.mark.parametrize("seed", [1])
def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
prefill_chunk_size: int, seed: int):
"""Verify greedy equality on a tiny model with different batch size."""
if prefill_chunk_size > 0:
common_llm_kwargs.update(
**{
"enable_chunked_prefill": True,
"max_num_batched_tokens": prefill_chunk_size,
"max_num_seqs": prefill_chunk_size
})
else:
common_llm_kwargs["enable_chunked_prefill"] = False
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
{
"model_name": "JackFram/llama-68m",
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
"disable_logprobs_during_spec_decoding": False,
},
{
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
"disable_logprobs_during_spec_decoding": True,
},
])
@pytest.mark.parametrize("output_len", [
8,
])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6])
def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int, seed: int,
logprobs: int):
"""Verify greedy equality on a tiny model with different batch size."""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs[
'disable_logprobs_during_spec_decoding'])
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Change block size since Cambricon-vLLM only supports block size with
# 16 in paged mode.
"block_size": 16,
# 2 for small prompt, 256//16 for generated.
"num_gpu_blocks_override": 2 + 256 // 16,
"max_model_len": (2 + 256 // 16) * 16,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
{
"model_name": "JackFram/llama-160m",
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
"enable_chunked_prefill": False,
},
{
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
"enable_chunked_prefill": True,
"speculative_disable_mqa_scorer": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
},
])
@pytest.mark.parametrize(
"output_len",
[
# Use small output len for fast test.
256,
])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_ngram_e2e_greedy_correctness_with_preemption(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality, even when some sequences are preempted mid-
generation.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
temperature=0,
seed=seed)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model_name": "JackFram/llama-68m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Workaround the restriction that cnnlGetTensorElementNum(key_cache_desc) <= INT32_MAX.
"num_gpu_blocks_override": 2048,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
"test_llm_kwargs",
[
{
"speculative_model": "[ngram]",
"num_speculative_tokens": k,
"ngram_prompt_lookup_max": 3,
}
# Try a range of common k, as well as large speculation.
for k in [1, 3, 5]
] + [
{
"speculative_model": "[ngram]",
"num_speculative_tokens": k,
"ngram_prompt_lookup_max": 1,
}
# Try a range of common k, as well as large speculation.
for k in [1, 3, 5]
])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_ngram_different_k(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify that ngram speculative decoding produces exact equality
to without spec decode with many different values of k and
different ngram_prompt_lookup_max.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model_name": "JackFram/llama-68m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Workaround the restriction that cnnlGetTensorElementNum(key_cache_desc) <= INT32_MAX.
"num_gpu_blocks_override": 2048,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
"speculative_disable_by_batch_size": 4
}, {
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
"speculative_disable_by_batch_size": 4,
"enable_chunked_prefill": True,
"speculative_disable_mqa_scorer": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_ngram_disable_queue(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify that ngram speculative decoding produces exact equality
to without spec decode with many different values of k and
different ngram_prompt_lookup_max.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model_name": "JackFram/llama-68m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_disable_mqa_scorer": True,
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_ngram_scorer(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify that ngram speculative decoding generates the same output
with batch expansion scorer and mqa scorer.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)

View File

@@ -0,0 +1,67 @@
import pytest
from .conftest import run_equality_correctness_test
# main model
MAIN_MODEL = "JackFram/llama-68m"
# speculative model
SPEC_MODEL = "JackFram/llama-160m"
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model_name": "JackFram/llama-68m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# speculative model
"speculative_model": "JackFram/llama-160m",
# num speculative tokens
"num_speculative_tokens": 3,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{"seed": 1}])
@pytest.mark.parametrize("test_llm_kwargs", [{"seed": 5}])
@pytest.mark.parametrize("batch_size", [1, 8, 32])
@pytest.mark.parametrize("temperature", [0.1, 1.0])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
20,
])
def test_seeded_consistency(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int,
temperature: float, output_len: int):
"""Verify outputs are consistent across multiple runs with same seed
"""
run_equality_correctness_test(
vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
temperature=temperature,
disable_seed=False,
)
# Ensure this same test does fail if we _don't_ include per-request seeds
with pytest.raises(AssertionError):
run_equality_correctness_test(
vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
temperature=temperature,
disable_seed=True,
)

View File

@@ -0,0 +1,101 @@
from typing import List
import pytest
import torch
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from .utils import create_seq_group_metadata_from_prompts, mock_worker
@pytest.mark.parametrize('num_target_seq_ids', [100])
@pytest.mark.skip_global_cleanup
def test_create_target_seq_id_iterator(num_target_seq_ids: int):
"""Verify all new sequence ids are greater than all input
seq ids.
"""
scorer = BatchExpansionTop1Scorer(mock_worker(), 'cuda:0', 32_000)
all_seq_ids = [
[1, 3, 5, 7],
list(range(100)) + [0],
[100],
]
for seq_ids in all_seq_ids:
max_seq_id = max(seq_ids)
iterator = scorer._create_target_seq_id_iterator(seq_ids) # pylint: disable=protected-access
for _ in range(num_target_seq_ids):
assert next(iterator) > max_seq_id
@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.skip_global_cleanup
def test_get_token_ids_to_score(k: int):
"""Verify correct tokens are selected for scoring.
"""
proposal_token_ids = torch.tensor(
list(range(k)),
dtype=torch.int64,
device='cuda',
)
expected_output: List[List[int]] = [
[],
]
for i in range(proposal_token_ids.shape[0]):
expected_output.append(proposal_token_ids[:i + 1].tolist())
scorer = BatchExpansionTop1Scorer(mock_worker(), 'cuda:0', 32_000)
actual_output = scorer._get_token_ids_to_score(proposal_token_ids.tolist()) # pylint: disable=protected-access
actual_output = [
x.tolist() if isinstance(x, torch.Tensor) else x for x in actual_output
]
assert actual_output == expected_output
@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.skip_global_cleanup
def test_create_single_target_seq_group_metadata(k: int):
"""Verify correct creation of a batch-expanded seq group metadata.
"""
prompt_tokens = [1, 2, 3]
prev_output_tokens = [4, 5, 6]
token_ids = list(range(k))
num_tokens_processed = len(prompt_tokens) + len(prev_output_tokens) - 1
final_seq_len = len(prompt_tokens) + len(prev_output_tokens) + len(
token_ids)
block_size = 32
input_seq_group_metadata = create_seq_group_metadata_from_prompts(
[prompt_tokens], 2048 // block_size, block_size, [final_seq_len],
[prev_output_tokens], [num_tokens_processed])[0]
input_seq_id = list(input_seq_group_metadata.seq_data.keys())[0]
target_seq_id = 100
scorer = BatchExpansionTop1Scorer(mock_worker(), 'cuda:0', 32_000)
output = scorer._create_single_target_seq_group_metadata( # pylint: disable=protected-access
input_seq_group_metadata,
input_seq_id,
target_seq_id,
token_ids,
input_seq_group_metadata.sampling_params,
)
assert output.request_id == input_seq_group_metadata.request_id
assert len(output.seq_data) == 1
assert output.seq_data[target_seq_id].get_prompt_token_ids() == tuple(
prompt_tokens)
assert output.seq_data[target_seq_id].get_output_token_ids() == tuple(
prev_output_tokens + token_ids)
assert len(output.block_tables) == 1
assert output.block_tables[
target_seq_id] == input_seq_group_metadata.block_tables[input_seq_id]

View File

@@ -0,0 +1,91 @@
from unittest.mock import MagicMock, patch
import pytest
import torch
from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.mlu_metrics import MLUAsyncMetricsCollector
from vllm.spec_decode.mlu_multi_step_worker import MLUMultiStepWorker
from vllm.spec_decode.mlu_spec_decode_worker import MLUSpecDecodeWorker
from vllm.spec_decode.top1_proposer import Top1Proposer
from .test_utils import mock_spec_decode_sampler
from .utils import create_batch, mock_worker
MultiStepWorker = MLUMultiStepWorker
AsyncMetricsCollector = MLUAsyncMetricsCollector
SpecDecodeWorker = MLUSpecDecodeWorker
@pytest.mark.parametrize('queue_size', [4])
@pytest.mark.parametrize('batch_size', [1])
@pytest.mark.parametrize('k', [1])
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int,
acceptance_sampler_method: str):
"""Verify that speculative tokens are disabled when the batch size
exceeds the threshold.
"""
disable_by_batch_size = 3
draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker()
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
worker = SpecDecodeWorker(proposer_worker=draft_worker,
scorer_worker=target_worker,
spec_decode_sampler=mock_spec_decode_sampler(
acceptance_sampler_method),
disable_logprobs=False,
metrics_collector=metrics_collector,
disable_by_batch_size=disable_by_batch_size)
exception_secret = 'artificial stop'
draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k,
running_queue_size=queue_size)
if queue_size > disable_by_batch_size:
with patch.object(worker,
'_run_no_spec',
side_effect=ValueError(exception_secret)), \
pytest.raises(ValueError, match=exception_secret):
worker.execute_model(execute_model_req=execute_model_req)
# When the batch size is larger than the threshold,
# we expect no speculative tokens (0).
expected_num_spec_tokens = None if queue_size < disable_by_batch_size else 0
assert seq_group_metadata_list[
0].num_speculative_tokens == expected_num_spec_tokens
draft_worker.sampler_output.side_effect = ValueError(exception_secret)
proposer = Top1Proposer(
worker=draft_worker,
device='cpu', # not used
vocab_size=100, # not used
# Must be long enough to avoid being skipped due to length.
max_proposal_len=1024,
)
if queue_size < disable_by_batch_size:
# Should raise exception when executing the mocked draft model.
with pytest.raises(ValueError, match=exception_secret):
proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k),
seq_ids_with_bonus_token_in_last_step=set())
else:
# Should not execute the draft model because spec decode is disabled
# for all requests. Accordingly, the proposal length should be 0.
proposals = proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k),
seq_ids_with_bonus_token_in_last_step=set())
assert proposals.proposal_lens.tolist() == [0] * batch_size

View File

@@ -0,0 +1,203 @@
import math
from unittest.mock import MagicMock
import pytest
import torch
from vllm.spec_decode.mlu_metrics import MLUAsyncMetricsCollector
AsyncMetricsCollector = MLUAsyncMetricsCollector
def test_initial_call_returns_none():
"""Expect first call to get metrics to return None.
"""
spec_decode_sampler = MagicMock()
spec_decode_sampler.num_accepted_tokens = torch.tensor(0,
dtype=torch.long,
device='cuda')
spec_decode_sampler.num_emitted_tokens = torch.tensor(0,
dtype=torch.long,
device='cuda')
spec_decode_sampler.num_draft_tokens = 0
collector = AsyncMetricsCollector(spec_decode_sampler)
collector.init_mlu_tensors(rank=0)
maybe_metrics = collector.maybe_collect_rejsample_metrics(k=5)
assert maybe_metrics is None
def test_second_call_returns_metrics():
"""Expect second call to not return None.
"""
spec_decode_sampler = MagicMock()
spec_decode_sampler.num_accepted_tokens = torch.tensor(0,
dtype=torch.long,
device='cuda')
spec_decode_sampler.num_emitted_tokens = torch.tensor(0,
dtype=torch.long,
device='cuda')
spec_decode_sampler.num_draft_tokens = 0
collect_interval_s = 5.0
timer = MagicMock()
timer.side_effect = [
0.0, collect_interval_s + 0.1, collect_interval_s + 0.2
]
collector = AsyncMetricsCollector(spec_decode_sampler=spec_decode_sampler,
timer=timer,
collect_interval_s=collect_interval_s)
collector.init_mlu_tensors(rank=0)
_ = collector.maybe_collect_rejsample_metrics(k=5)
metrics = collector.maybe_collect_rejsample_metrics(k=5)
assert metrics is not None
@pytest.mark.parametrize("rank", [1, 2, 3, 4])
def test_nonzero_rank_noop(rank):
"""Verify nonzero ranks don't collect metrics.
"""
spec_decode_sampler = MagicMock()
spec_decode_sampler.num_accepted_tokens = torch.tensor(0,
dtype=torch.long,
device='cuda')
spec_decode_sampler.num_emitted_tokens = torch.tensor(0,
dtype=torch.long,
device='cuda')
spec_decode_sampler.num_draft_tokens = 0
collector = AsyncMetricsCollector(spec_decode_sampler)
collector.init_mlu_tensors(rank=rank)
_ = collector.maybe_collect_rejsample_metrics(k=5)
metrics = collector.maybe_collect_rejsample_metrics(k=5)
assert metrics is None
def test_noop_until_time():
"""Verify metrics aren't collected until enough time passes.
"""
spec_decode_sampler = MagicMock()
spec_decode_sampler.num_accepted_tokens = torch.tensor(0,
dtype=torch.long,
device='cuda')
spec_decode_sampler.num_emitted_tokens = torch.tensor(0,
dtype=torch.long,
device='cuda')
spec_decode_sampler.num_draft_tokens = 0
collect_interval_s = 5.0
timer = MagicMock()
timer.side_effect = [
0.0, collect_interval_s - 0.1, collect_interval_s - 0.1,
collect_interval_s + 0.1, collect_interval_s + 0.1
]
collector = AsyncMetricsCollector(spec_decode_sampler=spec_decode_sampler,
timer=timer,
collect_interval_s=collect_interval_s)
collector.init_mlu_tensors(rank=0)
_ = collector.maybe_collect_rejsample_metrics(k=5)
metrics = collector.maybe_collect_rejsample_metrics(k=5)
assert metrics is None
_ = collector.maybe_collect_rejsample_metrics(k=5)
metrics = collector.maybe_collect_rejsample_metrics(k=5)
assert metrics is not None
def test_timer_is_reset():
"""Verify that the internal timer inside AsyncMetricsCollector
is reset after collection.
"""
spec_decode_sampler = MagicMock()
spec_decode_sampler.num_accepted_tokens = torch.tensor(0,
dtype=torch.long,
device='cuda')
spec_decode_sampler.num_emitted_tokens = torch.tensor(0,
dtype=torch.long,
device='cuda')
spec_decode_sampler.num_draft_tokens = 0
collect_interval_s = 5.0
timer = MagicMock()
timer.side_effect = [
0.0,
collect_interval_s + 0.1,
collect_interval_s + 0.1,
collect_interval_s + 0.2,
collect_interval_s + 0.2,
2 * collect_interval_s + 0.1,
2 * collect_interval_s + 0.1,
]
collector = AsyncMetricsCollector(spec_decode_sampler=spec_decode_sampler,
timer=timer,
collect_interval_s=collect_interval_s)
collector.init_mlu_tensors(rank=0)
_ = collector.maybe_collect_rejsample_metrics(k=5)
metrics = collector.maybe_collect_rejsample_metrics(k=5)
assert metrics is not None
_ = collector.maybe_collect_rejsample_metrics(k=5)
metrics = collector.maybe_collect_rejsample_metrics(k=5)
assert metrics is None
_ = collector.maybe_collect_rejsample_metrics(k=5)
metrics = collector.maybe_collect_rejsample_metrics(k=5)
assert metrics is not None
@pytest.mark.parametrize("has_data", [True, False])
def test_initial_metrics_has_correct_values(has_data: bool):
"""Test correctness of metrics data.
"""
if has_data:
num_accepted_tokens = 103
num_emitted_tokens = 104
num_draft_tokens = 105
else:
num_accepted_tokens = 0
num_emitted_tokens = 0
num_draft_tokens = 0
k = 5
max_num_emitted_tokens = AsyncMetricsCollector.get_max_num_emitted_tokens(
num_draft_tokens, k)
spec_decode_sampler = MagicMock()
spec_decode_sampler.num_accepted_tokens = torch.tensor(num_accepted_tokens,
dtype=torch.long,
device='cuda')
spec_decode_sampler.num_emitted_tokens = torch.tensor(num_emitted_tokens,
dtype=torch.long,
device='cuda')
spec_decode_sampler.num_draft_tokens = num_draft_tokens
collect_interval_s = 5.0
timer = MagicMock()
timer.side_effect = [
0.0, collect_interval_s + 0.1, collect_interval_s + 0.2
]
collector = AsyncMetricsCollector(spec_decode_sampler=spec_decode_sampler,
timer=timer,
collect_interval_s=collect_interval_s)
collector.init_mlu_tensors(rank=0)
_ = collector.maybe_collect_rejsample_metrics(k)
metrics = collector.maybe_collect_rejsample_metrics(k)
assert metrics.num_spec_tokens == k
assert metrics.accepted_tokens == num_accepted_tokens
assert metrics.draft_tokens == num_draft_tokens
assert metrics.emitted_tokens == num_emitted_tokens
if has_data:
assert (metrics.draft_acceptance_rate == num_accepted_tokens /
num_draft_tokens)
assert (metrics.system_efficiency == num_emitted_tokens /
max_num_emitted_tokens)
else:
assert math.isnan(metrics.draft_acceptance_rate)
assert math.isnan(metrics.system_efficiency)

View File

@@ -0,0 +1,842 @@
import random
from typing import Dict, List
from unittest.mock import MagicMock
import pytest
import torch
from vllm.attention.selector import (_Backend,
global_force_attn_backend_context_manager)
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import (ExecuteModelRequest, HiddenStates, Logprob,
get_all_seq_ids)
from vllm.spec_decode.mlu_draft_model_runner import MLUTP1DraftModelRunner
from vllm.spec_decode.mlu_multi_step_worker import MLUMultiStepWorker
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.mlu_worker import MLUWorker
from .utils import (assert_logprobs_dict_allclose, create_batch,
create_seq_group_metadata_from_prompts, create_worker,
patch_execute_model_with_seeds, zero_kv_cache)
@pytest.mark.parametrize('num_steps', list(range(1, 17)))
def test_assert_enough_kv_space(num_steps: int):
"""Test that the multi step worker checks for sufficient space in the KV
cache. It should throw if it cannot run all the steps.
"""
block_size = 16
num_gpu_blocks = 2048 // block_size
prompts = [
list(range(block_size * 3)),
list(range(block_size * 2)),
]
prev_output_tokens = [
list(range(block_size * 1)),
list(range(block_size * 2)),
]
final_prompt_lens = [
len(prompt + output) + num_steps
for prompt, output in zip(prompts, prev_output_tokens)
]
inputs = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
final_prompt_lens,
continuations=prev_output_tokens)
assert_enough_kv_space = MLUMultiStepWorker._assert_enough_kv_space # pylint: disable=protected-access
worker = MagicMock()
worker.model_runner.block_size = block_size
for seq_group_metadata in inputs:
original_block_tables = seq_group_metadata.block_tables
# No exception.
assert_enough_kv_space(worker, inputs, num_steps)
seq_group_metadata.block_tables = {
seq_id: []
for seq_id, physical_blocks in original_block_tables.items()
}
# Expect exception.
with pytest.raises(ValueError,
match='times but found insufficient KV space for'):
assert_enough_kv_space(worker, inputs, num_steps)
seq_group_metadata.block_tables = original_block_tables
@torch.inference_mode()
def test_same_output_for_single_step():
"""Verify the multi step worker produces the same output as the normal
worker for num_steps=1.
"""
seed = 100
model_name = 'JackFram/llama-68m'
# Change block size since Cambricon-vLLM only supports block size with
# 16 in paged mode.
block_size = 16
num_gpu_blocks = 2048 // block_size
multi_step_worker = create_worker(
MLUMultiStepWorker,
model_name,
block_size,
num_gpu_blocks,
seed,
model_runner_cls=MLUTP1DraftModelRunner,
)
worker = create_worker(
MLUWorker,
model_name,
block_size,
num_gpu_blocks,
seed,
)
# multi_step_worker.model_runner = worker.model_runner
# multi_step_worker.cache_engine = worker.cache_engine
num_steps = 1
prompts = [
[1, 2, 3, 4, 5],
[6, 7, 8, 9, 10],
]
final_prompt_lens = [len(prompt) + num_steps for prompt in prompts]
multi_step_seq_group = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
final_prompt_lens=final_prompt_lens)
zero_kv_cache(multi_step_worker.cache_engine)
set_random_seed(seed)
actual_output, _ = multi_step_worker.sampler_output(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=multi_step_seq_group),
sample_len=num_steps,
seq_ids_with_bonus_token_in_last_step=set())
assert len(actual_output) == num_steps
actual_output = actual_output[0]
single_step_seq_group = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
final_prompt_lens=final_prompt_lens)
zero_kv_cache(worker.cache_engine)
set_random_seed(seed)
expected_output = worker.execute_model(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=single_step_seq_group))[0]
actual_token_ids = [
output.samples[0].output_token for output in actual_output
]
actual_logprobs = [output.samples[0].logprobs for output in actual_output]
expected_token_ids = [
output.samples[0].output_token for output in expected_output
]
expected_logprobs = [
output.samples[0].logprobs for output in expected_output
]
assert actual_token_ids == expected_token_ids
print(f'{actual_logprobs=}')
print(f'{expected_logprobs=}')
assert_logprobs_dict_allclose(actual_logprobs, expected_logprobs)
@torch.inference_mode()
def test_same_output_for_multi_step():
"""Verify the multi-step worker produces the same output as the normal
worker when num_steps > 1. This test runs the multi-step worker once, and
then runs the worker num_steps times, and compares the output.
"""
seed = 100
model_name = 'JackFram/llama-68m'
block_size = 16
num_gpu_blocks = 2048 // block_size
multi_step_worker = create_worker(
MLUMultiStepWorker,
model_name,
block_size,
num_gpu_blocks,
seed,
model_runner_cls=MLUTP1DraftModelRunner,
)
worker = create_worker(
MLUWorker,
model_name,
block_size,
num_gpu_blocks,
seed,
)
# Make sure we go over the block boundary.
num_steps = block_size + 1
random.seed(seed)
prompts = [[
random.randint(0, 1000) for _ in range(random.randint(10, 20))
] for _ in range(10)]
final_prompt_lens = [len(prompt) + num_steps for prompt in prompts]
rand_seeds = list(random.randint(0, 100) for _ in range(num_steps))
multi_step_worker.execute_model = patch_execute_model_with_seeds(
multi_step_worker, rand_seeds)
worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds)
continuations = [[1] for _ in prompts]
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
continuations=continuations,
final_prompt_lens=final_prompt_lens)
# Run multi-step.
zero_kv_cache(multi_step_worker.cache_engine)
set_random_seed(seed)
multi_step_output, _ = multi_step_worker.sampler_output(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list),
sample_len=num_steps,
seq_ids_with_bonus_token_in_last_step=set())
# Run single-step repeatedly.
zero_kv_cache(worker.cache_engine)
single_step_output: List[SamplerOutput] = []
continuations = [[1] for _ in prompts]
set_random_seed(seed)
for _ in multi_step_output:
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
continuations=continuations,
final_prompt_lens=final_prompt_lens)
single_step_output.extend(
worker.execute_model(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list)))
# Append output tokens to new sequence data.
for i, seq_group_output in enumerate(single_step_output[-1]):
continuations[i].append(seq_group_output.samples[0].output_token)
# Get token ids and logprobs for comparison.
multi_step_output_logprobs: List[List[Dict[int,
Logprob]]] = [[]
for _ in prompts]
single_step_output_logprobs: List[List[Dict[int,
Logprob]]] = [[]
for _ in prompts]
multi_step_output_token_ids: List[List[int]] = [[] for _ in prompts]
single_step_output_token_ids: List[List[int]] = [[] for _ in prompts]
for i, _ in enumerate(prompts):
for multi_step, single_step in zip(multi_step_output,
single_step_output):
multi_step_output_token_ids[i].append(
multi_step[i].samples[0].output_token)
single_step_output_token_ids[i].append(
single_step[i].samples[0].output_token)
multi_step_output_logprobs[i].append(
multi_step[i].samples[0].logprobs)
single_step_output_logprobs[i].append(
single_step[i].samples[0].logprobs)
# Print per-sequence token ids
for i, (multi_step_tokens, single_step_tokens) in enumerate(
zip(multi_step_output_token_ids, single_step_output_token_ids)):
print(f'{i=} {multi_step_tokens=}')
print(f'{i=} {single_step_tokens=}')
print(f'{i=} equal {multi_step_tokens == single_step_tokens}')
# Assert token ids are equal.
for multi_step_tokens, single_step_tokens in zip(
multi_step_output_token_ids, single_step_output_token_ids):
assert multi_step_tokens == single_step_tokens
# Assert logprobs are equal.
for multi_step_logprobs, single_step_logprobs in zip(
multi_step_output_logprobs, single_step_output_logprobs):
assert_logprobs_dict_allclose(multi_step_logprobs,
single_step_logprobs)
@torch.inference_mode()
def test_multi_step_with_batch_expansion_correct_output():
"""
In this test we verify that the MLUMultiStepWorker is able to handle bonus
tokens correctly. The test verifies that if a sequence has a
bonus token then the MLUMultiStepWorker is able to expand the batch by adding
new sequences corresponding to the sequences with bonus tokens. The
expanded batch is then used for predicting the next tokens.
"""
seed = 100
model_name = 'JackFram/llama-68m'
block_size = 16
num_gpu_blocks = 2048 // block_size
batch_size = 128
multi_step_worker = create_worker(
MLUMultiStepWorker,
model_name,
block_size,
num_gpu_blocks,
seed,
model_runner_cls=MLUTP1DraftModelRunner,
)
multi_step_worker.set_include_gpu_probs_tensor()
worker = create_worker(
MLUWorker,
model_name,
block_size,
num_gpu_blocks,
seed,
)
random.seed(seed)
prompts = [[0] for _ in range(batch_size)]
num_steps = 2
final_prompt_lens = [(num_steps + 1) for prompt in prompts]
rand_seeds = list(random.randint(0, 100) for _ in range(num_steps))
multi_step_worker.execute_model = patch_execute_model_with_seeds(
multi_step_worker, rand_seeds)
worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds)
# Create the test continuations
continuations = [[random.randint(0, 1000)] for _ in prompts]
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
continuations=continuations,
final_prompt_lens=final_prompt_lens)
# Run single-step twice to generate 2 tokens. This
# will simulate the bonus token case with the second token
# being the bonus token.
zero_kv_cache(worker.cache_engine)
single_step_output: List[SamplerOutput] = []
set_random_seed(seed)
for _ in range(num_steps):
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
continuations=continuations,
final_prompt_lens=final_prompt_lens)
single_step_output.extend(
worker.execute_model(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list)))
# Append output tokens to new sequence data.
for i, seq_group_output in enumerate(single_step_output[-1]):
continuations[i].append(seq_group_output.samples[0].output_token)
# Create continuations for the MLUMultiStepWorker. The continuations have
# 2 tokens in order to simulate the bonus token case.
multi_step_continuations = []
for continuation in continuations:
multi_step_continuations.append(continuation[:2])
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
continuations=multi_step_continuations,
final_prompt_lens=final_prompt_lens)
# Run multi-step and verify that the third token prediction is accurate
# for all sequences.
zero_kv_cache(multi_step_worker.cache_engine)
all_seq_ids = {i for i in range(batch_size)}
multi_step_output, _ = multi_step_worker.sampler_output(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list),
sample_len=1,
seq_ids_with_bonus_token_in_last_step=all_seq_ids)
for index, output in enumerate(multi_step_output[-1].outputs):
assert (continuations[index][-1] == output.samples[0].output_token)
@torch.inference_mode()
def test_multi_step_with_batch_expansion_incorrect_output():
"""
Tests the MLUMultiStepWorker's ability to handle batch expansion with bonus
tokens in a negative case scenario. This test provides the MLUMultiStepWorker
with a batch containing sequences with bonus tokens but specifies the
sequence IDs with bonus tokens incorrectly. The test verifies that the
MLUMultiStepWorker generates correct tokens for the sequences where the
sequence ID is specified correctly and incorrect tokens for those where
the sequence ID is specified incorrectly.
"""
seed = 100
model_name = 'JackFram/llama-68m'
block_size = 16
num_gpu_blocks = 2048 // block_size
batch_size = 128
multi_step_worker = create_worker(
MLUMultiStepWorker,
model_name,
block_size,
num_gpu_blocks,
seed,
model_runner_cls=MLUTP1DraftModelRunner,
)
multi_step_worker.set_include_gpu_probs_tensor()
worker = create_worker(
MLUWorker,
model_name,
block_size,
num_gpu_blocks,
seed,
)
random.seed(seed)
prompts = [[0] for _ in range(batch_size)]
num_steps = 2
final_prompt_lens = [(num_steps + 1) for prompt in prompts]
rand_seeds = list(random.randint(0, 100) for _ in range(num_steps))
multi_step_worker.execute_model = patch_execute_model_with_seeds(
multi_step_worker, rand_seeds)
worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds)
# Create the test continuations
continuations = [[random.randint(0, 1000)] for _ in prompts]
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
continuations=continuations,
final_prompt_lens=final_prompt_lens)
# Run single-step twice to generate 2 tokens. This
# will simulate the bonus token case with the second token
# being the bonus token.
zero_kv_cache(worker.cache_engine)
single_step_output: List[SamplerOutput] = []
set_random_seed(seed)
for _ in range(num_steps):
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
continuations=continuations,
final_prompt_lens=final_prompt_lens)
single_step_output.extend(
worker.execute_model(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list)))
# Append output tokens to new sequence data.
for i, seq_group_output in enumerate(single_step_output[-1]):
continuations[i].append(seq_group_output.samples[0].output_token)
# Create continuations for the MLUMultiStepWorker. The continuations have
# 2 tokens in order to simulate the bonus token case.
multi_step_continuations = []
for continuation in continuations:
multi_step_continuations.append(continuation[:2])
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
continuations=multi_step_continuations,
final_prompt_lens=final_prompt_lens)
# Run multi-step. In this run INCORRECTLY specify that only the odd number
# sequences have bonus tokens. Verify that with this setting the third token
# prediction is accurate only for the odd numbered sequences. Also verify
# that the prediction might be wrong for some of the even numbered
# sequences.
zero_kv_cache(multi_step_worker.cache_engine)
set_random_seed(seed)
odd_seq_ids = {i for i in range(batch_size) if i % 2 != 0}
multi_step_output, _ = multi_step_worker.sampler_output(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list),
sample_len=1,
seq_ids_with_bonus_token_in_last_step=odd_seq_ids)
num_mismatch = 0
for index, output in enumerate(multi_step_output[-1].outputs):
if (index % 2) != 0:
assert (continuations[index][-1] == output.samples[0].output_token)
elif (continuations[index][-1] != output.samples[0].output_token):
num_mismatch += 1
# The prediction is accurate for some of the sequences even without proper
# handling of the bonus tokens. Hence verify that the number of sequences
# for which there is a mismatch is > 0.
assert (num_mismatch > 0)
@torch.inference_mode()
@pytest.mark.parametrize('num_steps', [1, 2, 3, 4])
# The choice of backends forces the multi_step_worker to choose between
# the vanilla model_runner and TP1DraftModelRunner and that we can test
# both code paths.
@pytest.mark.parametrize('attn_backend',
[_Backend.MLU_FLASH_ATTN])
def test_multi_step_correct_kvcache(num_steps, attn_backend):
"""Verify that the KV cache of the draft model
is correctly updated for sequences with bonus token.
"""
seed = 100
model_name = "JackFram/llama-68m"
block_size = 16
num_gpu_blocks = 2048 // block_size
batch_size = 1
with global_force_attn_backend_context_manager(attn_backend):
dtype = 'float16' if attn_backend == _Backend.MLU_FLASH_ATTN else 'float32'
multi_step_worker = create_worker(MLUMultiStepWorker,
model_name,
block_size,
num_gpu_blocks,
seed,
model_runner_cls=MLUTP1DraftModelRunner,
dtype=dtype)
multi_step_worker.set_include_gpu_probs_tensor()
worker = create_worker(MLUWorker,
model_name,
block_size,
num_gpu_blocks,
seed,
dtype=dtype)
prompts = [[0] for _ in range(batch_size)]
# Already generate two tokens for the sequence
# so that we can simulate the bonus token case
multi_step_continuations = [[
random.randint(0, 1000),
random.randint(0, 1000)
] for _ in prompts]
final_prompt_lens = [len(prompt) + 2 + num_steps for prompt in prompts]
seq_ids_with_bonus_token_in_last_step = set(range(batch_size))
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
continuations=multi_step_continuations,
final_prompt_lens=final_prompt_lens)
# Run multi-step.
zero_kv_cache(multi_step_worker.cache_engine)
multi_step_worker.sampler_output(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list),
sample_len=num_steps,
seq_ids_with_bonus_token_in_last_step=
seq_ids_with_bonus_token_in_last_step)
# Run single-step repeatedly.
zero_kv_cache(worker.cache_engine)
# Generate the kv cache for the bonus token first
single_step_continuations = [c[:1] for c in multi_step_continuations]
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
continuations=single_step_continuations,
final_prompt_lens=final_prompt_lens)
single_step_output = worker.execute_model(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list))
for _ in range(num_steps):
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
continuations=multi_step_continuations,
final_prompt_lens=final_prompt_lens)
single_step_output = worker.execute_model(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list))
for i, seq_group_output in enumerate(single_step_output[-1]):
multi_step_continuations[i].append(
seq_group_output.samples[0].output_token)
# Verify that the KV cache of the single-step and
# multi-step workers are the same.
single_step_gpu_cache = worker.cache_engine[0].gpu_cache
multi_step_gpu_cache = multi_step_worker.cache_engine[0].gpu_cache
num_layers = len(single_step_gpu_cache)
allclose = lambda a, b: torch.allclose(
a.cuda(), b.cuda(), rtol=1e-2, atol=1e-2)
for i in range(num_layers):
assert allclose(single_step_gpu_cache[i][0],
multi_step_gpu_cache[i][0])
assert allclose(single_step_gpu_cache[i][1],
multi_step_gpu_cache[i][1])
@torch.inference_mode()
def test_draft_proposals_full_speculation_len():
"""Verify Top1Proposer correctly handles case where all sequences
can speculate.
"""
k = 10
batch_size = 32
vocab_size = 32_000
device = 'cuda:0'
draft_worker = MagicMock()
proposer = Top1Proposer(
worker=draft_worker,
device=device,
vocab_size=vocab_size,
max_proposal_len=2048,
)
draft_worker.sampler_output.return_value = [
SamplerOutput(
outputs=[],
sampled_token_probs=torch.rand(batch_size,
vocab_size,
device=device,
dtype=torch.float32),
logprobs=torch.rand(batch_size,
vocab_size,
device=device,
dtype=torch.float32),
sampled_token_ids=torch.randint(low=0,
high=vocab_size,
size=(batch_size, ),
device=device,
dtype=torch.long),
) for _ in range(k)
], True
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
proposals = proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k),
seq_ids_with_bonus_token_in_last_step=set())
assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs)
assert proposals.proposal_token_ids.shape == torch.Size([batch_size, k])
assert proposals.proposal_probs.shape[:-1] == torch.Size([batch_size, k])
assert proposals.proposal_lens.shape == torch.Size([batch_size])
assert proposals.proposal_lens.tolist() == [k for _ in range(batch_size)]
@torch.inference_mode()
def test_draft_proposals_no_speculations():
"""Verify Top1Proposer correctly handles case where no sequences
can speculate.
"""
k = 10
batch_size = 32
vocab_size = 32_000
device = 'cuda:0'
prompt_len = 10
draft_worker = MagicMock()
proposer = Top1Proposer(
worker=draft_worker,
device=device,
vocab_size=vocab_size,
max_proposal_len=prompt_len + k - 1,
)
seq_group_metadata_list, _, _ = create_batch(batch_size,
k,
prompt_len=prompt_len)
proposals = proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k),
seq_ids_with_bonus_token_in_last_step=set())
assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs)
assert proposals.proposal_token_ids.shape == torch.Size([batch_size, k])
assert proposals.proposal_probs.shape[:-1] == torch.Size([batch_size, k])
assert proposals.proposal_lens.shape == torch.Size([batch_size])
assert proposals.proposal_lens.tolist() == [0 for _ in range(batch_size)]
@torch.inference_mode()
def test_draft_proposals_mixed_k():
"""Verify Top1Proposer correctly handles case some sequences can
speculate and some can't.
"""
k = 10
batch_size = 32
vocab_size = 32_000
device = 'cuda:0'
small_prompt_len = 5
long_prompt_len = 10
prev_output_token_len = 20
expected_num_proposal_seqs = 6
expected_num_no_proposal_seqs = batch_size - expected_num_proposal_seqs
prompt_len = [
small_prompt_len for _ in range(expected_num_proposal_seqs - 1)
] + [long_prompt_len
for _ in range(expected_num_no_proposal_seqs)] + [small_prompt_len]
draft_worker = MagicMock()
proposer = Top1Proposer(
worker=draft_worker,
device=device,
vocab_size=vocab_size,
max_proposal_len=long_prompt_len + prev_output_token_len + k - 1,
)
draft_worker.sampler_output.return_value = [
SamplerOutput(
outputs=[],
sampled_token_probs=torch.rand(expected_num_proposal_seqs,
vocab_size,
device=device,
dtype=torch.float32),
logprobs=torch.rand(expected_num_proposal_seqs,
vocab_size,
device=device,
dtype=torch.float32),
sampled_token_ids=torch.randint(
low=0,
high=vocab_size,
size=(expected_num_proposal_seqs, ),
device=device,
dtype=torch.long),
) for _ in range(k)
], True
seq_group_metadata_list, _, _ = create_batch(
batch_size,
k,
prompt_len=prompt_len,
prev_output_token_len=prev_output_token_len,
)
proposals = proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k),
seq_ids_with_bonus_token_in_last_step=set())
assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs)
assert proposals.proposal_token_ids.shape == torch.Size([batch_size, k])
assert proposals.proposal_probs.shape[:-1] == torch.Size([batch_size, k])
assert proposals.proposal_lens.shape == torch.Size([batch_size])
assert proposals.proposal_lens.tolist() == [
k for _ in range(expected_num_proposal_seqs - 1)
] + [0 for _ in range(expected_num_no_proposal_seqs)] + [k]
@torch.inference_mode()
def test_use_draft_model_runner_advance_step():
"""Verify that draft model runner triggers advance step
when applicable.
"""
seed = 100
model_name = 'JackFram/llama-68m'
k = 5
batch_size = 32
# Change block size since Cambricon-vLLM only supports block size with
# 16 in paged mode.
block_size = 16
num_gpu_blocks = 2048 // block_size
worker = create_worker(
MLUMultiStepWorker,
model_name,
block_size,
num_gpu_blocks,
seed,
model_runner_cls=MLUTP1DraftModelRunner,
)
# Mock "_gpu_advance_step" to raise an exception when called.
exception_secret = "artificial stop"
worker.model_runner._gpu_advance_step = MagicMock()
worker.model_runner._gpu_advance_step.side_effect = ValueError(
exception_secret)
seq_group_metadata_list, _, _ = create_batch(batch_size,
k,
block_size=block_size,
num_gpu_blocks=num_gpu_blocks)
# Fallback (should not call) when num_steps=1.
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k,
num_steps=1)
worker.execute_model(execute_model_req=execute_model_req)
# Expect exception if _gpu_advance_step is called.
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k,
num_steps=k)
with pytest.raises(ValueError, match=exception_secret):
worker.execute_model(execute_model_req=execute_model_req)
call_args_list = worker.model_runner._gpu_advance_step.call_args_list
assert len(call_args_list) == 1
@torch.inference_mode()
def test_expand_execute_model_request_sync_with_expand_hidden_states():
"""
In this test we verify that the logic for expanding the
seq_group_metadata_list remains in sync with the expansion logic of
the HiddenStates in _expand_execute_model_request.
"""
k = 5
batch_size = 16
seq_with_bonus_token_in_last_step = [1, 3, 8, 10, 13, 15]
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
execute_model_request = ExecuteModelRequest(
seq_group_metadata_list,
previous_hidden_states=HiddenStates(
torch.arange(batch_size), seq_group_metadata_list,
torch.arange(batch_size, 2 * batch_size)))
expanded_execute_model_request, orig_seq_group_ids = MLUMultiStepWorker.\
_expand_execute_model_request(execute_model_request,
seq_with_bonus_token_in_last_step)
all_seq_ids = torch.tensor(
get_all_seq_ids(
expanded_execute_model_request.seq_group_metadata_list))
ref_expanded_hidden_states = all_seq_ids + batch_size
ref_expanded_hidden_states[orig_seq_group_ids] -= batch_size
assert (ref_expanded_hidden_states == expanded_execute_model_request.
previous_hidden_states.hidden_states).all().item()

View File

@@ -0,0 +1,224 @@
import torch
from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.ngram_worker import NGramWorker
from vllm.spec_decode.top1_proposer import Top1Proposer
from .utils import create_seq_group_metadata_from_prompts, create_worker
def test_ngram_algo_correctness_for_single_no_match():
"""Verify our ngram algo find the right candidate in the prompt
For the scenario cannot find any candidate in one single batch
"""
# Change block size since Cambricon-vLLM only supports block size with
# 16 in paged mode.
block_size = 16
num_gpu_blocks = 2048 // block_size
seed = 100
model_name = 'JackFram/llama-68m'
vocab_size = 32_000
device = 'cuda:0'
ngram_worker = create_worker(
NGramWorker,
model_name,
block_size,
num_gpu_blocks,
seed,
)
proposer = Top1Proposer(
worker=ngram_worker,
device=device,
vocab_size=vocab_size,
max_proposal_len=20,
)
# set ngram window [1, 3], which is window=1/2/3
ngram_worker.set_ngram_window_size(1, 3)
prompts = [
# shall find no candidate
[1, 2, 3, 4, 5, 6, 7],
]
proposal_len = 5
final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts]
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
final_prompt_lens=final_prompt_lens)
proposals = proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=proposal_len),
seq_ids_with_bonus_token_in_last_step=None)
assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs)
assert proposals.proposal_token_ids.shape == torch.Size([1, proposal_len])
assert proposals.proposal_probs.shape[:-1] == torch.Size([1, proposal_len])
assert proposals.proposal_lens.shape == torch.Size([1])
assert proposals.proposal_lens.tolist() == [0]
def test_ngram_algo_correctness_for_batches_not_match_all():
"""Verify our ngram algo find the right candidate in the prompt
For the scenario find some candidate not full in batchs
"""
# Change block size since Cambricon-vLLM only supports block size with
# 16 in paged mode.
block_size = 16
num_gpu_blocks = 2048 // block_size
seed = 100
model_name = 'JackFram/llama-68m'
vocab_size = 32_000
device = 'cuda:0'
ngram_worker = create_worker(
NGramWorker,
model_name,
block_size,
num_gpu_blocks,
seed,
)
proposer = Top1Proposer(
worker=ngram_worker,
device=device,
vocab_size=vocab_size,
max_proposal_len=20,
)
# set ngram window [1, 3], which is window=1/2/3
ngram_worker.set_ngram_window_size(1, 3)
prompts = [
# shall find no candidate
[1, 2, 3, 4, 5, 6, 7],
# shall find candidate 12,13,14,15,16
[11, 12, 13, 14, 15, 16, 11],
# shall find candidate 23,24,25,26,21
[21, 21, 22, 23, 24, 25, 26, 21, 22],
# shall find candidate 34,35,36,37,38
[31, 32, 31, 32, 33, 34, 35, 36, 37, 38, 31, 32, 33],
# shall find no candidate as exceed max_proposal_len
[
31, 32, 31, 32, 31, 32, 31, 32, 31, 32, 31, 32, 33, 34, 35, 36, 37,
38, 31, 32, 33
],
]
proposal_len = 5
final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts]
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
final_prompt_lens=final_prompt_lens)
for sg in seq_group_metadata_list:
sg.is_prompt = False
proposals = proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=proposal_len),
seq_ids_with_bonus_token_in_last_step=None)
assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs)
assert proposals.proposal_token_ids.shape == torch.Size([5, proposal_len])
assert proposals.proposal_probs.shape[:-1] == torch.Size([5, proposal_len])
assert proposals.proposal_lens.shape == torch.Size([5])
# the first sequence has no match so proposal_len should be overwritten to 0
assert proposals.proposal_lens.tolist(
) == [0] + [proposal_len for _ in range(3)] + [0]
for i in range(proposal_len):
assert proposals.proposal_token_ids[0][i] == -1
assert proposals.proposal_token_ids[1][i] == prompts[1][i + 1]
assert proposals.proposal_token_ids[2][i] == prompts[2][i + 3]
assert proposals.proposal_token_ids[3][i] == prompts[3][i + 5]
assert proposals.proposal_token_ids[4][i] == -1
def test_ngram_algo_correctness_for_batches_match_all():
"""Verify our ngram algo find the right candidate in the prompt
For the scenario find candidate in all batches
"""
# Change block size since Cambricon-vLLM only supports block size with
# 16 in paged mode.
block_size = 16
num_gpu_blocks = 2048 // block_size
seed = 100
model_name = 'JackFram/llama-68m'
vocab_size = 32_000
device = 'cuda:0'
ngram_worker = create_worker(
NGramWorker,
model_name,
block_size,
num_gpu_blocks,
seed,
)
proposer = Top1Proposer(
worker=ngram_worker,
device=device,
vocab_size=vocab_size,
max_proposal_len=20,
)
# set ngram window [0, 3], which is window=1/2/3
ngram_worker.set_ngram_window_size(1, 3)
prompts = [
# shall find candidate 12,13,14,15,16
[11, 12, 13, 14, 15, 16, 11],
# shall find candidate 23,24,25,26,21
[21, 21, 22, 23, 24, 25, 26, 21, 22],
# shall find candidate 34,35,36,37,38
[31, 32, 31, 32, 33, 34, 35, 36, 37, 38, 31, 32, 33],
]
proposal_len = 5
final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts]
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
final_prompt_lens=final_prompt_lens)
# Normally drafter is run on decode requests only; here we check the output
# of the ngram worker as it is the sole proposer that has no forward.
for sg in seq_group_metadata_list:
sg.is_prompt = False
proposals = proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=proposal_len),
seq_ids_with_bonus_token_in_last_step=None)
assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs)
assert proposals.proposal_token_ids.shape == torch.Size([3, proposal_len])
assert proposals.proposal_probs.shape[:-1] == torch.Size([3, proposal_len])
assert proposals.proposal_lens.shape == torch.Size([3])
assert proposals.proposal_lens.tolist() == [proposal_len for _ in range(3)]
for i in range(proposal_len):
assert proposals.proposal_token_ids[0][i] == prompts[0][i + 1]
assert proposals.proposal_token_ids[1][i] == prompts[1][i + 3]
assert proposals.proposal_token_ids[2][i] == prompts[2][i + 5]

View File

@@ -0,0 +1,114 @@
import random
from typing import List
import pytest
import torch
from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.mlu_batch_expansion import MLUBatchExpansionTop1Scorer
from vllm.spec_decode.interfaces import SpeculativeProposals, SpeculativeScores
from vllm.spec_decode.mqa_scorer import MQAScorer
from vllm.worker.mlu_worker import MLUWorker
from .utils import create_batch, create_worker
def create_proposal(propose_lens: List[int], vocab_size: int,
device: str) -> SpeculativeProposals:
batch_size = len(propose_lens)
max_propose_len = max(propose_lens)
proposal_probs = torch.rand((batch_size, max_propose_len, vocab_size),
device=device)
proposal_token_ids = torch.full((batch_size, max_propose_len),
fill_value=-1,
device=device)
for i in range(batch_size):
proposal_token_ids[i][:propose_lens[i]] = torch.argmax(
proposal_probs[i][:propose_lens[i]], dim=-1)
propose_lens = torch.tensor(propose_lens, device=device)
return SpeculativeProposals(proposal_token_ids, proposal_probs,
propose_lens)
def assert_score_equal(score1: SpeculativeScores,
score2: SpeculativeScores) -> None:
assert torch.allclose(score1.probs, score2.probs)
assert torch.allclose(score1.logprobs, score2.logprobs)
assert torch.equal(
score1.token_ids,
score2.token_ids), f"{score1.token_ids}, {score2.token_ids}"
@pytest.mark.parametrize('model_name', ['facebook/opt-125m'])
@pytest.mark.parametrize('batch_size', [1, 2, 4, 8, 16])
@pytest.mark.parametrize('max_propose_len', [1, 3, 5])
@pytest.mark.parametrize('mixed_propose_len', [True])
@pytest.mark.parametrize('device', ['cuda'])
@pytest.mark.parametrize('prefill_chunking', [False, True])
def test_scorer(model_name: str, batch_size: int, max_propose_len: int,
mixed_propose_len: bool, device: str,
prefill_chunking: bool) -> None:
"""
Compare the batch expansion scorer and mqa scorer return the same score.
We test for both queries with the same propose length and different
propose length, as well as mixed prefill-decode batches.
"""
seed = 0
block_size = 16
num_gpu_blocks = 2048 // block_size
scorer_worker = create_worker(MLUWorker, model_name, block_size,
num_gpu_blocks, seed)
scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor = True
scorer_worker.model_runner.model.sampler.\
should_modify_greedy_probs_inplace = True
vocab_size = scorer_worker.vocab_size
if not mixed_propose_len:
propose_lens = [max_propose_len] * batch_size
else:
# There must be at least 1 decode request, otherwise
# we have nothing to score (`_run_no_spec`).
non_zero_cnt = random.randint(1, batch_size)
propose_lens = [max_propose_len
] * non_zero_cnt + [0] * (batch_size - non_zero_cnt)
random.shuffle(propose_lens)
seq_group_metadatalist, _, _ = create_batch(batch_size,
max_propose_len,
block_size=block_size,
num_gpu_blocks=num_gpu_blocks)
if mixed_propose_len and prefill_chunking and (n_prefills :=
batch_size - non_zero_cnt):
prefill, _, _ = create_batch(n_prefills,
None,
prefill_chunk_size=4,
block_size=block_size,
num_gpu_blocks=num_gpu_blocks,
seq_ids=list(
range(batch_size,
batch_size + n_prefills)))
# re-order to guarantee prefill|decode order
target_group_metadatalist = [
seq_group_metadatalist[i] for i, p in enumerate(propose_lens)
if p > 0
]
seq_group_metadatalist = prefill + target_group_metadatalist
propose_lens = [0] * n_prefills + [p for p in propose_lens if p > 0]
proposals = create_proposal(propose_lens, vocab_size, device)
requests = ExecuteModelRequest(seq_group_metadatalist,
num_lookahead_slots=max_propose_len)
batch_expansion_scorer = MLUBatchExpansionTop1Scorer(scorer_worker, device,
vocab_size)
batch_expansion_score = batch_expansion_scorer.score_proposals(
requests, proposals)
mqa_scorer = MQAScorer(scorer_worker, device, vocab_size)
mqa_score = mqa_scorer.score_proposals(requests, proposals)
assert_score_equal(batch_expansion_score, mqa_score)

View File

@@ -0,0 +1,909 @@
import random
from collections import defaultdict
from types import SimpleNamespace
from typing import Dict, List, Set
from unittest.mock import MagicMock
import pytest
import torch
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import ExecuteModelRequest, SequenceOutput
from vllm.spec_decode.mlu_batch_expansion import MLUBatchExpansionTop1Scorer
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.mlu_metrics import (MLUAsyncMetricsCollector,
SpecDecodeWorkerMetrics)
from vllm.spec_decode.mlu_multi_step_worker import MLUMultiStepWorker
from vllm.spec_decode.mlu_spec_decode_worker import (MLUSpecDecodeWorker,
split_num_cache_blocks_evenly)
from .test_utils import mock_spec_decode_sampler
from .utils import create_batch, create_sampler_output_list, mock_worker
BatchExpansionTop1Scorer = MLUBatchExpansionTop1Scorer
MultiStepWorker = MLUMultiStepWorker
AsyncMetricsCollector = MLUAsyncMetricsCollector
SpecDecodeWorker = MLUSpecDecodeWorker
@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_correctly_calls_draft_model(k: int, batch_size: int,
acceptance_sampler_method: str):
"""Verify SpecDecodeWorker calls the draft worker with correct
inputs. Everything else is mocked out.
"""
draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker()
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
worker = SpecDecodeWorker(
draft_worker,
target_worker,
mock_spec_decode_sampler(acceptance_sampler_method),
disable_logprobs=False,
metrics_collector=metrics_collector)
exception_secret = 'artificial stop'
draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)
with pytest.raises(ValueError, match=exception_secret):
worker.execute_model(execute_model_req=execute_model_req)
call_args_list = draft_worker.get_spec_proposals.call_args_list
assert len(call_args_list) == 1
for args, _ in call_args_list:
actual_execute_model_data = args[0]
assert actual_execute_model_data == execute_model_req
@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_batch_expansion_correctly_calls_target_model(
k: int, batch_size: int, acceptance_sampler_method: str):
"""Verify SpecDecodeWorker calls the target model with correct
inputs with batch expansion. Everything else is mocked out.
"""
draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False)
target_worker = mock_worker(use_spec=False)
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
draft_worker.device = 'cuda'
target_worker.device = 'cuda'
set_random_seed(1)
worker = SpecDecodeWorker(
draft_worker,
target_worker,
mock_spec_decode_sampler(acceptance_sampler_method),
disable_logprobs=False,
metrics_collector=metrics_collector,
disable_mqa_scorer=True)
worker.init_device()
vocab_size = 32_000
proposal_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64,
device='cuda')
proposal_probs = torch.rand(batch_size,
k,
vocab_size,
dtype=torch.float32,
device='cuda')
proposal_lens = torch.ones(batch_size, dtype=torch.int64,
device='cuda') * k
seq_group_metadata_list, prompts, prev_output_tokens = create_batch(
batch_size, k)
draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
proposal_token_ids=proposal_token_ids,
proposal_probs=proposal_probs,
proposal_lens=proposal_lens)
exception_secret = 'artificial stop'
target_worker.execute_model.side_effect = ValueError(exception_secret)
with pytest.raises(ValueError, match=exception_secret):
worker.execute_model(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k))
seen_contexts: List[List[int]] = []
call_args_list = target_worker.execute_model.call_args_list
assert len(call_args_list) == 1
for _, kwargs in call_args_list:
seq_group_metadata_list = kwargs[
"execute_model_req"].seq_group_metadata_list
assert len(seq_group_metadata_list) == (k + 1) * batch_size
for seq_group_metadata in seq_group_metadata_list:
for seq_data in seq_group_metadata.seq_data.values():
seen_contexts.append(seq_data.get_token_ids())
expected_seen_contexts: List[List[int]] = []
for prompt, prev_generated, draft_tokens in zip(
prompts, prev_output_tokens, proposal_token_ids.tolist()):
for i in range(len(draft_tokens) + 1):
expected_seen_contexts.append(prompt + prev_generated +
draft_tokens[:i])
seen_contexts.sort()
expected_seen_contexts.sort()
assert expected_seen_contexts == seen_contexts
@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_correctly_calls_spec_decode_sampler(k: int, batch_size: int,
acceptance_sampler_method: str):
"""Verify SpecDecodeWorker calls the rejection sampler with
correct inputs. Everything else is mocked out.
"""
vocab_size = 32_000
draft_worker = mock_worker(cls=MultiStepWorker,
vocab_size=vocab_size,
use_spec=False)
target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
draft_worker.device = 'cuda'
target_worker.device = 'cuda'
set_random_seed(1)
worker = SpecDecodeWorker(draft_worker,
target_worker,
spec_decode_sampler,
disable_logprobs=False,
metrics_collector=metrics_collector)
worker.init_device()
proposal_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64,
device='cuda')
proposal_probs = torch.rand(batch_size,
k,
vocab_size,
dtype=torch.float32,
device='cuda')
proposal_lens = torch.ones(batch_size, dtype=torch.int64,
device='cuda') * k
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
proposal_token_ids=proposal_token_ids,
proposal_probs=proposal_probs,
proposal_lens=proposal_lens)
target_token_ids = torch.randint(low=0,
high=vocab_size,
size=(1, batch_size * (k + 1)),
dtype=torch.int64,
device='cuda')
target_token_probs = torch.rand(1,
batch_size * (k + 1),
vocab_size,
dtype=torch.float32,
device='cuda')
target_token_logprobs = torch.rand(1,
batch_size * (k + 1),
vocab_size,
dtype=torch.float32,
device='cuda')
target_output = create_sampler_output_list(target_token_ids,
target_token_probs,
target_token_logprobs)
target_worker.execute_model.return_value = [target_output[0]]
exception_secret = 'artificial stop'
spec_decode_sampler.side_effect = ValueError(exception_secret)
with pytest.raises(ValueError, match=exception_secret):
worker.execute_model(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k))
assert len(spec_decode_sampler.call_args_list) == 1
_, kwargs = spec_decode_sampler.call_args_list[0]
actual = SimpleNamespace(**kwargs)
assert torch.equal(actual.bonus_token_ids,
target_token_ids.reshape(batch_size, k + 1)[:, -1:])
assert torch.equal(actual.target_with_bonus_probs,
target_token_probs.reshape(batch_size, k + 1, -1))
assert torch.equal(actual.draft_token_ids, proposal_token_ids)
assert torch.equal(actual.draft_probs, proposal_probs)
@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_correctly_formats_output(k: int, batch_size: int,
acceptance_sampler_method: str):
"""Verify SpecDecodeWorker formats sampler output correctly.
Everything else is mocked out.
"""
vocab_size = 32_000
draft_worker = mock_worker(cls=MultiStepWorker,
vocab_size=vocab_size,
use_spec=False)
target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
draft_worker.device = 'cuda'
target_worker.device = 'cuda'
set_random_seed(1)
spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
worker = SpecDecodeWorker(draft_worker,
target_worker,
spec_decode_sampler,
disable_logprobs=False,
metrics_collector=metrics_collector)
worker.init_device()
proposal_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64,
device='cuda')
proposal_probs = torch.rand(batch_size,
k,
vocab_size,
dtype=torch.float32,
device='cuda')
proposal_lens = torch.ones(batch_size, dtype=torch.int64,
device='cuda') * k
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
proposal_token_ids=proposal_token_ids,
proposal_probs=proposal_probs,
proposal_lens=proposal_lens)
target_token_ids = torch.randint(low=0,
high=vocab_size,
size=(1, batch_size * (k + 1)),
dtype=torch.int64,
device='cuda')
target_token_probs = torch.rand(1,
batch_size * (k + 1),
vocab_size,
dtype=torch.float32,
device='cuda')
target_token_logprobs = torch.rand(1,
batch_size * (k + 1),
vocab_size,
dtype=torch.float32,
device='cuda')
target_output = create_sampler_output_list(target_token_ids,
target_token_probs,
target_token_logprobs)
target_worker.execute_model.return_value = [target_output[0]]
spec_decode_sampler_output = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k + 1),
dtype=torch.int64,
device='cuda')
for i in range(batch_size):
minimum_accepted_tokens = 1
spec_decode_sampler_output[i][
-random.randint(minimum_accepted_tokens, k + 1):] = -1
spec_decode_sampler.return_value = spec_decode_sampler_output
output = worker.execute_model(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k))
expected_output = create_sampler_output_list(
token_ids=spec_decode_sampler_output.transpose(0, 1),
probs=[None for _ in range(k + 1)],
logprobs=[None for _ in range(k + 1)])
seq_ids = [
next(iter(seq_group_metadata.seq_data.keys()))
for seq_group_metadata in seq_group_metadata_list
]
actual_output_by_seq: Dict[int, List[SequenceOutput]] = {
seq_id: []
for seq_id in seq_ids
}
expected_output_by_seq: Dict[int, List[SequenceOutput]] = {
seq_id: []
for seq_id in seq_ids
}
for step in output:
for seq_group in step:
for sample in seq_group.samples:
seq_id = sample.parent_seq_id
actual_output_by_seq[seq_id].append(sample)
for step in expected_output:
for seq_group in step:
for sample in seq_group.samples:
seq_id = sample.parent_seq_id
expected_output_by_seq[seq_id].append(sample)
all_seen_seq_ids = set(
list(actual_output_by_seq.keys()) +
list(expected_output_by_seq.keys()))
for seq_id in all_seen_seq_ids:
actual_by_step = actual_output_by_seq[seq_id]
expected_by_step = expected_output_by_seq[seq_id]
for i in range(k + 1):
if i >= len(actual_by_step):
assert expected_by_step[i].output_token == -1
continue
assert actual_by_step[i].output_token == expected_by_step[
i].output_token
@pytest.mark.parametrize('k', [1, 2])
@pytest.mark.parametrize('batch_size', [1])
@pytest.mark.parametrize('returns_metrics', [True, False])
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool,
acceptance_sampler_method: str):
"""Verify SpecDecodeWorker collects metrics.
"""
vocab_size = 32_000
draft_worker = mock_worker(cls=MultiStepWorker,
vocab_size=vocab_size,
use_spec=False)
target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
draft_worker.device = 'cuda'
target_worker.device = 'cuda'
set_random_seed(1)
worker = SpecDecodeWorker(draft_worker,
target_worker,
spec_decode_sampler,
disable_logprobs=False,
metrics_collector=metrics_collector)
worker.init_device()
proposal_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64,
device='cuda')
proposal_probs = torch.rand(batch_size,
k,
vocab_size,
dtype=torch.float32,
device='cuda')
proposal_lens = torch.ones(batch_size, dtype=torch.int64,
device='cuda') * k
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
proposal_token_ids=proposal_token_ids,
proposal_probs=proposal_probs,
proposal_lens=proposal_lens)
target_token_ids = torch.randint(low=0,
high=vocab_size,
size=(1, batch_size * (k + 1)),
dtype=torch.int64,
device='cuda')
target_token_probs = torch.rand(1,
batch_size * (k + 1),
vocab_size,
dtype=torch.float32,
device='cuda')
target_token_logprobs = torch.rand(1,
batch_size * (k + 1),
vocab_size,
dtype=torch.float32,
device='cuda')
target_output = create_sampler_output_list(target_token_ids,
target_token_probs,
target_token_logprobs)
target_worker.execute_model.return_value = [target_output[0]]
spec_decode_sampler_output = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k + 1),
dtype=torch.int64,
device='cuda')
for i in range(batch_size):
minimum_accepted_tokens = 1
spec_decode_sampler_output[i][
-random.randint(minimum_accepted_tokens, k + 1):] = -1
spec_decode_sampler.return_value = spec_decode_sampler_output
mock_rejsample_metrics = MagicMock(
spec=SpecDecodeWorkerMetrics) if returns_metrics else None
metrics_collector.maybe_collect_rejsample_metrics.return_value = (
mock_rejsample_metrics)
output = worker.execute_model(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k))
assert output[0].spec_decode_worker_metrics == mock_rejsample_metrics
call_args_list = (
metrics_collector.maybe_collect_rejsample_metrics.call_args_list)
assert len(call_args_list) == 1
args, kwargs = call_args_list[0]
assert args[0] == k or kwargs.get('k', -1) == k
@pytest.mark.parametrize('k', [0])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_k_equals_zero(k: int, batch_size: int,
acceptance_sampler_method: str):
"""Verify that the SpecDecodeWorker calls the draft and target workers
when k is zero. This happens during prefill.
"""
draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker()
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
sampler_output = MagicMock(spec=SamplerOutput)
sampler_output.hidden_states = None
target_worker.execute_model.return_value = [sampler_output]
draft_worker.device = 'cuda'
target_worker.device = 'cuda'
set_random_seed(1)
worker = SpecDecodeWorker(
proposer_worker=draft_worker,
scorer_worker=target_worker,
spec_decode_sampler=mock_spec_decode_sampler(
acceptance_sampler_method),
disable_logprobs=False,
metrics_collector=metrics_collector,
)
seq_group_metadata_list, _, _ = create_batch(batch_size,
k,
prev_output_token_len=0)
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)
out = worker.execute_model(execute_model_req=execute_model_req)
assert len(out) == 1, f"expected only one token output when {k=}"
assert out[0].sampled_token_probs is None, (
"expect gpu tensor references to be None")
assert out[
0].sampled_token_ids is None, "expect gpu tensor references to be None"
draft_worker.execute_model.assert_called_once_with(execute_model_req)
target_worker.execute_model.assert_called_once_with(execute_model_req)
@pytest.mark.parametrize('k', [0, 5])
@pytest.mark.parametrize('batch_size', [0])
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_empty_input_batch(k: int, batch_size: int,
acceptance_sampler_method: str):
"""Verify that the SpecDecodeWorker calls the draft and target workers
when the input batch is empty. This can happen if the engine communicates
to the workers information without scheduling a batch.
"""
draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker()
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
sampler_output = MagicMock(spec=SamplerOutput)
sampler_output.hidden_states = None
target_worker.execute_model.return_value = [sampler_output]
draft_worker.device = 'cuda'
target_worker.device = 'cuda'
set_random_seed(1)
worker = SpecDecodeWorker(
proposer_worker=draft_worker,
scorer_worker=target_worker,
spec_decode_sampler=mock_spec_decode_sampler(
acceptance_sampler_method),
disable_logprobs=False,
metrics_collector=metrics_collector,
)
seq_group_metadata_list, _, _ = create_batch(batch_size,
k,
prev_output_token_len=0)
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)
out = worker.execute_model(execute_model_req=execute_model_req)
assert len(out) == 1, f"expected only one token output when {k=}"
assert out[0].sampled_token_probs is None, (
"expect gpu tensor references to be None")
assert out[
0].sampled_token_ids is None, "expect gpu tensor references to be None"
draft_worker.execute_model.assert_called_once_with(execute_model_req)
target_worker.execute_model.assert_called_once_with(execute_model_req)
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@pytest.mark.skip_global_cleanup
def test_init_device(acceptance_sampler_method: str):
"""Verify SpecDecodeWorker invokes proposer/scorer worker init_device, as
well as other GPU initialization.
"""
draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False)
target_worker = mock_worker(use_spec=False)
spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
worker = SpecDecodeWorker(
proposer_worker=draft_worker,
scorer_worker=target_worker,
spec_decode_sampler=spec_decode_sampler,
disable_logprobs=False,
metrics_collector=metrics_collector,
)
worker.init_device()
draft_worker.init_device.assert_called_once()
target_worker.init_device.assert_called_once()
metrics_collector.init_mlu_tensors.assert_called_once()
spec_decode_sampler.init_gpu_tensors.assert_called_once()
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_initialize_cache(acceptance_sampler_method):
"""Verify SpecDecodeWorker invokes initialize_cache on proposer/scorer
workers.
"""
draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker()
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
worker = SpecDecodeWorker(proposer_worker=draft_worker,
scorer_worker=target_worker,
spec_decode_sampler=mock_spec_decode_sampler(
acceptance_sampler_method),
metrics_collector=metrics_collector)
kwargs = {"num_gpu_blocks": 1024, "num_cpu_blocks": 1023}
worker.initialize_cache(**kwargs)
draft_worker.initialize_cache.assert_called_once_with(**kwargs)
target_worker.initialize_cache.assert_called_once_with(**kwargs)
@pytest.mark.parametrize('available_gpu_blocks', [1, 1024])
@pytest.mark.parametrize('available_cpu_blocks', [500])
@pytest.mark.parametrize('target_cache_block_size_bytes', [2 * 2 * 4096])
@pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096])
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@pytest.mark.skip_global_cleanup
def test_determine_num_available_blocks(available_gpu_blocks: int,
available_cpu_blocks: int,
target_cache_block_size_bytes: int,
draft_kv_size_bytes: int,
acceptance_sampler_method: str):
"""Verify SpecDecodeWorker correctly profiles num available GPU blocks.
Specifically, it should run profiling in the scorer worker, and then evenly
split the blocks between proposer and scorer worker.
"""
draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker()
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
target_worker.determine_num_available_blocks.return_value = (
available_gpu_blocks, available_cpu_blocks)
target_worker.get_cache_block_size_bytes.return_value = (
target_cache_block_size_bytes)
draft_worker.get_cache_block_size_bytes.return_value = draft_kv_size_bytes
worker = SpecDecodeWorker(
draft_worker, target_worker,
mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)
num_gpu_blocks, num_cpu_blocks = worker.determine_num_available_blocks()
target_worker.determine_num_available_blocks.assert_called_once()
assert num_cpu_blocks == available_cpu_blocks
assert num_gpu_blocks == split_num_cache_blocks_evenly(
target_cache_block_size_bytes, draft_kv_size_bytes,
available_gpu_blocks)
@pytest.mark.parametrize('available_gpu_blocks',
list(range(20)) + [1024, 1024**2])
@pytest.mark.parametrize('target_cache_block_size_bytes',
[2 * 2 * 4096, 2 * 2 * 8192])
@pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096])
@pytest.mark.skip_global_cleanup
def test_split_num_cache_blocks_evenly(available_gpu_blocks: int,
target_cache_block_size_bytes: int,
draft_kv_size_bytes: int):
"""Verify split_num_cache_blocks_evenly does not exceed original memory
allocation in bytes.
"""
num_blocks = split_num_cache_blocks_evenly(target_cache_block_size_bytes,
draft_kv_size_bytes,
available_gpu_blocks)
assert (num_blocks * target_cache_block_size_bytes) + (
num_blocks * draft_kv_size_bytes) <= (available_gpu_blocks *
target_cache_block_size_bytes)
@torch.inference_mode()
def test_populate_seq_ids_with_bonus_tokens():
"""
Verify that a call to _create_output_sampler_list correctly updates
seq_with_bonus_token_in_last_step.
seq_with_bonus_token_in_last_step is an internal data structure in
SpecDecodeWorker that tracks the sequence IDs which are assigned bonus
tokens by the target model in their last forward pass. This state is
maintained only for models relying on the KV cache, such as those using
the MultiStepWorker.
"""
batch_size = 10
k = 5
vocab_size = 10000
num_sequences_with_bonus_tokens = 5
target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)]
target_worker.device = 'cuda'
set_random_seed(1)
draft_worker = mock_worker(cls=MultiStepWorker)
draft_worker.device = 'cuda'
# The sequence_ids attached to each sequence in the batch.
# The sequence at index i has seq_id assigned_seq_ids[i]
assigned_seq_ids = list(range(batch_size))
seq_group_metadata_list, _, _ = create_batch(batch_size,
k,
seq_ids=assigned_seq_ids,
prev_output_token_len=10)
target_token_logprobs = torch.rand(batch_size, (k + 1),
vocab_size,
dtype=torch.float32,
device='cuda')
accepted_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, (k + 1)),
dtype=torch.int64,
device='cuda')
expected_request_id_seq_ids_mapping: Dict[str, Set[int]] = defaultdict(set)
for seq_group_metadata in seq_group_metadata_list:
for seq_id in seq_group_metadata.seq_data:
expected_request_id_seq_ids_mapping[
seq_group_metadata.request_id].add(seq_id)
# Generate a random sample of sequence indexes with bonus tokens
seq_indexes_with_bonus_tokens = random.sample(
range(batch_size), num_sequences_with_bonus_tokens)
# Create a mask that is True for indices in seq_indexes_with_bonus_tokens
mask = torch.ones(batch_size, dtype=torch.bool, device='cuda')
mask[seq_indexes_with_bonus_tokens] = False
# Set the last token ID to -1 for all indices not in
# seq_indexes_with_bonus_tokens to indicate the lack of bonus token in
# those indices.
accepted_token_ids[mask, -1:] = -1
worker = SpecDecodeWorker(draft_worker,
target_worker,
mock_spec_decode_sampler("rejection_sampler"),
disable_logprobs=False,
metrics_collector=metrics_collector)
# Initialize _seq_with_bonus_token_in_last_step with a set of sequence IDs.
# This set includes all sequence IDs in the batch as well as an additional
# `num_extra_sequence_ids` sequence IDs. Note that the sequence IDs are in
# the range [0, batch_size + num_extra_sequence_ids).
num_extra_sequence_ids = 10
worker._seq_with_bonus_token_in_last_step = set(
range(batch_size + num_extra_sequence_ids))
worker._create_output_sampler_list(
seq_group_metadata_list=seq_group_metadata_list,
accepted_token_ids=accepted_token_ids,
target_logprobs=target_token_logprobs,
k=k,
stage_times=(0, 0, 0))
# Verify that _seq_with_bonus_token_in_last_step contains the following:
# 1. Sequence IDs that were already present in
# _seq_with_bonus_token_in_last_step but were not part of the current
# batch are retained.
# 2. Of the sequence IDs present in the current batch, only those with a
# bonus token are retained in _seq_with_bonus_token_in_last_step.
# Sequence IDs that are present in the current batch but do not have
# bonus tokens are removed from _seq_with_bonus_token_in_last_step.
expected_seq_ids_with_bonus_tokens = \
set([assigned_seq_ids[i] for i in seq_indexes_with_bonus_tokens])
additional_sequence_ids = \
set(range(batch_size, batch_size + num_extra_sequence_ids))
assert worker._seq_with_bonus_token_in_last_step == \
expected_seq_ids_with_bonus_tokens.union(additional_sequence_ids)
assert worker._request_id_seq_id_mapping == \
expected_request_id_seq_ids_mapping
@torch.inference_mode()
def test_handle_finished_requests():
"""
Test to verify that finished request IDs are appropriately processed to
update the internal state of the SpecDecodeWorker.
This test initializes the SpecDecodeWorker with mock data, marks certain
requests as finished, and ensures that the corresponding sequence IDs are
correctly removed from the internal mappings.
"""
batch_size = 32
k = 3
draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker()
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
worker = SpecDecodeWorker(draft_worker, target_worker,
mock_spec_decode_sampler("rejection_sampler"),
metrics_collector)
# Initialize the request_id_seq_id_mapping mapping dict with a few fake
# request ids and corresponding sequence ids.
worker._request_id_seq_id_mapping = \
{'request-1': {1,2,3}, 'request-2': {4,5,6,7},
'request-3': {8,9}, 'request-4': {10,11}}
# Initialize seq_with_bonus_token_in_last_step with a few fake
# sequence ids.
worker._seq_with_bonus_token_in_last_step = {1, 4, 5, 8, 9, 10}
exception_secret = 'artificial stop'
draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
# Mark requests with ids request-1 and request-3 as finished.
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k,
finished_requests_ids=['request-1', 'request-3'])
with pytest.raises(ValueError, match=exception_secret):
worker.execute_model(execute_model_req=execute_model_req)
# Verify that request-1 and request-3 are removed from
# request_id_seq_id_mapping
assert worker._request_id_seq_id_mapping == \
{'request-2': {4,5,6,7}, 'request-4': {10,11}}
# Verify that all sequence ids corresponding to 'request-1'
# and 'request-3' are removed from seq_with_bonus_token_in_last_step.
assert worker._seq_with_bonus_token_in_last_step == \
{4,5,10}
@pytest.mark.parametrize('k', [3])
@pytest.mark.parametrize('batch_size', [2, 32])
@pytest.mark.parametrize("batch_composition",
["prefill_only", "decode_only", "mixed"])
@torch.inference_mode()
def test_chunked_prefill_flow(k: int, batch_size: int, batch_composition: str):
"""
Verify SpecDecodeWorker calls match the expected flow.
"""
vocab_size = 32_000
draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker()
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
worker = SpecDecodeWorker(draft_worker,
target_worker,
mock_spec_decode_sampler("rejection_sampler"),
disable_logprobs=False,
metrics_collector=metrics_collector)
exception_secret = 'artificial stop'
worker.scorer = mock_worker(BatchExpansionTop1Scorer)
worker.scorer.score_proposals.side_effect = ValueError(exception_secret)
# Create batch with combination of terminal/non-terminal prefill chunks
# and decodes (different seq_ids).
decodes, _, _ = create_batch(batch_size, k)
# Pre-chunking here, get 'batch_size' chunks.
prefill, _, _ = create_batch(batch_size,
k,
prefill_chunk_size=4,
seq_ids=list(range(batch_size,
batch_size * 2)))
if batch_composition == "prefill_only":
n_prefills = batch_size
elif batch_composition == "decode_only":
n_prefills = 0
else:
n_prefills = random.randint(1, batch_size - 1)
n_decodes = batch_size - n_prefills
prefill = random.sample(prefill, n_prefills)
decodes = random.sample(decodes, n_decodes)
target_group_metadata_list = prefill + decodes
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=target_group_metadata_list,
num_lookahead_slots=k)
target_token_ids = torch.randint(low=0,
high=vocab_size,
size=(1, batch_size * (k + 1)),
dtype=torch.int64,
device='cuda')
target_token_probs = torch.rand(1,
batch_size * (k + 1),
vocab_size,
dtype=torch.float32,
device='cuda')
target_token_logprobs = torch.rand(1,
batch_size * (k + 1),
vocab_size,
dtype=torch.float32,
device='cuda')
target_output = create_sampler_output_list(target_token_ids,
target_token_probs,
target_token_logprobs)
target_worker.execute_model.return_value = [target_output[0]]
if not len(decodes):
worker.execute_model(execute_model_req=execute_model_req)
# no spec run (prefill only)
draft_worker.execute_model.assert_called_once_with(execute_model_req)
target_worker.execute_model.assert_called_once_with(execute_model_req)
else:
# Decode-only run OR mixed batch, scorer call fails (it's mocked)
with pytest.raises(ValueError, match=exception_secret):
worker.execute_model(execute_model_req=execute_model_req)
# but first draft still counted
assert draft_worker.get_spec_proposals.call_count == 1

View File

@@ -0,0 +1,147 @@
from unittest.mock import MagicMock
import pytest
import torch
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.layers.sampler import _get_ranks
from vllm.model_executor.layers.typical_acceptance_sampler import (
TypicalAcceptanceSampler)
from vllm.sequence import SequenceGroupMetadata, get_all_seq_ids
from vllm.spec_decode.util import (get_sampled_token_logprobs,
split_batch_by_proposal_len)
def test_get_all_seq_ids():
"""Verify get_all_seq_ids extracts all seq ids.
"""
expected_seq_ids = list(range(10)) + list(range(100, 110))
seq_group_metadata_list = [
SequenceGroupMetadata(
request_id=str(seq_id),
is_prompt=True,
seq_data={
seq_id: MagicMock(),
},
sampling_params=MagicMock(),
block_tables={
seq_id: MagicMock(),
},
lora_request=None,
) for seq_id in expected_seq_ids
]
actual_seq_ids = get_all_seq_ids(seq_group_metadata_list)
assert actual_seq_ids == expected_seq_ids
@pytest.fixture
def fake_sequence_group_metadata():
seq_ids = list(range(3))
return [
SequenceGroupMetadata(
request_id=str(i),
is_prompt=True,
seq_data={
i: MagicMock(),
},
sampling_params=MagicMock(),
block_tables={
i: MagicMock(),
},
lora_request=None,
) for i in seq_ids
]
def test_filter_zero_length_proposals(fake_sequence_group_metadata):
proposal_lens = [0, 1, 0]
_, (filtered_groups,
indices) = split_batch_by_proposal_len(fake_sequence_group_metadata,
proposal_lens)
expected_groups = [
fake_sequence_group_metadata[0], fake_sequence_group_metadata[2]
]
expected_indices = [0, 2]
assert filtered_groups == expected_groups
assert indices == expected_indices
def test_filter_non_zero_length_proposals(fake_sequence_group_metadata):
proposal_lens = [0, 1, 2]
(filtered_groups,
indices), _ = split_batch_by_proposal_len(fake_sequence_group_metadata,
proposal_lens)
expected_groups = [
fake_sequence_group_metadata[1], fake_sequence_group_metadata[2]
]
expected_indices = [1, 2]
assert filtered_groups == expected_groups
assert indices == expected_indices
def test_empty_inputs():
_, (filtered_groups, indices) = split_batch_by_proposal_len([], [])
assert filtered_groups == []
assert indices == []
def test_all_zero_with_non_zero_filter(fake_sequence_group_metadata):
proposal_lens = [0, 0, 0]
(filtered_groups,
indices), _ = split_batch_by_proposal_len(fake_sequence_group_metadata,
proposal_lens)
assert filtered_groups == []
assert indices == []
def test_all_non_zero_with_zero_filter(fake_sequence_group_metadata):
proposal_lens = [1, 1, 1]
_, (filtered_groups,
indices) = split_batch_by_proposal_len(fake_sequence_group_metadata,
proposal_lens)
assert filtered_groups == []
assert indices == []
def mock_spec_decode_sampler(acceptance_sampler_method):
"""
Returns either a RejectionSampler or TypicalAcceptanceSampler
object depending on whether acceptance_sampler_method is
'rejection_sampler' or 'typical_acceptance_sampler' respectively.
"""
if acceptance_sampler_method == "rejection_sampler":
sampler = MagicMock(spec=RejectionSampler)
sampler.token_id_dtype = torch.int64
return sampler
elif acceptance_sampler_method == "typical_acceptance_sampler":
sampler = MagicMock(spec=TypicalAcceptanceSampler)
sampler.token_id_dtype = torch.int64
return sampler
else:
raise ValueError(f"Invalid sampler name {acceptance_sampler_method}")
def test_get_sampled_token_logprobs():
"""Verify get_sampled_token_logprobs returns consistent rankings
with regular get_ranks when probabilities match exactly.
"""
logprob_tensor = torch.tensor(
[[[-.1, -.1]] * 2]) # shape (num_steps, batch_size, vocab_size)
sampled_token_tensor = torch.tensor([[1,
0]]) # shape (num_steps, batch_size)
ranks_spec_dec, _ = get_sampled_token_logprobs(logprob_tensor,
sampled_token_tensor)
ranks_regular = _get_ranks(logprob_tensor.reshape((2, -1)),
sampled_token_tensor.reshape(-1))
assert torch.equal(ranks_spec_dec.reshape(-1), ranks_regular)

View File

@@ -0,0 +1,281 @@
from itertools import count
from typing import Callable, Dict, List, Optional
from typing import Sequence as GenericSequence
from typing import TypeVar, Union
from unittest.mock import MagicMock
import torch
from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.utils import set_random_seed
from vllm.sampling_params import SamplingParams
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
SequenceData, SequenceGroupMetadata, SequenceOutput)
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.model_runner import ModelRunner
from vllm.worker.worker import Worker
T = TypeVar("T", bound=Worker)
def round_up_to_next_block(seq_len: int, block_size: int) -> int:
return (seq_len + block_size - 1) // block_size
def mock_worker(cls=None,
vocab_size: int = 30_000,
max_model_len: int = 2048,
rank: int = 0,
use_spec: bool = True) -> MagicMock:
if cls is None:
cls = Worker
spec = cls if use_spec else None
worker = MagicMock(spec=spec)
worker.vocab_size = vocab_size
worker.max_model_len = max_model_len
worker.rank = rank
worker.device = 'cuda:0'
return worker
def patch_execute_model_with_seeds(worker: Worker, rand_seeds: List[int]):
seed_iter = iter(rand_seeds)
original_execute_model = worker.execute_model
def new_execute_model(*args, **kwargs):
result = original_execute_model(*args, **kwargs)
set_random_seed(next(seed_iter))
return result
return new_execute_model
def zero_kv_cache(cache_engine: List[CacheEngine]):
assert cache_engine[0].gpu_cache
# kv blocks in mlu platforms are different from those in gpu platforms.
# for key_blocks, value_blocks in cache_engine[0].gpu_cache:
# key_blocks.zero_()
# value_blocks.zero_()
for kv_blocks, scale_blocks in cache_engine[0].gpu_cache:
kv_blocks.zero_()
if scale_blocks is not None:
scale_blocks.zero_()
def create_worker(cls: Callable[..., T],
model_name: str,
block_size: int,
num_gpu_blocks: int,
seed: int,
is_driver_worker: bool = True,
enforce_eager: bool = True,
model_runner_cls: Optional[ModelRunner] = None,
dtype: Optional[str] = "auto") -> T:
engine_args = EngineArgs(
model=model_name,
seed=seed,
block_size=block_size,
enforce_eager=enforce_eager,
dtype=dtype,
)
engine_config = engine_args.create_engine_config()
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
worker = cls(
vllm_config=engine_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
is_driver_worker=is_driver_worker,
model_runner_cls=model_runner_cls,
)
worker.init_device()
worker.load_model()
engine_config.cache_config.num_gpu_blocks = num_gpu_blocks
engine_config.cache_config.num_cpu_blocks = 0
worker.initialize_cache(
num_gpu_blocks=engine_config.cache_config.num_gpu_blocks,
num_cpu_blocks=engine_config.cache_config.num_cpu_blocks)
return worker
def create_seq_group_metadata_from_prompts(
prompts: List[List[int]],
num_gpu_blocks: int,
block_size: int,
final_prompt_lens: List[int],
continuations: Optional[List[List[int]]] = None,
seq_ids: Optional[List[int]] = None,
) -> List[SequenceGroupMetadata]:
if continuations is None:
continuations = [[] for _ in prompts]
if seq_ids is None:
seq_ids = list(i for i, _ in enumerate(prompts))
free_gpu_blocks = list(range(num_gpu_blocks))
block_allocations = {
i: [
free_gpu_blocks.pop()
for _ in range(round_up_to_next_block(final_len, block_size))
]
for i, final_len in enumerate(final_prompt_lens)
}
seq_grou_metadata_list = []
for i, (prompt_token_ids,
cont_token_ids) in enumerate(zip(prompts, continuations)):
data = SequenceData.from_seqs(prompt_token_ids, cont_token_ids)
data.update_num_computed_tokens(
len(prompt_token_ids) + len(cont_token_ids) - 1)
seq_data = {i: data}
seq_grou_metadata_list.append(
SequenceGroupMetadata(
request_id=str(i),
is_prompt=len(cont_token_ids) == 0,
seq_data=seq_data,
sampling_params=SamplingParams(temperature=0.0),
block_tables={i: block_allocations[i][:]},
))
return seq_grou_metadata_list
def create_chunked_seq_group_metadata_from_prompt(
prompt: List[int],
num_gpu_blocks: int,
chunk_size: int,
block_size: int,
seq_id: Optional[int] = None) -> List[SequenceGroupMetadata]:
if seq_id is None:
seq_id = 0
free_gpu_blocks = list(range(num_gpu_blocks))
block_allocations = [
free_gpu_blocks.pop()
for _ in range(round_up_to_next_block(len(prompt), block_size))
]
seq_group_metadata_list = []
for i, idx in enumerate(range(0, len(prompt), chunk_size)):
chunk_ids = prompt[idx:idx + chunk_size]
data = SequenceData.from_seqs(prompt)
data.update_num_computed_tokens(idx)
seq_data = {i: data}
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=str(seq_id),
is_prompt=True,
do_sample=idx + chunk_size >= len(prompt), # terminal chunk
seq_data=seq_data,
sampling_params=SamplingParams(temperature=0.0),
block_tables={i: block_allocations},
token_chunk_size=len(chunk_ids)))
return seq_group_metadata_list
def assert_logprobs_dict_allclose(
actual_logprobs: List[Dict[int, Logprob]],
expected_logprobs: List[Dict[int, Logprob]]) -> None:
for single_step_actual_logprobs, single_step_expected_logprobs in zip(
actual_logprobs, expected_logprobs):
assert set(single_step_actual_logprobs.keys()) == set(
single_step_expected_logprobs.keys())
for token_id in single_step_actual_logprobs:
actual = torch.tensor(
single_step_actual_logprobs[token_id].logprob)
expected = torch.tensor(
single_step_expected_logprobs[token_id].logprob)
torch.testing.assert_close(actual, expected)
def create_sampler_output_list(
token_ids: torch.Tensor,
probs: GenericSequence[Optional[torch.Tensor]],
logprobs: GenericSequence[Optional[torch.Tensor]],
seq_ids: Optional[List[int]] = None) -> List[SamplerOutput]:
num_steps, batch_size = token_ids.shape
token_ids_by_step = token_ids.tolist()
if seq_ids is None:
seq_ids = list(range(batch_size))
return [
SamplerOutput(outputs=[
CompletionSequenceGroupOutput(
samples=[
SequenceOutput(
output_token=token_id,
parent_seq_id=seq_ids[seq_index],
logprobs={token_id: Logprob(0)},
)
],
prompt_logprobs=None,
) for seq_index, token_id in enumerate(token_ids_by_step[step])
],
sampled_token_probs=probs[step],
logprobs=logprobs[step],
sampled_token_ids=token_ids[step])
for step in range(num_steps)
]
def create_batch(batch_size,
k,
prompt_len: Union[int, List[int]] = 10,
prev_output_token_len: int = 10,
seq_ids: Optional[List[int]] = None,
num_gpu_blocks: Optional[int] = None,
block_size: Optional[int] = None,
prefill_chunk_size: Optional[int] = None):
if block_size is None:
block_size = 8
if num_gpu_blocks is None:
num_gpu_blocks = 2048 // block_size
iterator = count()
if isinstance(prompt_len, int):
prompt_lens = [prompt_len for _ in range(batch_size)]
else:
prompt_lens = prompt_len
prompts = [[next(iterator) for _ in range(p_len)] for p_len in prompt_lens]
if prefill_chunk_size:
# Create a batch of chunked prompts.
if not seq_ids:
seq_ids = list(range(len(prompts)))
seq_group_metadata_list = []
for p, sid in zip(prompts, seq_ids):
seq_group_metadata_list += \
create_chunked_seq_group_metadata_from_prompt(
p, num_gpu_blocks, prefill_chunk_size, block_size, sid)
seq_group_metadata_list = seq_group_metadata_list[:batch_size]
prev_output_tokens = []
else:
prev_output_tokens = [[
next(iterator) for _ in range(prev_output_token_len)
] for _ in range(batch_size)]
final_prompt_lens = [
len(prompt) + len(prev_output_token) + k + 1
for prompt, prev_output_token in zip(prompts, prev_output_tokens)
]
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts, num_gpu_blocks, block_size, final_prompt_lens,
prev_output_tokens, seq_ids)
return seq_group_metadata_list, prompts, prev_output_tokens