Sync from v0.13

This commit is contained in:
2026-01-19 10:38:50 +08:00
parent b2ef04d792
commit 5aef6c175a
3714 changed files with 854317 additions and 89342 deletions

0
tests/v1/e2e/__init__.py Normal file
View File

View File

@@ -0,0 +1,388 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from itertools import repeat
from typing import Any
import pytest
import torch._dynamo.config as dynamo_config
from vllm import SamplingParams
from vllm.logprobs import Logprob
from vllm.platforms import current_platform
from vllm.sampling_params import StructuredOutputsParams
from vllm.v1.metrics.reader import Metric
from ...conftest import VllmRunner
from ...models.utils import check_outputs_equal
# Target model for the non-spec-decoding consistency tests.
MODEL = "Qwen/Qwen3-0.6B"
# Target model for the speculative-decoding tests (paired with EAGLE3
# draft models defined in test_with_spec_decoding).
MTP_MODEL = "meta-llama/Llama-3.2-1B-Instruct"
# A prompt with a highly predictable continuation, useful for greedy
# output comparison across engine configurations.
first_prompt = (
    "The following numbers of the sequence "
    + ", ".join(str(i) for i in range(10))
    + " are:"
)
# 34 prompts total: enough to exercise batching, preemption and
# prefill chunking in a single generate() call.
example_prompts = [first_prompt, "In one word, the capital of France is "] + [
    f"Tell me about the number {i}: " for i in range(32)
]
# Shared sampling defaults; per-test parameter dicts override/extend these.
default_params = dict(
    temperature=0.0,  # greedy
    max_tokens=23,
    min_tokens=18,
)
def test_without_spec_decoding(
    sample_json_schema,
    monkeypatch: pytest.MonkeyPatch,
):
    """Test consistency of combos of async scheduling, preemption,
    uni/multiproc executor, prefill chunking."""
    struct_outputs = StructuredOutputsParams(json=sample_json_schema)
    test_sampling_params: list[dict[str, Any]] = [
        dict(),
        # dict(min_tokens=20),
        dict(presence_penalty=-1.0),
        dict(bad_words=["the", " the"]),
        dict(logprobs=2),
        dict(logprobs=2, presence_penalty=-1.0),
        dict(structured_outputs=struct_outputs),
        dict(
            structured_outputs=struct_outputs,
            logprobs=2,
            presence_penalty=-1.0,
        ),
    ]
    # Each tuple: (test_preemption, executor, async_scheduling,
    # spec_config, test_prefill_chunking)
    test_configs = [
        (False, "mp", False, None, False),
        (True, "mp", False, None, True),
        (False, "mp", True, None, False),
        (False, "uni", True, None, False),
        (True, "mp", True, None, False),
        (True, "uni", True, None, False),
        (False, "mp", True, None, True),
        (True, "mp", True, None, True),
        (True, "uni", True, None, True),
    ]
    if current_platform.is_rocm():
        # On ROCm, only run the deterministic structured_outputs cases
        # and drop chunked-prefill configs (more variable there).
        test_configs = [config for config in test_configs if not config[4]]
        test_sampling_params = [
            params
            for params in test_sampling_params
            if params.get("structured_outputs") is not None
        ]
    run_tests(monkeypatch, MODEL, test_configs, test_sampling_params)
def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch):
    """Test consistency and acceptance rates with some different combos of
    preemption, executor, async scheduling, prefill chunking,
    spec decoding model length.
    """
    spec_config = {
        "method": "eagle3",
        "num_speculative_tokens": 2,
        "model": "nm-testing/Llama3_2_1B_speculator.eagle3",
    }
    # Set small draft model len to force doesn't-fit-in-drafter case.
    short_spec_config = spec_config | {"max_model_len": 50}
    test_sampling_params = [dict(), dict(logprobs=2)]
    # Each tuple: (test_preemption, executor, async_scheduling,
    # spec_config, test_prefill_chunking)
    test_configs = [
        (False, "mp", False, None, False),
        (False, "mp", False, spec_config, False),
        (True, "mp", False, spec_config, True),
        (True, "uni", False, short_spec_config, True),
        (False, "mp", True, spec_config, False),
        (True, "mp", True, spec_config, False),
        (False, "mp", True, short_spec_config, True),
        (True, "uni", True, spec_config, False),
        (True, "uni", True, short_spec_config, False),
        (True, "mp", True, spec_config, True),
        (True, "uni", True, short_spec_config, True),
    ]
    # On ROCm, use TRITON_ATTN + float32 for better numerical consistency
    run_tests(
        monkeypatch,
        MTP_MODEL,
        test_configs,
        test_sampling_params,
        is_testing_with_spec_decoding=True,
    )
@dynamo_config.patch(cache_size_limit=16)
def run_tests(
    monkeypatch: pytest.MonkeyPatch,
    model: str,
    test_configs: list[tuple],
    test_sampling_params: list[dict[str, Any]],
    is_testing_with_spec_decoding: bool = False,
):
    """Test consistency of combos of async scheduling, preemption,
    uni/multiproc executor with spec decoding.

    Runs ``run_test`` once per config in ``test_configs``, treats the first
    config's outputs as the baseline, and asserts every other config
    produces matching outputs/logprobs (and comparable acceptance rates
    when spec decoding is enabled).
    """
    with monkeypatch.context() as m:
        # avoid precision errors
        if current_platform.is_rocm():
            if is_testing_with_spec_decoding:
                # Use TRITON_ATTN for spec decoding test for consistency
                m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN")
            else:
                m.setenv("VLLM_ATTENTION_BACKEND", "ROCM_AITER_FA")
        else:
            m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
        # lock matmul precision to full FP32 (IEEE)
        m.setenv("VLLM_FLOAT32_MATMUL_PRECISION", "ieee")
        # m.setenv("VLLM_BATCH_INVARIANT", "1")
        # Each entry: (config description, per-params results,
        # per-params acceptance rates or None).
        outputs: list[tuple[str, list, list]] = []
        for n, (
            test_preemption,
            executor,
            async_scheduling,
            spec_config,
            test_prefill_chunking,
        ) in enumerate(test_configs, 1):
            test_str = f"{n}/{len(test_configs)}"
            test_results = run_test(
                model,
                test_str,
                test_sampling_params,
                test_preemption,
                executor,
                async_scheduling,
                spec_config,
                test_prefill_chunking=test_prefill_chunking,
                is_testing_with_spec_decoding=is_testing_with_spec_decoding,
            )
            outputs.append(test_results)
        # First config is the output baseline; the first config that
        # reported acceptance rates is the acceptance-rate baseline.
        baseline_config, baseline_tests, _ = outputs[0]
        _, _, baseline_acceptances = next(
            (o for o in outputs if o[2] is not None), (None, None, None)
        )
        print(f"BASELINE: config=[{baseline_config}], accept_rates={baseline_acceptances}")
        # Record the first failure but keep comparing the remaining combos
        # so all PASSED/FAILED lines are printed before raising.
        failure = None
        for test_config, test_outputs, test_acceptance_rates in outputs[1:]:
            for (base_outs, base_logprobs), base_acceptance_rate, (
                test_outs,
                test_logprobs,
            ), test_acceptance_rate, params in zip(
                baseline_tests,
                baseline_acceptances or repeat(None),
                test_outputs,
                test_acceptance_rates or repeat(None),
                test_sampling_params,
            ):
                try:
                    check_outputs_equal(
                        outputs_0_lst=base_outs,
                        outputs_1_lst=test_outs,
                        name_0=f"baseline=[{baseline_config}], params={params}",
                        name_1=f"config=[{test_config}], params={params}",
                    )
                    # On ROCm with TRITON_ATTN (spec decoding test), skip strict
                    # logprobs comparison when logprobs are requested
                    skip_logprobs_check = (
                        current_platform.is_rocm()
                        and params.get("logprobs")
                        and is_testing_with_spec_decoding
                    )
                    if not skip_logprobs_check:
                        assert _all_logprobs_match(base_logprobs, test_logprobs)
                    if (
                        base_acceptance_rate is not None
                        and test_acceptance_rate is not None
                    ):
                        # spec_mml=None means the full-length drafter config.
                        if "spec_mml=None" in test_config:
                            # Preemption causes more variance in acceptance rates
                            if (
                                current_platform.is_rocm()
                                and "preemption=True" in test_config
                            ):
                                tolerance = 0.10
                            else:
                                tolerance = 0.05
                            assert (
                                test_acceptance_rate > base_acceptance_rate
                                or test_acceptance_rate
                                == pytest.approx(base_acceptance_rate, rel=tolerance)
                            )
                        else:
                            # Currently the reported acceptance rate is expected to be
                            # lower when we sometimes skip drafting altogether.
                            assert test_acceptance_rate > 0.1
                    print(
                        f"PASSED: config=[{test_config}], params={params}"
                        f" accept_rate={test_acceptance_rate}"
                    )
                except AssertionError as e:
                    print(
                        f"FAILED: config=[{test_config}], params={params}"
                        f" accept_rate={test_acceptance_rate}"
                    )
                    if failure is None:
                        failure = e
        if failure is not None:
            raise failure
def run_test(
    model: str,
    test_str: str,
    sampling_param_tests: list[dict[str, Any]],
    test_preemption: bool,
    executor: str,
    async_scheduling: bool,
    spec_config: dict[str, Any] | None,
    test_prefill_chunking: bool,
    is_testing_with_spec_decoding: bool = False,
):
    """Run one engine configuration over every sampling-param variant.

    Returns:
        A tuple of (config description string, per-variant generation
        results, per-variant acceptance rates or None when spec decoding
        is disabled).
    """
    spec_decoding = spec_config is not None
    cache_arg: dict[str, Any] = (
        # Force preemptions
        dict(num_gpu_blocks_override=32)
        if test_preemption
        else dict(gpu_memory_utilization=0.9)
    )
    spec_mml = (spec_config or {}).get("max_model_len")
    # Human-readable config id; run_tests string-matches on parts of it
    # (e.g. "spec_mml=None", "preemption=True").
    test_config = (
        f"executor={executor}, preemption={test_preemption}, "
        f"async_sched={async_scheduling}, "
        f"chunk_prefill={test_prefill_chunking}, "
        f"spec_decoding={spec_decoding}, spec_mml={spec_mml}"
    )
    print("-" * 80)
    print(f"---- TESTING {test_str}: {test_config}")
    print("-" * 80)
    # On ROCm: use float16 for first test (ROCM_AITER_FA), but float32 for
    # spec decoding test (TRITON_ATTN) for better precision.
    # On others: always use float32.
    if current_platform.is_rocm() and not is_testing_with_spec_decoding:
        dtype = "float16"
    else:
        dtype = "float32"
    with VllmRunner(
        model,
        max_model_len=512,
        enable_chunked_prefill=test_prefill_chunking,
        # Force prefill chunking
        max_num_batched_tokens=48 if test_prefill_chunking else None,
        # enforce_eager=True,
        async_scheduling=async_scheduling,
        distributed_executor_backend=executor,
        dtype=dtype,
        speculative_config=spec_config,
        disable_log_stats=False,
        **cache_arg,
    ) as vllm_model:
        results = []
        acceptance_rates: list[float] | None = [] if spec_decoding else None
        for override_params in sampling_param_tests:
            # Snapshot metrics around each generation so per-variant
            # acceptance rates / preemption counts can be derived.
            metrics_before = vllm_model.llm.get_metrics()
            print(f"----------- RUNNING PARAMS: {override_params}")
            results.append(
                vllm_model.generate(
                    example_prompts,
                    sampling_params=SamplingParams(**default_params, **override_params),
                    return_logprobs=True,
                )
            )
            metrics_after = vllm_model.llm.get_metrics()
            if acceptance_rates is not None:
                acceptance_rate = _get_acceptance_rate(metrics_before, metrics_after)
                acceptance_rates.append(acceptance_rate)
                print(f"ACCEPTANCE RATE {acceptance_rate}")
            if test_preemption:
                preemptions = _get_count(
                    metrics_before, metrics_after, "vllm:num_preemptions"
                )
                assert preemptions > 0, "preemption test had no preemptions"
        if len(results) > 1:
            # First check that the different parameter configs
            # actually result in different output.
            for (other_test_outs, other_test_logprobs), params in zip(
                results[1:], sampling_param_tests[1:]
            ):
                # Expect an AssertionError: either the outputs differ or
                # the logprobs differ from the first variant's.
                with pytest.raises(AssertionError):
                    check_outputs_equal(
                        outputs_0_lst=results[0][0],
                        outputs_1_lst=other_test_outs,
                        name_0=f"baseline params={params}",
                        name_1=f"other params={params}",
                    )
                    assert _all_logprobs_match(results[0][1], other_test_logprobs)
        return test_config, results, acceptance_rates
def _all_logprobs_match(req_a, req_b) -> bool:
return (
req_a == req_b
or len(req_a) == len(req_b)
and all(
len(seq_a) == len(seq_b)
and all(_logprobs_match(a, b) for a, b in zip(seq_a, seq_b))
for seq_a, seq_b in zip(req_a, req_b)
)
)
def _logprobs_match(lps_a: dict[int, Logprob], lps_b: dict[int, Logprob]) -> bool:
    """Compare two top-logprob dicts for a single token position.

    Tokens must match exactly (id, decoded text, rank); logprob values
    are compared with platform-dependent tolerances.
    """
    if current_platform.is_rocm():
        # ROCm has higher numerical variance
        # due to use of float16.
        rel_tol, abs_tol = 5e-2, 1e-5
    else:
        rel_tol, abs_tol = 1e-3, 1e-6
    if len(lps_a) != len(lps_b) or lps_a.keys() != lps_b.keys():
        return False
    for token_id in lps_a:
        a, b = lps_a[token_id], lps_b[token_id]
        if a.decoded_token != b.decoded_token or a.rank != b.rank:
            return False
        if a.logprob != pytest.approx(b.logprob, rel=rel_tol, abs=abs_tol):
            return False
    return True
def _get_acceptance_rate(before: list[Metric], after: list[Metric]) -> float:
    """Accepted/drafted token ratio over the metrics window (0.0 if nothing
    was drafted)."""
    drafted = _get_count(before, after, "vllm:spec_decode_num_draft_tokens")
    accepted = _get_count(before, after, "vllm:spec_decode_num_accepted_tokens")
    if drafted > 0:
        return accepted / drafted
    return 0.0
def _get_count(before: list[Metric], after: list[Metric], name: str) -> int:
before_val = next(m.value for m in before if m.name == name)
after_val = next(m.value for m in after if m.name == name)
return after_val - before_val

View File

@@ -0,0 +1,131 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test that verifies no implicit GPU-CPU synchronization occurs during
speculative decoding generation under expected conditions.
"""
import multiprocessing
import sys
import traceback
import pytest
import torch
@pytest.fixture
def sync_tracker():
    """
    Fixture that patches CommonAttentionMetadata.seq_lens_cpu to detect
    lazy init syncs. Prints stack traces immediately when syncs occur.

    Yields a SyncTracker whose `count` mirrors the number of lazy inits
    observed and whose `assert_no_sync` fails if any occurred. The patch
    is reverted (and dynamo reset) on teardown.
    """
    from vllm.v1.attention.backends.utils import CommonAttentionMetadata

    # Shared counter for cross-process communication (inherited by fork)
    sync_count = multiprocessing.Value("i", 0)
    # Save original property
    original_prop = CommonAttentionMetadata.seq_lens_cpu
    original_fget = original_prop.fget

    # Create tracking wrapper
    def tracking_seq_lens_cpu(self):
        # Only the lazy-init path (cache still None) counts as a sync.
        if self._seq_lens_cpu is None:
            # Increment counter
            with sync_count.get_lock():
                sync_count.value += 1
            count = sync_count.value
            # Print stack trace immediately (shows in subprocess output)
            print(f"\n{'=' * 60}", file=sys.stderr)
            print(f"SYNC #{count}: seq_lens_cpu lazy init triggered!", file=sys.stderr)
            print(f"{'=' * 60}", file=sys.stderr)
            traceback.print_stack(file=sys.stderr)
            print(f"{'=' * 60}\n", file=sys.stderr)
            sys.stderr.flush()
        return original_fget(self)

    # Apply patch
    CommonAttentionMetadata.seq_lens_cpu = property(tracking_seq_lens_cpu)

    class SyncTracker:
        """Read-only view over the shared counter plus an assert helper."""

        @property
        def count(self) -> int:
            return sync_count.value

        def assert_no_sync(self, msg: str = ""):
            count = sync_count.value
            assert count == 0, (
                f"Unexpected GPU-CPU sync: seq_lens_cpu lazy init triggered "
                f"{count} times. See stack traces above. {msg}"
            )

    yield SyncTracker()

    # Restore original property
    CommonAttentionMetadata.seq_lens_cpu = original_prop
    torch._dynamo.reset()
# Test configurations: (model, spec_model, method, num_spec_tokens, backend_env)
# Each pytest.param bundles a target model with its draft model and the
# speculative-decoding method to use.
SPEC_DECODE_CONFIGS = [
    pytest.param(
        "meta-llama/Llama-3.2-1B-Instruct",
        "nm-testing/Llama3_2_1B_speculator.eagle3",
        "eagle3",
        2,
        id="eagle3-llama",
    ),
    pytest.param(
        "eagle618/deepseek-v3-random",
        "eagle618/eagle-deepseek-v3-random",
        "eagle",
        2,
        id="eagle-mla-deepseek",
    ),
]
@pytest.mark.parametrize(
    "model,spec_model,method,num_spec_tokens",
    SPEC_DECODE_CONFIGS,
)
def test_no_sync_with_spec_decode(
    sync_tracker,
    model: str,
    spec_model: str,
    method: str,
    num_spec_tokens: int,
):
    """
    Test that no implicit GPU-CPU sync occurs during speculative decoding
    generation.
    """
    # Import vLLM AFTER sync_tracker fixture has applied the patch
    from vllm import LLM, SamplingParams
    from vllm.distributed import cleanup_dist_env_and_memory

    spec_config = {
        "method": method,
        "num_speculative_tokens": num_spec_tokens,
        "model": spec_model,
    }
    llm = LLM(
        model=model,
        max_model_len=256,
        speculative_config=spec_config,
        enforce_eager=True,
        async_scheduling=True,
    )
    generations = llm.generate(
        ["Hello, my name is"],
        SamplingParams(temperature=0, max_tokens=10),
    )
    assert len(generations) == 1
    assert len(generations[0].outputs[0].text) > 0
    del llm
    torch.cuda.empty_cache()
    cleanup_dist_env_and_memory()
    sync_tracker.assert_no_sync()

View File

@@ -0,0 +1,37 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from vllm import LLM, SamplingParams
from ...utils import create_new_process_for_each_test
@create_new_process_for_each_test()
@pytest.mark.parametrize("attn_backend", ["FLASH_ATTN", "FLASHINFER"])
def test_cascade_attention(example_system_message, monkeypatch, attn_backend):
    """Greedy outputs for a batch of identical long-prefix prompts must
    match the single-prompt output regardless of cascade attention."""
    if attn_backend == "FLASHINFER":
        pytest.skip(
            "This test is failing with FlashInfer backend and "
            "needs investigation. See issue #25679."
        )

    prompt = "\n<User>: Implement fibonacci sequence in Python.\n<Claude>:"
    with monkeypatch.context() as m:
        m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
        llm = LLM(model="Qwen/Qwen2-1.5B-Instruct")
        sampling_params = SamplingParams(temperature=0.0, max_tokens=100)

        # No cascade attention.
        single = llm.generate([example_system_message + prompt], sampling_params)
        ref_output = single[0].outputs[0].text

        # (Probably) Use cascade attention.
        batched_prompts = [example_system_message + prompt] * 64
        for response in llm.generate(batched_prompts, sampling_params):
            assert response.outputs[0].text == ref_output

View File

@@ -0,0 +1,63 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for vLLM `vllm/v1/engine/processor.Processor._validate_model_input()`
handling of maximum context length for decoder models.
This test ensures:
- A prompt that is one token shorter than the model's maximum context length
can be processed successfully when requesting one additional token.
- A prompt that reaches the model's maximum context length throws a
`ValueError` when requesting at least one additional token.
"""
import pytest
from tests.conftest import VllmRunner
from tests.utils import create_new_process_for_each_test
@create_new_process_for_each_test()
@pytest.mark.parametrize("model, max_model_len", [("JackFram/llama-160m", 2048)])
@pytest.mark.parametrize(
    "prompt_len, max_tokens",
    [
        (2047, 1),  # prompt_len = max_model_len - 1 -> allowed
        (2048, 1),  # prompt_len = max_model_len -> not allowed
    ],
)
def test_decoder_max_context_length_validation(
    model: str,
    max_model_len: int,
    vllm_runner: type[VllmRunner],
    prompt_len: int,
    max_tokens: int,
) -> None:
    """Check vLLM decoder model input validation for edge cases where
    the prompt length is (almost) equal to the max model length."""
    prompt_ids = [[43] * prompt_len]
    fits_in_context = prompt_len + max_tokens <= max_model_len
    with vllm_runner(
        model_name=model,
        tokenizer_name=model,
        max_model_len=max_model_len,
        max_num_seqs=1,
        tensor_parallel_size=1,
    ) as vllm_model:
        if fits_in_context:
            # Should succeed as constraints are met
            vllm_model.generate_greedy(prompt_ids, max_tokens)
        else:
            # Should raise the ValueError defined in
            # vllm/v1/engine/processor.Processor_validate_model_input()
            expected_msg = (
                f"The decoder prompt (length {prompt_len}) plus the number of "
                f"requested output tokens (at least 1) is longer than "
                f"the maximum model length of {max_model_len}. "
                "Make sure that `max_model_len` is no smaller than the number of "
                "text tokens (prompt + requested output tokens)."
            )
            with pytest.raises(ValueError) as excinfo:
                vllm_model.generate_greedy(prompt_ids, max_tokens)
            assert expected_msg in str(excinfo.value)

View File

@@ -0,0 +1,98 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
import pytest
from vllm import LLM, SamplingParams
from vllm.platforms import current_platform
from ...utils import check_answers, prep_prompts
@dataclass
class TestConfig:
    # Per-model parameters for the sliding-window retrieval test.
    # sliding_window: the model's sliding-window size in tokens.
    # ln_range: range of assignment-line counts used by prep_prompts to
    # size the prompts.
    sliding_window: int
    ln_range: tuple[int, int]
# Maps model name -> TestConfig(sliding_window, ln_range) used below.
model_config = {
    "bigcode/starcoder2-3b": TestConfig(4096, (800, 1100)),
    "google/gemma-3-1b-it": TestConfig(4096, (400, 800)),
}
@pytest.mark.parametrize(
    "model",
    [
        "bigcode/starcoder2-3b",  # sliding window only
        "google/gemma-3-1b-it",  # sliding window + full attention
    ],
)
@pytest.mark.parametrize("batch_size", [5])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("disable_hybrid_kv_cache_manager", [True, False])
def test_sliding_window_retrieval(
    model, batch_size, seed, disable_hybrid_kv_cache_manager
):
    """
    The test does a bunch of assignments "x1 = 10\nx2 = 33\n..." and then
    asks for value of one of them (which is outside the sliding window).
    If we tell it upfront which we are going to be looking for, then
    it answers correctly (mostly).
    """
    # NOTE: For ROCm, we have to enforce eager mode to use custom kernel
    # implementation of GELU with tanh approximation, as PyTorch's native
    # implementation is currently unstable with torch.compile and produces garbage.
    test_config = model_config[model]
    llm = LLM(
        model=model,
        disable_hybrid_kv_cache_manager=disable_hybrid_kv_cache_manager,
        enforce_eager=current_platform.is_rocm(),
    )
    sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
    prompts, answer, indices = prep_prompts(batch_size, ln_range=test_config.ln_range)
    check_length(prompts, llm, test_config.sliding_window)

    # Generate twice with identical prompts: a fresh pass, then a second
    # pass that exercises prefix caching.
    for _ in range(2):
        responses = llm.generate(prompts, sampling_params)
        check_answers(
            indices,
            answer,
            [response.outputs[0].text for response in responses],
            accept_rate=1.0,
        )
def check_length(prompts: list[str], llm: LLM, sliding_window: int):
    """
    Check if the prompt length is valid, i.e., longer than the sliding window
    size and shorter than the model's max length.

    Args:
        prompts: list of prompts
        llm: LLM object
        sliding_window: Sliding window size
    """
    tokenizer = llm.get_tokenizer()
    max_model_len = llm.llm_engine.model_config.max_model_len
    lengths = [len(tokenizer.encode(prompt)) for prompt in prompts]
    assert any(n > sliding_window for n in lengths), "Prompt is too short for test"
    assert all(n <= max_model_len for n in lengths), "Prompt is too long for test"

View File

@@ -0,0 +1,101 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
import pytest
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig, CompilationMode
from vllm.platforms import current_platform
from ...utils import check_answers, fork_new_process_for_each_test, prep_prompts
# global seed
SEED = 42
@pytest.fixture
def test_prompts():
    """
    Adapted from tests/v1/e2e/test_spec_decode.py
    """
    # Setting higher num prompts increases the chance of numerics mismatch
    # due to matrix multiplication numerics depending on batch dimension
    num_prompts = 10
    random.seed(0)
    kinds = random.choices(["repeat", "sentence"], k=num_prompts)
    prompts = []
    for kind in kinds:
        word = random.choice(["test", "temp", "hello", "where"])
        if kind == "repeat":
            prompts.append(f"""please repeat the word '{word}' 10 times.""")
        elif kind == "sentence":
            prompts.append(f"""please give a ten-word sentence that
            uses the word {word} at least once.""")
        else:
            raise ValueError(f"Unknown prompt type: {kind}")
    return prompts
# Fork-per-test is skipped on ROCm (identity decorator instead) —
# presumably due to the same CUDA re-initialization issue worked around
# with spawn below; TODO confirm.
use_fork_for_test = (
    fork_new_process_for_each_test if not current_platform.is_rocm() else lambda x: x
)
@use_fork_for_test
@pytest.mark.parametrize("kv_sharing_fast_prefill", [False, True])
@pytest.mark.parametrize("enforce_eager", [True, False])
def test_kv_sharing_fast_prefill(
    monkeypatch: pytest.MonkeyPatch,
    kv_sharing_fast_prefill: bool,
    enforce_eager: bool,
):
    """Verify gemma-3n retrieval answers with KV-sharing fast prefill
    on and off, in both eager and compiled modes."""
    if not enforce_eager and current_platform.is_rocm():
        # Relevant context: https://github.com/vllm-project/vllm/pull/29244
        pytest.skip(
            "ROCm: torch.compile produces incorrect output for gemma-3n's GELU "
            "with tanh approximation. Use enforce_eager=True instead."
        )

    compile_mode = (
        CompilationMode.NONE if enforce_eager else CompilationMode.VLLM_COMPILE
    )
    compilation_config = CompilationConfig(
        # This allows vLLM compilation backend to handle allocating and
        # managing buffers for cudagraph
        cudagraph_copy_inputs=True,
        mode=compile_mode,
    )
    sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
    batch_size = 10
    with monkeypatch.context() as m:
        # Make scheduling deterministic for reproducibility
        if current_platform.is_rocm():
            # Use spawn to prevent cuda re-initialization error
            m.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
        else:
            m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
        prompts, answer, indices = prep_prompts(batch_size)
        llm = LLM(
            model="google/gemma-3n-E2B-it",
            enforce_eager=enforce_eager,
            compilation_config=compilation_config,
            seed=SEED,
            kv_sharing_fast_prefill=kv_sharing_fast_prefill,
        )
        responses = llm.generate(prompts, sampling_params)
        check_answers(
            indices,
            answer,
            [response.outputs[0].text for response in responses],
            accept_rate=1.0,
        )

View File

@@ -0,0 +1,139 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This script contains:
1. test lora with speculative decoding for batch inference
"""
import random
import numpy as np
import pytest
import torch
from vllm import LLM, SamplingParams
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.lora.request import LoRARequest
from vllm.platforms import current_platform
# Maps LoRA adapter repo id -> a prompt exercising that adapter.
LORA_TEST_PROMPT_MAP: dict[str, str] = {}
LORA_TEST_PROMPT_MAP["premjatin/qwen-linear-algebra-coder"] = """
### INSTRUCTION:
You are an AI assistant that generates Python code to solve linear
algebra problems.
### PROBLEM:
Find the eigenvalues and eigenvectors of the following 3x3 matrix:
[[3, 2, 0],
[2, 3, 0],
[0, 0, 2]]
### OUTPUT FORMAT (STRICT):
Numbers should be represented as integers only.
### PYTHON SOLUTION:
"""
# Fixed seed shared by torch/numpy/random and SamplingParams below.
SEED = 42
@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA not available")
@pytest.mark.parametrize(
    "model_setup",
    [
        (
            "eagle3",
            "Qwen/Qwen3-1.7B",
            "AngelSlim/Qwen3-1.7B_eagle3",
            "premjatin/qwen-linear-algebra-coder",
            1,
        )
    ],
)
def test_batch_inference_correctness(
    monkeypatch: pytest.MonkeyPatch,
    model_setup: tuple[str, str, str, str, int],
):
    """
    Compare the outputs of a LLM with only Lora and a LLM with both SD and Lora.
    Should be the same and no failure when doing batch inference.

    model_setup: (method, model_name, spec_model_name, lora_path, tp_size)
    """
    with monkeypatch.context() as m:
        # Disable randomness
        m.setenv("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
        torch.manual_seed(SEED)
        np.random.seed(SEED)
        random.seed(SEED)
        torch.cuda.manual_seed_all(SEED)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        method, model_name, spec_model_name, lora_path, tp_size = model_setup
        # Reference run: LoRA only, without speculative decoding.
        ref_llm = LLM(
            model=model_name,
            trust_remote_code=True,
            tensor_parallel_size=tp_size,
            max_model_len=2048,
            max_num_seqs=4,
            enable_lora=True,
            max_loras=1,
            max_cpu_loras=1,
            max_lora_rank=16,
        )
        prompts = [LORA_TEST_PROMPT_MAP[lora_path]] * 100
        lora_request = LoRARequest("adapter", 1, lora_path)
        sampling_params = SamplingParams(
            temperature=0.0, top_p=1.0, top_k=-1, seed=SEED, max_tokens=128
        )
        ref_outputs = ref_llm.generate(
            prompts, sampling_params, lora_request=lora_request
        )
        # Tear down the reference engine before creating the second one so
        # both runs have the full GPU to themselves.
        del ref_llm
        torch.cuda.empty_cache()
        cleanup_dist_env_and_memory()
        # Second run: same LoRA plus speculative decoding.
        lora_spec_llm = LLM(
            model=model_name,
            trust_remote_code=True,
            tensor_parallel_size=tp_size,
            speculative_config={
                "method": method,
                "model": spec_model_name,
                "num_speculative_tokens": 3,
                "max_model_len": 2048,
            },
            max_model_len=2048,
            max_num_seqs=4,
            enable_lora=True,
            max_loras=1,
            max_cpu_loras=1,
            max_lora_rank=16,
        )
        lora_spec_outputs = lora_spec_llm.generate(
            prompts, sampling_params, lora_request=lora_request
        )
        matches = 0
        misses = 0
        for ref_output, spec_output in zip(ref_outputs, lora_spec_outputs):
            if ref_output.outputs[0].text == spec_output.outputs[0].text:
                matches += 1
            else:
                misses += 1
                print(f"ref_output: {ref_output.outputs[0].text}")
                print(f"spec_output: {spec_output.outputs[0].text}")
        # Heuristic: expect at least 90% of the prompts to match exactly
        # Upon failure, inspect the outputs to check for inaccuracy.
        print(f"match ratio: {matches}/{len(ref_outputs)}")
        assert matches > int(0.90 * len(ref_outputs))
        del lora_spec_llm
        torch.cuda.empty_cache()
        cleanup_dist_env_and_memory()

View File

@@ -0,0 +1,502 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Comprehensive end-to-end tests for `min_tokens` in the V1 engine.
Addresses #21950: verify and add CI coverage.
Covers:
1) Basic functionality
2) Stop strings with `min_tokens` (bug #21987; fix in PR #22014)
3) EOS behavior with `min_tokens` (potential logits-processor bug)
4) Edge cases (min_tokens == max_tokens, min_tokens == 0)
5) Multiple stop conditions
"""
import pytest
from vllm import LLM, SamplingParams
from vllm.outputs import RequestOutput
# Test configuration
TEST_MODEL = "facebook/opt-125m" # Small model for fast CI execution
GREEDY = 0.0 # Deterministic generation for consistent testing
class MinTokensTestCase:
"""Data class for min_tokens test scenarios"""
def __init__(
self,
name: str,
min_tokens: int,
max_tokens: int,
stop: str | list[str] | None = None,
expected_min_len: int | None = None,
expected_exact_len: int | None = None,
):
self.name = name
self.min_tokens = min_tokens
self.max_tokens = max_tokens
self.stop = stop
self.expected_min_len = expected_min_len or min_tokens
self.expected_exact_len = expected_exact_len
def __str__(self):
return (
f"{self.name}: min={self.min_tokens}, "
f"max={self.max_tokens}, stop={self.stop}"
)
# Test scenarios covering all critical cases.
# Plain MinTokensTestCase entries are expected to pass; pytest.param
# entries carry xfail marks for the known bugs they reproduce.
MIN_TOKENS_TEST_CASES = [
    # === BASIC FUNCTIONALITY (should work) ===
    MinTokensTestCase(
        name="basic_min_tokens_no_stop",
        min_tokens=8,
        max_tokens=20,
        stop=None,
        expected_min_len=8,
    ),
    MinTokensTestCase(
        name="min_tokens_zero",
        min_tokens=0,
        max_tokens=10,
        stop=None,
        expected_min_len=0,
    ),
    MinTokensTestCase(
        name="min_equals_max_no_stop",
        min_tokens=15,
        max_tokens=15,
        stop=None,
        expected_exact_len=15,
    ),
    # === STOP STRINGS WITH MIN_TOKENS ===
    # These tests expose the detokenizer bug where stop strings
    # bypass min_tokens
    # Using mathematically guaranteed approach with wide stop nets
    pytest.param(
        MinTokensTestCase(
            name="min_tokens_with_comprehensive_stops",
            min_tokens=5,
            max_tokens=20,
            stop=[
                "a",
                "e",
                "i",
                "o",
                "u",
                "t",
                "n",
                "s",
                "r",
                "l",
                " ",
            ],
            expected_min_len=5,
        ),
        marks=pytest.mark.xfail(
            reason=(
                "Known bug #21987: stop strings bypass min_tokens (fixed by PR #22014)"
            ),
            strict=False,
        ),
        id="min_tokens_with_comprehensive_stops",
    ),
    pytest.param(
        MinTokensTestCase(
            name="min_tokens_with_simple_char_stop",
            min_tokens=3,
            max_tokens=15,
            stop=["e", "a", " "],
            expected_min_len=3,
        ),
        marks=pytest.mark.xfail(
            reason=(
                "Known bug #21987: stop strings bypass min_tokens (fixed by PR #22014)"
            ),
            strict=False,
        ),
        id="min_tokens_with_simple_char_stop",
    ),
    # === EOS TOKEN WITH MIN_TOKENS (potential LogitsProcessor bug) ===
    # These test the MinTokensLogitsProcessor handling of EOS tokens
    pytest.param(
        MinTokensTestCase(
            name="min_equals_max_eos_only",
            min_tokens=20,
            max_tokens=20,
            stop=None,  # Relies on default EOS token behavior
            expected_exact_len=20,
        ),
        marks=pytest.mark.xfail(
            reason=("Potential logits-processor bug: EOS tokens may bypass min_tokens"),
            strict=False,
        ),
        id="min_equals_max_eos_only",
    ),
    # === EDGE CASES ===
    MinTokensTestCase(
        name="large_min_tokens",
        min_tokens=50,
        max_tokens=60,
        stop=None,
        expected_min_len=50,
    ),
    MinTokensTestCase(
        name="min_tokens_with_empty_stop_list",
        min_tokens=5,
        max_tokens=15,
        stop=[],  # Empty stop list
        expected_min_len=5,
    ),
]
@pytest.fixture(scope="module")
def llm_v1():
    """Create V1 LLM instance for testing"""
    return LLM(
        model=TEST_MODEL,
        tensor_parallel_size=1,
        max_model_len=1024,  # Small context for fast testing
        enforce_eager=True,  # Avoid graph compilation overhead
    )
def get_token_count(output: RequestOutput) -> int:
    """Return the number of generated tokens in the first completion,
    or 0 when the request produced no completions."""
    completions = output.outputs
    return len(completions[0].token_ids) if completions else 0
def assert_min_tokens_satisfied(
    output: RequestOutput, test_case: MinTokensTestCase
) -> None:
    """Assert that min_tokens requirement is satisfied"""
    token_count = get_token_count(output)
    stop_reason = output.outputs[0].stop_reason if output.outputs else "no output"
    exact = test_case.expected_exact_len
    if exact is None:
        # Minimum length requirement
        minimum = test_case.expected_min_len or 0
        assert token_count >= minimum, (
            f"Expected at least {test_case.expected_min_len} tokens, "
            f"got {token_count} tokens. "
            f"Stop reason: {stop_reason}"
        )
    else:
        # Exact length requirement
        assert token_count == exact, (
            f"Expected exactly {test_case.expected_exact_len} tokens, "
            f"got {token_count} tokens. "
            f"Stop reason: {stop_reason}"
        )
@pytest.mark.parametrize(
    "test_case",
    MIN_TOKENS_TEST_CASES,
    ids=lambda tc: tc.name,
)
def test_min_tokens_comprehensive(llm_v1: LLM, test_case: MinTokensTestCase):
    """
    Comprehensive test for min_tokens functionality in V1 engine.

    This test covers all critical scenarios for min_tokens:
    - Basic functionality (should work)
    - Stop strings with min_tokens (known bug)
    - EOS tokens with min_tokens (potential bug)
    - Edge cases

    Args:
        llm_v1: V1 LLM instance
        test_case: Test scenario parameters
    """
    # Known failing cases are handled via param-level xfail marks above.
    sampling_params = SamplingParams(
        min_tokens=test_case.min_tokens,
        max_tokens=test_case.max_tokens,
        stop=test_case.stop,
        temperature=GREEDY,
        include_stop_str_in_output=True,  # Include stop strings for debugging
    )
    # Use simple prompt. Comprehensive stop lists should catch any generation
    results = llm_v1.generate(["Hello"], sampling_params)
    assert len(results) == 1, "Expected exactly one output"
    result = results[0]

    # Debug information
    first_completion = result.outputs[0] if result.outputs else None
    token_count = get_token_count(result)
    generated_text = first_completion.text if first_completion else ""
    stop_reason = first_completion.stop_reason if first_completion else "unknown"
    print(f"\nTest: {test_case.name}")
    print(f"Generated {token_count} tokens")
    print(f"Stop reason: {stop_reason}")
    print(f"Generated text: {repr(generated_text)}")
    print(f"Expected min: {test_case.expected_min_len}")
    if test_case.expected_exact_len:
        print(f"Expected exact: {test_case.expected_exact_len}")

    # Validate min_tokens requirement
    assert_min_tokens_satisfied(result, test_case)
def test_min_tokens_basic_functionality(llm_v1: LLM):
    """
    Test basic min_tokens functionality without stop conditions.
    Baseline case that should always pass: with no stop conditions the
    generation length must land within [min_tokens, max_tokens].
    """
    params = SamplingParams(min_tokens=10, max_tokens=20, temperature=GREEDY)
    results = llm_v1.generate(["Once upon a time"], params)
    assert len(results) == 1
    n_generated = get_token_count(results[0])
    assert n_generated >= 10, f"Expected at least 10 tokens, got {n_generated}"
    assert n_generated <= 20, f"Expected at most 20 tokens, got {n_generated}"
@pytest.mark.xfail(
    reason=("Known bug #21987: stop strings bypass min_tokens (fixed by PR #22014)"),
    strict=False,
)
def test_min_tokens_stop_strings_bug(llm_v1: LLM):
    """
    Test the specific bug where stop strings bypass min_tokens.
    Reproduces the bug addressed in PR #22014; it should fail until that
    fix is merged.
    Strategy: use a stop character ('e') that is practically guaranteed
    to appear in any generated text.
    """
    # If the bug is fixed upstream, this test will XPASS
    params = SamplingParams(
        min_tokens=15,
        max_tokens=50,
        # Common letter; likely appears early
        stop=["e"],
        temperature=GREEDY,
        include_stop_str_in_output=True,
    )
    # Simple prompt that will generate text containing "e"
    results = llm_v1.generate(["The quick brown fox"], params)
    assert len(results) == 1
    choice = results[0].outputs[0] if results[0].outputs else None
    n_generated = get_token_count(results[0])
    text = choice.text if choice else ""
    # Debug info to understand what happened
    print(f"Generated text: {repr(text)}")
    print(f"Token count: {n_generated}")
    print(f"Contains 'e': {'e' in text}")
    # This assertion should fail due to the bug - if stop string is found early,
    # the model should still continue generating until min_tokens is reached
    reason = choice.stop_reason if choice else "no output"
    assert n_generated >= 15, (
        "Bug confirmed: "
        f"{n_generated} tokens < min_tokens=15. "
        f"Reason: {reason}. "
        f"Text: {repr(text)}"
    )
@pytest.mark.xfail(
    reason=("Known bug #21987: stop strings bypass min_tokens (fixed by PR #22014)"),
    strict=False,
)
def test_min_tokens_stop_strings_guaranteed_early_trigger(llm_v1: LLM):
    """
    Guaranteed test for stop strings bypassing min_tokens bug.
    Uses greedy decoding plus many extremely common stop strings so an
    early stop-string hit is virtually certain, combined with a large
    min_tokens value to expose the bug regardless of model behavior.
    """
    # If the bug is fixed upstream, this test will XPASS
    params = SamplingParams(
        min_tokens=50,  # Set high min_tokens to ensure bug detection
        max_tokens=200,
        # Use multiple very common patterns - at least one will appear
        stop=["e", "a", "i", "o", "u", " ", "t", "n", "s", "r"],
        temperature=GREEDY,
        include_stop_str_in_output=True,
    )
    # Simple prompt that will generate some text
    results = llm_v1.generate(["The cat"], params)
    assert len(results) == 1
    choice = results[0].outputs[0] if results[0].outputs else None
    n_generated = get_token_count(results[0])
    text = choice.text if choice else ""
    reason = choice.stop_reason if choice else "unknown"
    print(f"Generated text: {repr(text)}")
    print(f"Token count: {n_generated}")
    print(f"Stop reason: {reason}")
    # With the bug, this will fail because ANY of the common characters
    # will trigger early termination before min_tokens=50 is reached
    # It's virtually impossible to generate 50 tokens without hitting
    # at least one of: e, a, i, o, u, space, t, n, s, r
    finish = choice.finish_reason if choice else "unknown"
    print(f"Finish reason: {finish}")
    if finish == "stop":
        assert n_generated >= 50, (
            "Bug confirmed: "
            f"{n_generated} tokens < min_tokens=50. "
            f"Reason: {finish}. "
            f"Text: {repr(text)}"
        )
@pytest.mark.xfail(
    reason=("Potential logits-processor bug: EOS tokens may bypass min_tokens"),
    strict=False,
)
def test_min_tokens_eos_behavior(llm_v1: LLM):
    """
    Verify EOS handling with and without min_tokens.
    - Without min_tokens: expect early EOS -> finish_reason == "stop",
      stop_reason is None, and fewer than max_tokens tokens generated.
    - With min_tokens: EOS should be blocked until min_tokens is reached
      (finish_reason == "length"); verify that eos_token_id does not appear
      in generated token_ids.
    """
    # tokenizer + eos id
    eos_token_id = llm_v1.get_tokenizer().eos_token_id
    prompt = "Give a file extension."
    max_toks = 32
    # Case 1: WITHOUT min_tokens
    no_min_outputs = llm_v1.generate(
        [prompt], SamplingParams(max_tokens=max_toks, temperature=GREEDY)
    )
    assert len(no_min_outputs) == 1
    first = no_min_outputs[0].outputs[0]
    first_ids = first.token_ids or []
    print(
        "[no-min] tokens=",
        len(first_ids),
        " finish=",
        first.finish_reason,
        " stop_reason=",
        first.stop_reason,
    )
    assert first.finish_reason == "stop", (
        f"Expected finish_reason 'stop' without min_tokens, got {first.finish_reason}"
    )
    assert first.stop_reason is None, (
        "For EOS-based stop (no user stop strings), stop_reason should be None."
    )
    assert len(first_ids) < max_toks, (
        f"Expected early EOS with < {max_toks} tokens, got {len(first_ids)}"
    )
    # Case 2: WITH min_tokens
    with_min_outputs = llm_v1.generate(
        [prompt],
        SamplingParams(min_tokens=max_toks, max_tokens=max_toks, temperature=GREEDY),
    )
    assert len(with_min_outputs) == 1
    second = with_min_outputs[0].outputs[0]
    second_ids = second.token_ids or []
    print(
        "[with-min] tokens=",
        len(second_ids),
        " finish=",
        second.finish_reason,
        " stop_reason=",
        second.stop_reason,
    )
    # Exact length reached; EOS should have been blocked
    assert len(second_ids) == max_toks, (
        f"Expected exactly {max_toks} tokens with min_tokens; got {len(second_ids)}"
    )
    assert second.finish_reason == "length", (
        f"Expected finish_reason 'length'; got {second.finish_reason}"
    )
    assert eos_token_id not in second_ids, (
        "EOS token id should not appear when min_tokens prevents early EOS."
    )
def test_min_tokens_validation():
    """
    Test that SamplingParams correctly validates min_tokens parameters.
    Exercises the parameter validation logic in SamplingParams for both
    accepted and rejected combinations.
    """
    # Valid cases: 0 <= min_tokens <= max_tokens must construct cleanly.
    for valid_min in (0, 5, 10):
        SamplingParams(min_tokens=valid_min, max_tokens=10)
    # Invalid: negative min_tokens
    with pytest.raises(
        ValueError,
        match="min_tokens must be greater than or equal to 0",
    ):
        SamplingParams(min_tokens=-1, max_tokens=10)
    # Invalid: min_tokens exceeding max_tokens
    with pytest.raises(
        ValueError,
        match="min_tokens must be less than or equal to max_tokens",
    ):
        SamplingParams(min_tokens=15, max_tokens=10)
if __name__ == "__main__":
    """
    Run tests locally for development.
    Usage:
        cd vllm/
        python -m pytest tests/v1/e2e/test_min_tokens.py -v
    """
    # Delegate to pytest so direct execution behaves like the test runner.
    pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,167 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch.nn as nn
from vllm.platforms import current_platform
# Long multi-verse lyric text ("War Pigs", Black Sabbath) used as the shared
# prompt for the chunked-prefill and prefix-caching pooling tests below.
# NOTE(review): test_pooling_prefix_cache splits this on "\n\n" — it assumes
# blank lines separate the verses; confirm against the on-disk file.
prompt = """
Generals gathered in their masses
Just like witches at black masses
Evil minds that plot destruction
Sorcerer of death's construction
In the fields, the bodies burning
As the war machine keeps turning
Death and hatred to mankind
Poisoning their brainwashed minds
Oh, Lord, yeah
Politicians hide themselves away
They only started the war
Why should they go out to fight?
They leave that all to the poor, yeah
Time will tell on their power minds
Making war just for fun
Treating people just like pawns in chess
Wait till their judgment day comes, yeah
Now, in darkness, world stops turning
Ashes where their bodies burning
No more war pigs have the power
Hand of God has struck the hour
Day of Judgment, God is calling
On their knees, the war pigs crawling
Begging mercies for their sins
Satan, laughing, spreads his wings
Oh, Lord, yeah
"""
class WrapperPooler(nn.Module):
    """Transparent pooler wrapper that records, per forward call, how many
    hidden-state rows (tokens) were handed to the wrapped pooler."""

    def __init__(self, pooler):
        super().__init__()
        self.pooler = pooler
        # One entry per forward() invocation: hidden_states.shape[0].
        self.chunks = []

    def get_pooling_updates(self, task):
        # Pure delegation to the wrapped pooler.
        return self.pooler.get_pooling_updates(task)

    def forward(self, hidden_states, pooling_metadata):
        # Record the chunk size, then delegate pooling unchanged.
        self.chunks.append(hidden_states.shape[0])
        return self.pooler(hidden_states, pooling_metadata)
def inject_pooler(self):
    """Swap the model's pooler for a chunk-recording WrapperPooler.

    Meant to run on each worker via collective_rpc (hence the bare
    ``self`` parameter).
    """
    model = self.get_model()
    model.pooler = WrapperPooler(model.pooler)
def retrieve_chunks(self):
    """Return and reset the chunk sizes recorded by the injected pooler.

    Meant to run on each worker via collective_rpc (hence the bare
    ``self`` parameter).
    """
    pooler = self.get_model().pooler
    recorded, pooler.chunks = pooler.chunks, []
    return recorded
@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA not available")
def test_pooling_chunked_prefill(vllm_runner, monkeypatch):
    """Test chunked prefill for pooling models with LastPool."""
    with monkeypatch.context() as m:
        m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
        model_id = "Qwen/Qwen3-Embedding-0.6B"
        chunk_size = 10
        # Set chunking parameters to force chunked prefill
        # Note: Chunked prefill is automatically handled by vLLM
        # internally based on the model size and prompt
        with vllm_runner(
            model_id,
            runner="pooling",
            long_prefill_token_threshold=chunk_size,
            tensor_parallel_size=1,
            enforce_eager=True,
            enable_chunked_prefill=True,
        ) as llm:
            engine = llm.get_llm().llm_engine
            engine.collective_rpc(inject_pooler)
            token_ids = llm.get_llm().get_tokenizer()(prompt)["input_ids"]
            prompt_len = len(token_ids)
            # The prompt should reach the pooler as chunk_size-sized pieces
            # plus one trailing partial chunk (if any).
            n_full, remainder = divmod(prompt_len, chunk_size)
            expected_chunks = [chunk_size] * n_full + (
                [remainder] if remainder else []
            )
            llm.embed([prompt])
            chunks = engine.collective_rpc(retrieve_chunks)[0]
            # Check that PoolerWrapper was called and chunks were received
            assert len(chunks) > 1
            assert chunks == expected_chunks
        # Disable chunked prefill
        with vllm_runner(
            model_id,
            runner="pooling",
            tensor_parallel_size=1,
            enforce_eager=True,
        ) as llm:
            engine = llm.get_llm().llm_engine
            engine.collective_rpc(inject_pooler)
            llm.embed([prompt])
            chunks = engine.collective_rpc(retrieve_chunks)[0]
            # Check that PoolerWrapper was called and no chunks were received
            assert len(chunks) == 1
            assert chunks[0] == prompt_len
@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA not available")
def test_pooling_prefix_cache(vllm_runner, monkeypatch):
    """Test prefix caching for pooling models with LastPool."""
    verses = prompt.split("\n\n")
    with monkeypatch.context() as m:
        m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
        model_id = "Qwen/Qwen3-Embedding-0.6B"
        with vllm_runner(
            model_id,
            runner="pooling",
            enable_prefix_caching=True,
            tensor_parallel_size=1,
            enforce_eager=True,
        ) as llm:
            engine = llm.get_llm().llm_engine
            engine.collective_rpc(inject_pooler)
            tokenizer = llm.get_llm().get_tokenizer()
            # Two prompts sharing the first verse as a common prefix.
            prompt1 = "\n\n".join([verses[0], verses[1]])
            prompt2 = "\n\n".join([verses[0], verses[2]])
            prompt1_len = len(tokenizer(prompt1)["input_ids"])
            prompt2_len = len(tokenizer(prompt2)["input_ids"])
            llm.embed([prompt1])
            chunks = engine.collective_rpc(retrieve_chunks)[0]
            assert len(chunks) == 1
            assert chunks[0] == prompt1_len
            llm.embed([prompt2])
            chunks = engine.collective_rpc(retrieve_chunks)[0]
            assert len(chunks) == 1
            # With the shared prefix cached, fewer tokens than the full
            # second prompt should reach the pooler.
            assert chunks[0] <= prompt1_len
            assert chunks[0] < prompt2_len
            cache_config = engine.cache_config
            print(f"{cache_config=}")
            # Prefixes are cached in blocks
            assert (prompt2_len - chunks[0]) % cache_config.block_size == 0

View File

@@ -0,0 +1,580 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
from typing import Any
import pytest
import torch
from tests.utils import get_attn_backend_list_based_on_platform, large_gpu_mark
from vllm import LLM, SamplingParams
from vllm.assets.base import VLLM_S3_BUCKET_URL
from vllm.assets.image import VLM_IMAGES_DIR
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.platforms import current_platform
# Minimum exact-match fraction required between reference and MTP
# speculative-decoding outputs in test_mtp_correctness.
MTP_SIMILARITY_RATE = 0.8
def _skip_if_insufficient_gpus_for_tp(tp_size: int):
    """Skip the current test if available GPUs < tp_size on ROCm."""
    # Only ROCm runs are gated on physical GPU count here.
    if not current_platform.is_rocm():
        return
    available_gpus = torch.cuda.device_count()
    if available_gpus < tp_size:
        pytest.skip(
            f"Test requires {tp_size} GPUs, but only {available_gpus} available"
        )
def get_test_prompts(mm_enabled: bool):
    """Build a deterministic batch of 100 single-turn chat prompts.

    The batch mixes prompt kinds so that some completions are easy for
    n-gram style drafters to predict ("repeat") and some likely are not
    ("sentence"). When mm_enabled is True, image prompts ("mm") are mixed
    in as well. ``random`` is re-seeded locally, so repeated calls with
    the same argument return identical prompts.

    Args:
        mm_enabled: include multi-modal (image) prompts in the mix.

    Returns:
        A list of conversations, each a list containing one user message.
    """
    prompt_types = ["repeat", "sentence"]
    if mm_enabled:
        prompt_types.append("mm")
    num_prompts = 100
    prompts = []
    # Hoisted loop invariant: candidate words for the templated prompts.
    word_choices = ["test", "temp", "hello", "where"]
    random.seed(0)
    random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)
    print(f"Prompt types: {random_prompt_type_choices}")
    # Generate a mixed batch of prompts, some of which can be easily
    # predicted by n-gram matching and some which likely cannot.
    for kind in random_prompt_type_choices:
        word = random.choice(word_choices)
        prompt: str | list[dict[str, Any]] = ""
        if kind == "repeat":
            prompt = f"""
            please repeat the word '{word}' 10 times.
            give no other output than the word at least ten times in a row,
            in lowercase with spaces between each word and without quotes.
            """
        elif kind == "sentence":
            prompt = f"""
            please give a ten-word sentence that
            uses the word {word} at least once.
            give no other output than that simple sentence without quotes.
            """
        elif kind == "mm":
            placeholders = [
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"{VLLM_S3_BUCKET_URL}/{VLM_IMAGES_DIR}/stop_sign.jpg"
                    },
                }
            ]
            prompt = [
                *placeholders,
                {"type": "text", "text": "The meaning of the image is"},
            ]
        else:
            raise ValueError(f"Unknown prompt type: {kind}")
        prompts.append([{"role": "user", "content": prompt}])
    return prompts
@pytest.fixture
def sampling_config():
    """Shared greedy sampling settings: short outputs, EOS honored."""
    params = SamplingParams(temperature=0, max_tokens=10, ignore_eos=False)
    return params
@pytest.fixture
def model_name():
    """Target (verifier) model shared by the correctness tests."""
    name = "meta-llama/Llama-3.1-8B-Instruct"
    return name
@pytest.fixture(autouse=True)
def reset_torch_dynamo():
    """Reset the torch dynamo cache after each test (autouse teardown)."""
    yield
    # Cleanup after test: clear compiled-graph state so one test's dynamo
    # caches cannot leak into the next.
    torch._dynamo.reset()
@pytest.mark.parametrize(
    "speculative_config",
    [
        {
            "method": "ngram",
            "prompt_lookup_max": 5,
            "prompt_lookup_min": 3,
            "num_speculative_tokens": 3,
        },
        {
            "method": "suffix",
            "suffix_decoding_max_spec_factor": 2.0,
        },
    ],
)
def test_ngram_and_suffix_correctness(
    speculative_config: dict,
    monkeypatch: pytest.MonkeyPatch,
    sampling_config: SamplingParams,
    model_name: str,
):
    """
    Outputs of a plain LLM and one using ngram/suffix speculative decoding
    should match on (nearly) all prompts, since speculation must not change
    greedy results.
    """
    test_prompts = get_test_prompts(mm_enabled=False)
    # Reference run without speculation.
    ref_llm = LLM(model=model_name, max_model_len=1024)
    ref_outputs = ref_llm.chat(test_prompts, sampling_config)
    del ref_llm
    torch.cuda.empty_cache()
    cleanup_dist_env_and_memory()
    # Speculative run with the parametrized drafting method.
    spec_llm = LLM(
        model=model_name,
        speculative_config=speculative_config,
        max_model_len=1024,
    )
    spec_outputs = spec_llm.chat(test_prompts, sampling_config)
    matches = 0
    for ref_out, spec_out in zip(ref_outputs, spec_outputs):
        ref_text = ref_out.outputs[0].text
        spec_text = spec_out.outputs[0].text
        if ref_text == spec_text:
            matches += 1
        else:
            print(f"ref_output: {ref_text}")
            print(f"spec_output: {spec_text}")
    # Heuristic: expect at least 66% of the prompts to match exactly
    # Upon failure, inspect the outputs to check for inaccuracy.
    assert matches >= int(0.66 * len(ref_outputs))
    del spec_llm
    torch.cuda.empty_cache()
    cleanup_dist_env_and_memory()
def test_suffix_decoding_acceptance(
    monkeypatch: pytest.MonkeyPatch,
    sampling_config: SamplingParams,
    model_name: str,
):
    """
    Check that suffix decoding caching takes effect and improves acceptance
    lengths and acceptance rates over multiple runs of the same prompts.
    """
    test_prompts = get_test_prompts(mm_enabled=False)
    spec_llm = LLM(
        model=model_name,
        speculative_config={
            "method": "suffix",
            "suffix_decoding_max_spec_factor": 2.0,
            "suffix_decoding_max_cached_requests": 1000,
        },
        max_model_len=1024,
        disable_log_stats=False,
    )
    # Run several times and check that the accepted tokens increase.
    num_draft = []
    num_accept = []
    for _ in range(10):  # Run multiple times to warm up the cache.
        spec_llm.chat(test_prompts, sampling_config)
        # Collect draft and acceptance stats (cumulative counters).
        for metric in spec_llm.get_metrics():
            if metric.name == "vllm:spec_decode_num_draft_tokens":
                num_draft.append(metric.value)
            elif metric.name == "vllm:spec_decode_num_accepted_tokens":
                num_accept.append(metric.value)
    # Acceptance rate of the very first run.
    first_accept_tokens = num_accept[0]
    first_draft_tokens = num_draft[0]
    first_accept_rate = first_accept_tokens / first_draft_tokens
    # Last run's rate: take diffs since the stats are cumulative.
    last_accept_tokens = num_accept[-1] - num_accept[-2]
    last_draft_tokens = num_draft[-1] - num_draft[-2]
    last_accept_rate = last_accept_tokens / last_draft_tokens
    # Expect the acceptance length to improve.
    assert first_accept_tokens < last_accept_tokens
    # Expect the acceptance rate to improve.
    assert first_accept_rate < last_accept_rate
    # Heuristic: expect at least 80.0% acceptance rate at the end.
    assert last_accept_rate > 0.80
    del spec_llm
    torch.cuda.empty_cache()
    cleanup_dist_env_and_memory()
@pytest.mark.parametrize(
    "model_path",
    [
        "RedHatAI/Llama-3.1-8B-Instruct-speculator.eagle3",
        "RedHatAI/Qwen3-8B-speculator.eagle3",
    ],
    ids=["llama3_eagle3_speculator", "qwen3_eagle3_speculator"],
)
def test_speculators_model_integration(
    monkeypatch: pytest.MonkeyPatch,
    sampling_config: SamplingParams,
    model_path: str,
):
    """
    Test that speculators models work with the simplified integration.
    This verifies the `vllm serve <speculator-model>` use case where
    speculative config is automatically detected from the model config
    without requiring explicit --speculative-config argument.
    Tests:
    1. Speculator model is correctly detected
    2. Verifier model is extracted from speculator config
    3. Speculative decoding is automatically enabled
    4. Text generation works correctly
    5. Output matches reference (non-speculative) generation
    """
    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
    # Generate test prompts
    test_prompts = get_test_prompts(mm_enabled=False)
    # First run: Direct speculator model (simplified integration)
    spec_llm = LLM(model=model_path, max_model_len=1024)
    spec_outputs = spec_llm.chat(test_prompts, sampling_config)
    vllm_config = spec_llm.llm_engine.vllm_config
    # Verify speculative config was auto-detected
    assert vllm_config.speculative_config is not None, (
        f"Speculative config should be auto-detected for {model_path}"
    )
    spec_config = vllm_config.speculative_config
    assert spec_config.num_speculative_tokens > 0, (
        f"Expected positive speculative tokens, "
        f"got {spec_config.num_speculative_tokens}"
    )
    # Verify draft model is set to the speculator model
    assert spec_config.model == model_path, (
        f"Draft model should be {model_path}, got {spec_config.model}"
    )
    # Extract verifier model for reference run
    verifier_model = vllm_config.model_config.model
    del spec_llm
    torch.cuda.empty_cache()
    cleanup_dist_env_and_memory()
    # Second run: Reference without speculative decoding
    ref_llm = LLM(model=verifier_model, max_model_len=1024)
    ref_outputs = ref_llm.chat(test_prompts, sampling_config)
    del ref_llm
    torch.cuda.empty_cache()
    cleanup_dist_env_and_memory()
    # Compare outputs
    matches = sum(
        1
        for ref, spec in zip(ref_outputs, spec_outputs)
        if ref.outputs[0].text == spec.outputs[0].text
    )
    # Heuristic: expect at least 66% of prompts to match exactly
    assert matches >= int(0.66 * len(ref_outputs)), (
        f"Only {matches}/{len(ref_outputs)} outputs matched. "
        f"Expected at least {int(0.66 * len(ref_outputs))} matches."
    )
@pytest.mark.parametrize(
    ["model_setup", "mm_enabled", "enable_chunked_prefill", "model_impl"],
    [
        (
            ("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1),
            False,
            False,
            "auto",
        ),
        (
            ("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1),
            False,
            False,
            "transformers",
        ),
        pytest.param(
            (
                "eagle3",
                "Qwen/Qwen3-VL-8B-Instruct",
                "taobao-mnn/Qwen3-VL-8B-Instruct-Eagle3",
                1,
            ),
            False,
            False,
            "auto",
            marks=pytest.mark.skip(
                reason="architecture of its eagle3 is LlamaForCausalLMEagle3"
            ),
        ),
        pytest.param(
            (
                "eagle3",
                "Qwen/Qwen2.5-VL-7B-Instruct",
                "Rayzl/qwen2.5-vl-7b-eagle3-sgl",
                1,
            ),
            False,
            False,
            "auto",
            marks=pytest.mark.skip(
                reason="Skipping due to its head_dim not being a multiple of 32"
            ),
        ),
        pytest.param(
            (
                "eagle",
                "meta-llama/Llama-3.1-8B-Instruct",
                "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
                1,
            ),
            False,
            True,
            "auto",
            marks=large_gpu_mark(min_gb=40),
        ),  # works on 4x H100
        (
            (
                "eagle3",
                "meta-llama/Llama-3.1-8B-Instruct",
                "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
                1,
            ),
            False,
            False,
            "auto",
        ),
        pytest.param(
            (
                "eagle",
                "meta-llama/Llama-4-Scout-17B-16E-Instruct",
                "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
                4,
            ),
            False,
            False,
            "auto",
            marks=large_gpu_mark(min_gb=80),
        ),  # works on 4x H100
        pytest.param(
            (
                "eagle",
                "meta-llama/Llama-4-Scout-17B-16E-Instruct",
                "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
                4,
            ),
            True,
            True,
            "auto",
            marks=large_gpu_mark(min_gb=80),
        ),  # works on 4x H100
        (
            (
                "eagle",
                "eagle618/deepseek-v3-random",
                "eagle618/eagle-deepseek-v3-random",
                1,
            ),
            False,
            False,
            "auto",
        ),
    ],
    ids=[
        "qwen3_eagle3",
        "qwen3_eagle3-transformers",
        "qwen3_vl_eagle3",
        "qwen2_5_vl_eagle3",
        "llama3_eagle",
        "llama3_eagle3",
        "llama4_eagle",
        "llama4_eagle_mm",
        "deepseek_eagle",
    ],
)
@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform())
def test_eagle_correctness(
    monkeypatch: pytest.MonkeyPatch,
    sampling_config: SamplingParams,
    model_setup: tuple[str, str, str, int],
    mm_enabled: bool,
    enable_chunked_prefill: bool,
    model_impl: str,
    attn_backend: str,
):
    """Verify eagle/eagle3 speculative decoding output correctness.

    The outputs of an original LLM and a speculative LLM should be
    (mostly) the same when using eagle speculative decoding.

    model_setup: (method, model_name, eagle_model_name, tp_size)
    """
    if attn_backend == "TREE_ATTN":
        # TODO: Fix this flaky test
        pytest.skip(
            "TREE_ATTN is flaky in the test disable for now until it can be "
            "resolved (see https://github.com/vllm-project/vllm/issues/22922)"
        )
    if model_impl == "transformers":
        import transformers
        from packaging.version import Version

        # Eagle3 support in the Transformers backend only landed in 5.x.
        installed = Version(transformers.__version__)
        required = Version("5.0.0.dev")
        if installed < required:
            pytest.skip(
                "Eagle3 with the Transformers modeling backend requires "
                f"transformers>={required}, but got {installed}"
            )
    # Generate test prompts inside the function instead of using fixture
    test_prompts = get_test_prompts(mm_enabled)
    with monkeypatch.context() as m:
        if "Llama-4-Scout" in model_setup[1] and attn_backend == "FLASH_ATTN":
            # Scout requires default backend selection
            # because vision encoder has head_dim 88 being incompatible
            # with FLASH_ATTN and needs to fall back to Flex Attn
            # pass if not ROCm
            if current_platform.is_rocm():
                # TODO: Enable Flex Attn for spec_decode on ROCm
                pytest.skip("Flex Attn for spec_decode not supported on ROCm currently")
        else:
            m.setenv("VLLM_MLA_DISABLE", "1")
            m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
        if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
            pytest.skip(
                "TRITON_ATTN does not support "
                "multi-token eagle spec decode on current platform"
            )
        if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
            if "deepseek" in model_setup[1].lower():
                pytest.skip("ROCM_AITER_FA for deepseek not supported on ROCm platform")
            else:
                m.setenv("VLLM_ROCM_USE_AITER", "1")
        method, model_name, spec_model_name, tp_size = model_setup
        _skip_if_insufficient_gpus_for_tp(tp_size)
        max_model_len = 2048
        max_num_batched_tokens = 128 if enable_chunked_prefill else max_model_len
        # Reference run without speculation.
        ref_llm = LLM(
            model=model_name, max_model_len=max_model_len, tensor_parallel_size=tp_size
        )
        ref_outputs = ref_llm.chat(test_prompts, sampling_config)
        del ref_llm
        torch.cuda.empty_cache()
        cleanup_dist_env_and_memory()
        # Speculative run with the eagle draft model.
        spec_llm = LLM(
            model=model_name,
            trust_remote_code=True,
            tensor_parallel_size=tp_size,
            speculative_config={
                "method": method,
                "model": spec_model_name,
                "num_speculative_tokens": 3,
                "max_model_len": max_model_len,
            },
            max_model_len=max_model_len,
            max_num_batched_tokens=max_num_batched_tokens,
            enable_chunked_prefill=enable_chunked_prefill,
            model_impl=model_impl,
        )
        spec_outputs = spec_llm.chat(test_prompts, sampling_config)
        matches = 0
        misses = 0
        for ref_output, spec_output in zip(ref_outputs, spec_outputs):
            if ref_output.outputs[0].text == spec_output.outputs[0].text:
                matches += 1
            else:
                misses += 1
                print(f"ref_output: {ref_output.outputs[0].text}")
                print(f"spec_output: {spec_output.outputs[0].text}")
        # Heuristic: expect at least 60% of the prompts to match exactly
        # Upon failure, inspect the outputs to check for inaccuracy.
        assert matches > int(0.6 * len(ref_outputs))
        del spec_llm
        torch.cuda.empty_cache()
        cleanup_dist_env_and_memory()
@pytest.mark.parametrize(
    ["model_setup", "mm_enabled"],
    [
        (("mtp", "XiaomiMiMo/MiMo-7B-Base", 1), False),
        (("mtp", "ZixiQi/DeepSeek-V3-4layers-MTP-FP8", 1), False),
    ],
    ids=["mimo", "deepseek"],
)
def test_mtp_correctness(
    monkeypatch: pytest.MonkeyPatch,
    sampling_config: SamplingParams,
    model_setup: tuple[str, str, int],
    mm_enabled: bool,
):
    """Verify MTP speculative decoding output correctness.

    The outputs of an original LLM and a speculative LLM should be
    (mostly) the same when using MTP speculative decoding.

    model_setup: (method, model_name, tp_size)
    """
    # Generate test prompts inside the function instead of using fixture
    test_prompts = get_test_prompts(mm_enabled)
    with monkeypatch.context() as m:
        m.setenv("VLLM_MLA_DISABLE", "1")
        method, model_name, tp_size = model_setup
        _skip_if_insufficient_gpus_for_tp(tp_size)
        # Reference run without speculation.
        ref_llm = LLM(
            model=model_name,
            max_model_len=2048,
            tensor_parallel_size=tp_size,
            trust_remote_code=True,
        )
        ref_outputs = ref_llm.chat(test_prompts, sampling_config)
        del ref_llm
        torch.cuda.empty_cache()
        cleanup_dist_env_and_memory()
        # Speculative run using the model's own MTP head as the drafter.
        spec_llm = LLM(
            model=model_name,
            trust_remote_code=True,
            tensor_parallel_size=tp_size,
            speculative_config={
                "method": method,
                "num_speculative_tokens": 1,
                "max_model_len": 2048,
            },
            max_model_len=2048,
        )
        spec_outputs = spec_llm.chat(test_prompts, sampling_config)
        matches = 0
        misses = 0
        for ref_output, spec_output in zip(ref_outputs, spec_outputs):
            if ref_output.outputs[0].text == spec_output.outputs[0].text:
                matches += 1
            else:
                misses += 1
                print(f"ref_output: {ref_output.outputs[0].text}")
                print(f"spec_output: {spec_output.outputs[0].text}")
        # Heuristic: expect at least 80% of the prompts to match exactly
        # Upon failure, inspect the outputs to check for inaccuracy.
        assert matches > int(MTP_SIMILARITY_RATE * len(ref_outputs))
        del spec_llm
        torch.cuda.empty_cache()
        cleanup_dist_env_and_memory()