Sync from v0.13
This commit is contained in:
0
tests/v1/e2e/__init__.py
Normal file
0
tests/v1/e2e/__init__.py
Normal file
388
tests/v1/e2e/test_async_scheduling.py
Normal file
388
tests/v1/e2e/test_async_scheduling.py
Normal file
@@ -0,0 +1,388 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from itertools import repeat
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
import torch._dynamo.config as dynamo_config
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.logprobs import Logprob
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sampling_params import StructuredOutputsParams
|
||||
from vllm.v1.metrics.reader import Metric
|
||||
|
||||
from ...conftest import VllmRunner
|
||||
from ...models.utils import check_outputs_equal
|
||||
|
||||
MODEL = "Qwen/Qwen3-0.6B"
|
||||
MTP_MODEL = "meta-llama/Llama-3.2-1B-Instruct"
|
||||
|
||||
|
||||
# Shared, fully deterministic prompt set so outputs are comparable across
# engine configurations.
_SEQ = ", ".join(str(i) for i in range(10))
first_prompt = "The following numbers of the sequence " + _SEQ + " are:"

example_prompts = [
    first_prompt,
    "In one word, the capital of France is ",
    *(f"Tell me about the number {i}: " for i in range(32)),
]

# Base sampling parameters applied to every variant (variants override these).
default_params = {
    "temperature": 0.0,  # greedy
    "max_tokens": 23,
    "min_tokens": 18,
}
|
||||
|
||||
|
||||
def test_without_spec_decoding(
    sample_json_schema,
    monkeypatch: pytest.MonkeyPatch,
):
    """Check output consistency across combos of async scheduling,
    preemption, uni/multiproc executor and prefill chunking."""
    guided = StructuredOutputsParams(json=sample_json_schema)
    sampling_variants: list[dict[str, Any]] = [
        {},
        # {"min_tokens": 20},
        {"presence_penalty": -1.0},
        {"bad_words": ["the", " the"]},
        {"logprobs": 2},
        {"logprobs": 2, "presence_penalty": -1.0},
        {"structured_outputs": guided},
        {
            "structured_outputs": guided,
            "logprobs": 2,
            "presence_penalty": -1.0,
        },
    ]

    # Each tuple is (test_preemption, executor, async_scheduling,
    # spec_config, test_prefill_chunking); the first entry is the baseline.
    configs = [
        (False, "mp", False, None, False),
        (True, "mp", False, None, True),
        (False, "mp", True, None, False),
        (False, "uni", True, None, False),
        (True, "mp", True, None, False),
        (True, "uni", True, None, False),
        (False, "mp", True, None, True),
        (True, "mp", True, None, True),
        (True, "uni", True, None, True),
    ]

    if current_platform.is_rocm():
        # On ROCm, Only test with structured_outputs (deterministic)
        # and skip chunk_prefill (more variable).
        configs = [cfg for cfg in configs if not cfg[4]]  # skip chunk_prefill=True
        sampling_variants = [
            p for p in sampling_variants if p.get("structured_outputs") is not None
        ]

    run_tests(monkeypatch, MODEL, configs, sampling_variants)
|
||||
|
||||
|
||||
def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch):
    """Check consistency and acceptance rates for combos of preemption,
    executor, async scheduling, prefill chunking and spec-decode draft
    model length."""

    eagle_cfg = {
        "method": "eagle3",
        "num_speculative_tokens": 2,
        "model": "nm-testing/Llama3_2_1B_speculator.eagle3",
    }
    # Set small draft model len to force doesn't-fit-in-drafter case.
    eagle_cfg_short = {**eagle_cfg, "max_model_len": 50}

    sampling_variants = [
        {},
        {"logprobs": 2},
    ]

    # Each tuple is (test_preemption, executor, async_scheduling,
    # spec_config, test_prefill_chunking); the first entry is the baseline.
    configs = [
        (False, "mp", False, None, False),
        (False, "mp", False, eagle_cfg, False),
        (True, "mp", False, eagle_cfg, True),
        (True, "uni", False, eagle_cfg_short, True),
        (False, "mp", True, eagle_cfg, False),
        (True, "mp", True, eagle_cfg, False),
        (False, "mp", True, eagle_cfg_short, True),
        (True, "uni", True, eagle_cfg, False),
        (True, "uni", True, eagle_cfg_short, False),
        (True, "mp", True, eagle_cfg, True),
        (True, "uni", True, eagle_cfg_short, True),
    ]

    # On ROCm, use TRITON_ATTN + float32 for better numerical consistency
    run_tests(
        monkeypatch,
        MTP_MODEL,
        configs,
        sampling_variants,
        is_testing_with_spec_decoding=True,
    )
|
||||
|
||||
|
||||
@dynamo_config.patch(cache_size_limit=16)
def run_tests(
    monkeypatch: pytest.MonkeyPatch,
    model: str,
    test_configs: list[tuple],
    test_sampling_params: list[dict[str, Any]],
    is_testing_with_spec_decoding: bool = False,
):
    """Test consistency of combos of async scheduling, preemption,
    uni/multiproc executor with spec decoding.

    Runs ``run_test`` once per entry of ``test_configs`` (the first entry is
    the baseline) and then compares every other config's outputs, logprobs
    and spec-decode acceptance rates against the baseline. All failures are
    printed; the first AssertionError is re-raised at the end.
    """

    with monkeypatch.context() as m:
        # avoid precision errors
        if current_platform.is_rocm():
            if is_testing_with_spec_decoding:
                # Use TRITON_ATTN for spec decoding test for consistency
                m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN")
            else:
                m.setenv("VLLM_ATTENTION_BACKEND", "ROCM_AITER_FA")
        else:
            m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
        # lock matmul precision to full FP32 (IEEE)
        m.setenv("VLLM_FLOAT32_MATMUL_PRECISION", "ieee")
        # m.setenv("VLLM_BATCH_INVARIANT", "1")

        # Each element: (config description string, per-variant results,
        # per-variant acceptance rates or None).
        outputs: list[tuple[str, list, list]] = []
        for n, (
            test_preemption,
            executor,
            async_scheduling,
            spec_config,
            test_prefill_chunking,
        ) in enumerate(test_configs, 1):
            test_str = f"{n}/{len(test_configs)}"
            test_results = run_test(
                model,
                test_str,
                test_sampling_params,
                test_preemption,
                executor,
                async_scheduling,
                spec_config,
                test_prefill_chunking=test_prefill_chunking,
                is_testing_with_spec_decoding=is_testing_with_spec_decoding,
            )
            outputs.append(test_results)

        baseline_config, baseline_tests, _ = outputs[0]
        # Acceptance-rate baseline comes from the first config that actually
        # ran with spec decoding (the token baseline itself may not have).
        _, _, baseline_acceptances = next(
            (o for o in outputs if o[2] is not None), (None, None, None)
        )

        print(f"BASELINE: config=[{baseline_config}], accept_rates={baseline_acceptances}")

        failure = None
        for test_config, test_outputs, test_acceptance_rates in outputs[1:]:
            # repeat(None) stands in for acceptance rates when a side ran
            # without spec decoding, keeping the zip aligned per variant.
            for (base_outs, base_logprobs), base_acceptance_rate, (
                test_outs,
                test_logprobs,
            ), test_acceptance_rate, params in zip(
                baseline_tests,
                baseline_acceptances or repeat(None),
                test_outputs,
                test_acceptance_rates or repeat(None),
                test_sampling_params,
            ):
                try:
                    check_outputs_equal(
                        outputs_0_lst=base_outs,
                        outputs_1_lst=test_outs,
                        name_0=f"baseline=[{baseline_config}], params={params}",
                        name_1=f"config=[{test_config}], params={params}",
                    )

                    # On ROCm with TRITON_ATTN (spec decoding test), skip strict
                    # logprobs comparison when logprobs are requested
                    skip_logprobs_check = (
                        current_platform.is_rocm()
                        and params.get("logprobs")
                        and is_testing_with_spec_decoding
                    )
                    if not skip_logprobs_check:
                        assert _all_logprobs_match(base_logprobs, test_logprobs)

                    if (
                        base_acceptance_rate is not None
                        and test_acceptance_rate is not None
                    ):
                        # spec_mml=None means the full-length drafter was
                        # used; only then are rates directly comparable.
                        if "spec_mml=None" in test_config:
                            # Preemption causes more variance in acceptance rates
                            if (
                                current_platform.is_rocm()
                                and "preemption=True" in test_config
                            ):
                                tolerance = 0.10
                            else:
                                tolerance = 0.05
                            assert (
                                test_acceptance_rate > base_acceptance_rate
                                or test_acceptance_rate
                                == pytest.approx(base_acceptance_rate, rel=tolerance)
                            )
                        else:
                            # Currently the reported acceptance rate is expected to be
                            # lower when we sometimes skip drafting altogether.
                            assert test_acceptance_rate > 0.1
                    print(
                        f"PASSED: config=[{test_config}], params={params}"
                        f" accept_rate={test_acceptance_rate}"
                    )
                except AssertionError as e:
                    # Keep going so every mismatching combo is reported, but
                    # remember the first failure to re-raise below.
                    print(
                        f"FAILED: config=[{test_config}], params={params}"
                        f" accept_rate={test_acceptance_rate}"
                    )
                    if failure is None:
                        failure = e

    if failure is not None:
        raise failure
|
||||
|
||||
|
||||
def run_test(
    model: str,
    test_str: str,
    sampling_param_tests: list[dict[str, Any]],
    test_preemption: bool,
    executor: str,
    async_scheduling: bool,
    spec_config: dict[str, Any] | None,
    test_prefill_chunking: bool,
    is_testing_with_spec_decoding: bool = False,
):
    """Run one engine configuration over all sampling-param variants.

    Returns a tuple ``(test_config, results, acceptance_rates)``:
    ``test_config`` is a human-readable description string, ``results`` holds
    one ``(outputs, logprobs)`` pair per sampling-param variant, and
    ``acceptance_rates`` is a per-variant list of spec-decode acceptance
    rates (``None`` when spec decoding is disabled).
    """
    spec_decoding = spec_config is not None
    cache_arg: dict[str, Any] = (
        # Force preemptions
        dict(num_gpu_blocks_override=32)
        if test_preemption
        else dict(gpu_memory_utilization=0.9)
    )
    spec_mml = (spec_config or {}).get("max_model_len")
    # NOTE: run_tests() string-matches "spec_mml=None" / "preemption=True"
    # in this description, so keep the format stable.
    test_config = (
        f"executor={executor}, preemption={test_preemption}, "
        f"async_sched={async_scheduling}, "
        f"chunk_prefill={test_prefill_chunking}, "
        f"spec_decoding={spec_decoding}, spec_mml={spec_mml}"
    )
    print("-" * 80)
    print(f"---- TESTING {test_str}: {test_config}")
    print("-" * 80)

    # On ROCm: use float16 for first test (ROCM_AITER_FA), but float32 for
    # spec decoding test (TRITON_ATTN) for better precision.
    # On others: always use float32.
    if current_platform.is_rocm() and not is_testing_with_spec_decoding:
        dtype = "float16"
    else:
        dtype = "float32"

    with VllmRunner(
        model,
        max_model_len=512,
        enable_chunked_prefill=test_prefill_chunking,
        # Force prefill chunking
        max_num_batched_tokens=48 if test_prefill_chunking else None,
        # enforce_eager=True,
        async_scheduling=async_scheduling,
        distributed_executor_backend=executor,
        dtype=dtype,
        speculative_config=spec_config,
        disable_log_stats=False,
        **cache_arg,
    ) as vllm_model:
        results = []
        acceptance_rates: list[float] | None = [] if spec_decoding else None
        for override_params in sampling_param_tests:
            metrics_before = vllm_model.llm.get_metrics()
            print(f"----------- RUNNING PARAMS: {override_params}")
            results.append(
                vllm_model.generate(
                    example_prompts,
                    sampling_params=SamplingParams(**default_params, **override_params),
                    return_logprobs=True,
                )
            )
            metrics_after = vllm_model.llm.get_metrics()
            if acceptance_rates is not None:
                acceptance_rate = _get_acceptance_rate(metrics_before, metrics_after)
                acceptance_rates.append(acceptance_rate)
                print(f"ACCEPTANCE RATE {acceptance_rate}")

        if test_preemption:
            # Uses the metric snapshots from the *last* variant only.
            preemptions = _get_count(
                metrics_before, metrics_after, "vllm:num_preemptions"
            )
            assert preemptions > 0, "preemption test had no preemptions"

        if len(results) > 1:
            # First check that the different parameter configs
            # actually result in different output: either the tokens or the
            # logprobs must differ from the first variant, otherwise
            # pytest.raises reports "DID NOT RAISE".
            for (other_test_outs, other_test_logprobs), params in zip(
                results[1:], sampling_param_tests[1:]
            ):
                with pytest.raises(AssertionError):
                    check_outputs_equal(
                        outputs_0_lst=results[0][0],
                        outputs_1_lst=other_test_outs,
                        # BUGFIX: label the baseline side with its own params
                        # (the first variant), not the comparison variant's.
                        name_0=f"baseline params={sampling_param_tests[0]}",
                        name_1=f"other params={params}",
                    )
                    # Reached only when token outputs were identical.
                    assert _all_logprobs_match(results[0][1], other_test_logprobs)

    return test_config, results, acceptance_rates
|
||||
|
||||
|
||||
def _all_logprobs_match(req_a, req_b) -> bool:
|
||||
return (
|
||||
req_a == req_b
|
||||
or len(req_a) == len(req_b)
|
||||
and all(
|
||||
len(seq_a) == len(seq_b)
|
||||
and all(_logprobs_match(a, b) for a, b in zip(seq_a, seq_b))
|
||||
for seq_a, seq_b in zip(req_a, req_b)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _logprobs_match(lps_a: dict[int, Logprob], lps_b: dict[int, Logprob]) -> bool:
    """Return True if two per-token top-logprob dicts agree.

    Token ids, decoded strings and ranks must match exactly; logprob values
    are compared with a platform-dependent tolerance.
    """
    if current_platform.is_rocm():
        # ROCm has higher numerical variance
        # due to use of float16.
        rel_tol, abs_tol = 5e-2, 1e-5
    else:
        rel_tol, abs_tol = 1e-3, 1e-6

    if len(lps_a) != len(lps_b) or lps_a.keys() != lps_b.keys():
        return False
    for token_id, entry_a in lps_a.items():
        entry_b = lps_b[token_id]
        if entry_a.decoded_token != entry_b.decoded_token:
            return False
        if entry_a.rank != entry_b.rank:
            return False
        if entry_a.logprob != pytest.approx(entry_b.logprob, rel=rel_tol, abs=abs_tol):
            return False
    return True
|
||||
|
||||
|
||||
def _get_acceptance_rate(before: list[Metric], after: list[Metric]) -> float:
    """Compute the spec-decode acceptance rate between two metric snapshots."""
    num_draft = _get_count(before, after, "vllm:spec_decode_num_draft_tokens")
    num_accepted = _get_count(before, after, "vllm:spec_decode_num_accepted_tokens")
    if num_draft <= 0:
        # No drafting happened in this window; report zero rather than divide.
        return 0.0
    return num_accepted / num_draft
|
||||
|
||||
|
||||
def _get_count(before: list[Metric], after: list[Metric], name: str) -> int:
    """Return the delta of the named counter metric between two snapshots."""

    def _value(metrics: list[Metric]) -> int:
        # Raises StopIteration if the metric is absent, same as the original.
        return next(m.value for m in metrics if m.name == name)

    return _value(after) - _value(before)
|
||||
131
tests/v1/e2e/test_async_spec_decode.py
Normal file
131
tests/v1/e2e/test_async_spec_decode.py
Normal file
@@ -0,0 +1,131 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Test that verifies no implicit GPU-CPU synchronization occurs during
|
||||
speculative decoding generation under expected conditions.
|
||||
"""
|
||||
|
||||
import multiprocessing
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
|
||||
@pytest.fixture
def sync_tracker():
    """
    Fixture that patches CommonAttentionMetadata.seq_lens_cpu to detect
    lazy init syncs. Prints stack traces immediately when syncs occur.

    Yields a SyncTracker handle exposing ``count`` and ``assert_no_sync()``.
    The patch is installed before the test body runs and removed afterwards.
    """
    from vllm.v1.attention.backends.utils import CommonAttentionMetadata

    # Shared counter for cross-process communication (inherited by fork)
    sync_count = multiprocessing.Value("i", 0)

    # Save original property so it can be restored on teardown
    original_prop = CommonAttentionMetadata.seq_lens_cpu
    original_fget = original_prop.fget

    # Create tracking wrapper: counts only the lazy-init path, i.e. when
    # _seq_lens_cpu is still None and reading the property forces a sync.
    def tracking_seq_lens_cpu(self):
        if self._seq_lens_cpu is None:
            # Increment counter
            with sync_count.get_lock():
                sync_count.value += 1
                count = sync_count.value
            # Print stack trace immediately (shows in subprocess output)
            print(f"\n{'=' * 60}", file=sys.stderr)
            print(f"SYNC #{count}: seq_lens_cpu lazy init triggered!", file=sys.stderr)
            print(f"{'=' * 60}", file=sys.stderr)
            traceback.print_stack(file=sys.stderr)
            print(f"{'=' * 60}\n", file=sys.stderr)
            sys.stderr.flush()
        return original_fget(self)

    # Apply patch
    CommonAttentionMetadata.seq_lens_cpu = property(tracking_seq_lens_cpu)

    class SyncTracker:
        @property
        def count(self) -> int:
            # Number of lazy-init syncs observed so far.
            return sync_count.value

        def assert_no_sync(self, msg: str = ""):
            # Fail loudly if any lazy-init sync was recorded.
            count = sync_count.value
            assert count == 0, (
                f"Unexpected GPU-CPU sync: seq_lens_cpu lazy init triggered "
                f"{count} times. See stack traces above. {msg}"
            )

    yield SyncTracker()

    # Restore original property
    CommonAttentionMetadata.seq_lens_cpu = original_prop
    torch._dynamo.reset()
|
||||
|
||||
|
||||
# Test configurations: (model, spec_model, method, num_spec_tokens)
# NOTE(review): an earlier comment also listed "backend_env", but each
# pytest.param carries only the four values consumed by the parametrize below.
SPEC_DECODE_CONFIGS = [
    pytest.param(
        "meta-llama/Llama-3.2-1B-Instruct",
        "nm-testing/Llama3_2_1B_speculator.eagle3",
        "eagle3",
        2,
        id="eagle3-llama",
    ),
    pytest.param(
        "eagle618/deepseek-v3-random",
        "eagle618/eagle-deepseek-v3-random",
        "eagle",
        2,
        id="eagle-mla-deepseek",
    ),
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "model,spec_model,method,num_spec_tokens",
    SPEC_DECODE_CONFIGS,
)
def test_no_sync_with_spec_decode(
    sync_tracker,
    model: str,
    spec_model: str,
    method: str,
    num_spec_tokens: int,
):
    """
    Test that no implicit GPU-CPU sync occurs during speculative decoding
    generation.
    """
    # Import vLLM AFTER sync_tracker fixture has applied the patch
    from vllm import LLM, SamplingParams
    from vllm.distributed import cleanup_dist_env_and_memory

    spec_cfg = {
        "method": method,
        "num_speculative_tokens": num_spec_tokens,
        "model": spec_model,
    }
    llm = LLM(
        model=model,
        max_model_len=256,
        speculative_config=spec_cfg,
        enforce_eager=True,
        async_scheduling=True,
    )

    results = llm.generate(
        ["Hello, my name is"],
        SamplingParams(temperature=0, max_tokens=10),
    )

    assert len(results) == 1
    assert len(results[0].outputs[0].text) > 0

    # Tear the engine down before checking, so any sync triggered during
    # generation has already been recorded.
    del llm
    torch.cuda.empty_cache()
    cleanup_dist_env_and_memory()

    sync_tracker.assert_no_sync()
|
||||
37
tests/v1/e2e/test_cascade_attention.py
Normal file
37
tests/v1/e2e/test_cascade_attention.py
Normal file
@@ -0,0 +1,37 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
from ...utils import create_new_process_for_each_test
|
||||
|
||||
|
||||
@create_new_process_for_each_test()
@pytest.mark.parametrize("attn_backend", ["FLASH_ATTN", "FLASHINFER"])
def test_cascade_attention(example_system_message, monkeypatch, attn_backend):
    """A batch of identical shared-prefix prompts (which may trigger cascade
    attention) must produce the same text as the single-prompt reference."""
    prompt = "\n<User>: Implement fibonacci sequence in Python.\n<Claude>:"

    if attn_backend == "FLASHINFER":
        pytest.skip(
            "This test is failing with FlashInfer backend and "
            "needs investigation. See issue #25679."
        )

    with monkeypatch.context() as m:
        m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)

        llm = LLM(model="Qwen/Qwen2-1.5B-Instruct")
        params = SamplingParams(temperature=0.0, max_tokens=100)
        full_prompt = example_system_message + prompt

        # Reference: a single prompt, no cascade attention.
        ref_output = llm.generate([full_prompt], params)[0].outputs[0].text

        # 64 copies of the same prompt (probably) use cascade attention.
        responses = llm.generate([full_prompt] * 64, params)
        for response in responses:
            assert response.outputs[0].text == ref_output
|
||||
63
tests/v1/e2e/test_context_length.py
Normal file
63
tests/v1/e2e/test_context_length.py
Normal file
@@ -0,0 +1,63 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Tests for vLLM `vllm/v1/engine/processor.Processor._validate_model_input()`
|
||||
handling of maximum context length for decoder models.
|
||||
|
||||
This test ensures:
|
||||
- A prompt that is one token shorter than the model's maximum context length
|
||||
can be processed successfully when requesting one additional token.
|
||||
- A prompt that reaches the model's maximum context length throws a
|
||||
`ValueError` when requesting at least one additional token.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.conftest import VllmRunner
|
||||
from tests.utils import create_new_process_for_each_test
|
||||
|
||||
|
||||
@create_new_process_for_each_test()
@pytest.mark.parametrize("model, max_model_len", [("JackFram/llama-160m", 2048)])
@pytest.mark.parametrize(
    "prompt_len, max_tokens",
    [
        (2047, 1),  # prompt_len = max_model_len - 1 -> allowed
        (2048, 1),  # prompt_len = max_model_len -> not allowed
    ],
)
def test_decoder_max_context_length_validation(
    model: str,
    max_model_len: int,
    vllm_runner: type[VllmRunner],
    prompt_len: int,
    max_tokens: int,
) -> None:
    """Check vLLM decoder model input validation for edge cases where
    the prompt length is (almost) equal to the max model length."""

    prompt_ids = [[43] * prompt_len]
    fits_in_context = prompt_len + max_tokens <= max_model_len

    with vllm_runner(
        model_name=model,
        tokenizer_name=model,
        max_model_len=max_model_len,
        max_num_seqs=1,
        tensor_parallel_size=1,
    ) as vllm_model:
        if fits_in_context:
            # Constraints met: generation must simply succeed.
            vllm_model.generate_greedy(prompt_ids, max_tokens)
            return

        # Over budget: expect the ValueError raised by
        # vllm/v1/engine/processor.Processor_validate_model_input()
        expected_msg = (
            f"The decoder prompt (length {prompt_len}) plus the number of "
            f"requested output tokens (at least 1) is longer than "
            f"the maximum model length of {max_model_len}. "
            "Make sure that `max_model_len` is no smaller than the number of "
            "text tokens (prompt + requested output tokens)."
        )
        with pytest.raises(ValueError) as excinfo:
            vllm_model.generate_greedy(prompt_ids, max_tokens)
        assert expected_msg in str(excinfo.value)
|
||||
98
tests/v1/e2e/test_correctness_sliding_window.py
Normal file
98
tests/v1/e2e/test_correctness_sliding_window.py
Normal file
@@ -0,0 +1,98 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ...utils import check_answers, prep_prompts
|
||||
|
||||
|
||||
@dataclass
class TestConfig:
    """Per-model settings for the sliding-window retrieval test."""

    # NOTE(review): the Test* class name triggers a PytestCollectionWarning
    # (dataclass has __init__); harmless, but renaming would silence it.
    sliding_window: int  # model's sliding attention window size, in tokens
    # (lo, hi) passed to prep_prompts as ln_range — presumably bounds on the
    # generated prompt length; TODO confirm against prep_prompts.
    ln_range: tuple[int, int]


# Maps model name -> settings; keys must match the parametrized model ids.
model_config = {
    "bigcode/starcoder2-3b": TestConfig(4096, (800, 1100)),
    "google/gemma-3-1b-it": TestConfig(4096, (400, 800)),
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "model",
    [
        "bigcode/starcoder2-3b",  # sliding window only
        "google/gemma-3-1b-it",  # sliding window + full attention
    ],
)
@pytest.mark.parametrize("batch_size", [5])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("disable_hybrid_kv_cache_manager", [True, False])
def test_sliding_window_retrieval(
    model, batch_size, seed, disable_hybrid_kv_cache_manager
):
    """
    The test does a bunch of assignments "x1 = 10\nx2 = 33\n..." and then
    asks for value of one of them (which is outside the sliding window).
    If we tell it upfront which we are going to be looking for, then
    it answers correctly (mostly).
    """
    # NOTE: For ROCm, we have to enforce eager mode to use custom kernel
    # implementation of GELU with tanh approximation, as PyTorch's native
    # implementation is currently unstable with torch.compile and produces garbage.
    enforce_eager = current_platform.is_rocm()

    test_config = model_config[model]

    llm = LLM(
        model=model,
        disable_hybrid_kv_cache_manager=disable_hybrid_kv_cache_manager,
        enforce_eager=enforce_eager,
    )
    sampling_params = SamplingParams(temperature=0.0, max_tokens=100)

    prompts, answer, indices = prep_prompts(batch_size, ln_range=test_config.ln_range)

    check_length(prompts, llm, test_config.sliding_window)

    def _generate_and_check():
        responses = llm.generate(prompts, sampling_params)
        check_answers(
            indices,
            answer,
            [response.outputs[0].text for response in responses],
            accept_rate=1.0,
        )

    # Fresh generation
    _generate_and_check()

    # Re-generate with the same prompts to test prefix caching
    _generate_and_check()
|
||||
|
||||
|
||||
def check_length(prompts: list[str], llm: LLM, sliding_window: int):
    """
    Validate prompt lengths for the sliding-window test: at least one prompt
    must exceed the sliding window (otherwise nothing is exercised) and all
    prompts must fit in the model's max length.

    Args:
        prompts: list of prompts
        llm: LLM object
        sliding_window: Sliding window size
    """
    tokenizer = llm.get_tokenizer()
    max_model_len = llm.llm_engine.model_config.max_model_len
    token_lengths = [len(tokenizer.encode(prompt)) for prompt in prompts]
    assert any(length > sliding_window for length in token_lengths), (
        "Prompt is too short for test"
    )
    assert all(length <= max_model_len for length in token_lengths), (
        "Prompt is too long for test"
    )
|
||||
101
tests/v1/e2e/test_kv_sharing_fast_prefill.py
Normal file
101
tests/v1/e2e/test_kv_sharing_fast_prefill.py
Normal file
@@ -0,0 +1,101 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import random
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.config import CompilationConfig, CompilationMode
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ...utils import check_answers, fork_new_process_for_each_test, prep_prompts
|
||||
|
||||
# global seed
|
||||
SEED = 42
|
||||
|
||||
|
||||
@pytest.fixture
def test_prompts():
    """
    Adapted from tests/v1/e2e/test_spec_decode.py
    """
    # Setting higher num prompts increases the chance of numerics mismatch
    # due to matrix multiplication numerics depending on batch dimension
    num_prompts = 10
    prompt_types = ["repeat", "sentence"]

    # Fixed seed keeps the prompt mix deterministic across runs.
    random.seed(0)
    chosen_kinds = random.choices(prompt_types, k=num_prompts)

    prompts = []
    for kind in chosen_kinds:
        word = random.choice(["test", "temp", "hello", "where"])
        if kind == "repeat":
            prompt = f"""please repeat the word '{word}' 10 times."""
        elif kind == "sentence":
            prompt = f"""please give a ten-word sentence that
            uses the word {word} at least once."""
        else:
            raise ValueError(f"Unknown prompt type: {kind}")
        prompts.append(prompt)

    return prompts
|
||||
|
||||
|
||||
# On ROCm the fork decorator is swapped for an identity passthrough so the
# test runs in-process.
use_fork_for_test = (
    (lambda x: x) if current_platform.is_rocm() else fork_new_process_for_each_test
)
|
||||
|
||||
|
||||
@use_fork_for_test
@pytest.mark.parametrize("kv_sharing_fast_prefill", [False, True])
@pytest.mark.parametrize("enforce_eager", [True, False])
def test_kv_sharing_fast_prefill(
    monkeypatch: pytest.MonkeyPatch,
    kv_sharing_fast_prefill: bool,
    enforce_eager: bool,
):
    """End-to-end answer check for gemma-3n with KV-sharing fast prefill
    toggled on and off, in both eager and compiled modes."""
    if current_platform.is_rocm() and not enforce_eager:
        # Relevant context: https://github.com/vllm-project/vllm/pull/29244
        pytest.skip(
            "ROCm: torch.compile produces incorrect output for gemma-3n's GELU "
            "with tanh approximation. Use enforce_eager=True instead."
        )

    batch_size = 10
    sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
    compile_mode = (
        CompilationMode.NONE if enforce_eager else CompilationMode.VLLM_COMPILE
    )
    compilation_config = CompilationConfig(
        # This allows vLLM compilation backend to handle allocating and
        # managing buffers for cudagraph
        cudagraph_copy_inputs=True,
        mode=compile_mode,
    )

    with monkeypatch.context() as m:
        # Make scheduling deterministic for reproducibility
        if current_platform.is_rocm():
            # Use spawn to prevent cuda re-initialization error
            m.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
        else:
            m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

        prompts, answer, indices = prep_prompts(batch_size)

        llm = LLM(
            model="google/gemma-3n-E2B-it",
            enforce_eager=enforce_eager,
            compilation_config=compilation_config,
            seed=SEED,
            kv_sharing_fast_prefill=kv_sharing_fast_prefill,
        )
        responses = llm.generate(prompts, sampling_params)
        check_answers(
            indices,
            answer,
            [response.outputs[0].text for response in responses],
            accept_rate=1.0,
        )
|
||||
139
tests/v1/e2e/test_lora_with_spec_decode.py
Normal file
139
tests/v1/e2e/test_lora_with_spec_decode.py
Normal file
@@ -0,0 +1,139 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
This script contains:
|
||||
1. test lora with speculative decoding for batch inference
|
||||
"""
|
||||
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.distributed import cleanup_dist_env_and_memory
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
# Maps a LoRA adapter repo id to the prompt used to exercise it.
LORA_TEST_PROMPT_MAP: dict[str, str] = {}

# Prompt for the linear-algebra coder adapter; the strict output format makes
# greedy outputs easy to compare across the two engine instances.
LORA_TEST_PROMPT_MAP["premjatin/qwen-linear-algebra-coder"] = """
### INSTRUCTION:
You are an AI assistant that generates Python code to solve linear
algebra problems.

### PROBLEM:
Find the eigenvalues and eigenvectors of the following 3x3 matrix:
[[3, 2, 0],
[2, 3, 0],
[0, 0, 2]]

### OUTPUT FORMAT (STRICT):
Numbers should be represented as integers only.

### PYTHON SOLUTION:
"""

# Global seed for reproducible sampling across both LLM instances.
SEED = 42
|
||||
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA not available")
@pytest.mark.parametrize(
    "model_setup",
    [
        (
            "eagle3",
            "Qwen/Qwen3-1.7B",
            "AngelSlim/Qwen3-1.7B_eagle3",
            "premjatin/qwen-linear-algebra-coder",
            1,
        )
    ],
)
def test_batch_inference_correctness(
    monkeypatch: pytest.MonkeyPatch,
    model_setup: tuple[str, str, str, str, int],
):
    """
    Compare the outputs of a LLM with only Lora and a LLM with both SD and Lora.
    Should be the same and no failure when doing batch inference.

    model_setup: (method, model_name, spec_model_name, lora_path, tp_size)
    """
    with monkeypatch.context() as m:
        # Disable randomness: deterministic cuBLAS workspace plus fixed seeds
        # for every RNG the run touches.
        m.setenv("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
        torch.manual_seed(SEED)
        np.random.seed(SEED)
        random.seed(SEED)
        torch.cuda.manual_seed_all(SEED)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

        method, model_name, spec_model_name, lora_path, tp_size = model_setup

        # Reference run: LoRA only, without speculative decoding.
        ref_llm = LLM(
            model=model_name,
            trust_remote_code=True,
            tensor_parallel_size=tp_size,
            max_model_len=2048,
            max_num_seqs=4,
            enable_lora=True,
            max_loras=1,
            max_cpu_loras=1,
            max_lora_rank=16,
        )

        # Same prompt repeated so batching itself is exercised.
        prompts = [LORA_TEST_PROMPT_MAP[lora_path]] * 100
        lora_request = LoRARequest("adapter", 1, lora_path)
        # Greedy sampling so both runs should produce identical text.
        sampling_params = SamplingParams(
            temperature=0.0, top_p=1.0, top_k=-1, seed=SEED, max_tokens=128
        )

        ref_outputs = ref_llm.generate(
            prompts, sampling_params, lora_request=lora_request
        )
        # Free GPU memory before constructing the second engine.
        del ref_llm
        torch.cuda.empty_cache()
        cleanup_dist_env_and_memory()

        # Second run: same LoRA adapter plus speculative decoding.
        lora_spec_llm = LLM(
            model=model_name,
            trust_remote_code=True,
            tensor_parallel_size=tp_size,
            speculative_config={
                "method": method,
                "model": spec_model_name,
                "num_speculative_tokens": 3,
                "max_model_len": 2048,
            },
            max_model_len=2048,
            max_num_seqs=4,
            enable_lora=True,
            max_loras=1,
            max_cpu_loras=1,
            max_lora_rank=16,
        )

        lora_spec_outputs = lora_spec_llm.generate(
            prompts, sampling_params, lora_request=lora_request
        )

        # Count exact text matches between the two runs; print mismatches
        # to aid debugging on failure.
        matches = 0
        misses = 0
        for ref_output, spec_output in zip(ref_outputs, lora_spec_outputs):
            if ref_output.outputs[0].text == spec_output.outputs[0].text:
                matches += 1
            else:
                misses += 1
                print(f"ref_output: {ref_output.outputs[0].text}")
                print(f"spec_output: {spec_output.outputs[0].text}")

        # Heuristic: expect at least 90% of the prompts to match exactly
        # Upon failure, inspect the outputs to check for inaccuracy.
        print(f"match ratio: {matches}/{len(ref_outputs)}")
        assert matches > int(0.90 * len(ref_outputs))
        del lora_spec_llm
        torch.cuda.empty_cache()
        cleanup_dist_env_and_memory()
||||
502
tests/v1/e2e/test_min_tokens.py
Normal file
502
tests/v1/e2e/test_min_tokens.py
Normal file
@@ -0,0 +1,502 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Comprehensive end-to-end tests for `min_tokens` in the V1 engine.
|
||||
|
||||
Addresses #21950: verify and add CI coverage.
|
||||
|
||||
Covers:
|
||||
1) Basic functionality
|
||||
2) Stop strings with `min_tokens` (bug #21987; fix in PR #22014)
|
||||
3) EOS behavior with `min_tokens` (potential logits-processor bug)
|
||||
4) Edge cases (min_tokens == max_tokens, min_tokens == 0)
|
||||
5) Multiple stop conditions
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.outputs import RequestOutput
|
||||
|
||||
# Test configuration
TEST_MODEL = "facebook/opt-125m"  # Small model for fast CI execution
GREEDY = 0.0  # Deterministic generation for consistent testing
|
||||
|
||||
class MinTokensTestCase:
|
||||
"""Data class for min_tokens test scenarios"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
min_tokens: int,
|
||||
max_tokens: int,
|
||||
stop: str | list[str] | None = None,
|
||||
expected_min_len: int | None = None,
|
||||
expected_exact_len: int | None = None,
|
||||
):
|
||||
self.name = name
|
||||
self.min_tokens = min_tokens
|
||||
self.max_tokens = max_tokens
|
||||
self.stop = stop
|
||||
self.expected_min_len = expected_min_len or min_tokens
|
||||
self.expected_exact_len = expected_exact_len
|
||||
|
||||
def __str__(self):
|
||||
return (
|
||||
f"{self.name}: min={self.min_tokens}, "
|
||||
f"max={self.max_tokens}, stop={self.stop}"
|
||||
)
|
||||
|
||||
|
||||
# Test scenarios covering all critical cases.  Known-failing scenarios are
# wrapped in pytest.param with an xfail mark so CI stays green until the
# referenced fixes land.
MIN_TOKENS_TEST_CASES = [
    # === BASIC FUNCTIONALITY (should work) ===
    MinTokensTestCase(
        name="basic_min_tokens_no_stop",
        min_tokens=8,
        max_tokens=20,
        stop=None,
        expected_min_len=8,
    ),
    MinTokensTestCase(
        name="min_tokens_zero",
        min_tokens=0,
        max_tokens=10,
        stop=None,
        expected_min_len=0,
    ),
    MinTokensTestCase(
        name="min_equals_max_no_stop",
        min_tokens=15,
        max_tokens=15,
        stop=None,
        expected_exact_len=15,
    ),
    # === STOP STRINGS WITH MIN_TOKENS ===
    # These tests expose the detokenizer bug where stop strings
    # bypass min_tokens
    # Using mathematically guaranteed approach with wide stop nets
    pytest.param(
        MinTokensTestCase(
            name="min_tokens_with_comprehensive_stops",
            min_tokens=5,
            max_tokens=20,
            # Wide net of very common characters: some stop is virtually
            # guaranteed to appear before min_tokens is reached.
            stop=[
                "a",
                "e",
                "i",
                "o",
                "u",
                "t",
                "n",
                "s",
                "r",
                "l",
                " ",
            ],
            expected_min_len=5,
        ),
        marks=pytest.mark.xfail(
            reason=(
                "Known bug #21987: stop strings bypass min_tokens (fixed by PR #22014)"
            ),
            strict=False,
        ),
        id="min_tokens_with_comprehensive_stops",
    ),
    pytest.param(
        MinTokensTestCase(
            name="min_tokens_with_simple_char_stop",
            min_tokens=3,
            max_tokens=15,
            stop=["e", "a", " "],
            expected_min_len=3,
        ),
        marks=pytest.mark.xfail(
            reason=(
                "Known bug #21987: stop strings bypass min_tokens (fixed by PR #22014)"
            ),
            strict=False,
        ),
        id="min_tokens_with_simple_char_stop",
    ),
    # === EOS TOKEN WITH MIN_TOKENS (potential LogitsProcessor bug) ===
    # These test the MinTokensLogitsProcessor handling of EOS tokens
    pytest.param(
        MinTokensTestCase(
            name="min_equals_max_eos_only",
            min_tokens=20,
            max_tokens=20,
            stop=None,  # Relies on default EOS token behavior
            expected_exact_len=20,
        ),
        marks=pytest.mark.xfail(
            reason=("Potential logits-processor bug: EOS tokens may bypass min_tokens"),
            strict=False,
        ),
        id="min_equals_max_eos_only",
    ),
    # === EDGE CASES ===
    MinTokensTestCase(
        name="large_min_tokens",
        min_tokens=50,
        max_tokens=60,
        stop=None,
        expected_min_len=50,
    ),
    MinTokensTestCase(
        name="min_tokens_with_empty_stop_list",
        min_tokens=5,
        max_tokens=15,
        stop=[],  # Empty stop list
        expected_min_len=5,
    ),
]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def llm_v1():
    """Create V1 LLM instance for testing.

    Module-scoped so the model is loaded once and shared by every test in
    this file.
    """
    llm = LLM(
        model=TEST_MODEL,
        tensor_parallel_size=1,
        max_model_len=1024,  # Small context for fast testing
        enforce_eager=True,  # Avoid graph compilation overhead
    )
    return llm
|
||||
|
||||
|
||||
def get_token_count(output: RequestOutput) -> int:
    """Return the number of tokens in the first completion of *output*.

    An output with no completions counts as zero tokens.
    """
    completions = output.outputs
    return len(completions[0].token_ids) if completions else 0
|
||||
|
||||
|
||||
def assert_min_tokens_satisfied(
    output: RequestOutput, test_case: MinTokensTestCase
) -> None:
    """Fail the calling test if *output* violates the length expectations.

    When the case pins an exact length the completion must contain exactly
    that many tokens; otherwise it must contain at least
    ``expected_min_len`` tokens.
    """
    completions = output.outputs
    # Token counting inlined so this check is self-contained.
    token_count = len(completions[0].token_ids) if completions else 0
    stop_reason = completions[0].stop_reason if completions else "no output"

    exact = test_case.expected_exact_len
    if exact is None:
        # Minimum length requirement
        assert token_count >= (test_case.expected_min_len or 0), (
            f"Expected at least {test_case.expected_min_len} tokens, "
            f"got {token_count} tokens. "
            f"Stop reason: {stop_reason}"
        )
    else:
        # Exact length requirement
        assert token_count == exact, (
            f"Expected exactly {test_case.expected_exact_len} tokens, "
            f"got {token_count} tokens. "
            f"Stop reason: {stop_reason}"
        )
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "test_case",
    MIN_TOKENS_TEST_CASES,
    # Plain cases use their name as the pytest id; pytest.param entries
    # carry an explicit id already.
    ids=lambda tc: tc.name,
)
def test_min_tokens_comprehensive(llm_v1: LLM, test_case: MinTokensTestCase):
    """
    Comprehensive test for min_tokens functionality in V1 engine.

    This test covers all critical scenarios for min_tokens:
    - Basic functionality (should work)
    - Stop strings with min_tokens (known bug)
    - EOS tokens with min_tokens (potential bug)
    - Edge cases

    Args:
        llm_v1: V1 LLM instance
        test_case: Test scenario parameters
    """
    # Known failing cases are handled via param-level xfail marks above.

    # Create sampling parameters
    sampling_params = SamplingParams(
        min_tokens=test_case.min_tokens,
        max_tokens=test_case.max_tokens,
        stop=test_case.stop,
        temperature=GREEDY,
        include_stop_str_in_output=True,  # Include stop strings for debugging
    )

    # Use simple prompt. Comprehensive stop lists should catch any generation
    prompt = "Hello"

    # Generate output
    outputs = llm_v1.generate([prompt], sampling_params)

    assert len(outputs) == 1, "Expected exactly one output"
    output = outputs[0]

    # Debug information printed so failures can be diagnosed from CI logs.
    token_count = get_token_count(output)
    generated_text = output.outputs[0].text if output.outputs else ""
    stop_reason = output.outputs[0].stop_reason if output.outputs else "unknown"

    print(f"\nTest: {test_case.name}")
    print(f"Generated {token_count} tokens")
    print(f"Stop reason: {stop_reason}")
    print(f"Generated text: {repr(generated_text)}")
    print(f"Expected min: {test_case.expected_min_len}")
    if test_case.expected_exact_len:
        print(f"Expected exact: {test_case.expected_exact_len}")

    # Validate min_tokens requirement
    assert_min_tokens_satisfied(output, test_case)
|
||||
|
||||
|
||||
def test_min_tokens_basic_functionality(llm_v1: LLM):
    """
    Test basic min_tokens functionality without stop conditions.

    This is a baseline test that should always pass and validates
    that min_tokens works correctly in the simple case.
    """
    sampling_params = SamplingParams(min_tokens=10, max_tokens=20, temperature=GREEDY)

    prompt = "Once upon a time"
    outputs = llm_v1.generate([prompt], sampling_params)

    assert len(outputs) == 1
    token_count = get_token_count(outputs[0])

    # Generated length must land within [min_tokens, max_tokens].
    assert token_count >= 10, f"Expected at least 10 tokens, got {token_count}"
    assert token_count <= 20, f"Expected at most 20 tokens, got {token_count}"
|
||||
|
||||
|
||||
@pytest.mark.xfail(
    reason=("Known bug #21987: stop strings bypass min_tokens (fixed by PR #22014)"),
    strict=False,
)
def test_min_tokens_stop_strings_bug(llm_v1: LLM):
    """
    Test the specific bug where stop strings bypass min_tokens.

    This test specifically reproduces the bug Calvin is fixing in PR #22014.
    It should fail until that fix is merged.

    Strategy: Use guaranteed stop characters that will appear
    in any generated text.
    """
    # If the bug is fixed upstream, this test will XPASS

    sampling_params = SamplingParams(
        min_tokens=15,
        max_tokens=50,
        # Common letter; likely appears early
        stop=["e"],
        temperature=GREEDY,
        include_stop_str_in_output=True,
    )

    # Simple prompt that will generate text containing "e"
    prompt = "The quick brown fox"
    outputs = llm_v1.generate([prompt], sampling_params)

    assert len(outputs) == 1
    token_count = get_token_count(outputs[0])
    generated_text = outputs[0].outputs[0].text if outputs[0].outputs else ""

    # Debug info to understand what happened
    print(f"Generated text: {repr(generated_text)}")
    print(f"Token count: {token_count}")
    print(f"Contains 'e': {'e' in generated_text}")

    # This assertion should fail due to the bug - if stop string is found early,
    # the model should still continue generating until min_tokens is reached
    stop_reason = (
        outputs[0].outputs[0].stop_reason if outputs[0].outputs else "no output"
    )
    assert token_count >= 15, (
        "Bug confirmed: "
        f"{token_count} tokens < min_tokens=15. "
        f"Reason: {stop_reason}. "
        f"Text: {repr(generated_text)}"
    )
|
||||
|
||||
|
||||
@pytest.mark.xfail(
    reason=("Known bug #21987: stop strings bypass min_tokens (fixed by PR #22014)"),
    strict=False,
)
def test_min_tokens_stop_strings_guaranteed_early_trigger(llm_v1: LLM):
    """
    Guaranteed test for stop strings bypassing min_tokens bug.

    Strategy: Use greedy sampling and multiple common stop strings
    to virtually guarantee early detection, combined with long min_tokens
    to ensure the bug is exposed regardless of model behavior.
    """
    # If the bug is fixed upstream, this test will XPASS

    sampling_params = SamplingParams(
        min_tokens=50,  # Set high min_tokens to ensure bug detection
        max_tokens=200,
        # Use multiple very common patterns - at least one will appear
        stop=["e", "a", "i", "o", "u", " ", "t", "n", "s", "r"],
        temperature=GREEDY,
        include_stop_str_in_output=True,
    )

    # Simple prompt that will generate some text
    prompt = "The cat"
    outputs = llm_v1.generate([prompt], sampling_params)

    assert len(outputs) == 1
    token_count = get_token_count(outputs[0])
    generated_text = outputs[0].outputs[0].text if outputs[0].outputs else ""
    stop_reason = outputs[0].outputs[0].stop_reason if outputs[0].outputs else "unknown"

    print(f"Generated text: {repr(generated_text)}")
    print(f"Token count: {token_count}")
    print(f"Stop reason: {stop_reason}")

    # With the bug, this will fail because ANY of the common characters
    # will trigger early termination before min_tokens=50 is reached
    # It's virtually impossible to generate 50 tokens without hitting
    # at least one of: e, a, i, o, u, space, t, n, s, r
    finish_reason = (
        outputs[0].outputs[0].finish_reason if outputs[0].outputs else "unknown"
    )

    print(f"Finish reason: {finish_reason}")

    # Only assert when the engine actually stopped on a stop string;
    # a "length" finish already satisfies min_tokens trivially.
    if finish_reason == "stop":
        assert token_count >= 50, (
            "Bug confirmed: "
            f"{token_count} tokens < min_tokens=50. "
            f"Reason: {finish_reason}. "
            f"Text: {repr(generated_text)}"
        )
|
||||
|
||||
|
||||
@pytest.mark.xfail(
    reason=("Potential logits-processor bug: EOS tokens may bypass min_tokens"),
    strict=False,
)
def test_min_tokens_eos_behavior(llm_v1: LLM):
    """
    Verify EOS handling with and without min_tokens.

    - Without min_tokens: expect early EOS -> finish_reason == "stop",
      stop_reason is None, and generated tokens < max_tokens (32).
    - With min_tokens: EOS should be blocked until min_tokens is reached
      (finish_reason == "length"); verify that eos_token_id does not appear
      in generated token_ids.
    """
    # tokenizer + eos id
    tokenizer = llm_v1.get_tokenizer()
    eos_token_id = tokenizer.eos_token_id

    # Short-answer prompt: the model is expected to emit EOS early.
    prompt = "Give a file extension."
    max_toks = 32

    # Case 1: WITHOUT min_tokens
    sp_no_min = SamplingParams(
        max_tokens=max_toks,
        temperature=GREEDY,
    )
    out_no_min = llm_v1.generate([prompt], sp_no_min)
    assert len(out_no_min) == 1
    choice_no_min = out_no_min[0].outputs[0]

    ids_no_min = choice_no_min.token_ids or []
    finish_no_min = choice_no_min.finish_reason
    stop_no_min = choice_no_min.stop_reason

    print(
        "[no-min] tokens=",
        len(ids_no_min),
        " finish=",
        finish_no_min,
        " stop_reason=",
        stop_no_min,
    )

    assert finish_no_min == "stop", (
        f"Expected finish_reason 'stop' without min_tokens, got {finish_no_min}"
    )
    assert stop_no_min is None, (
        "For EOS-based stop (no user stop strings), stop_reason should be None."
    )
    assert len(ids_no_min) < max_toks, (
        f"Expected early EOS with < {max_toks} tokens, got {len(ids_no_min)}"
    )

    # Case 2: WITH min_tokens (min == max so length must be exact)
    sp_with_min = SamplingParams(
        min_tokens=max_toks,
        max_tokens=max_toks,
        temperature=GREEDY,
    )
    out_with_min = llm_v1.generate([prompt], sp_with_min)
    assert len(out_with_min) == 1
    choice_with_min = out_with_min[0].outputs[0]

    ids_with_min = choice_with_min.token_ids or []
    finish_with_min = choice_with_min.finish_reason
    stop_with_min = choice_with_min.stop_reason

    print(
        "[with-min] tokens=",
        len(ids_with_min),
        " finish=",
        finish_with_min,
        " stop_reason=",
        stop_with_min,
    )

    # Exact length reached; EOS should have been blocked
    assert len(ids_with_min) == max_toks, (
        f"Expected exactly {max_toks} tokens with min_tokens; got {len(ids_with_min)}"
    )
    assert finish_with_min == "length", (
        f"Expected finish_reason 'length'; got {finish_with_min}"
    )
    assert eos_token_id not in ids_with_min, (
        "EOS token id should not appear when min_tokens prevents early EOS."
    )
|
||||
|
||||
|
||||
def test_min_tokens_validation():
    """
    Test that SamplingParams correctly validates min_tokens parameters.

    This tests the parameter validation logic in SamplingParams.
    """
    # Valid cases: 0 <= min_tokens <= max_tokens must construct cleanly.
    SamplingParams(min_tokens=0, max_tokens=10)
    SamplingParams(min_tokens=5, max_tokens=10)
    SamplingParams(min_tokens=10, max_tokens=10)

    # Invalid cases: out-of-range values must raise with a specific message.
    with pytest.raises(
        ValueError,
        match="min_tokens must be greater than or equal to 0",
    ):
        SamplingParams(min_tokens=-1, max_tokens=10)

    with pytest.raises(
        ValueError,
        match="min_tokens must be less than or equal to max_tokens",
    ):
        SamplingParams(min_tokens=15, max_tokens=10)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    """
    Run tests locally for development.

    Usage:
        cd vllm/
        python -m pytest tests/v1/e2e/test_min_tokens.py -v
    """
    # Running the file directly delegates to pytest on this module.
    pytest.main([__file__, "-v"])
|
||||
167
tests/v1/e2e/test_pooling_chunked_prefill.py
Normal file
167
tests/v1/e2e/test_pooling_chunked_prefill.py
Normal file
@@ -0,0 +1,167 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
prompt = """
|
||||
Generals gathered in their masses
|
||||
Just like witches at black masses
|
||||
Evil minds that plot destruction
|
||||
Sorcerer of death's construction
|
||||
In the fields, the bodies burning
|
||||
As the war machine keeps turning
|
||||
Death and hatred to mankind
|
||||
Poisoning their brainwashed minds
|
||||
Oh, Lord, yeah
|
||||
|
||||
Politicians hide themselves away
|
||||
They only started the war
|
||||
Why should they go out to fight?
|
||||
They leave that all to the poor, yeah
|
||||
Time will tell on their power minds
|
||||
Making war just for fun
|
||||
Treating people just like pawns in chess
|
||||
Wait till their judgment day comes, yeah
|
||||
|
||||
Now, in darkness, world stops turning
|
||||
Ashes where their bodies burning
|
||||
No more war pigs have the power
|
||||
Hand of God has struck the hour
|
||||
Day of Judgment, God is calling
|
||||
On their knees, the war pigs crawling
|
||||
Begging mercies for their sins
|
||||
Satan, laughing, spreads his wings
|
||||
Oh, Lord, yeah
|
||||
"""
|
||||
|
||||
|
||||
class WrapperPooler(nn.Module):
    """Transparent pooler wrapper that records per-call chunk sizes.

    Wraps an existing pooler and, on every forward pass, logs how many
    token rows the hidden-state tensor contained before delegating.
    """

    def __init__(self, pooler):
        super().__init__()
        self.pooler = pooler
        # One entry per forward call: number of token rows in that chunk.
        self.chunks = []

    def get_pooling_updates(self, task):
        # Pure delegation to the wrapped pooler.
        return self.pooler.get_pooling_updates(task)

    def forward(self, hidden_states, pooling_metadata):
        # Record the chunk size, then pool as the wrapped object would.
        self.chunks.append(hidden_states.shape[0])
        return self.pooler(hidden_states, pooling_metadata)
|
||||
|
||||
|
||||
def inject_pooler(self):
    """Worker-side RPC: replace the model's pooler with a recording wrapper.

    Executed via collective_rpc on each worker; ``self`` is the worker.
    """
    model = self.get_model()
    model.pooler = WrapperPooler(model.pooler)
|
||||
|
||||
|
||||
def retrieve_chunks(self):
    """Worker-side RPC: return recorded chunk sizes and clear the log.

    Executed via collective_rpc; ``self`` is the worker whose model carries
    the injected WrapperPooler.
    """
    pooler = self.get_model().pooler
    recorded = pooler.chunks
    pooler.chunks = []
    return recorded
|
||||
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA not available")
def test_pooling_chunked_prefill(vllm_runner, monkeypatch):
    """Test chunked prefill for pooling models with LastPool."""

    with monkeypatch.context() as m:
        # Needed so the inject_pooler/retrieve_chunks callables can be
        # shipped over collective_rpc.
        m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
        model_id = "Qwen/Qwen3-Embedding-0.6B"

        chunk_size = 10

        # Set chunking parameters to force chunked prefill
        # Note: Chunked prefill is automatically handled by vLLM
        # internally based on the model size and prompt
        with vllm_runner(
            model_id,
            runner="pooling",
            long_prefill_token_threshold=chunk_size,
            tensor_parallel_size=1,
            enforce_eager=True,
            enable_chunked_prefill=True,
        ) as llm:
            llm.get_llm().llm_engine.collective_rpc(inject_pooler)

            # Compute the chunk sizes we expect the pooler to observe:
            # full chunks of chunk_size plus an optional remainder.
            tokenizer = llm.get_llm().get_tokenizer()
            tokens = tokenizer(prompt)["input_ids"]
            prompt_len = len(tokens)
            full_chunks, last_chunk = divmod(prompt_len, chunk_size)
            expected_chunks = [chunk_size] * full_chunks
            if last_chunk:
                expected_chunks.append(last_chunk)
            llm.embed([prompt])
            chunks = llm.get_llm().llm_engine.collective_rpc(retrieve_chunks)[0]

        # Check that PoolerWrapper was called and chunks were received
        assert len(chunks) > 1
        assert chunks == expected_chunks

        # Disable chunked prefill
        with vllm_runner(
            model_id,
            runner="pooling",
            tensor_parallel_size=1,
            enforce_eager=True,
        ) as llm:
            llm.get_llm().llm_engine.collective_rpc(inject_pooler)
            llm.embed([prompt])
            chunks = llm.get_llm().llm_engine.collective_rpc(retrieve_chunks)[0]

        # Check that PoolerWrapper was called and no chunks were received
        assert len(chunks) == 1
        assert chunks[0] == prompt_len
|
||||
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA not available")
def test_pooling_prefix_cache(vllm_runner, monkeypatch):
    """Test prefix caching for pooling models with LastPool.

    Two prompts share the same first verse; the second embed should only
    process the non-cached suffix, which the injected pooler observes.
    """

    verses = prompt.split("\n\n")

    with monkeypatch.context() as m:
        # Needed so the RPC callables can be serialized to workers.
        m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
        model_id = "Qwen/Qwen3-Embedding-0.6B"

        with vllm_runner(
            model_id,
            runner="pooling",
            enable_prefix_caching=True,
            tensor_parallel_size=1,
            enforce_eager=True,
        ) as llm:
            llm.get_llm().llm_engine.collective_rpc(inject_pooler)
            tokenizer = llm.get_llm().get_tokenizer()

            # prompt1 and prompt2 share verses[0] as a common prefix.
            prompt1 = "\n\n".join([verses[0], verses[1]])
            prompt2 = "\n\n".join([verses[0], verses[2]])
            tokens1 = tokenizer(prompt1)["input_ids"]
            tokens2 = tokenizer(prompt2)["input_ids"]
            prompt1_len = len(tokens1)
            prompt2_len = len(tokens2)

            # First embed: nothing cached, whole prompt is processed.
            llm.embed([prompt1])
            chunks = llm.get_llm().llm_engine.collective_rpc(retrieve_chunks)[0]

            assert len(chunks) == 1
            assert chunks[0] == prompt1_len

            # Second embed: shared prefix should be served from cache,
            # so fewer tokens than the full prompt are processed.
            llm.embed([prompt2])
            chunks = llm.get_llm().llm_engine.collective_rpc(retrieve_chunks)[0]

            assert len(chunks) == 1
            assert chunks[0] <= prompt1_len
            assert chunks[0] < prompt2_len

            cache_config = llm.get_llm().llm_engine.cache_config
            print(f"{cache_config=}")
            # Prefixes are cached in blocks
            assert (prompt2_len - chunks[0]) % cache_config.block_size == 0
|
||||
580
tests/v1/e2e/test_spec_decode.py
Normal file
580
tests/v1/e2e/test_spec_decode.py
Normal file
@@ -0,0 +1,580 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import random
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.utils import get_attn_backend_list_based_on_platform, large_gpu_mark
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.assets.base import VLLM_S3_BUCKET_URL
|
||||
from vllm.assets.image import VLM_IMAGES_DIR
|
||||
from vllm.distributed import cleanup_dist_env_and_memory
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
# Fraction of exactly-matching outputs required by MTP comparison tests
# (assumed threshold for tests defined later in this file — TODO confirm).
MTP_SIMILARITY_RATE = 0.8
|
||||
|
||||
|
||||
def _skip_if_insufficient_gpus_for_tp(tp_size: int):
    """Skip the current test on ROCm when fewer than *tp_size* GPUs exist."""
    if not current_platform.is_rocm():
        return
    available_gpus = torch.cuda.device_count()
    if available_gpus < tp_size:
        pytest.skip(
            f"Test requires {tp_size} GPUs, but only {available_gpus} available"
        )
|
||||
|
||||
|
||||
def get_test_prompts(mm_enabled: bool):
    """Build a seeded, mixed batch of chat prompts for spec-decode tests.

    Args:
        mm_enabled: When True, also include multimodal ("mm") prompts that
            reference a test image from the S3 asset bucket.

    Returns:
        A list of single-turn chat conversations (each a one-element list of
        ``{"role": "user", "content": ...}`` messages).
    """
    prompt_types = ["repeat", "sentence"]
    if mm_enabled:
        prompt_types.append("mm")
    num_prompts = 100
    prompts = []

    # Fixed seed so the prompt mix is identical across reference and
    # speculative runs.
    random.seed(0)
    random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)
    print(f"Prompt types: {random_prompt_type_choices}")

    # Generate a mixed batch of prompts, some of which can be easily
    # predicted by n-gram matching and some which likely cannot.
    for kind in random_prompt_type_choices:
        word_choices = ["test", "temp", "hello", "where"]
        word = random.choice(word_choices)
        prompt: str | list[dict[str, Any]] = ""
        if kind == "repeat":
            # Highly repetitive output: easy for n-gram speculation.
            prompt = f"""
            please repeat the word '{word}' 10 times.
            give no other output than the word at least ten times in a row,
            in lowercase with spaces between each word and without quotes.
            """
        elif kind == "sentence":
            # Free-form output: hard for n-gram speculation.
            prompt = f"""
            please give a ten-word sentence that
            uses the word {word} at least once.
            give no other output than that simple sentence without quotes.
            """
        elif kind == "mm":
            # Multimodal prompt: image URL placeholder plus a text part.
            placeholders = [
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"{VLLM_S3_BUCKET_URL}/{VLM_IMAGES_DIR}/stop_sign.jpg"
                    },
                }
            ]
            prompt = [
                *placeholders,
                {"type": "text", "text": "The meaning of the image is"},
            ]
        else:
            raise ValueError(f"Unknown prompt type: {kind}")
        prompts.append([{"role": "user", "content": prompt}])

    return prompts
|
||||
|
||||
|
||||
@pytest.fixture
def sampling_config():
    """Greedy, short-output sampling params shared by the spec-decode tests."""
    return SamplingParams(temperature=0, max_tokens=10, ignore_eos=False)
|
||||
|
||||
|
||||
@pytest.fixture
def model_name():
    """Target (verifier) model used by the spec-decode tests."""
    return "meta-llama/Llama-3.1-8B-Instruct"
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def reset_torch_dynamo():
    """Reset torch dynamo cache before each test"""
    yield
    # Cleanup after test: drop compiled graphs so tests don't share state.
    torch._dynamo.reset()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "speculative_config",
    [
        {
            "method": "ngram",
            "prompt_lookup_max": 5,
            "prompt_lookup_min": 3,
            "num_speculative_tokens": 3,
        },
        {
            "method": "suffix",
            "suffix_decoding_max_spec_factor": 2.0,
        },
    ],
)
def test_ngram_and_suffix_correctness(
    speculative_config: dict,
    monkeypatch: pytest.MonkeyPatch,
    sampling_config: SamplingParams,
    model_name: str,
):
    """
    Compare the outputs of an original LLM and a speculative LLM
    should be the same when using ngram speculative decoding.
    """
    test_prompts = get_test_prompts(mm_enabled=False)

    # Reference run without speculation.
    ref_llm = LLM(model=model_name, max_model_len=1024)
    ref_outputs = ref_llm.chat(test_prompts, sampling_config)
    # Free GPU memory before building the speculative engine.
    del ref_llm
    torch.cuda.empty_cache()
    cleanup_dist_env_and_memory()

    spec_llm = LLM(
        model=model_name,
        speculative_config=speculative_config,
        max_model_len=1024,
    )
    spec_outputs = spec_llm.chat(test_prompts, sampling_config)
    matches = 0
    misses = 0
    for ref_output, spec_output in zip(ref_outputs, spec_outputs):
        if ref_output.outputs[0].text == spec_output.outputs[0].text:
            matches += 1
        else:
            misses += 1
            print(f"ref_output: {ref_output.outputs[0].text}")
            print(f"spec_output: {spec_output.outputs[0].text}")

    # Heuristic: expect at least 66% of the prompts to match exactly
    # Upon failure, inspect the outputs to check for inaccuracy.
    assert matches >= int(0.66 * len(ref_outputs))
    del spec_llm
    torch.cuda.empty_cache()
    cleanup_dist_env_and_memory()
|
||||
|
||||
|
||||
def test_suffix_decoding_acceptance(
    monkeypatch: pytest.MonkeyPatch,
    sampling_config: SamplingParams,
    model_name: str,
):
    """
    Check that suffix decoding caching takes effect and improves acceptance
    lengths and acceptance rates over multiple runs of the same prompts.
    """
    test_prompts = get_test_prompts(mm_enabled=False)

    spec_llm = LLM(
        model=model_name,
        speculative_config={
            "method": "suffix",
            "suffix_decoding_max_spec_factor": 2.0,
            "suffix_decoding_max_cached_requests": 1000,
        },
        max_model_len=1024,
        # Stats must be enabled so get_metrics() reports spec-decode counters.
        disable_log_stats=False,
    )

    # Run several times and check that the accepted tokens increase.
    num_draft = []
    num_accept = []
    for i in range(10):  # Run multiple times to warm up the cache.
        spec_llm.chat(test_prompts, sampling_config)
        # Collect draft and acceptance stats.
        metrics = spec_llm.get_metrics()
        for metric in metrics:
            if metric.name == "vllm:spec_decode_num_draft_tokens":
                num_draft.append(metric.value)
            if metric.name == "vllm:spec_decode_num_accepted_tokens":
                num_accept.append(metric.value)

    # Calculate the acceptance rates for the first and last runs.
    first_accept_tokens = num_accept[0]
    first_draft_tokens = num_draft[0]
    first_accept_rate = first_accept_tokens / first_draft_tokens

    # Take the diff since the stats are cumulative.
    last_accept_tokens = num_accept[-1] - num_accept[-2]
    last_draft_tokens = num_draft[-1] - num_draft[-2]
    last_accept_rate = last_accept_tokens / last_draft_tokens

    # Expect the acceptance length to improve.
    assert first_accept_tokens < last_accept_tokens

    # Expect the acceptance rate to improve.
    assert first_accept_rate < last_accept_rate

    # Heuristic: expect at least 80.0% acceptance rate at the end.
    assert last_accept_rate > 0.80

    del spec_llm
    torch.cuda.empty_cache()
    cleanup_dist_env_and_memory()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "model_path",
    [
        "RedHatAI/Llama-3.1-8B-Instruct-speculator.eagle3",
        "RedHatAI/Qwen3-8B-speculator.eagle3",
    ],
    ids=["llama3_eagle3_speculator", "qwen3_eagle3_speculator"],
)
def test_speculators_model_integration(
    monkeypatch: pytest.MonkeyPatch,
    sampling_config: SamplingParams,
    model_path: str,
):
    """Exercise the simplified speculators-model integration path.

    Covers the ``vllm serve <speculator-model>`` use case: pointing vLLM
    directly at a speculator checkpoint should auto-detect the speculative
    config (no explicit ``--speculative-config`` needed), resolve the
    verifier model from the speculator's own config, enable speculative
    decoding, and still generate text that matches a plain
    (non-speculative) run of the verifier on most prompts.
    """
    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")

    # Build the shared prompt set (text-only).
    prompts = get_test_prompts(mm_enabled=False)

    # Run 1: load the speculator checkpoint directly.
    spec_llm = LLM(model=model_path, max_model_len=1024)
    spec_outputs = spec_llm.chat(prompts, sampling_config)

    # The speculative config must have been picked up from the checkpoint.
    assert spec_llm.llm_engine.vllm_config.speculative_config is not None, (
        f"Speculative config should be auto-detected for {model_path}"
    )

    spec_config = spec_llm.llm_engine.vllm_config.speculative_config
    assert spec_config.num_speculative_tokens > 0, (
        f"Expected positive speculative tokens, "
        f"got {spec_config.num_speculative_tokens}"
    )

    # The draft model recorded in the config is the speculator itself.
    assert spec_config.model == model_path, (
        f"Draft model should be {model_path}, got {spec_config.model}"
    )

    # Remember which verifier the speculator resolved to, for the baseline.
    verifier_model = spec_llm.llm_engine.vllm_config.model_config.model

    del spec_llm
    torch.cuda.empty_cache()
    cleanup_dist_env_and_memory()

    # Run 2: baseline generation on the verifier without speculation.
    ref_llm = LLM(model=verifier_model, max_model_len=1024)
    ref_outputs = ref_llm.chat(prompts, sampling_config)
    del ref_llm
    torch.cuda.empty_cache()
    cleanup_dist_env_and_memory()

    # Count prompts whose generations agree exactly between the two runs.
    matches = 0
    for ref_out, spec_out in zip(ref_outputs, spec_outputs):
        if ref_out.outputs[0].text == spec_out.outputs[0].text:
            matches += 1

    # Heuristic: expect at least 66% of prompts to match exactly
    assert matches >= int(0.66 * len(ref_outputs)), (
        f"Only {matches}/{len(ref_outputs)} outputs matched. "
        f"Expected at least {int(0.66 * len(ref_outputs))} matches."
    )
|
||||
|
||||
@pytest.mark.parametrize(
    ["model_setup", "mm_enabled", "enable_chunked_prefill", "model_impl"],
    [
        (
            ("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1),
            False,
            False,
            "auto",
        ),
        (
            ("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1),
            False,
            False,
            "transformers",
        ),
        pytest.param(
            (
                "eagle3",
                "Qwen/Qwen3-VL-8B-Instruct",
                "taobao-mnn/Qwen3-VL-8B-Instruct-Eagle3",
                1,
            ),
            False,
            False,
            "auto",
            marks=pytest.mark.skip(
                reason="architecture of its eagle3 is LlamaForCausalLMEagle3"
            ),
        ),
        pytest.param(
            (
                "eagle3",
                "Qwen/Qwen2.5-VL-7B-Instruct",
                "Rayzl/qwen2.5-vl-7b-eagle3-sgl",
                1,
            ),
            False,
            False,
            "auto",
            marks=pytest.mark.skip(
                reason="Skipping due to its head_dim not being a multiple of 32"
            ),
        ),
        pytest.param(
            (
                "eagle",
                "meta-llama/Llama-3.1-8B-Instruct",
                "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
                1,
            ),
            False,
            True,
            "auto",
            marks=large_gpu_mark(min_gb=40),
        ),  # works on 4x H100
        (
            (
                "eagle3",
                "meta-llama/Llama-3.1-8B-Instruct",
                "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
                1,
            ),
            False,
            False,
            "auto",
        ),
        pytest.param(
            (
                "eagle",
                "meta-llama/Llama-4-Scout-17B-16E-Instruct",
                "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
                4,
            ),
            False,
            False,
            "auto",
            marks=large_gpu_mark(min_gb=80),
        ),  # works on 4x H100
        pytest.param(
            (
                "eagle",
                "meta-llama/Llama-4-Scout-17B-16E-Instruct",
                "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
                4,
            ),
            True,
            True,
            "auto",
            marks=large_gpu_mark(min_gb=80),
        ),  # works on 4x H100
        (
            (
                "eagle",
                "eagle618/deepseek-v3-random",
                "eagle618/eagle-deepseek-v3-random",
                1,
            ),
            False,
            False,
            "auto",
        ),
    ],
    ids=[
        "qwen3_eagle3",
        "qwen3_eagle3-transformers",
        "qwen3_vl_eagle3",
        "qwen2_5_vl_eagle3",
        "llama3_eagle",
        "llama3_eagle3",
        "llama4_eagle",
        "llama4_eagle_mm",
        "deepseek_eagle",
    ],
)
@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform())
def test_eagle_correctness(
    monkeypatch: pytest.MonkeyPatch,
    sampling_config: SamplingParams,
    model_setup: tuple[str, str, str, int],
    mm_enabled: bool,
    enable_chunked_prefill: bool,
    model_impl: str,
    attn_backend: str,
):
    """
    Compare the outputs of an original LLM and a speculative LLM:
    they should be the same when using eagle speculative decoding.

    model_setup: (method, model_name, eagle_model_name, tp_size)
    """
    if attn_backend == "TREE_ATTN":
        # TODO: Fix this flaky test
        pytest.skip(
            "TREE_ATTN is flaky in the test disable for now until it can be "
            "resolved (see https://github.com/vllm-project/vllm/issues/22922)"
        )
    if model_impl == "transformers":
        # The Transformers modeling backend gained Eagle3 support only in
        # transformers >= 5.0; skip on older installs.
        import transformers
        from packaging.version import Version

        installed = Version(transformers.__version__)
        required = Version("5.0.0.dev")
        if installed < required:
            pytest.skip(
                "Eagle3 with the Transformers modeling backend requires "
                f"transformers>={required}, but got {installed}"
            )

    # Generate test prompts inside the function instead of using fixture
    test_prompts = get_test_prompts(mm_enabled)

    with monkeypatch.context() as m:
        if "Llama-4-Scout" in model_setup[1] and attn_backend == "FLASH_ATTN":
            # Scout requires default backend selection
            # because vision encoder has head_dim 88 being incompatible
            # with FLASH_ATTN and needs to fall back to Flex Attn

            # pass if not ROCm
            if current_platform.is_rocm():
                # TODO: Enable Flex Attn for spec_decode on ROCm
                pytest.skip("Flex Attn for spec_decode not supported on ROCm currently")
        else:
            m.setenv("VLLM_MLA_DISABLE", "1")
            m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)

        if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
            pytest.skip(
                "TRITON_ATTN does not support "
                "multi-token eagle spec decode on current platform"
            )

        if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
            if "deepseek" in model_setup[1].lower():
                pytest.skip("ROCM_AITER_FA for deepseek not supported on ROCm platform")
            else:
                m.setenv("VLLM_ROCM_USE_AITER", "1")

        method, model_name, spec_model_name, tp_size = model_setup
        _skip_if_insufficient_gpus_for_tp(tp_size)

        max_model_len = 2048
        # Small batched-token budget forces prefill chunking when enabled.
        max_num_batched_tokens = 128 if enable_chunked_prefill else max_model_len

        # Reference run: plain (non-speculative) generation.
        ref_llm = LLM(
            model=model_name, max_model_len=max_model_len, tensor_parallel_size=tp_size
        )
        ref_outputs = ref_llm.chat(test_prompts, sampling_config)
        del ref_llm
        torch.cuda.empty_cache()
        cleanup_dist_env_and_memory()

        # Speculative run with the eagle draft model attached.
        spec_llm = LLM(
            model=model_name,
            trust_remote_code=True,
            tensor_parallel_size=tp_size,
            speculative_config={
                "method": method,
                "model": spec_model_name,
                "num_speculative_tokens": 3,
                "max_model_len": max_model_len,
            },
            max_model_len=max_model_len,
            max_num_batched_tokens=max_num_batched_tokens,
            enable_chunked_prefill=enable_chunked_prefill,
            model_impl=model_impl,
        )
        spec_outputs = spec_llm.chat(test_prompts, sampling_config)
        matches = 0
        misses = 0
        for ref_output, spec_output in zip(ref_outputs, spec_outputs):
            if ref_output.outputs[0].text == spec_output.outputs[0].text:
                matches += 1
            else:
                misses += 1
                print(f"ref_output: {ref_output.outputs[0].text}")
                print(f"spec_output: {spec_output.outputs[0].text}")

        # Heuristic: expect at least 60% of the prompts to match exactly
        # Upon failure, inspect the outputs to check for inaccuracy.
        assert matches > int(0.6 * len(ref_outputs))
        del spec_llm
        torch.cuda.empty_cache()
        cleanup_dist_env_and_memory()
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ["model_setup", "mm_enabled"],
    [
        (("mtp", "XiaomiMiMo/MiMo-7B-Base", 1), False),
        (("mtp", "ZixiQi/DeepSeek-V3-4layers-MTP-FP8", 1), False),
    ],
    ids=["mimo", "deepseek"],
)
def test_mtp_correctness(
    monkeypatch: pytest.MonkeyPatch,
    sampling_config: SamplingParams,
    model_setup: tuple[str, str, int],
    mm_enabled: bool,
):
    """
    Compare the outputs of an original LLM and a speculative LLM:
    they should be the same when using MTP speculative decoding.

    model_setup: (method, model_name, tp_size)
    """
    # Generate test prompts inside the function instead of using fixture
    test_prompts = get_test_prompts(mm_enabled)

    with monkeypatch.context() as m:
        m.setenv("VLLM_MLA_DISABLE", "1")

        method, model_name, tp_size = model_setup
        _skip_if_insufficient_gpus_for_tp(tp_size)

        # Reference run: plain (non-speculative) generation.
        ref_llm = LLM(
            model=model_name,
            max_model_len=2048,
            tensor_parallel_size=tp_size,
            trust_remote_code=True,
        )
        ref_outputs = ref_llm.chat(test_prompts, sampling_config)
        del ref_llm
        torch.cuda.empty_cache()
        cleanup_dist_env_and_memory()

        # Speculative run with MTP (draft model comes from the base model's
        # own MTP head, so no separate "model" entry is needed).
        spec_llm = LLM(
            model=model_name,
            trust_remote_code=True,
            tensor_parallel_size=tp_size,
            speculative_config={
                "method": method,
                "num_speculative_tokens": 1,
                "max_model_len": 2048,
            },
            max_model_len=2048,
        )
        spec_outputs = spec_llm.chat(test_prompts, sampling_config)
        matches = 0
        misses = 0
        for ref_output, spec_output in zip(ref_outputs, spec_outputs):
            if ref_output.outputs[0].text == spec_output.outputs[0].text:
                matches += 1
            else:
                misses += 1
                print(f"ref_output: {ref_output.outputs[0].text}")
                print(f"spec_output: {spec_output.outputs[0].text}")

        # Heuristic: expect at least 80% of the prompts to match exactly
        # Upon failure, inspect the outputs to check for inaccuracy.
        assert matches > int(MTP_SIMILARITY_RATE * len(ref_outputs))
        del spec_llm
        torch.cuda.empty_cache()
        cleanup_dist_env_and_memory()
|
||||
Reference in New Issue
Block a user