[SpecDecode] Add spec decode support (#500)

### What this PR does / why we need it?
Backport: https://github.com/vllm-project/vllm-ascend/pull/252
This adds speculative decoding support on Ascend, including speculating
with a draft model, matching n-grams in the prompt, using MLP speculators,
and using EAGLE-based draft models.
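For context, a minimal usage sketch (not part of this PR): the model names
and the `speculative_config` shape below mirror the e2e tests added in this
change; treat the exact keyword arguments as assumptions against your vLLM
version.

```
# Hedged sketch: enabling speculative decoding on Ascend through the vLLM
# API. The target/draft pair mirrors the MAIN_MODEL / SPEC_MODEL constants
# in the Medusa e2e test below.
from vllm import LLM, SamplingParams

llm = LLM(
    model="JackFram/llama-68m",  # target model
    speculative_config={
        "model": "abhigoyal/vllm-medusa-llama-68m-random",  # Medusa draft
        "num_speculative_tokens": 5,
    },
    enforce_eager=True,
)
outputs = llm.generate(["The future of AI is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
```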

Backport: https://github.com/vllm-project/vllm-ascend/pull/423
Make the spec decode `MultiStepWorker` fully support `TP1DraftModelRunner`:
run the draft_model_runner with multi-step preparation directly on the NPU,
and support MLA in the draft_model_runner.

1. Before this PR, `MultiStepWorker` would not step into the branch using
NPU prepare, only into the branch using CPU prepare (`line 52` of
`vllm_ascend/patch/patch_multi_step_worker.py`). Although this has no
effect on the correct operation of speculative decoding, and the
performance of the two branches is basically the same in the current
version, this PR enables the NPU-prepare branch. In general, there are two
main changes in `patch_multi_step_worker.py`: first, the `is_cuda_like()`
check is removed and the `TP1DraftModelRunner` rewritten in vllm_ascend is
used instead; second, the `supports_gpu_multi_step()` function is made to
return `True` on NPU devices whenever the outer `MultiStepWorker` can work
correctly. A hedged sketch of this shape is given below.
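A minimal sketch of what the patched check could look like; this is
illustrative only, and the actual conditions in
`vllm_ascend/patch/patch_multi_step_worker.py` may differ:

```
# Illustrative sketch only -- not the actual patch body.
from vllm.sequence import ExecuteModelRequest


def supports_gpu_multi_step(self, execute_model_req: ExecuteModelRequest) -> bool:
    """Assumed shape of the patched check: allow on-device (NPU)
    multi-step preparation whenever every sequence group is already in
    the decode phase. Upstream vLLM additionally gates this on
    is_cuda_like(), which the vllm-ascend patch drops in favor of its
    own TP1DraftModelRunner."""
    # Multi-step on-device prepare only applies to decode-only batches.
    return all(not seq_group.is_prompt
               for seq_group in execute_model_req.seq_group_metadata_list)
```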

2. Before this PR, `TP1DraftModelRunner` only supported Attention on NPU,
but not MLA. The relevant adaptation is in
`vllm_ascend/worker/draft_model_runner.py`. Although it is unclear why the
`input_positions` of `model_input.attn_metadata` in vllm-ascend needs to be
added in `execute_model`, `model_runner.py` does it, so the corresponding
change is made here as well; otherwise, when the attn_backend is MLA, it
reports that `input_positions` cannot be found. A rough sketch follows
below.
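For illustration, roughly what that accommodation amounts to; the field
names here are assumptions based on the description above, not a copy of
the actual code in `draft_model_runner.py`:

```
# Illustrative sketch (assumed field names): mirror what model_runner.py
# does by attaching input_positions to the attention metadata, so the MLA
# backend can find them during multi-step draft execution.
attn_metadata = model_input.attn_metadata
if getattr(attn_metadata, "input_positions", None) is None:
    attn_metadata.input_positions = model_input.input_positions
```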

3. I commented out two lines in `draft_model_runner.py` (at `line 118`) to
support the K>1 scenario:
  ```
  # lora_mapping=model_input.lora_mapping,
  # lora_requests=model_input.lora_requests,
  ```
I added comments there; once vllm-ascend supports the LoRA feature, these
changes can be restored.

TODO:
- [ ] revert the patch when the related issues are addressed in vllm

### How was this patch tested?
CI passed with the newly added tests.
- e2e test for medusa proposer:
tests/singlecard/spec_decode/e2e/test_medusa_correctness.py
- e2e test for mlp proposer:
tests/singlecard/spec_decode/e2e/test_mlp_correctness.py
- e2e test for n-gram proposer:
tests/singlecard/spec_decode/e2e/test_ngram_correctness.py

Tests for patched files:
- tests/singlecard/spec_decode/test_dynamic_spec_decode.py
- tests/singlecard/spec_decode/test_multi_step_worker.py
- tests/singlecard/spec_decode/test_ngram_worker.py
- tests/singlecard/spec_decode/test_spec_decode_worker.py

---------

Signed-off-by: MengqingCao <cmq0113@163.com>
Co-authored-by: mengwei805 <mengwei25@huawei.com>
Commit 6ee7f5cf71 (parent b71f193cb0) by Mengqing Cao, committed by GitHub,
2025-04-17 20:16:32 +08:00.
27 changed files with 5813 additions and 11 deletions


@@ -122,10 +122,10 @@ jobs:
VLLM_USE_V1: 0
run: |
if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
-pytest -sv tests/singlecard
+pytest -sv tests/singlecard/test_offline_inference.py
pytest -sv tests/ops
else
-pytest -sv tests/multicard
+pytest -sv tests/multicard/test_offline_inference_distributed.py
pytest -sv tests/ops
fi
@@ -135,13 +135,35 @@ jobs:
VLLM_WORKER_MULTIPROC_METHOD: spawn
run: |
if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
-pytest -sv tests/singlecard
+pytest -sv tests/singlecard/test_offline_inference.py
pytest -sv tests/ops
else
-pytest -sv tests/multicard
+pytest -sv tests/multicard/test_offline_inference_distributed.py
pytest -sv tests/ops
fi
- name: Check for changes in Speculative Decode
id: filter_spec_decode
uses: dorny/paths-filter@v2
with:
filters: |
speculative_tests_changed:
- "tests/singlecard/spec_decode/**"
- "tests/multicard/spec_decode_e2e/**"
- "vllm_ascend/worker/multi_step_runner.py"
- "vllm_ascend/worker/multi_step_worker.py"
- "vllm_ascend/patch/patch_rejection_sampler.py"
- "vllm_ascend/patch/patch_spec_decode_worker.py"
- "vllm_ascend/patch/patch_multi_step_worker.py"
- name: Run vllm-project/vllm-ascend Speculative Decode test
env:
HF_ENDPOINT: https://hf-mirror.com
if: steps.filter_spec_decode.outputs.speculative_tests_changed
run: |
if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
pytest -sv tests/singlecard/spec_decode
fi
- name: Run vllm-project/vllm test for V0 Engine
env:
VLLM_USE_V1: 0


@@ -3,4 +3,4 @@ modelscope
pytest >= 6.0
pytest-asyncio
lm-eval
-ray
+ray

tests/__init__.py (new, empty file)


@@ -24,7 +24,8 @@ import os
import pytest
import vllm # noqa: F401
-from conftest import VllmRunner
+from tests.conftest import VllmRunner
os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256"

tests/ops/__init__.py (new, empty file)

@@ -0,0 +1,28 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/tests/spec_decode/conftest.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pytest
@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
"""
Since this module is V0 only, set VLLM_USE_V1=0 for
all tests in the module.
"""
monkeypatch.setenv('VLLM_USE_V1', '0')


@@ -0,0 +1,256 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/tests/spec_decode/e2e/conftest.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from itertools import cycle
from typing import List, Optional, Sequence, Tuple, Union
import pytest
import torch
from vllm import LLM, SamplingParams
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import PromptLogprobs, SampleLogprobs
from ....model_utils import (TokensTextLogprobs,
TokensTextLogprobsPromptLogprobs,
check_logprobs_close, check_outputs_equal)
PROMPTS = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
"San Francisco is know for its",
"Facebook was created in 2004 by",
"Curious George is a",
"Python 3.11 brings improvements to its",
]
@pytest.fixture
def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
test_llm_kwargs, seed):
def generate():
kwargs = {
**common_llm_kwargs,
**per_test_common_llm_kwargs,
**test_llm_kwargs,
}
llm = LLM(**kwargs)
if seed is not None:
set_random_seed(seed)
yield llm
del llm
cleanup_dist_env_and_memory()
return generate
def maybe_assert_ngram_worker(llm):
# Verify the proposer worker is ngram if ngram is specified.
if (llm.llm_engine.speculative_config is not None
and llm.llm_engine.speculative_config.method == "ngram"):
from vllm.spec_decode.ngram_worker import NGramWorker
assert isinstance(
llm.llm_engine.model_executor.driver_worker.proposer_worker,
NGramWorker)
def get_output_from_llm_generator(
llm_generator, prompts,
sampling_params) -> Tuple[List[str], List[List[int]], float]:
tokens: List[str] = []
token_ids: List[List[int]] = []
acceptance_rate: float = -1.0
for llm in llm_generator():
maybe_assert_ngram_worker(llm)
outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
token_ids = [output.outputs[0].token_ids for output in outputs]
tokens = [output.outputs[0].text for output in outputs]
# Fetch acceptance rate if logging is enabled.
if stat_loggers := getattr(llm.llm_engine, "stat_loggers", None):
stat_logger = stat_loggers["prometheus"]
acceptance_rate = (stat_logger.metrics.
gauge_spec_decode_draft_acceptance_rate.labels(
**stat_logger.labels)._value.get())
del llm
return tokens, token_ids, acceptance_rate
def check_logprobs_correctness(
spec_outputs: Sequence[Union[TokensTextLogprobs,
TokensTextLogprobsPromptLogprobs]],
baseline_outputs: Sequence[Union[TokensTextLogprobs,
TokensTextLogprobsPromptLogprobs]],
disable_logprobs: bool = False,
):
"""Compare sampled and prompt logprobs between baseline and spec decoding
"""
if not disable_logprobs:
return check_logprobs_close(
outputs_0_lst=baseline_outputs,
outputs_1_lst=spec_outputs,
name_0="org",
name_1="sd",
)
# Check correctness when disable_logprobs == True
for spec_output, baseline_output in zip(spec_outputs, baseline_outputs):
# Check generated token logprobs.
spec_logprobs = spec_output[2]
baseline_logprobs = baseline_output[2]
_check_logprobs_when_output_disabled(spec_logprobs,
baseline_logprobs,
is_prompt_logprobs=False)
# Check prompt logprobs too, if they exist
if len(baseline_output) == 4:
assert len(spec_output) == 4
spec_prompt_logprobs = spec_output[3]
baseline_prompt_logprobs = baseline_output[3]
_check_logprobs_when_output_disabled(spec_prompt_logprobs,
baseline_prompt_logprobs,
is_prompt_logprobs=True)
def _check_logprobs_when_output_disabled(
spec_logprobs: Union[Optional[PromptLogprobs], SampleLogprobs],
baseline_logprobs: Union[Optional[PromptLogprobs], SampleLogprobs],
is_prompt_logprobs: bool = False,
):
# Prompt logprobs are optional
if is_prompt_logprobs and baseline_logprobs is None:
assert spec_logprobs is None
return
assert spec_logprobs is not None
assert baseline_logprobs is not None
assert len(spec_logprobs) == len(baseline_logprobs)
# For each generated position of the sequence.
for pos, (spec_pos_logprobs, baseline_pos_logprobs) in enumerate(
zip(spec_logprobs, baseline_logprobs)):
# First prompt logprob is expected to be None
if is_prompt_logprobs and baseline_pos_logprobs is None:
assert spec_pos_logprobs is None
assert pos == 0
continue
assert spec_pos_logprobs is not None
assert baseline_pos_logprobs is not None
# When disabled, the 1 logprob is returned with dummy values for the
# score and rank, but the token id should match the baseline model
assert len(spec_pos_logprobs) == 1
(spec_pos_logprob_token_id,
spec_pos_logprob) = next(iter(spec_pos_logprobs.items()))
assert spec_pos_logprob.rank == -1
assert spec_pos_logprob.logprob == 0.0
if isinstance(spec_pos_logprob_token_id, torch.Tensor):
spec_pos_logprob_token_id = spec_pos_logprob_token_id.item()
assert spec_pos_logprob_token_id in baseline_pos_logprobs
def run_equality_correctness_test(
vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size: int,
max_output_len: int,
seed: Optional[int] = 0,
temperature: float = 0.0,
disable_seed: bool = False,
ignore_eos: bool = True,
ensure_all_accepted: bool = False,
expected_acceptance_rate: Optional[float] = None,
logprobs: Optional[int] = None,
prompt_logprobs: Optional[int] = None,
disable_logprobs: bool = False):
org_args = {
**common_llm_kwargs,
**per_test_common_llm_kwargs,
**baseline_llm_kwargs,
}
sd_args = {
**common_llm_kwargs,
**per_test_common_llm_kwargs,
**test_llm_kwargs,
}
prompts = [prompt for prompt, _ in zip(cycle(PROMPTS), range(batch_size))]
if disable_seed:
seed = None
sampling_params = SamplingParams(temperature=temperature,
max_tokens=max_output_len,
seed=seed,
ignore_eos=ignore_eos,
logprobs=logprobs,
prompt_logprobs=prompt_logprobs)
with vllm_runner(**org_args) as vllm_model:
org_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)
with vllm_runner(**sd_args) as vllm_model:
if ensure_all_accepted or expected_acceptance_rate is not None:
# Force log interval to be 0 to catch all metrics.
stat_logger = vllm_model.model.llm_engine.stat_loggers[
'prometheus']
stat_logger.local_interval = -100
sd_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)
if ensure_all_accepted or expected_acceptance_rate is not None:
acceptance_rate = (stat_logger.metrics.
gauge_spec_decode_draft_acceptance_rate.labels(
**stat_logger.labels)._value.get())
if ensure_all_accepted:
assert True
# FIXME: ci fails to log acceptance rate.
# It works locally.
# assert acceptance_rate == 1.0
if expected_acceptance_rate is not None:
assert acceptance_rate >= expected_acceptance_rate - 1e-2
# Only pass token entries, not the logprobs
check_outputs_equal(outputs_0_lst=[out[0:2] for out in org_outputs],
outputs_1_lst=[out[0:2] for out in sd_outputs],
name_0="org",
name_1="sd")
# Check logprobs if requested
if logprobs is not None or prompt_logprobs is not None:
check_logprobs_correctness(sd_outputs, org_outputs, disable_logprobs)


@@ -0,0 +1,451 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/tests/spec_decode/e2e/test_medusa_correctness.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""This docstring details important information on the testing methodology.
Most of the tests rely on "greedy equality", where we expect the output of
speculative decoding on a sequence to exactly match the output of normal non-
speculative decoding.
Since speculative decoding with rejection sampling guarantees that the output
distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality.
However, we still need to verify the scenarios below pass:
* Batch size 1 greedy equality
* Batch size >1 greedy equality
* Test greedy equality under preemption
* Test greedy equality under various numbers of speculative tokens.
With those tests, we can say at least, Medusa would not break the
correctness for the target model outputs.
"""
import os
import pytest
from tests.singlecard.spec_decode.e2e.conftest import \
run_equality_correctness_test
from tests.singlecard.spec_decode.utils import maybe_enable_chunked_prefill
# main model
# lmsys/vicuna-7b-v1.3 was to be used but it's causing
# OOM in CI pipeline, so using a smaller model.
MAIN_MODEL = "JackFram/llama-68m"
# speculative model
SPEC_MODEL = "abhigoyal/vllm-medusa-llama-68m-random"
# max number of speculative tokens: this corresponds to
# num_heads in the config.json of the speculator model.
MAX_SPEC_TOKENS = 5
# precision
# TODO: vLLM uses float32 here, but some ops in vllm-ascend do not
# support float32 yet, such as RoPE. Once that is fixed, it is
# recommended to change this back to float32 to keep it consistent
# with vLLM.
PRECISION = "float16"
PREFILL_CHUNK_SIZE = [
-1,
# TODO:enable chunked prefill when it is supported
# 32
]
os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256"
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize("output_len", [
128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE)
def test_medusa_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
seed: int, prefill_chunk_size: int):
"""Verify greedy equality with different batch size."""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs": False,
},
},
{
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs": True,
},
},
])
@pytest.mark.parametrize("output_len", [
8,
])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6])
@pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE)
def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
seed: int, logprobs: int,
prefill_chunk_size: int):
"""Verify greedy equality with different batch size."""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(
vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs["speculative_config"]
["disable_logprobs"])
# TODO: Enable this test when vllm-ascend supports graph mode.
# @pytest.mark.parametrize(
# "common_llm_kwargs",
# [{
# "enforce_eager": False,
# # Print spec metrics.
# "disable_log_stats": False,
# # Precision
# "dtype": PRECISION,
# # Main model
# "model_name": MAIN_MODEL,
# }])
# @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
# @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
# @pytest.mark.parametrize("test_llm_kwargs", [
# {
# "speculative_config": {
# "model": SPEC_MODEL,
# "num_speculative_tokens": MAX_SPEC_TOKENS,
# },
# },
# ])
# @pytest.mark.parametrize("output_len", [
# 128,
# ])
# @pytest.mark.parametrize("batch_size", [1, 32])
# @pytest.mark.parametrize("seed", [1])
# @pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE)
# def test_medusa_e2e_greedy_correctness_cuda_graph(
# vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
# baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
# seed: int, prefill_chunk_size: int):
# """Verify greedy equality with cuda graph enabled and different
# batch sizes."""
# maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
# run_equality_correctness_test(vllm_runner,
# common_llm_kwargs,
# per_test_common_llm_kwargs,
# baseline_llm_kwargs,
# test_llm_kwargs,
# batch_size,
# max_output_len=output_len,
# seed=seed,
# temperature=0.0)
# TODO: There is a problem with preemptive scheduling in the current
# version, which makes this case fail. Please re-enable this case once
# the preemptive scheduling problem is solved.
# @pytest.mark.parametrize(
# "common_llm_kwargs",
# [{
# "block_size": 8,
# # 2 for small prompt, 256//8 for generated.
# "num_gpu_blocks_override": 2 + 256 // 8,
# "max_model_len": (2 + 256 // 8) * 8,
# # Skip cuda graph recording for fast test.
# "enforce_eager": True,
# # Precision
# "dtype": PRECISION,
# # Main model
# "model_name": MAIN_MODEL,
# }])
# @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
# @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
# @pytest.mark.parametrize("test_llm_kwargs", [
# {
# "speculative_config": {
# "model": SPEC_MODEL,
# "num_speculative_tokens": MAX_SPEC_TOKENS,
# },
# },
# ])
# @pytest.mark.parametrize(
# "output_len",
# [
# # Use small output len for fast test.
# 128,
# ])
# @pytest.mark.parametrize("batch_size", [4])
# @pytest.mark.parametrize("seed", [1])
# @pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE)
# def test_medusa_e2e_greedy_correctness_with_preemption(
# vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
# baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
# seed: int, prefill_chunk_size: int):
# """Verify greedy equality, even when some sequences are preempted mid-
# generation.
# """
# maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
# run_equality_correctness_test(vllm_runner,
# common_llm_kwargs,
# per_test_common_llm_kwargs,
# baseline_llm_kwargs,
# test_llm_kwargs,
# batch_size,
# max_output_len=output_len,
# seed=seed,
# temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
"test_llm_kwargs",
[
{
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": k,
},
}
# Try a range of num. speculative tokens
for k in range(1, 1 + MAX_SPEC_TOKENS)
])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE)
def test_medusa_different_k(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int, prefill_chunk_size: int):
"""Verify that medusa speculative decoding produces exact equality
to without spec decode with different values of num_speculative_tokens.
"""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_by_batch_size": 4,
},
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE)
def test_medusa_disable_queue(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int,
output_len: int, seed: int,
prefill_chunk_size: int):
"""Verify that medusa speculative decoding produces exact equality
to without spec decode when speculation is disabled for large
batch sizes.
"""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_by_batch_size": 4,
"disable_mqa_scorer": True,
},
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE)
def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
output_len: int, seed: int, prefill_chunk_size: int):
"""Verify that speculative decoding generates the same output
with batch expansion scorer and mqa scorer.
"""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
if __name__ == "__main__":
import pytest
pytest.main([__file__])


@@ -0,0 +1,564 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/tests/spec_decode/e2e/test_mlp_correctness.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""This docstring details important information on the testing methodology.
Most of the tests rely on "greedy equality", where we expect the output of
speculative decoding on a sequence to exactly match the output of normal non-
speculative decoding.
Since speculative decoding with rejection sampling guarantees that the output
distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality.
However, we still need to verify the scenarios below pass:
* Batch size 1 greedy equality
* Batch size >1 greedy equality
* Test greedy equality under preemption
* Test greedy equality under various number of speculative tokens.
With those tests, we can say at least, MLPSpeculator would not break the
correctness for the target model outputs.
"""
import pytest
from vllm.model_executor.layers.vocab_parallel_embedding import \
pad_vocab_size # noqa: F401
from tests.singlecard.spec_decode.e2e.conftest import \
run_equality_correctness_test
from tests.singlecard.spec_decode.utils import maybe_enable_chunked_prefill
# main model
MAIN_MODEL = "JackFram/llama-160m"
# speculative model
SPEC_MODEL = "ibm-ai-platform/llama-160m-accelerator"
# max. number of speculative tokens: this corresponds to
# n_predict in the config.json of the speculator model.
MAX_SPEC_TOKENS = 3
PREFILL_CHUNK_SIZE_1 = [
-1,
# TODO:enable chunked prefill when it is supported
# 4
]
PREFILL_CHUNK_SIZE_2 = [
-1,
# TODO:enable chunked prefill when it is supported
# 32
]
# precision
# TODO: vLLM uses float32 here, but some ops in vllm-ascend do not
# support float32 yet, such as RoPE. Once that is fixed, it is
# recommended to change this back to float32 to keep it consistent
# with vLLM.
PRECISION = "float16"
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"model": SPEC_MODEL,
},
},
])
@pytest.mark.parametrize("output_len", [
128,
])
@pytest.mark.parametrize("batch_size", [4, 32])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE_2)
def test_mlp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
seed: int, prefill_chunk_size: int):
"""Verify greedy equality with different batch size."""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"model": SPEC_MODEL,
"disable_logprobs": False,
},
},
{
"speculative_config": {
"model": SPEC_MODEL,
"disable_logprobs": True,
},
},
])
@pytest.mark.parametrize("output_len", [8])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6])
@pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE_1)
def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int, seed: int,
logprobs: int, prefill_chunk_size: int):
"""Verify greedy equality with different batch size."""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
# NOTE: The test is sensitive enough that if we don't enable chunked
# prefill scheduling on the baseline too, we get slightly different
# logprobs, ending up sampling different tokens at the tail (i.e. top
# tokens don't change).
# TL;DR: sd+cp == org+cp but sd+cp != org... is this expected?
maybe_enable_chunked_prefill(prefill_chunk_size, baseline_llm_kwargs)
run_equality_correctness_test(
vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs["speculative_config"]
["disable_logprobs"])
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"model": SPEC_MODEL,
},
},
])
@pytest.mark.parametrize("output_len", [2048])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE_1)
def test_mlp_e2e_acceptance_rate(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
prefill_chunk_size: int, seed: int):
"""Verify acceptance rate with different batch size and large output
length."""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
temperature=0.0,
seed=seed,
expected_acceptance_rate=0.48)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
# Speculative config
"speculative_config": {
"model": SPEC_MODEL,
},
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{"seed": 1}])
@pytest.mark.parametrize("test_llm_kwargs", [{"seed": 5}])
@pytest.mark.parametrize("output_len", [64])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("temperature", [1.0])
@pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE_1)
@pytest.mark.parametrize("seed", [1])
def test_mlp_e2e_seeded_correctness(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
temperature: float,
prefill_chunk_size: int, seed: int):
"""Verify seeded runs produce the same output."""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
maybe_enable_chunked_prefill(prefill_chunk_size, baseline_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
temperature=temperature,
seed=seed)
# Ensure this same test does fail if we _don't_ include per-request seeds
with pytest.raises(AssertionError):
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
temperature=temperature,
seed=seed,
disable_seed=True)
# TODO: There is a problem with preemptive scheduling in the current
# version, which makes this case fail. Please re-enable this case once
# the preemptive scheduling problem is solved.
# @pytest.mark.parametrize(
# "common_llm_kwargs",
# [{
# "block_size": 8,
# # 2 for small prompt, 256//8 for generated.
# "num_gpu_blocks_override": 2 + 256 // 8,
# "max_model_len": (2 + 256 // 8) * 8,
# # Skip cuda graph recording for fast test.
# "enforce_eager": True,
# # Precision
# "dtype": PRECISION,
# # Main model
# "model_name": MAIN_MODEL,
# }])
# @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
# @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
# @pytest.mark.parametrize("test_llm_kwargs", [
# {
# "speculative_config": {
# "model": SPEC_MODEL,
# },
# },
# ])
# @pytest.mark.parametrize(
# "output_len",
# [
# # Use small output len for fast test.
# 128,
# ])
# @pytest.mark.parametrize("batch_size", [4])
# @pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE_1)
# @pytest.mark.parametrize("seed", [1])
# def test_mlp_e2e_greedy_correctness_with_preemption(
# vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
# baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
# prefill_chunk_size: int, seed: int):
# """Verify greedy equality, even when some sequences are preempted mid-
# generation.
# """
# maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
# run_equality_correctness_test(vllm_runner,
# common_llm_kwargs,
# per_test_common_llm_kwargs,
# baseline_llm_kwargs,
# test_llm_kwargs,
# batch_size,
# max_output_len=output_len,
# seed=seed,
# temperature=0.0)
# TODO: There is a problem with preemptive scheduling in the current
# version, which makes this case fail. Please re-enable this case once
# the preemptive scheduling problem is solved.
# @pytest.mark.parametrize(
# "common_llm_kwargs",
# [{
# "block_size": 8,
# # 2 for small prompt, 256//8 for generated.
# "num_gpu_blocks_override": 2 + 256 // 8,
# "max_model_len": (2 + 256 // 8) * 8,
# # Skip cuda graph recording for fast test.
# "enforce_eager": True,
# # Precision
# "dtype": PRECISION,
# # Main model
# "model_name": MAIN_MODEL,
# }])
# @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
# @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
# @pytest.mark.parametrize("test_llm_kwargs", [
# {
# "speculative_config": {
# "model": SPEC_MODEL,
# },
# },
# ])
# @pytest.mark.parametrize(
# "output_len",
# [
# # Use small output len for fast test.
# 128,
# ])
# @pytest.mark.parametrize("batch_size", [4])
# @pytest.mark.parametrize("seed", [1])
# @pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE_1)
# def test_mlp_e2e_greedy_correctness_with_padding(
# vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
# baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
# prefill_chunk_size: int, seed: int):
# """Verify greedy equality when the vocab dimension is padded
# """
# maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
# # Default pad_to is 64, test model has vocab_size of 32000
# def patched_pad_vocab_size(vocab_size, pad_to=None):
# return pad_vocab_size(vocab_size, pad_to=32064)
# # NOTE: Compared with vLLM, the patch method has been modified
# from vllm.model_executor.layers.vocab_parallel_embedding import pad_vocab_size
# pad_vocab_size = patched_pad_vocab_size
# run_equality_correctness_test(vllm_runner,
# common_llm_kwargs,
# per_test_common_llm_kwargs,
# baseline_llm_kwargs,
# test_llm_kwargs,
# batch_size,
# max_output_len=output_len,
# seed=seed,
# temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
"test_llm_kwargs",
[
{
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": k,
},
}
# Try a range of num. speculative tokens
for k in range(1, 1 + MAX_SPEC_TOKENS)
])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE_1)
@pytest.mark.parametrize("seed", [1])
def test_mlp_different_k(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int,
prefill_chunk_size: int, seed: int, output_len: int):
"""Verify that mlp speculative decoding produces exact equality
to without spec decode with different values of num_speculative_tokens.
"""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"model": SPEC_MODEL,
"disable_by_batch_size": 4,
},
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
# Speculative decoding is disabled when sequences reach decoding and the batch
# consists of single-token requests. Hence we set `max_num_seqs`
# >= `speculative_disable_by_batch_size` to test feature interaction.
@pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE_1)
@pytest.mark.parametrize("seed", [1])
def test_mlp_disable_queue(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int,
prefill_chunk_size: int, seed: int,
output_len: int):
"""Verify that mlp speculative decoding produces exact equality
to without spec decode when speculation is disabled for large
batch sizes.
"""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model_name": MAIN_MODEL,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"model": SPEC_MODEL,
"disable_mqa_scorer": True,
},
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE_1)
@pytest.mark.parametrize("seed", [1])
def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
output_len: int, prefill_chunk_size: int, seed: int):
"""Verify that speculative decoding generates the same output
with batch expansion scorer and mqa scorer.
"""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)


@@ -0,0 +1,406 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/tests/spec_decode/e2e/test_ngram_correctness.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""This docstring details important information on the testing methodology.
Most of the tests rely on "greedy equality", where we expect the output of
speculative decoding on a sequence to exactly match the output of normal non-
speculative decoding.
Since speculative decoding with rejection sampling guarantees that the output
distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality.
For ngram lookup, the idea comes from https://github.com/apoorvumang/prompt-lookup-decoding,
and was merged into the transformers code base: https://github.com/huggingface/transformers/pull/27775.
Since no model is needed to generate the proposals, we can make the
test cases much simpler than the multi-step drafter ones.
However, we still need to verify the scenarios below pass:
* Batch size 1 greedy equality
* Batch size >1 greedy equality
* Test greedy equality under preemption
* Test greedy equality under various ngram sizes / speculative sizes
With those tests, we can say at least, ngram spec would not break the
correctness for the target model outputs.
"""
import pytest
from tests.singlecard.spec_decode.e2e.conftest import \
run_equality_correctness_test
from tests.singlecard.spec_decode.utils import maybe_enable_chunked_prefill
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
{
"model_name": "JackFram/llama-68m",
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": 5,
"prompt_lookup_max": 3,
"disable_mqa_scorer": False,
},
},
{
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": 5,
"prompt_lookup_max": 3,
"disable_mqa_scorer": True,
},
},
])
@pytest.mark.parametrize("output_len", [
256,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize(
"prefill_chunk_size",
[
-1,
# TODO:enable chunked prefill when it is supported
# 4
])
@pytest.mark.parametrize("seed", [1])
def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
prefill_chunk_size: int, seed: int):
"""Verify greedy equality on a tiny model with different batch size."""
maybe_enable_chunked_prefill(prefill_chunk_size, common_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
{
"model_name": "JackFram/llama-68m",
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": 5,
"prompt_lookup_max": 3,
"disable_logprobs": False,
},
},
{
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": 5,
"prompt_lookup_max": 3,
"disable_logprobs": True,
},
},
])
@pytest.mark.parametrize("output_len", [
8,
])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6])
def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int, seed: int,
logprobs: int):
"""Verify greedy equality on a tiny model with different batch size."""
run_equality_correctness_test(
vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs["speculative_config"]
["disable_logprobs"])
# TODO: There is a problem with preemptive scheduling in the current
# version, which makes this case fail. Please re-enable this case once
# the preemptive scheduling problem is solved.
# @pytest.mark.parametrize(
# "common_llm_kwargs",
# [{
# "block_size": 8,
# # 2 for small prompt, 256//8 for generated.
# "num_gpu_blocks_override": 2 + 256 // 8,
# "max_model_len": (2 + 256 // 8) * 8,
# # Skip cuda graph recording for fast test.
# "enforce_eager": True,
# }])
# @pytest.mark.parametrize("per_test_common_llm_kwargs", [
# {
# "model_name": "JackFram/llama-160m",
# },
# ])
# @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
# @pytest.mark.parametrize("test_llm_kwargs", [
# {
# "speculative_config": {
# "method": "ngram",
# "num_speculative_tokens": 5,
# "prompt_lookup_max": 3,
# },
# "enable_chunked_prefill": False,
# },
# {
# "speculative_config": {
# "method": "ngram",
# "num_speculative_tokens": 5,
# "prompt_lookup_max": 3,
# "disable_mqa_scorer": True,
# },
# "enable_chunked_prefill": True,
# "max_num_batched_tokens": 4,
# "max_num_seqs": 4
# },
# ])
# @pytest.mark.parametrize(
# "output_len",
# [
# # Use small output len for fast test.
# 256,
# ])
# @pytest.mark.parametrize("batch_size", [4])
# @pytest.mark.parametrize("seed", [1])
# def test_ngram_e2e_greedy_correctness_with_preemption(
# vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
# baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
# seed: int):
# """Verify greedy equality, even when some sequences are preempted mid-
# generation.
# """
# run_equality_correctness_test(vllm_runner,
# common_llm_kwargs,
# per_test_common_llm_kwargs,
# baseline_llm_kwargs,
# test_llm_kwargs,
# batch_size,
# max_output_len=output_len,
# temperature=0,
# seed=seed)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model_name": "JackFram/llama-68m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
"test_llm_kwargs",
[
{
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": k,
"prompt_lookup_max": 3,
},
}
# Try a range of common k, as well as large speculation.
for k in [1, 3, 5]
] + [
{
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": k,
"prompt_lookup_max": 1,
},
}
# Try a range of common k, as well as large speculation.
for k in [1, 3, 5]
])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_ngram_different_k(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify that ngram speculative decoding produces exact equality
to without spec decode with many different values of k and
different ngram_prompt_lookup_max.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model_name": "JackFram/llama-68m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
"test_llm_kwargs",
[
{
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": 5,
"prompt_lookup_max": 3,
"disable_by_batch_size": 4
},
},
{
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": 5,
"prompt_lookup_max": 3,
"disable_by_batch_size": 4,
"disable_mqa_scorer": True,
},
"enable_chunked_prefill": False,
# FIXME: enable me when chunked prefill is available
# "max_num_batched_tokens": 4,
"max_num_seqs": 4
}
])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_ngram_disable_queue(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify that ngram speculative decoding produces exact equality
to without spec decode with many different values of k and
different ngram_prompt_lookup_max.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model_name": "JackFram/llama-68m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": 5,
"prompt_lookup_max": 3,
"disable_mqa_scorer": True,
},
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_ngram_scorer(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify that ngram speculative decoding generates the same output
with batch expansion scorer and mqa scorer.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)


@@ -0,0 +1,106 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/tests/spec_decode/test_dynamic_spec_decode.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from unittest.mock import MagicMock, patch
import pytest
import torch
from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.metrics import AsyncMetricsCollector
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
from vllm.spec_decode.top1_proposer import Top1Proposer
from tests.singlecard.spec_decode.test_utils import mock_spec_decode_sampler
from tests.singlecard.spec_decode.utils import create_batch, mock_worker
from vllm_ascend.patch.worker import patch_common # noqa: F401
@pytest.mark.parametrize('queue_size', [4])
@pytest.mark.parametrize('batch_size', [1])
@pytest.mark.parametrize('k', [1])
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int,
acceptance_sampler_method: str):
"""Verify that speculative tokens are disabled when the batch size
exceeds the threshold.
"""
disable_by_batch_size = 3
draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker()
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
worker = SpecDecodeWorker(proposer_worker=draft_worker,
scorer_worker=target_worker,
spec_decode_sampler=mock_spec_decode_sampler(
acceptance_sampler_method),
disable_logprobs=False,
metrics_collector=metrics_collector,
disable_by_batch_size=disable_by_batch_size)
exception_secret = 'artificial stop'
draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k,
running_queue_size=queue_size)
if queue_size > disable_by_batch_size:
with patch.object(worker,
'_run_no_spec',
side_effect=ValueError(exception_secret)), \
pytest.raises(ValueError, match=exception_secret):
worker.execute_model(execute_model_req=execute_model_req)
# When the batch size is larger than the threshold,
# we expect no speculative tokens (0).
expected_num_spec_tokens = None if queue_size < disable_by_batch_size else 0
assert seq_group_metadata_list[
0].num_speculative_tokens == expected_num_spec_tokens
draft_worker.sampler_output.side_effect = ValueError(exception_secret)
proposer = Top1Proposer(
worker=draft_worker,
device='cpu', # not used
vocab_size=100, # not used
# Must be long enough to avoid being skipped due to length.
max_proposal_len=1024,
)
if queue_size < disable_by_batch_size:
# Should raise exception when executing the mocked draft model.
with pytest.raises(ValueError, match=exception_secret):
proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k),
seq_ids_with_bonus_token_in_last_step=set())
else:
# Should not execute the draft model because spec decode is disabled
# for all requests. Accordingly, the proposal length should be 0.
proposals = proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k),
seq_ids_with_bonus_token_in_last_step=set())
assert proposals.proposal_lens.tolist() == [0] * batch_size


@@ -0,0 +1,847 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/tests/spec_decode/test_multi_step_worker.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import random
from unittest.mock import MagicMock
import pytest
import torch
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import (ExecuteModelRequest, HiddenStates, Logprob,
get_all_seq_ids)
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.top1_proposer import Top1Proposer
from tests.singlecard.spec_decode.utils import (
assert_logprobs_dict_allclose, create_batch,
create_seq_group_metadata_from_prompts, create_worker,
patch_execute_model_with_seeds, zero_kv_cache)
from vllm_ascend.patch.worker import patch_common # noqa: F401
from vllm_ascend.worker.draft_model_runner import TP1DraftModelRunner
from vllm_ascend.worker.worker import NPUWorker
@pytest.mark.parametrize('num_steps', list(range(1, 17)))
def test_assert_enough_kv_space(num_steps: int):
"""Test that the multi step worker checks for sufficient space in the KV
cache. It should throw if it cannot run all the steps.
"""
block_size = 16
num_gpu_blocks = 2048 // block_size
prompts = [
list(range(block_size * 3)),
list(range(block_size * 2)),
]
prev_output_tokens = [
list(range(block_size * 1)),
list(range(block_size * 2)),
]
final_prompt_lens = [
len(prompt + output) + num_steps
for prompt, output in zip(prompts, prev_output_tokens)
]
inputs = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
final_prompt_lens,
continuations=prev_output_tokens)
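    # Grab the unbound helper so it can be called with a mocked worker
    # standing in for self, without constructing a real MultiStepWorker.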
assert_enough_kv_space = MultiStepWorker._assert_enough_kv_space # pylint: disable=protected-access
worker = MagicMock()
worker.model_runner.block_size = block_size
for seq_group_metadata in inputs:
original_block_tables = seq_group_metadata.block_tables
# No exception.
assert_enough_kv_space(worker, inputs, num_steps)
seq_group_metadata.block_tables = {
seq_id: []
for seq_id, physical_blocks in original_block_tables.items()
}
# Expect exception.
with pytest.raises(ValueError,
match='times but found insufficient KV space for'):
assert_enough_kv_space(worker, inputs, num_steps)
seq_group_metadata.block_tables = original_block_tables
@torch.inference_mode()
def test_same_output_for_single_step():
"""Verify the multi step worker produces the same output as the normal
worker for num_steps=1.
"""
seed = 100
model_name = 'JackFram/llama-68m'
block_size = 32
num_gpu_blocks = 2048 // block_size
multi_step_worker = create_worker(
MultiStepWorker,
model_name,
block_size,
num_gpu_blocks,
seed,
model_runner_cls=TP1DraftModelRunner,
)
worker = create_worker(
NPUWorker,
model_name,
block_size,
num_gpu_blocks,
seed,
)
# multi_step_worker.model_runner = worker.model_runner
# multi_step_worker.cache_engine = worker.cache_engine
num_steps = 1
prompts = [
[1, 2, 3, 4, 5],
[6, 7, 8, 9, 10],
]
final_prompt_lens = [len(prompt) + num_steps for prompt in prompts]
multi_step_seq_group = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
final_prompt_lens=final_prompt_lens)
zero_kv_cache(multi_step_worker.cache_engine)
set_random_seed(seed)
actual_output, _ = multi_step_worker.sampler_output(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=multi_step_seq_group),
sample_len=num_steps,
seq_ids_with_bonus_token_in_last_step=set())
assert len(actual_output) == num_steps
actual_output = actual_output[0]
single_step_seq_group = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
final_prompt_lens=final_prompt_lens)
zero_kv_cache(worker.cache_engine)
set_random_seed(seed)
expected_output = worker.execute_model(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=single_step_seq_group))[0]
actual_token_ids = [
output.samples[0].output_token for output in actual_output
]
actual_logprobs = [output.samples[0].logprobs for output in actual_output]
expected_token_ids = [
output.samples[0].output_token for output in expected_output
]
expected_logprobs = [
output.samples[0].logprobs for output in expected_output
]
assert actual_token_ids == expected_token_ids
print(f'{actual_logprobs=}')
print(f'{expected_logprobs=}')
assert_logprobs_dict_allclose(actual_logprobs, expected_logprobs)
@torch.inference_mode()
def test_same_output_for_multi_step():
"""Verify the multi-step worker produces the same output as the normal
    worker when num_steps > 1. This test runs the multi-step worker once,
    then runs the normal worker num_steps times, and compares the outputs.
"""
seed = 100
model_name = 'JackFram/llama-68m'
block_size = 16
num_gpu_blocks = 2048 // block_size
multi_step_worker = create_worker(
MultiStepWorker,
model_name,
block_size,
num_gpu_blocks,
seed,
)
worker = create_worker(
NPUWorker,
model_name,
block_size,
num_gpu_blocks,
seed,
)
# Make sure we go over the block boundary.
num_steps = block_size + 1
random.seed(seed)
prompts = [[
random.randint(0, 1000) for _ in range(random.randint(10, 20))
] for _ in range(10)]
final_prompt_lens = [len(prompt) + num_steps for prompt in prompts]
rand_seeds = list(random.randint(0, 100) for _ in range(num_steps))
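    # Re-seed both workers' execute_model calls with the same per-step seeds
    # so their sampling is deterministic and directly comparable.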
multi_step_worker.execute_model = patch_execute_model_with_seeds(
multi_step_worker, rand_seeds)
worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds)
continuations = [[1] for _ in prompts]
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
continuations=continuations,
final_prompt_lens=final_prompt_lens)
# Run multi-step.
zero_kv_cache(multi_step_worker.cache_engine)
set_random_seed(seed)
multi_step_output, _ = multi_step_worker.sampler_output(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list),
sample_len=num_steps,
seq_ids_with_bonus_token_in_last_step=set())
# Run single-step repeatedly.
zero_kv_cache(worker.cache_engine)
single_step_output: list[SamplerOutput] = []
continuations = [[1] for _ in prompts]
set_random_seed(seed)
for _ in multi_step_output:
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
continuations=continuations,
final_prompt_lens=final_prompt_lens)
single_step_output.extend(
worker.execute_model(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list)))
# Append output tokens to new sequence data.
for i, seq_group_output in enumerate(single_step_output[-1]):
continuations[i].append(seq_group_output.samples[0].output_token)
# Get token ids and logprobs for comparison.
multi_step_output_logprobs: list[list[dict[int,
Logprob]]] = [[]
for _ in prompts]
single_step_output_logprobs: list[list[dict[int,
Logprob]]] = [[]
for _ in prompts]
multi_step_output_token_ids: list[list[int]] = [[] for _ in prompts]
single_step_output_token_ids: list[list[int]] = [[] for _ in prompts]
for i, _ in enumerate(prompts):
for multi_step, single_step in zip(multi_step_output,
single_step_output):
multi_step_output_token_ids[i].append(
multi_step[i].samples[0].output_token)
single_step_output_token_ids[i].append(
single_step[i].samples[0].output_token)
multi_step_output_logprobs[i].append(
multi_step[i].samples[0].logprobs)
single_step_output_logprobs[i].append(
single_step[i].samples[0].logprobs)
# Print per-sequence token ids
for i, (multi_step_tokens, single_step_tokens) in enumerate(
zip(multi_step_output_token_ids, single_step_output_token_ids)):
print(f'{i=} {multi_step_tokens=}')
print(f'{i=} {single_step_tokens=}')
print(f'{i=} equal {multi_step_tokens == single_step_tokens}')
# Assert token ids are equal.
for multi_step_tokens, single_step_tokens in zip(
multi_step_output_token_ids, single_step_output_token_ids):
assert multi_step_tokens == single_step_tokens
# Assert logprobs are equal.
for multi_step_logprobs, single_step_logprobs in zip(
multi_step_output_logprobs, single_step_output_logprobs):
assert_logprobs_dict_allclose(multi_step_logprobs,
single_step_logprobs)
@torch.inference_mode()
def test_multi_step_with_batch_expansion_correct_output():
"""
    Verify that the MultiStepWorker handles bonus tokens correctly: if a
    sequence has a bonus token, the worker expands the batch by adding new
    sequences corresponding to the sequences with bonus tokens, and the
    expanded batch is then used to predict the next tokens.
"""
seed = 100
model_name = 'JackFram/llama-68m'
block_size = 16
num_gpu_blocks = 2048 // block_size
batch_size = 128
multi_step_worker = create_worker(
MultiStepWorker,
model_name,
block_size,
num_gpu_blocks,
seed,
model_runner_cls=TP1DraftModelRunner,
)
multi_step_worker.set_include_gpu_probs_tensor()
worker = create_worker(
NPUWorker,
model_name,
block_size,
num_gpu_blocks,
seed,
)
random.seed(seed)
prompts = [[0] for _ in range(batch_size)]
num_steps = 2
final_prompt_lens = [(num_steps + 1) for prompt in prompts]
rand_seeds = list(random.randint(0, 100) for _ in range(num_steps))
multi_step_worker.execute_model = patch_execute_model_with_seeds(
multi_step_worker, rand_seeds)
worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds)
# Create the test continuations
continuations = [[random.randint(0, 1000)] for _ in prompts]
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
continuations=continuations,
final_prompt_lens=final_prompt_lens)
# Run single-step twice to generate 2 tokens. This
# will simulate the bonus token case with the second token
# being the bonus token.
zero_kv_cache(worker.cache_engine)
single_step_output: list[SamplerOutput] = []
set_random_seed(seed)
for _ in range(num_steps):
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
continuations=continuations,
final_prompt_lens=final_prompt_lens)
single_step_output.extend(
worker.execute_model(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list)))
# Append output tokens to new sequence data.
for i, seq_group_output in enumerate(single_step_output[-1]):
continuations[i].append(seq_group_output.samples[0].output_token)
# Create continuations for the MultiStepWorker. The continuations have
# 2 tokens in order to simulate the bonus token case.
multi_step_continuations = []
for continuation in continuations:
multi_step_continuations.append(continuation[:2])
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
continuations=multi_step_continuations,
final_prompt_lens=final_prompt_lens)
# Run multi-step and verify that the third token prediction is accurate
# for all sequences.
zero_kv_cache(multi_step_worker.cache_engine)
all_seq_ids = {i for i in range(batch_size)}
multi_step_output, _ = multi_step_worker.sampler_output(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list),
sample_len=1,
seq_ids_with_bonus_token_in_last_step=all_seq_ids)
for index, output in enumerate(multi_step_output[-1].outputs):
assert (continuations[index][-1] == output.samples[0].output_token)
@torch.inference_mode()
def test_multi_step_with_batch_expansion_incorrect_output():
"""
Tests the MultiStepWorker's ability to handle batch expansion with bonus
tokens in a negative case scenario. This test provides the MultiStepWorker
with a batch containing sequences with bonus tokens but specifies the
sequence IDs with bonus tokens incorrectly. The test verifies that the
MultiStepWorker generates correct tokens for the sequences where the
sequence ID is specified correctly and incorrect tokens for those where
the sequence ID is specified incorrectly.
"""
seed = 100
model_name = 'JackFram/llama-68m'
block_size = 16
num_gpu_blocks = 2048 // block_size
batch_size = 128
multi_step_worker = create_worker(
MultiStepWorker,
model_name,
block_size,
num_gpu_blocks,
seed,
model_runner_cls=TP1DraftModelRunner,
)
multi_step_worker.set_include_gpu_probs_tensor()
worker = create_worker(
NPUWorker,
model_name,
block_size,
num_gpu_blocks,
seed,
)
random.seed(seed)
prompts = [[0] for _ in range(batch_size)]
num_steps = 2
final_prompt_lens = [(num_steps + 1) for prompt in prompts]
rand_seeds = list(random.randint(0, 100) for _ in range(num_steps))
multi_step_worker.execute_model = patch_execute_model_with_seeds(
multi_step_worker, rand_seeds)
worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds)
# Create the test continuations
continuations = [[random.randint(0, 1000)] for _ in prompts]
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
continuations=continuations,
final_prompt_lens=final_prompt_lens)
# Run single-step twice to generate 2 tokens. This
# will simulate the bonus token case with the second token
# being the bonus token.
zero_kv_cache(worker.cache_engine)
single_step_output: list[SamplerOutput] = []
set_random_seed(seed)
for _ in range(num_steps):
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
continuations=continuations,
final_prompt_lens=final_prompt_lens)
single_step_output.extend(
worker.execute_model(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list)))
# Append output tokens to new sequence data.
for i, seq_group_output in enumerate(single_step_output[-1]):
continuations[i].append(seq_group_output.samples[0].output_token)
# Create continuations for the MultiStepWorker. The continuations have
# 2 tokens in order to simulate the bonus token case.
multi_step_continuations = []
for continuation in continuations:
multi_step_continuations.append(continuation[:2])
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
continuations=multi_step_continuations,
final_prompt_lens=final_prompt_lens)
    # Run multi-step. In this run, INCORRECTLY specify that only the
    # odd-numbered sequences have bonus tokens. Verify that with this setting
    # the third-token prediction is accurate only for the odd-numbered
    # sequences, and that the prediction may be wrong for some of the
    # even-numbered sequences.
zero_kv_cache(multi_step_worker.cache_engine)
set_random_seed(seed)
odd_seq_ids = {i for i in range(batch_size) if i % 2 != 0}
multi_step_output, _ = multi_step_worker.sampler_output(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list),
sample_len=1,
seq_ids_with_bonus_token_in_last_step=odd_seq_ids)
num_mismatch = 0
for index, output in enumerate(multi_step_output[-1].outputs):
if (index % 2) != 0:
assert (continuations[index][-1] == output.samples[0].output_token)
elif (continuations[index][-1] != output.samples[0].output_token):
num_mismatch += 1
# The prediction is accurate for some of the sequences even without proper
# handling of the bonus tokens. Hence verify that the number of sequences
# for which there is a mismatch is > 0.
assert (num_mismatch > 0)
@torch.inference_mode()
@pytest.mark.parametrize('num_steps', [1, 2, 3, 4])
def test_multi_step_correct_kvcache(num_steps):
"""Verify that the KV cache of the draft model
is correctly updated for sequences with bonus token.
"""
seed = 100
model_name = "JackFram/llama-68m"
block_size = 16
num_gpu_blocks = 2048 // block_size
batch_size = 1
dtype = 'float16'
multi_step_worker = create_worker(MultiStepWorker,
model_name,
block_size,
num_gpu_blocks,
seed,
model_runner_cls=TP1DraftModelRunner,
dtype=dtype)
multi_step_worker.set_include_gpu_probs_tensor()
worker = create_worker(NPUWorker,
model_name,
block_size,
num_gpu_blocks,
seed,
dtype=dtype)
prompts = [[0] for _ in range(batch_size)]
    # Two tokens have already been generated for each sequence
    # so that we can simulate the bonus token case.
multi_step_continuations = [[
random.randint(0, 1000),
random.randint(0, 1000)
] for _ in prompts]
final_prompt_lens = [len(prompt) + 2 + num_steps for prompt in prompts]
seq_ids_with_bonus_token_in_last_step = set(range(batch_size))
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
continuations=multi_step_continuations,
final_prompt_lens=final_prompt_lens)
# Run multi-step.
zero_kv_cache(multi_step_worker.cache_engine)
multi_step_worker.sampler_output(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list),
sample_len=num_steps,
seq_ids_with_bonus_token_in_last_step=
seq_ids_with_bonus_token_in_last_step)
# Run single-step repeatedly.
zero_kv_cache(worker.cache_engine)
# Generate the kv cache for the bonus token first
single_step_continuations = [c[:1] for c in multi_step_continuations]
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
continuations=single_step_continuations,
final_prompt_lens=final_prompt_lens)
single_step_output = worker.execute_model(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list))
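    # Replay num_steps single-step decodes, appending each sampled token to
    # the continuations, to build the reference KV cache for comparison.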
for _ in range(num_steps):
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
continuations=multi_step_continuations,
final_prompt_lens=final_prompt_lens)
single_step_output = worker.execute_model(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list))
for i, seq_group_output in enumerate(single_step_output[-1]):
multi_step_continuations[i].append(
seq_group_output.samples[0].output_token)
# Verify that the KV cache of the single-step and
# multi-step workers are the same.
single_step_gpu_cache = worker.cache_engine[0].gpu_cache
multi_step_gpu_cache = multi_step_worker.cache_engine[0].gpu_cache
num_layers = len(single_step_gpu_cache)
allclose = lambda a, b: torch.allclose( # noqa: E731
a.npu(), b.npu(), rtol=1e-2, atol=1e-2)
for i in range(num_layers):
assert allclose(single_step_gpu_cache[i][0],
multi_step_gpu_cache[i][0])
assert allclose(single_step_gpu_cache[i][1],
multi_step_gpu_cache[i][1])
@torch.inference_mode()
def test_draft_proposals_full_speculation_len():
"""Verify Top1Proposer correctly handles case where all sequences
can speculate.
"""
k = 10
batch_size = 32
vocab_size = 32_000
device = 'npu:0'
draft_worker = MagicMock()
proposer = Top1Proposer(
worker=draft_worker,
device=device,
vocab_size=vocab_size,
max_proposal_len=2048,
)
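    # Mock k steps of draft sampler output; the trailing True mirrors the
    # (sampler_output, transposed_flag) tuple that Top1Proposer expects back
    # from the draft worker's sampler_output.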
draft_worker.sampler_output.return_value = [
SamplerOutput(
outputs=[],
sampled_token_probs=torch.rand(batch_size,
vocab_size,
device=device,
dtype=torch.float32),
logprobs=torch.rand(batch_size,
vocab_size,
device=device,
dtype=torch.float32),
sampled_token_ids=torch.randint(low=0,
high=vocab_size,
size=(batch_size, ),
device=device,
dtype=torch.long),
) for _ in range(k)
], True
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
proposals = proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k),
seq_ids_with_bonus_token_in_last_step=set())
assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs)
assert proposals.proposal_token_ids.shape == torch.Size([batch_size, k])
assert proposals.proposal_probs.shape[:-1] == torch.Size([batch_size, k])
assert proposals.proposal_lens.shape == torch.Size([batch_size])
assert proposals.proposal_lens.tolist() == [k for _ in range(batch_size)]
@torch.inference_mode()
def test_draft_proposals_no_speculations():
"""Verify Top1Proposer correctly handles case where no sequences
can speculate.
"""
k = 10
batch_size = 32
vocab_size = 32_000
device = 'npu:0'
prompt_len = 10
draft_worker = MagicMock()
proposer = Top1Proposer(
worker=draft_worker,
device=device,
vocab_size=vocab_size,
max_proposal_len=prompt_len + k - 1,
)
seq_group_metadata_list, _, _ = create_batch(batch_size,
k,
prompt_len=prompt_len)
proposals = proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k),
seq_ids_with_bonus_token_in_last_step=set())
assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs)
assert proposals.proposal_token_ids.shape == torch.Size([batch_size, k])
assert proposals.proposal_probs.shape[:-1] == torch.Size([batch_size, k])
assert proposals.proposal_lens.shape == torch.Size([batch_size])
assert proposals.proposal_lens.tolist() == [0 for _ in range(batch_size)]
@torch.inference_mode()
def test_draft_proposals_mixed_k():
"""Verify Top1Proposer correctly handles case some sequences can
speculate and some can't.
"""
k = 10
batch_size = 32
vocab_size = 32_000
device = 'npu:0'
small_prompt_len = 5
long_prompt_len = 10
prev_output_token_len = 20
expected_num_proposal_seqs = 6
expected_num_no_proposal_seqs = batch_size - expected_num_proposal_seqs
prompt_len = [
small_prompt_len for _ in range(expected_num_proposal_seqs - 1)
] + [long_prompt_len
for _ in range(expected_num_no_proposal_seqs)] + [small_prompt_len]
draft_worker = MagicMock()
proposer = Top1Proposer(
worker=draft_worker,
device=device,
vocab_size=vocab_size,
max_proposal_len=long_prompt_len + prev_output_token_len + k - 1,
)
draft_worker.sampler_output.return_value = [
SamplerOutput(
outputs=[],
sampled_token_probs=torch.rand(expected_num_proposal_seqs,
vocab_size,
device=device,
dtype=torch.float32),
logprobs=torch.rand(expected_num_proposal_seqs,
vocab_size,
device=device,
dtype=torch.float32),
sampled_token_ids=torch.randint(
low=0,
high=vocab_size,
size=(expected_num_proposal_seqs, ),
device=device,
dtype=torch.long),
) for _ in range(k)
], True
seq_group_metadata_list, _, _ = create_batch(
batch_size,
k,
prompt_len=prompt_len,
prev_output_token_len=prev_output_token_len,
)
proposals = proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k),
seq_ids_with_bonus_token_in_last_step=set())
assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs)
assert proposals.proposal_token_ids.shape == torch.Size([batch_size, k])
assert proposals.proposal_probs.shape[:-1] == torch.Size([batch_size, k])
assert proposals.proposal_lens.shape == torch.Size([batch_size])
assert proposals.proposal_lens.tolist() == [
k for _ in range(expected_num_proposal_seqs - 1)
] + [0 for _ in range(expected_num_no_proposal_seqs)] + [k]
@torch.inference_mode()
def test_use_draft_model_runner_advance_step():
"""Verify that draft model runner triggers advance step
when applicable.
"""
seed = 100
model_name = 'JackFram/llama-68m'
k = 5
batch_size = 32
block_size = 32
num_gpu_blocks = 2048 // block_size
worker = create_worker(
MultiStepWorker,
model_name,
block_size,
num_gpu_blocks,
seed,
model_runner_cls=TP1DraftModelRunner,
)
# Mock "_gpu_advance_step" to raise an exception when called.
exception_secret = "artificial stop"
worker.model_runner._gpu_advance_step = MagicMock()
worker.model_runner._gpu_advance_step.side_effect = ValueError(
exception_secret)
seq_group_metadata_list, _, _ = create_batch(batch_size,
k,
block_size=block_size,
num_gpu_blocks=num_gpu_blocks)
    # With num_steps=1 the runner falls back to the single-step path, so
    # _gpu_advance_step should not be called.
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k,
num_steps=1)
worker.execute_model(execute_model_req=execute_model_req)
# Expect exception if _gpu_advance_step is called.
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k,
num_steps=k)
with pytest.raises(ValueError, match=exception_secret):
worker.execute_model(execute_model_req=execute_model_req)
call_args_list = worker.model_runner._gpu_advance_step.call_args_list
assert len(call_args_list) == 1
@torch.inference_mode()
def test_expand_execute_model_request_sync_with_expand_hidden_states():
"""
In this test we verify that the logic for expanding the
seq_group_metadata_list remains in sync with the expansion logic of
the HiddenStates in _expand_execute_model_request.
"""
k = 5
batch_size = 16
seq_with_bonus_token_in_last_step = [1, 3, 8, 10, 13, 15]
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
execute_model_request = ExecuteModelRequest(
seq_group_metadata_list,
previous_hidden_states=HiddenStates(
torch.arange(batch_size), seq_group_metadata_list,
torch.arange(batch_size, 2 * batch_size)))
expanded_execute_model_request, orig_seq_group_ids = MultiStepWorker.\
_expand_execute_model_request(execute_model_request,
seq_with_bonus_token_in_last_step)
all_seq_ids = torch.tensor(
get_all_seq_ids(
expanded_execute_model_request.seq_group_metadata_list))
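    # Build the reference mapping: expanded copies should read the
    # second-last-token hidden states (offset by batch_size), while the
    # original sequence groups keep their last-token hidden states.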
ref_expanded_hidden_states = all_seq_ids + batch_size
ref_expanded_hidden_states[orig_seq_group_ids] -= batch_size
assert (ref_expanded_hidden_states == expanded_execute_model_request.
previous_hidden_states.hidden_states).all().item()

View File

@@ -0,0 +1,238 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/tests/spec_decode/test_ngram_worker.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.ngram_worker import NGramWorker
from vllm.spec_decode.top1_proposer import Top1Proposer
from tests.singlecard.spec_decode.utils import (
create_seq_group_metadata_from_prompts, create_worker)
from vllm_ascend.patch.worker import patch_common # noqa: F401
def test_ngram_algo_correctness_for_single_no_match():
"""Verify our ngram algo find the right candidate in the prompt
For the scenario cannot find any candidate in one single batch
"""
block_size = 32
num_gpu_blocks = 2048 // block_size
seed = 100
model_name = 'JackFram/llama-68m'
vocab_size = 32_000
device = 'npu:0'
ngram_worker = create_worker(
NGramWorker,
model_name,
block_size,
num_gpu_blocks,
seed,
)
proposer = Top1Proposer(
worker=ngram_worker,
device=device,
vocab_size=vocab_size,
max_proposal_len=20,
)
# set ngram window [1, 3], which is window=1/2/3
ngram_worker.set_ngram_window_size(1, 3)
prompts = [
# shall find no candidate
[1, 2, 3, 4, 5, 6, 7],
]
proposal_len = 5
final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts]
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
final_prompt_lens=final_prompt_lens)
proposals = proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=proposal_len),
seq_ids_with_bonus_token_in_last_step=None)
assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs)
assert proposals.proposal_token_ids.shape == torch.Size([1, proposal_len])
assert proposals.proposal_probs.shape[:-1] == torch.Size([1, proposal_len])
assert proposals.proposal_lens.shape == torch.Size([1])
assert proposals.proposal_lens.tolist() == [0]
def test_ngram_algo_correctness_for_batches_not_match_all():
"""Verify our ngram algo find the right candidate in the prompt
For the scenario find some candidate not full in batchs
"""
block_size = 32
num_gpu_blocks = 2048 // block_size
seed = 100
model_name = 'JackFram/llama-68m'
vocab_size = 32_000
device = 'npu:0'
ngram_worker = create_worker(
NGramWorker,
model_name,
block_size,
num_gpu_blocks,
seed,
)
proposer = Top1Proposer(
worker=ngram_worker,
device=device,
vocab_size=vocab_size,
max_proposal_len=20,
)
# set ngram window [1, 3], which is window=1/2/3
ngram_worker.set_ngram_window_size(1, 3)
prompts = [
# shall find no candidate
[1, 2, 3, 4, 5, 6, 7],
# shall find candidate 12,13,14,15,16
[11, 12, 13, 14, 15, 16, 11],
# shall find candidate 23,24,25,26,21
[21, 21, 22, 23, 24, 25, 26, 21, 22],
# shall find candidate 34,35,36,37,38
[31, 32, 31, 32, 33, 34, 35, 36, 37, 38, 31, 32, 33],
# shall find no candidate as exceed max_proposal_len
[
31, 32, 31, 32, 31, 32, 31, 32, 31, 32, 31, 32, 33, 34, 35, 36, 37,
38, 31, 32, 33
],
]
proposal_len = 5
final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts]
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
final_prompt_lens=final_prompt_lens)
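    # Mark all sequences as decode requests; the ngram proposer only
    # generates proposals for decodes.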
for sg in seq_group_metadata_list:
sg.is_prompt = False
proposals = proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=proposal_len),
seq_ids_with_bonus_token_in_last_step=None)
assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs)
assert proposals.proposal_token_ids.shape == torch.Size([5, proposal_len])
assert proposals.proposal_probs.shape[:-1] == torch.Size([5, proposal_len])
assert proposals.proposal_lens.shape == torch.Size([5])
    # The first sequence has no match and the last exceeds max_proposal_len,
    # so their proposal lengths are overwritten to 0.
assert proposals.proposal_lens.tolist(
) == [0] + [proposal_len for _ in range(3)] + [0]
for i in range(proposal_len):
assert proposals.proposal_token_ids[0][i] == -1
assert proposals.proposal_token_ids[1][i] == prompts[1][i + 1]
assert proposals.proposal_token_ids[2][i] == prompts[2][i + 3]
assert proposals.proposal_token_ids[3][i] == prompts[3][i + 5]
assert proposals.proposal_token_ids[4][i] == -1
def test_ngram_algo_correctness_for_batches_match_all():
"""Verify our ngram algo find the right candidate in the prompt
For the scenario find candidate in all batches
"""
block_size = 32
num_gpu_blocks = 2048 // block_size
seed = 100
model_name = 'JackFram/llama-68m'
vocab_size = 32_000
device = 'npu:0'
ngram_worker = create_worker(
NGramWorker,
model_name,
block_size,
num_gpu_blocks,
seed,
)
proposer = Top1Proposer(
worker=ngram_worker,
device=device,
vocab_size=vocab_size,
max_proposal_len=20,
)
    # set ngram window [1, 3], which is window=1/2/3
ngram_worker.set_ngram_window_size(1, 3)
prompts = [
# shall find candidate 12,13,14,15,16
[11, 12, 13, 14, 15, 16, 11],
# shall find candidate 23,24,25,26,21
[21, 21, 22, 23, 24, 25, 26, 21, 22],
# shall find candidate 34,35,36,37,38
[31, 32, 31, 32, 33, 34, 35, 36, 37, 38, 31, 32, 33],
]
proposal_len = 5
final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts]
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
final_prompt_lens=final_prompt_lens)
    # Normally the drafter runs only on decode requests; here we check the
    # ngram worker's output directly, as it is the only proposer that has no
    # forward pass.
for sg in seq_group_metadata_list:
sg.is_prompt = False
proposals = proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=proposal_len),
seq_ids_with_bonus_token_in_last_step=None)
assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs)
assert proposals.proposal_token_ids.shape == torch.Size([3, proposal_len])
assert proposals.proposal_probs.shape[:-1] == torch.Size([3, proposal_len])
assert proposals.proposal_lens.shape == torch.Size([3])
assert proposals.proposal_lens.tolist() == [proposal_len for _ in range(3)]
for i in range(proposal_len):
assert proposals.proposal_token_ids[0][i] == prompts[0][i + 1]
assert proposals.proposal_token_ids[1][i] == prompts[1][i + 3]
assert proposals.proposal_token_ids[2][i] == prompts[2][i + 5]

View File

@@ -0,0 +1,959 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/tests/spec_decode/test_spec_decode_worker.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import random
from collections import defaultdict
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
import torch
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import ExecuteModelRequest, SequenceOutput
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.metrics import (AsyncMetricsCollector,
SpecDecodeWorkerMetrics)
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.spec_decode_worker import (SpecDecodeWorker,
split_num_cache_blocks_evenly)
from tests.singlecard.spec_decode.test_utils import mock_spec_decode_sampler
from tests.singlecard.spec_decode.utils import (create_batch,
create_sampler_output_list,
create_worker, mock_worker)
# patch SpecDecodeWorker, AsyncMetricsCollector
from vllm_ascend.patch.worker import patch_common # noqa: F401
from vllm_ascend.worker.draft_model_runner import TP1DraftModelRunner
from vllm_ascend.worker.worker import NPUWorker
@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_correctly_calls_draft_model(k: int, batch_size: int,
acceptance_sampler_method: str):
"""Verify SpecDecodeWorker calls the draft worker with correct
inputs. Everything else is mocked out.
"""
draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker()
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
worker = SpecDecodeWorker(
draft_worker,
target_worker,
mock_spec_decode_sampler(acceptance_sampler_method),
disable_logprobs=False,
metrics_collector=metrics_collector)
exception_secret = 'artificial stop'
draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)
with pytest.raises(ValueError, match=exception_secret):
worker.execute_model(execute_model_req=execute_model_req)
call_args_list = draft_worker.get_spec_proposals.call_args_list
assert len(call_args_list) == 1
for args, _ in call_args_list:
actual_execute_model_data = args[0]
assert actual_execute_model_data == execute_model_req
@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_batch_expansion_correctly_calls_target_model(
k: int, batch_size: int, acceptance_sampler_method: str):
"""Verify SpecDecodeWorker calls the target model with correct
inputs with batch expansion. Everything else is mocked out.
"""
draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False)
target_worker = mock_worker(use_spec=False)
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
draft_worker.device = 'npu'
target_worker.device = 'npu'
set_random_seed(1)
worker = SpecDecodeWorker(
draft_worker,
target_worker,
mock_spec_decode_sampler(acceptance_sampler_method),
disable_logprobs=False,
metrics_collector=metrics_collector,
disable_mqa_scorer=True)
worker.init_device()
vocab_size = 32_000
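    # Fabricate a full-length draft proposal (k tokens per sequence) for the
    # mocked draft worker to return.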
proposal_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64,
device='npu')
proposal_probs = torch.rand(batch_size,
k,
vocab_size,
dtype=torch.float32,
device='npu')
proposal_lens = torch.ones(batch_size, dtype=torch.int64, device='npu') * k
seq_group_metadata_list, prompts, prev_output_tokens = create_batch(
batch_size, k)
draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
proposal_token_ids=proposal_token_ids,
proposal_probs=proposal_probs,
proposal_lens=proposal_lens)
exception_secret = 'artificial stop'
target_worker.execute_model.side_effect = ValueError(exception_secret)
with pytest.raises(ValueError, match=exception_secret):
worker.execute_model(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k))
seen_contexts: list[list[int]] = []
call_args_list = target_worker.execute_model.call_args_list
assert len(call_args_list) == 1
for _, kwargs in call_args_list:
seq_group_metadata_list = kwargs[
"execute_model_req"].seq_group_metadata_list
assert len(seq_group_metadata_list) == (k + 1) * batch_size
for seq_group_metadata in seq_group_metadata_list:
for seq_data in seq_group_metadata.seq_data.values():
seen_contexts.append(seq_data.get_token_ids())
expected_seen_contexts: list[list[int]] = []
for prompt, prev_generated, draft_tokens in zip(
prompts, prev_output_tokens, proposal_token_ids.tolist()):
for i in range(len(draft_tokens) + 1):
expected_seen_contexts.append(prompt + prev_generated +
draft_tokens[:i])
seen_contexts.sort()
expected_seen_contexts.sort()
assert expected_seen_contexts == seen_contexts
@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_correctly_calls_spec_decode_sampler(k: int, batch_size: int,
acceptance_sampler_method: str):
"""Verify SpecDecodeWorker calls the rejection sampler with
correct inputs. Everything else is mocked out.
"""
vocab_size = 32_000
draft_worker = mock_worker(cls=MultiStepWorker,
vocab_size=vocab_size,
use_spec=False)
target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
draft_worker.device = 'npu'
target_worker.device = 'npu'
set_random_seed(1)
worker = SpecDecodeWorker(draft_worker,
target_worker,
spec_decode_sampler,
disable_logprobs=False,
metrics_collector=metrics_collector)
worker.init_device()
proposal_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64,
device='npu')
proposal_probs = torch.rand(batch_size,
k,
vocab_size,
dtype=torch.float32,
device='npu')
proposal_lens = torch.ones(batch_size, dtype=torch.int64, device='npu') * k
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
proposal_token_ids=proposal_token_ids,
proposal_probs=proposal_probs,
proposal_lens=proposal_lens)
target_token_ids = torch.randint(low=0,
high=vocab_size,
size=(1, batch_size * (k + 1)),
dtype=torch.int64,
device='npu')
target_token_probs = torch.rand(1,
batch_size * (k + 1),
vocab_size,
dtype=torch.float32,
device='npu')
target_token_logprobs = torch.rand(1,
batch_size * (k + 1),
vocab_size,
dtype=torch.float32,
device='npu')
target_output = create_sampler_output_list(target_token_ids,
target_token_probs,
target_token_logprobs)
target_worker.execute_model.return_value = [target_output[0]]
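    # Abort inside the sampler so the test only needs to inspect the
    # arguments that were passed to it.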
exception_secret = 'artificial stop'
spec_decode_sampler.side_effect = ValueError(exception_secret)
with pytest.raises(ValueError, match=exception_secret):
worker.execute_model(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k))
assert len(spec_decode_sampler.call_args_list) == 1
_, kwargs = spec_decode_sampler.call_args_list[0]
actual = SimpleNamespace(**kwargs)
assert torch.equal(actual.bonus_token_ids,
target_token_ids.reshape(batch_size, k + 1)[:, -1:])
assert torch.equal(actual.target_with_bonus_probs,
target_token_probs.reshape(batch_size, k + 1, -1))
assert torch.equal(actual.draft_token_ids, proposal_token_ids)
assert torch.equal(actual.draft_probs, proposal_probs)
@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_correctly_formats_output(k: int, batch_size: int,
acceptance_sampler_method: str):
"""Verify SpecDecodeWorker formats sampler output correctly.
Everything else is mocked out.
"""
vocab_size = 32_000
draft_worker = mock_worker(cls=MultiStepWorker,
vocab_size=vocab_size,
use_spec=False)
target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
draft_worker.device = 'npu'
target_worker.device = 'npu'
set_random_seed(1)
spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
worker = SpecDecodeWorker(draft_worker,
target_worker,
spec_decode_sampler,
disable_logprobs=False,
metrics_collector=metrics_collector)
worker.init_device()
proposal_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64,
device='npu')
proposal_probs = torch.rand(batch_size,
k,
vocab_size,
dtype=torch.float32,
device='npu')
proposal_lens = torch.ones(batch_size, dtype=torch.int64, device='npu') * k
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
proposal_token_ids=proposal_token_ids,
proposal_probs=proposal_probs,
proposal_lens=proposal_lens)
target_token_ids = torch.randint(low=0,
high=vocab_size,
size=(1, batch_size * (k + 1)),
dtype=torch.int64,
device='npu')
target_token_probs = torch.rand(1,
batch_size * (k + 1),
vocab_size,
dtype=torch.float32,
device='npu')
target_token_logprobs = torch.rand(1,
batch_size * (k + 1),
vocab_size,
dtype=torch.float32,
device='npu')
target_output = create_sampler_output_list(target_token_ids,
target_token_probs,
target_token_logprobs)
target_worker.execute_model.return_value = [target_output[0]]
spec_decode_sampler_output = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k + 1),
dtype=torch.int64,
device='npu')
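    # Reject a random suffix of each row (marked -1), leaving between 0 and k
    # accepted tokens per sequence.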
for i in range(batch_size):
minimum_accepted_tokens = 1
spec_decode_sampler_output[i][
-random.randint(minimum_accepted_tokens, k + 1):] = -1
spec_decode_sampler.return_value = spec_decode_sampler_output
output = worker.execute_model(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k))
expected_output = create_sampler_output_list(
token_ids=spec_decode_sampler_output.transpose(0, 1),
probs=[None for _ in range(k + 1)],
logprobs=[None for _ in range(k + 1)])
seq_ids = [
next(iter(seq_group_metadata.seq_data.keys()))
for seq_group_metadata in seq_group_metadata_list
]
actual_output_by_seq: dict[int, list[SequenceOutput]] = {
seq_id: []
for seq_id in seq_ids
}
expected_output_by_seq: dict[int, list[SequenceOutput]] = {
seq_id: []
for seq_id in seq_ids
}
for step in output:
for seq_group in step:
for sample in seq_group.samples:
seq_id = sample.parent_seq_id
actual_output_by_seq[seq_id].append(sample)
for step in expected_output:
for seq_group in step:
for sample in seq_group.samples:
seq_id = sample.parent_seq_id
expected_output_by_seq[seq_id].append(sample)
all_seen_seq_ids = set(
list(actual_output_by_seq.keys()) +
list(expected_output_by_seq.keys()))
for seq_id in all_seen_seq_ids:
actual_by_step = actual_output_by_seq[seq_id]
expected_by_step = expected_output_by_seq[seq_id]
for i in range(k + 1):
if i >= len(actual_by_step):
assert expected_by_step[i].output_token == -1
continue
assert actual_by_step[i].output_token == expected_by_step[
i].output_token
@pytest.mark.parametrize('k', [1, 2])
@pytest.mark.parametrize('batch_size', [1])
@pytest.mark.parametrize('returns_metrics', [True, False])
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool,
acceptance_sampler_method: str):
"""Verify SpecDecodeWorker collects metrics.
"""
vocab_size = 32_000
draft_worker = mock_worker(cls=MultiStepWorker,
vocab_size=vocab_size,
use_spec=False)
target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
draft_worker.device = 'npu'
target_worker.device = 'npu'
set_random_seed(1)
worker = SpecDecodeWorker(draft_worker,
target_worker,
spec_decode_sampler,
disable_logprobs=False,
metrics_collector=metrics_collector)
worker.init_device()
proposal_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64,
device='npu')
proposal_probs = torch.rand(batch_size,
k,
vocab_size,
dtype=torch.float32,
device='npu')
proposal_lens = torch.ones(batch_size, dtype=torch.int64, device='npu') * k
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
proposal_token_ids=proposal_token_ids,
proposal_probs=proposal_probs,
proposal_lens=proposal_lens)
target_token_ids = torch.randint(low=0,
high=vocab_size,
size=(1, batch_size * (k + 1)),
dtype=torch.int64,
device='npu')
target_token_probs = torch.rand(1,
batch_size * (k + 1),
vocab_size,
dtype=torch.float32,
device='npu')
target_token_logprobs = torch.rand(1,
batch_size * (k + 1),
vocab_size,
dtype=torch.float32,
device='npu')
target_output = create_sampler_output_list(target_token_ids,
target_token_probs,
target_token_logprobs)
target_worker.execute_model.return_value = [target_output[0]]
spec_decode_sampler_output = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k + 1),
dtype=torch.int64,
device='npu')
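    # Reject a random suffix of each row (-1) to emulate partial acceptance.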
for i in range(batch_size):
minimum_accepted_tokens = 1
spec_decode_sampler_output[i][
-random.randint(minimum_accepted_tokens, k + 1):] = -1
spec_decode_sampler.return_value = spec_decode_sampler_output
mock_rejsample_metrics = MagicMock(
spec=SpecDecodeWorkerMetrics) if returns_metrics else None
metrics_collector.maybe_collect_rejsample_metrics.return_value = (
mock_rejsample_metrics)
output = worker.execute_model(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k))
assert output[0].spec_decode_worker_metrics == mock_rejsample_metrics
call_args_list = (
metrics_collector.maybe_collect_rejsample_metrics.call_args_list)
assert len(call_args_list) == 1
args, kwargs = call_args_list[0]
assert args[0] == k or kwargs.get('k', -1) == k
@pytest.mark.parametrize('k', [0])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_k_equals_zero(k: int, batch_size: int,
acceptance_sampler_method: str):
"""Verify that the SpecDecodeWorker calls the draft and target workers
when k is zero. This happens during prefill.
"""
draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker()
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
sampler_output = MagicMock(spec=SamplerOutput)
sampler_output.hidden_states = None
target_worker.execute_model.return_value = [sampler_output]
draft_worker.device = 'npu'
target_worker.device = 'npu'
set_random_seed(1)
worker = SpecDecodeWorker(
proposer_worker=draft_worker,
scorer_worker=target_worker,
spec_decode_sampler=mock_spec_decode_sampler(
acceptance_sampler_method),
disable_logprobs=False,
metrics_collector=metrics_collector,
)
seq_group_metadata_list, _, _ = create_batch(batch_size,
k,
prev_output_token_len=0)
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)
out = worker.execute_model(execute_model_req=execute_model_req)
assert len(out) == 1, f"expected only one token output when {k=}"
assert out[0].sampled_token_probs is None, (
"expect gpu tensor references to be None")
assert out[
0].sampled_token_ids is None, "expect gpu tensor references to be None"
draft_worker.execute_model.assert_called_once_with(execute_model_req)
target_worker.execute_model.assert_called_once_with(execute_model_req)
@pytest.mark.parametrize('k', [0, 5])
@pytest.mark.parametrize('batch_size', [0])
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_empty_input_batch(k: int, batch_size: int,
acceptance_sampler_method: str):
"""Verify that the SpecDecodeWorker calls the draft and target workers
when the input batch is empty. This can happen if the engine communicates
to the workers information without scheduling a batch.
"""
draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker()
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
sampler_output = MagicMock(spec=SamplerOutput)
sampler_output.hidden_states = None
target_worker.execute_model.return_value = [sampler_output]
draft_worker.device = 'npu'
target_worker.device = 'npu'
set_random_seed(1)
worker = SpecDecodeWorker(
proposer_worker=draft_worker,
scorer_worker=target_worker,
spec_decode_sampler=mock_spec_decode_sampler(
acceptance_sampler_method),
disable_logprobs=False,
metrics_collector=metrics_collector,
)
seq_group_metadata_list, _, _ = create_batch(batch_size,
k,
prev_output_token_len=0)
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)
out = worker.execute_model(execute_model_req=execute_model_req)
assert len(out) == 1, f"expected only one token output when {k=}"
assert out[0].sampled_token_probs is None, (
"expect gpu tensor references to be None")
assert out[
0].sampled_token_ids is None, "expect gpu tensor references to be None"
draft_worker.execute_model.assert_called_once_with(execute_model_req)
target_worker.execute_model.assert_called_once_with(execute_model_req)
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@pytest.mark.skip_global_cleanup
def test_init_device(acceptance_sampler_method: str):
"""Verify SpecDecodeWorker invokes proposer/scorer worker init_device, as
well as other GPU initialization.
"""
draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False)
target_worker = mock_worker(use_spec=False)
spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
worker = SpecDecodeWorker(
proposer_worker=draft_worker,
scorer_worker=target_worker,
spec_decode_sampler=spec_decode_sampler,
disable_logprobs=False,
metrics_collector=metrics_collector,
)
worker.init_device()
draft_worker.init_device.assert_called_once()
target_worker.init_device.assert_called_once()
metrics_collector.init_tensors.assert_called_once()
spec_decode_sampler.init_tensors.assert_called_once()
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_initialize_cache(acceptance_sampler_method):
"""Verify SpecDecodeWorker invokes initialize_cache on proposer/scorer
workers.
"""
draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker()
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
worker = SpecDecodeWorker(proposer_worker=draft_worker,
scorer_worker=target_worker,
spec_decode_sampler=mock_spec_decode_sampler(
acceptance_sampler_method),
metrics_collector=metrics_collector)
kwargs = {"num_gpu_blocks": 1024, "num_cpu_blocks": 1023}
worker.initialize_cache(**kwargs)
draft_worker.initialize_cache.assert_called_once_with(**kwargs)
target_worker.initialize_cache.assert_called_once_with(**kwargs)
@pytest.mark.parametrize('available_gpu_blocks', [1, 1024])
@pytest.mark.parametrize('available_cpu_blocks', [500])
@pytest.mark.parametrize('target_cache_block_size_bytes', [2 * 2 * 4096])
@pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096])
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@pytest.mark.skip_global_cleanup
def test_determine_num_available_blocks(available_gpu_blocks: int,
available_cpu_blocks: int,
target_cache_block_size_bytes: int,
draft_kv_size_bytes: int,
acceptance_sampler_method: str):
"""Verify SpecDecodeWorker correctly profiles num available GPU blocks.
Specifically, it should run profiling in the scorer worker, and then evenly
split the blocks between proposer and scorer worker.
"""
draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker()
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
target_worker.determine_num_available_blocks.return_value = (
available_gpu_blocks, available_cpu_blocks)
target_worker.get_cache_block_size_bytes.return_value = (
target_cache_block_size_bytes)
draft_worker.get_cache_block_size_bytes.return_value = draft_kv_size_bytes
worker = SpecDecodeWorker(
draft_worker, target_worker,
mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)
num_gpu_blocks, num_cpu_blocks = worker.determine_num_available_blocks()
target_worker.determine_num_available_blocks.assert_called_once()
assert num_cpu_blocks == available_cpu_blocks
assert num_gpu_blocks == split_num_cache_blocks_evenly(
target_cache_block_size_bytes, draft_kv_size_bytes,
available_gpu_blocks)
@pytest.mark.parametrize('available_gpu_blocks',
list(range(20)) + [1024, 1024**2])
@pytest.mark.parametrize('target_cache_block_size_bytes',
[2 * 2 * 4096, 2 * 2 * 8192])
@pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096])
@pytest.mark.skip_global_cleanup
def test_split_num_cache_blocks_evenly(available_gpu_blocks: int,
target_cache_block_size_bytes: int,
draft_kv_size_bytes: int):
"""Verify split_num_cache_blocks_evenly does not exceed original memory
allocation in bytes.
"""
num_blocks = split_num_cache_blocks_evenly(target_cache_block_size_bytes,
draft_kv_size_bytes,
available_gpu_blocks)
assert (num_blocks * target_cache_block_size_bytes) + (
num_blocks * draft_kv_size_bytes) <= (available_gpu_blocks *
target_cache_block_size_bytes)
@torch.inference_mode()
def test_populate_seq_ids_with_bonus_tokens():
"""
Verify that a call to _create_output_sampler_list correctly updates
seq_with_bonus_token_in_last_step.
seq_with_bonus_token_in_last_step is an internal data structure in
SpecDecodeWorker that tracks the sequence IDs which are assigned bonus
tokens by the target model in their last forward pass. This state is
maintained only for models relying on the KV cache, such as those using
the MultiStepWorker.
"""
batch_size = 10
k = 5
vocab_size = 10000
num_sequences_with_bonus_tokens = 5
target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)]
target_worker.device = 'npu'
set_random_seed(1)
draft_worker = mock_worker(cls=MultiStepWorker)
draft_worker.device = 'npu'
# The sequence_ids attached to each sequence in the batch.
# The sequence at index i has seq_id assigned_seq_ids[i]
assigned_seq_ids = list(range(batch_size))
seq_group_metadata_list, _, _ = create_batch(batch_size,
k,
seq_ids=assigned_seq_ids,
prev_output_token_len=10)
target_token_logprobs = torch.rand(batch_size, (k + 1),
vocab_size,
dtype=torch.float32,
device='npu')
accepted_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, (k + 1)),
dtype=torch.int64,
device='npu')
expected_request_id_seq_ids_mapping: dict[str, set[int]] = defaultdict(set)
for seq_group_metadata in seq_group_metadata_list:
for seq_id in seq_group_metadata.seq_data:
expected_request_id_seq_ids_mapping[
seq_group_metadata.request_id].add(seq_id)
# Generate a random sample of sequence indexes with bonus tokens
seq_indexes_with_bonus_tokens = random.sample(
range(batch_size), num_sequences_with_bonus_tokens)
# Create a mask that is True for indices in seq_indexes_with_bonus_tokens
mask = torch.ones(batch_size, dtype=torch.bool, device='npu')
mask[seq_indexes_with_bonus_tokens] = False
# Set the last token ID to -1 for all indices not in
# seq_indexes_with_bonus_tokens to indicate the lack of bonus token in
# those indices.
accepted_token_ids[mask, -1:] = -1
worker = SpecDecodeWorker(draft_worker,
target_worker,
mock_spec_decode_sampler("rejection_sampler"),
disable_logprobs=False,
metrics_collector=metrics_collector)
# Initialize _seq_with_bonus_token_in_last_step with a set of sequence IDs.
# This set includes all sequence IDs in the batch as well as an additional
# `num_extra_sequence_ids` sequence IDs. Note that the sequence IDs are in
# the range [0, batch_size + num_extra_sequence_ids).
num_extra_sequence_ids = 10
worker._seq_with_bonus_token_in_last_step = set(
range(batch_size + num_extra_sequence_ids))
worker._create_output_sampler_list(
seq_group_metadata_list=seq_group_metadata_list,
accepted_token_ids=accepted_token_ids,
target_logprobs=target_token_logprobs,
prompt_logprobs=None,
k=k,
stage_times=(0, 0, 0))
# Verify that _seq_with_bonus_token_in_last_step contains the following:
# 1. Sequence IDs that were already present in
# _seq_with_bonus_token_in_last_step but were not part of the current
# batch are retained.
# 2. Of the sequence IDs present in the current batch, only those with a
# bonus token are retained in _seq_with_bonus_token_in_last_step.
# Sequence IDs that are present in the current batch but do not have
# bonus tokens are removed from _seq_with_bonus_token_in_last_step.
expected_seq_ids_with_bonus_tokens = \
set([assigned_seq_ids[i] for i in seq_indexes_with_bonus_tokens])
additional_sequence_ids = \
set(range(batch_size, batch_size + num_extra_sequence_ids))
assert worker._seq_with_bonus_token_in_last_step == \
expected_seq_ids_with_bonus_tokens.union(additional_sequence_ids)
assert worker._request_id_seq_id_mapping == \
expected_request_id_seq_ids_mapping
@torch.inference_mode()
def test_handle_finished_requests():
"""
Test to verify that finished request IDs are appropriately processed to
update the internal state of the SpecDecodeWorker.
This test initializes the SpecDecodeWorker with mock data, marks certain
requests as finished, and ensures that the corresponding sequence IDs are
correctly removed from the internal mappings.
"""
batch_size = 32
k = 3
draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker()
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
worker = SpecDecodeWorker(draft_worker, target_worker,
mock_spec_decode_sampler("rejection_sampler"),
metrics_collector)
# Initialize the request_id_seq_id_mapping mapping dict with a few fake
# request ids and corresponding sequence ids.
worker._request_id_seq_id_mapping = \
{'request-1': {1,2,3}, 'request-2': {4,5,6,7},
'request-3': {8,9}, 'request-4': {10,11}}
# Initialize seq_with_bonus_token_in_last_step with a few fake
# sequence ids.
worker._seq_with_bonus_token_in_last_step = {1, 4, 5, 8, 9, 10}
exception_secret = 'artificial stop'
draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
# Mark requests with ids request-1 and request-3 as finished.
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k,
finished_requests_ids=['request-1', 'request-3'])
with pytest.raises(ValueError, match=exception_secret):
worker.execute_model(execute_model_req=execute_model_req)
# Verify that request-1 and request-3 are removed from
# request_id_seq_id_mapping
assert worker._request_id_seq_id_mapping == \
{'request-2': {4,5,6,7}, 'request-4': {10,11}}
# Verify that all sequence ids corresponding to 'request-1'
# and 'request-3' are removed from seq_with_bonus_token_in_last_step.
assert worker._seq_with_bonus_token_in_last_step == \
{4,5,10}
@pytest.mark.parametrize('k', [3])
@pytest.mark.parametrize('batch_size', [2, 32])
@pytest.mark.parametrize("batch_composition",
["prefill_only", "decode_only", "mixed"])
@torch.inference_mode()
def test_chunked_prefill_flow(k: int, batch_size: int, batch_composition: str):
"""
Verify SpecDecodeWorker calls match the expected flow.
"""
vocab_size = 32_000
draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker()
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
worker = SpecDecodeWorker(draft_worker,
target_worker,
mock_spec_decode_sampler("rejection_sampler"),
disable_logprobs=False,
metrics_collector=metrics_collector)
exception_secret = 'artificial stop'
worker.scorer = mock_worker(BatchExpansionTop1Scorer)
worker.scorer.score_proposals.side_effect = ValueError(exception_secret)
# Create batch with combination of terminal/non-terminal prefill chunks
# and decodes (different seq_ids).
decodes, _, _ = create_batch(batch_size, k)
# Pre-chunking here, get 'batch_size' chunks.
prefill, _, _ = create_batch(batch_size,
k,
prefill_chunk_size=4,
seq_ids=list(range(batch_size,
batch_size * 2)))
if batch_composition == "prefill_only":
n_prefills = batch_size
elif batch_composition == "decode_only":
n_prefills = 0
else:
n_prefills = random.randint(1, batch_size - 1)
n_decodes = batch_size - n_prefills
prefill = random.sample(prefill, n_prefills)
decodes = random.sample(decodes, n_decodes)
target_group_metadata_list = prefill + decodes
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=target_group_metadata_list,
# For prefill only batches we expect num_lookahead_slots = 0.
num_lookahead_slots=k if n_decodes > 0 else 0)
target_token_ids = torch.randint(low=0,
high=vocab_size,
size=(1, batch_size * (k + 1)),
dtype=torch.int64,
device='npu')
target_token_probs = torch.rand(1,
batch_size * (k + 1),
vocab_size,
dtype=torch.float32,
device='npu')
target_token_logprobs = torch.rand(1,
batch_size * (k + 1),
vocab_size,
dtype=torch.float32,
device='npu')
target_output = create_sampler_output_list(target_token_ids,
target_token_probs,
target_token_logprobs)
target_worker.execute_model.return_value = [target_output[0]]
if not len(decodes):
worker.execute_model(execute_model_req=execute_model_req)
# no spec run (prefill only)
draft_worker.execute_model.assert_called_once_with(execute_model_req)
target_worker.execute_model.assert_called_once_with(execute_model_req)
else:
# Decode-only run OR mixed batch, scorer call fails (it's mocked)
with pytest.raises(ValueError, match=exception_secret):
worker.execute_model(execute_model_req=execute_model_req)
# but first draft still counted
assert draft_worker.get_spec_proposals.call_count == 1
def test_correctly_load_weight_for_eagle():
"""
Verify SpecDecodeWorker loads lm_head weight for eagle correctly.
"""
seed = 100
block_size = 32
num_gpu_blocks = 8096 // block_size
target_worker = create_worker(
NPUWorker,
"JackFram/llama-68m",
block_size,
num_gpu_blocks,
seed,
)
draft_worker = create_worker(
MultiStepWorker,
"abhigoyal/vllm-eagle-llama-68m-random",
block_size,
num_gpu_blocks,
seed,
model_runner_cls=TP1DraftModelRunner,
)
spec_decode_sampler = mock_spec_decode_sampler("rejection_sampler")
worker = SpecDecodeWorker(draft_worker,
target_worker,
spec_decode_sampler,
disable_logprobs=False)
worker.proposer_worker.maybe_load_lm_head_weight(
target_worker.model_runner.model.lm_head.weight.data)
assert torch.allclose(
worker.proposer_worker.worker.model_runner.model.lm_head.weight.data,
worker.scorer_worker.model_runner.model.lm_head.weight.data)


@@ -0,0 +1,165 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/tests/spec_decode/test_utils.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from unittest.mock import MagicMock
import pytest
import torch
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.layers.sampler import _get_ranks
from vllm.model_executor.layers.typical_acceptance_sampler import \
TypicalAcceptanceSampler
from vllm.sequence import SequenceGroupMetadata, get_all_seq_ids
from vllm.spec_decode.util import (get_sampled_token_logprobs,
split_batch_by_proposal_len)
def test_get_all_seq_ids():
"""Verify get_all_seq_ids extracts all seq ids.
"""
expected_seq_ids = list(range(10)) + list(range(100, 110))
seq_group_metadata_list = [
SequenceGroupMetadata(
request_id=str(seq_id),
is_prompt=True,
seq_data={
seq_id: MagicMock(),
},
sampling_params=MagicMock(),
block_tables={
seq_id: MagicMock(),
},
lora_request=None,
) for seq_id in expected_seq_ids
]
actual_seq_ids = get_all_seq_ids(seq_group_metadata_list)
assert actual_seq_ids == expected_seq_ids
@pytest.fixture
def fake_sequence_group_metadata():
seq_ids = list(range(3))
return [
SequenceGroupMetadata(
request_id=str(i),
is_prompt=True,
seq_data={
i: MagicMock(),
},
sampling_params=MagicMock(),
block_tables={
i: MagicMock(),
},
lora_request=None,
) for i in seq_ids
]
def test_filter_zero_length_proposals(fake_sequence_group_metadata):
proposal_lens = [0, 1, 0]
_, (filtered_groups,
indices) = split_batch_by_proposal_len(fake_sequence_group_metadata,
proposal_lens)
expected_groups = [
fake_sequence_group_metadata[0], fake_sequence_group_metadata[2]
]
expected_indices = [0, 2]
assert filtered_groups == expected_groups
assert indices == expected_indices
def test_filter_non_zero_length_proposals(fake_sequence_group_metadata):
proposal_lens = [0, 1, 2]
(filtered_groups,
indices), _ = split_batch_by_proposal_len(fake_sequence_group_metadata,
proposal_lens)
expected_groups = [
fake_sequence_group_metadata[1], fake_sequence_group_metadata[2]
]
expected_indices = [1, 2]
assert filtered_groups == expected_groups
assert indices == expected_indices
def test_empty_inputs():
_, (filtered_groups, indices) = split_batch_by_proposal_len([], [])
assert filtered_groups == []
assert indices == []
def test_all_zero_with_non_zero_filter(fake_sequence_group_metadata):
proposal_lens = [0, 0, 0]
(filtered_groups,
indices), _ = split_batch_by_proposal_len(fake_sequence_group_metadata,
proposal_lens)
assert filtered_groups == []
assert indices == []
def test_all_non_zero_with_zero_filter(fake_sequence_group_metadata):
proposal_lens = [1, 1, 1]
_, (filtered_groups,
indices) = split_batch_by_proposal_len(fake_sequence_group_metadata,
proposal_lens)
assert filtered_groups == []
assert indices == []
def mock_spec_decode_sampler(acceptance_sampler_method):
"""
Returns either a RejectionSampler or TypicalAcceptanceSampler
object depending on whether acceptance_sampler_method is
'rejection_sampler' or 'typical_acceptance_sampler' respectively.
"""
if acceptance_sampler_method == "rejection_sampler":
sampler = MagicMock(spec=RejectionSampler)
sampler.token_id_dtype = torch.int64
return sampler
elif acceptance_sampler_method == "typical_acceptance_sampler":
sampler = MagicMock(spec=TypicalAcceptanceSampler)
sampler.token_id_dtype = torch.int64
return sampler
else:
raise ValueError(f"Invalid sampler name {acceptance_sampler_method}")
def test_get_sampled_token_logprobs():
"""Verify get_sampled_token_logprobs returns consistent rankings
with regular get_ranks when probabilities match exactly.
"""
logprob_tensor = torch.tensor(
[[[-.1, -.1]] * 2]) # shape (num_steps, batch_size, vocab_size)
sampled_token_tensor = torch.tensor([[1,
0]]) # shape (num_steps, batch_size)
ranks_spec_dec, _ = get_sampled_token_logprobs(logprob_tensor,
sampled_token_tensor)
ranks_regular = _get_ranks(logprob_tensor.reshape((2, -1)),
sampled_token_tensor.reshape(-1))
assert torch.equal(ranks_spec_dec.reshape(-1), ranks_regular)
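For context, the rank semantics this check relies on can be reproduced by hand: a token's rank is 1 plus the number of vocab entries with a strictly greater logprob (rank 1 being the most likely token). A minimal hand-rolled sketch, not vLLM's internal code:
```python
import torch

# Hand-rolled rank computation for illustration; values below are hypothetical.
logprobs = torch.tensor([[-0.1, -2.0, -0.5]])  # shape (batch_size, vocab_size)
sampled = torch.tensor([2])                    # sampled token id per batch entry
sampled_logprob = logprobs[torch.arange(logprobs.size(0)), sampled]
ranks = (logprobs > sampled_logprob.unsqueeze(-1)).sum(-1) + 1
assert ranks.tolist() == [2]  # token 2 has the second-highest logprob
```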


@@ -0,0 +1,317 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/tests/spec_decode/utils.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from collections.abc import Sequence as GenericSequence
from itertools import count
from typing import Callable, Optional, TypeVar, Union
from unittest.mock import MagicMock
import torch
from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.utils import set_random_seed
from vllm.sampling_params import SamplingParams
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
SequenceData, SequenceGroupMetadata, SequenceOutput)
from vllm.spec_decode.ngram_worker import NGramWorker # noqa: F401
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.worker.cache_engine import CacheEngine
from vllm_ascend.worker.model_runner import NPUModelRunner
from vllm_ascend.worker.worker import NPUWorker
T = TypeVar("T", bound=NPUWorker)
def round_up_to_next_block(seq_len: int, block_size: int) -> int:
return (seq_len + block_size - 1) // block_size
def mock_worker(cls=None,
vocab_size: int = 30_000,
max_model_len: int = 2048,
rank: int = 0,
use_spec: bool = True) -> MagicMock:
if cls is None:
cls = NPUWorker
spec = cls if use_spec else None
worker = MagicMock(spec=spec)
worker.vocab_size = vocab_size
worker.max_model_len = max_model_len
worker.rank = rank
worker.device = 'npu:0'
return worker
def patch_execute_model_with_seeds(worker: NPUWorker, rand_seeds: list[int]):
seed_iter = iter(rand_seeds)
original_execute_model = worker.execute_model
def new_execute_model(*args, **kwargs):
result = original_execute_model(*args, **kwargs)
set_random_seed(next(seed_iter))
return result
return new_execute_model
def zero_kv_cache(cache_engine: list[CacheEngine]):
assert cache_engine[0].gpu_cache
for key_blocks, value_blocks in cache_engine[0].gpu_cache:
key_blocks.zero_()
value_blocks.zero_()
def create_worker(cls: Callable[..., T],
model_name: str,
block_size: int,
num_gpu_blocks: int,
seed: int,
is_driver_worker: bool = True,
enforce_eager: bool = True,
model_runner_cls: Optional[NPUModelRunner] = None,
dtype: Optional[str] = "auto") -> T:
engine_args = EngineArgs(
model=model_name,
seed=seed,
block_size=block_size,
enforce_eager=enforce_eager,
dtype=dtype,
)
engine_config = engine_args.create_engine_config()
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
if cls.__name__ == "NGramWorker":
# we need to pass the device type to enable this on npu
worker = cls(vllm_config=engine_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
is_driver_worker=is_driver_worker,
model_runner_cls=model_runner_cls,
device_type="npu")
else:
worker = cls(
vllm_config=engine_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
is_driver_worker=is_driver_worker,
model_runner_cls=model_runner_cls,
)
worker.init_device()
worker.load_model()
engine_config.cache_config.num_gpu_blocks = num_gpu_blocks
engine_config.cache_config.num_cpu_blocks = 0
worker.initialize_cache(
num_gpu_blocks=engine_config.cache_config.num_gpu_blocks,
num_cpu_blocks=engine_config.cache_config.num_cpu_blocks)
return worker
def create_seq_group_metadata_from_prompts(
prompts: list[list[int]],
num_gpu_blocks: int,
block_size: int,
final_prompt_lens: list[int],
continuations: Optional[list[list[int]]] = None,
seq_ids: Optional[list[int]] = None,
) -> list[SequenceGroupMetadata]:
if continuations is None:
continuations = [[] for _ in prompts]
if seq_ids is None:
seq_ids = list(i for i, _ in enumerate(prompts))
free_gpu_blocks = list(range(num_gpu_blocks))
block_allocations = {
i: [
free_gpu_blocks.pop()
for _ in range(round_up_to_next_block(final_len, block_size))
]
for i, final_len in enumerate(final_prompt_lens)
}
seq_group_metadata_list = []
for i, (prompt_token_ids,
cont_token_ids) in enumerate(zip(prompts, continuations)):
data = SequenceData.from_seqs(prompt_token_ids, cont_token_ids)
data.update_num_computed_tokens(
len(prompt_token_ids) + len(cont_token_ids) - 1)
seq_data = {i: data}
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=str(i),
is_prompt=len(cont_token_ids) == 0,
seq_data=seq_data,
sampling_params=SamplingParams(temperature=0.0),
block_tables={i: block_allocations[i][:]},
))
return seq_group_metadata_list
def create_chunked_seq_group_metadata_from_prompt(
prompt: list[int],
num_gpu_blocks: int,
chunk_size: int,
block_size: int,
seq_id: Optional[int] = None) -> list[SequenceGroupMetadata]:
if seq_id is None:
seq_id = 0
free_gpu_blocks = list(range(num_gpu_blocks))
block_allocations = [
free_gpu_blocks.pop()
for _ in range(round_up_to_next_block(len(prompt), block_size))
]
seq_group_metadata_list = []
for i, idx in enumerate(range(0, len(prompt), chunk_size)):
chunk_ids = prompt[idx:idx + chunk_size]
data = SequenceData.from_seqs(prompt)
data.update_num_computed_tokens(idx)
seq_data = {i: data}
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=str(seq_id),
is_prompt=True,
do_sample=idx + chunk_size >= len(prompt), # terminal chunk
seq_data=seq_data,
sampling_params=SamplingParams(temperature=0.0),
block_tables={i: block_allocations},
token_chunk_size=len(chunk_ids)))
return seq_group_metadata_list
def assert_logprobs_dict_allclose(
actual_logprobs: list[dict[int, Logprob]],
expected_logprobs: list[dict[int, Logprob]]) -> None:
for single_step_actual_logprobs, single_step_expected_logprobs in zip(
actual_logprobs, expected_logprobs):
assert set(single_step_actual_logprobs.keys()) == set(
single_step_expected_logprobs.keys())
for token_id in single_step_actual_logprobs:
actual = torch.tensor(
single_step_actual_logprobs[token_id].logprob)
expected = torch.tensor(
single_step_expected_logprobs[token_id].logprob)
torch.testing.assert_close(actual, expected)
def create_sampler_output_list(
token_ids: torch.Tensor,
probs: GenericSequence[Optional[torch.Tensor]],
logprobs: GenericSequence[Optional[torch.Tensor]],
seq_ids: Optional[list[int]] = None) -> list[SamplerOutput]:
num_steps, batch_size = token_ids.shape
token_ids_by_step = token_ids.tolist()
if seq_ids is None:
seq_ids = list(range(batch_size))
return [
SamplerOutput(outputs=[
CompletionSequenceGroupOutput(
samples=[
SequenceOutput(
output_token=token_id,
parent_seq_id=seq_ids[seq_index],
logprobs={token_id: Logprob(0)},
)
],
prompt_logprobs=None,
) for seq_index, token_id in enumerate(token_ids_by_step[step])
],
sampled_token_probs=probs[step],
logprobs=logprobs[step],
sampled_token_ids=token_ids[step])
for step in range(num_steps)
]
def create_batch(batch_size,
k,
prompt_len: Union[int, list[int]] = 10,
prev_output_token_len: int = 10,
seq_ids: Optional[list[int]] = None,
num_gpu_blocks: Optional[int] = None,
block_size: Optional[int] = None,
prefill_chunk_size: Optional[int] = None):
if block_size is None:
block_size = 8
if num_gpu_blocks is None:
num_gpu_blocks = 2048 // block_size
iterator = count()
if isinstance(prompt_len, int):
prompt_lens = [prompt_len for _ in range(batch_size)]
else:
prompt_lens = prompt_len
prompts = [[next(iterator) for _ in range(p_len)] for p_len in prompt_lens]
if prefill_chunk_size:
# Create a batch of chunked prompts.
if not seq_ids:
seq_ids = list(range(len(prompts)))
seq_group_metadata_list = []
for p, sid in zip(prompts, seq_ids):
seq_group_metadata_list += \
create_chunked_seq_group_metadata_from_prompt(
p, num_gpu_blocks, prefill_chunk_size, block_size, sid)
seq_group_metadata_list = seq_group_metadata_list[:batch_size]
prev_output_tokens = []
else:
prev_output_tokens = [[
next(iterator) for _ in range(prev_output_token_len)
] for _ in range(batch_size)]
final_prompt_lens = [
len(prompt) + len(prev_output_token) + k + 1
for prompt, prev_output_token in zip(prompts, prev_output_tokens)
]
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts, num_gpu_blocks, block_size, final_prompt_lens,
prev_output_tokens, seq_ids)
return seq_group_metadata_list, prompts, prev_output_tokens
def maybe_enable_chunked_prefill(prefill_chunk_size, llm_kwargs):
if prefill_chunk_size > 0:
llm_kwargs.update(
**{
"enable_chunked_prefill": True,
"max_num_batched_tokens": prefill_chunk_size,
"max_num_seqs": prefill_chunk_size
})
else:
llm_kwargs["enable_chunked_prefill"] = False


@@ -24,9 +24,9 @@ import os
import pytest
import vllm # noqa: F401
from conftest import VllmRunner
import vllm_ascend # noqa: F401
from tests.conftest import VllmRunner
MODELS = [
"Qwen/Qwen2.5-0.5B-Instruct",

tests/utils.py

@@ -0,0 +1,735 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/tests/utils.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import copy
import functools
import os
import signal
import subprocess
import sys
import time
import warnings
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Type, Union
import openai
import pytest
import requests
import torch
import torch.nn.functional as F
import vllm.envs as envs
from openai.types.completion import Completion
from typing_extensions import ParamSpec
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.model_executor.model_loader.loader import get_model_loader
from vllm.platforms import current_platform
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils import FlexibleArgumentParser, GB_bytes, get_open_port
from .model_utils import TextTextLogprobs
VLLM_PATH = Path(__file__).parent.parent
"""Path to root of the vLLM repository."""
class RemoteOpenAIServer:
DUMMY_API_KEY = "token-abc123" # vLLM's OpenAI server does not need API key
def __init__(self,
model: str,
vllm_serve_args: List[str],
*,
env_dict: Optional[Dict[str, str]] = None,
auto_port: bool = True,
max_wait_seconds: Optional[float] = None) -> None:
if auto_port:
if "-p" in vllm_serve_args or "--port" in vllm_serve_args:
raise ValueError("You have manually specified the port "
"when `auto_port=True`.")
# Don't mutate the input args
vllm_serve_args = vllm_serve_args + [
"--port", str(get_open_port())
]
parser = FlexibleArgumentParser(
description="vLLM's remote OpenAI server.")
parser = make_arg_parser(parser)
args = parser.parse_args(["--model", model, *vllm_serve_args])
self.host = str(args.host or 'localhost')
self.port = int(args.port)
# download the model before starting the server to avoid timeout
is_local = os.path.isdir(model)
if not is_local:
engine_args = AsyncEngineArgs.from_cli_args(args)
model_config = engine_args.create_model_config()
load_config = engine_args.create_load_config()
model_loader = get_model_loader(load_config)
model_loader.download_model(model_config)
env = os.environ.copy()
# the current process might initialize cuda,
# to be safe, we should use spawn method
env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
if env_dict is not None:
env.update(env_dict)
self.proc = subprocess.Popen(
["vllm", "serve", model, *vllm_serve_args],
env=env,
stdout=sys.stdout,
stderr=sys.stderr,
)
max_wait_seconds = max_wait_seconds or 240
self._wait_for_server(url=self.url_for("health"),
timeout=max_wait_seconds)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
self.proc.terminate()
try:
self.proc.wait(8)
except subprocess.TimeoutExpired:
# force kill if needed
self.proc.kill()
def _wait_for_server(self, *, url: str, timeout: float):
# run health check
start = time.time()
while True:
try:
if requests.get(url).status_code == 200:
break
except Exception:
# this exception can only be raised by requests.get,
# which means the server is not ready yet.
# the stack trace is not useful, so we suppress it
# by using `raise from None`.
result = self.proc.poll()
if result is not None and result != 0:
raise RuntimeError("Server exited unexpectedly.") from None
time.sleep(0.5)
if time.time() - start > timeout:
raise RuntimeError(
"Server failed to start in time.") from None
@property
def url_root(self) -> str:
return f"http://{self.host}:{self.port}"
def url_for(self, *parts: str) -> str:
return self.url_root + "/" + "/".join(parts)
def get_client(self, **kwargs):
if "timeout" not in kwargs:
kwargs["timeout"] = 600
return openai.OpenAI(
base_url=self.url_for("v1"),
api_key=self.DUMMY_API_KEY,
max_retries=0,
**kwargs,
)
def get_async_client(self, **kwargs):
if "timeout" not in kwargs:
kwargs["timeout"] = 600
return openai.AsyncOpenAI(base_url=self.url_for("v1"),
api_key=self.DUMMY_API_KEY,
max_retries=0,
**kwargs)
def _test_completion(
client: openai.OpenAI,
model: str,
prompt: str,
token_ids: List[int],
):
results = []
# test with text prompt
completion = client.completions.create(model=model,
prompt=prompt,
max_tokens=5,
temperature=0.0)
results.append({
"test": "single_completion",
"text": completion.choices[0].text,
"finish_reason": completion.choices[0].finish_reason,
"usage": completion.usage,
})
# test using token IDs
completion = client.completions.create(
model=model,
prompt=token_ids,
max_tokens=5,
temperature=0.0,
)
results.append({
"test": "token_ids",
"text": completion.choices[0].text,
"finish_reason": completion.choices[0].finish_reason,
"usage": completion.usage,
})
# test seeded random sampling
completion = client.completions.create(model=model,
prompt=prompt,
max_tokens=5,
seed=33,
temperature=1.0)
results.append({
"test": "seeded_sampling",
"text": completion.choices[0].text,
"finish_reason": completion.choices[0].finish_reason,
"usage": completion.usage,
})
# test seeded random sampling with multiple prompts
completion = client.completions.create(model=model,
prompt=[prompt, prompt],
max_tokens=5,
seed=33,
temperature=1.0)
results.append({
"test":
"seeded_sampling",
"text": [choice.text for choice in completion.choices],
"finish_reason":
[choice.finish_reason for choice in completion.choices],
"usage":
completion.usage,
})
# test simple list
batch = client.completions.create(
model=model,
prompt=[prompt, prompt],
max_tokens=5,
temperature=0.0,
)
results.append({
"test": "simple_list",
"text0": batch.choices[0].text,
"text1": batch.choices[1].text,
})
# test streaming
batch = client.completions.create(
model=model,
prompt=[prompt, prompt],
max_tokens=5,
temperature=0.0,
stream=True,
)
texts = [""] * 2
for chunk in batch:
assert len(chunk.choices) == 1
choice = chunk.choices[0]
texts[choice.index] += choice.text
results.append({
"test": "streaming",
"texts": texts,
})
return results
def _test_completion_close(
client: openai.OpenAI,
model: str,
prompt: str,
):
results = []
# test with text prompt
completion = client.completions.create(model=model,
prompt=prompt,
max_tokens=1,
logprobs=5,
temperature=0.0)
logprobs = completion.choices[0].logprobs.top_logprobs[0]
logprobs = {k: round(v, 2) for k, v in logprobs.items()}
results.append({
"test": "completion_close",
"logprobs": logprobs,
})
return results
def _test_embeddings(
client: openai.OpenAI,
model: str,
text: str,
):
results = []
# test with text input
embeddings = client.embeddings.create(
model=model,
input=text,
encoding_format="float",
)
results.append({
"test": "single_embedding",
"embedding": embeddings.data[0].embedding,
"usage": embeddings.usage,
})
return results
def _test_image_text(
client: openai.OpenAI,
model_name: str,
image_url: str,
):
results = []
# test pure text input
messages = [{
"role":
"user",
"content": [
{
"type": "text",
"text": "How do you feel today?"
},
],
}]
chat_completion = client.chat.completions.create(model=model_name,
messages=messages,
temperature=0.0,
max_tokens=1,
logprobs=True,
top_logprobs=5)
top_logprobs = chat_completion.choices[0].logprobs.content[0].top_logprobs
for x in top_logprobs:
x.logprob = round(x.logprob, 2)
results.append({
"test": "pure_text",
"logprobs": top_logprobs,
})
messages = [{
"role":
"user",
"content": [
{
"type": "image_url",
"image_url": {
"url": image_url
}
},
{
"type": "text",
"text": "What's in this image?"
},
],
}]
chat_completion = client.chat.completions.create(model=model_name,
messages=messages,
temperature=0.0,
max_tokens=1,
logprobs=True,
top_logprobs=5)
top_logprobs = chat_completion.choices[0].logprobs.content[0].top_logprobs
results.append({
"test": "text_image",
"logprobs": top_logprobs,
})
return results
def compare_two_settings(model: str,
arg1: List[str],
arg2: List[str],
env1: Optional[Dict[str, str]] = None,
env2: Optional[Dict[str, str]] = None,
*,
method: str = "generate",
max_wait_seconds: Optional[float] = None) -> None:
"""
Launch API server with two different sets of arguments/environments
and compare the results of the API calls.
Args:
model: The model to test.
arg1: The first set of arguments to pass to the API server.
arg2: The second set of arguments to pass to the API server.
env1: The first set of environment variables to pass to the API server.
env2: The second set of environment variables to pass to the API server.
"""
compare_all_settings(
model,
[arg1, arg2],
[env1, env2],
method=method,
max_wait_seconds=max_wait_seconds,
)
def compare_all_settings(model: str,
all_args: List[List[str]],
all_envs: List[Optional[Dict[str, str]]],
*,
method: str = "generate",
max_wait_seconds: Optional[float] = None) -> None:
"""
Launch API server with several different sets of arguments/environments
and compare the results of the API calls with the first set of arguments.
Args:
model: The model to test.
all_args: A list of argument lists to pass to the API server.
all_envs: A list of environment dictionaries to pass to the API server.
"""
trust_remote_code = False
for args in all_args:
if "--trust-remote-code" in args:
trust_remote_code = True
break
tokenizer_mode = "auto"
for args in all_args:
if "--tokenizer-mode" in args:
tokenizer_mode = args[args.index("--tokenizer-mode") + 1]
break
tokenizer = get_tokenizer(
model,
trust_remote_code=trust_remote_code,
tokenizer_mode=tokenizer_mode,
)
can_force_load_format = True
for args in all_args:
if "--load-format" in args:
can_force_load_format = False
break
prompt = "Hello, my name is"
token_ids = tokenizer(prompt).input_ids
ref_results: List = []
for i, (args, env) in enumerate(zip(all_args, all_envs)):
if can_force_load_format:
# we are comparing the results and
# usually we don't need real weights.
# we force to use dummy weights by default,
# and it should work for most of the cases.
# if not, we can use VLLM_TEST_FORCE_LOAD_FORMAT
# environment variable to force the load format,
# e.g. in quantization tests.
args = args + ["--load-format", envs.VLLM_TEST_FORCE_LOAD_FORMAT]
compare_results: List = []
results = ref_results if i == 0 else compare_results
with RemoteOpenAIServer(model,
args,
env_dict=env,
max_wait_seconds=max_wait_seconds) as server:
client = server.get_client()
# test models list
models = client.models.list()
models = models.data
served_model = models[0]
results.append({
"test": "models_list",
"id": served_model.id,
"root": served_model.root,
})
if method == "generate":
results += _test_completion(client, model, prompt, token_ids)
elif method == "generate_close":
results += _test_completion_close(client, model, prompt)
elif method == "generate_with_image":
results += _test_image_text(
client, model,
"https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png"
)
elif method == "encode":
results += _test_embeddings(client, model, prompt)
else:
raise ValueError(f"Unknown method: {method}")
if i > 0:
# if any setting fails, raise an error early
ref_args = all_args[0]
ref_envs = all_envs[0]
compare_args = all_args[i]
compare_envs = all_envs[i]
for ref_result, compare_result in zip(ref_results,
compare_results):
ref_result = copy.deepcopy(ref_result)
compare_result = copy.deepcopy(compare_result)
if "embedding" in ref_result and method == "encode":
sim = F.cosine_similarity(
torch.tensor(ref_result["embedding"]),
torch.tensor(compare_result["embedding"]),
dim=0,
)
assert sim >= 0.999, (
f"Embedding for {model=} are not the same.\n"
f"cosine_similarity={sim}\n")
del ref_result["embedding"]
del compare_result["embedding"]
assert ref_result == compare_result, (
f"Results for {model=} are not the same.\n"
f"{ref_args=} {ref_envs=}\n"
f"{compare_args=} {compare_envs=}\n"
f"{ref_result=}\n"
f"{compare_result=}\n")
def init_test_distributed_environment(
tp_size: int,
pp_size: int,
rank: int,
distributed_init_port: str,
local_rank: int = -1,
) -> None:
distributed_init_method = f"tcp://localhost:{distributed_init_port}"
init_distributed_environment(
world_size=pp_size * tp_size,
rank=rank,
distributed_init_method=distributed_init_method,
local_rank=local_rank)
ensure_model_parallel_initialized(tp_size, pp_size)
def multi_process_parallel(
tp_size: int,
pp_size: int,
test_target: Any,
) -> None:
import ray
# Using ray makes it easier to debug failures
# than multiprocessing does.
# NOTE: We need to set working_dir for distributed tests,
# otherwise we may get import errors on ray workers
ray.init(runtime_env={"working_dir": VLLM_PATH})
distributed_init_port = get_open_port()
refs = []
for rank in range(tp_size * pp_size):
refs.append(
test_target.remote(tp_size, pp_size, rank, distributed_init_port))
ray.get(refs)
ray.shutdown()
@contextmanager
def error_on_warning(category: Type[Warning] = Warning):
"""
Within the scope of this context manager, tests will fail if any warning
of the given category is emitted.
"""
with warnings.catch_warnings():
warnings.filterwarnings("error", category=category)
yield
_P = ParamSpec("_P")
def fork_new_process_for_each_test(
f: Callable[_P, None]) -> Callable[_P, None]:
"""Decorator to fork a new process for each test function.
See https://github.com/vllm-project/vllm/issues/7053 for more details.
"""
@functools.wraps(f)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
# Make the process the leader of its own process group
# to avoid sending SIGTERM to the parent process
os.setpgrp()
from _pytest.outcomes import Skipped
pid = os.fork()
print(f"Fork a new process to run a test {pid}")
if pid == 0:
try:
f(*args, **kwargs)
except Skipped as e:
# convert Skipped to exit code 0
print(str(e))
os._exit(0)
except Exception:
import traceback
traceback.print_exc()
os._exit(1)
else:
os._exit(0)
else:
pgid = os.getpgid(pid)
_pid, _exitcode = os.waitpid(pid, 0)
# ignore SIGTERM signal itself
old_signal_handler = signal.signal(signal.SIGTERM, signal.SIG_IGN)
# kill all child processes
os.killpg(pgid, signal.SIGTERM)
# restore the signal handler
signal.signal(signal.SIGTERM, old_signal_handler)
assert _exitcode == 0, (f"function {f} failed when called with"
f" args {args} and kwargs {kwargs}")
return wrapper
def large_gpu_mark(min_gb: int) -> pytest.MarkDecorator:
"""
Get a pytest mark, which skips the test if the GPU doesn't meet
a minimum memory requirement in GB.
This can be leveraged via `@large_gpu_test` to skip tests in environments
without enough resources, or called when filtering tests to run directly.
"""
try:
if current_platform.is_cpu():
memory_gb = 0
else:
memory_gb = current_platform.get_device_total_memory() / GB_bytes
except Exception as e:
warnings.warn(
f"An error occurred when finding the available memory: {e}",
stacklevel=2,
)
memory_gb = 0
return pytest.mark.skipif(
memory_gb < min_gb,
reason=f"Need at least {min_gb}GB GPU memory to run the test.",
)
def large_gpu_test(*, min_gb: int):
"""
Decorate a test to be skipped if no GPU is available or it does not have
sufficient memory.
Currently, the CI machine uses L4 GPU which has 24 GB VRAM.
"""
mark = large_gpu_mark(min_gb)
def wrapper(f: Callable[_P, None]) -> Callable[_P, None]:
return mark(f)
return wrapper
async def completions_with_server_args(
prompts: List[str],
model_name: str,
server_cli_args: List[str],
num_logprobs: Optional[int],
max_wait_seconds: int = 240,
max_tokens: Union[int, list] = 5,
) -> List[Completion]:
'''Construct a remote OpenAI server, obtain an async client to the
server & invoke the completions API to obtain completions.
Args:
prompts: test prompts
model_name: model to spin up on the vLLM server
server_cli_args: CLI args for starting the server
num_logprobs: Number of logprobs to report (or `None`)
max_wait_seconds: timeout interval for bringing up server.
Default: 240sec
max_tokens: max_tokens value for each of the given input prompts.
if only one max_token value is given, the same value is used
for all the prompts.
Returns:
OpenAI Completion instance
'''
if isinstance(max_tokens, int):
max_tokens = [max_tokens] * len(prompts)
assert len(max_tokens) == len(prompts)
outputs = None
with RemoteOpenAIServer(model_name,
server_cli_args,
max_wait_seconds=max_wait_seconds) as server:
client = server.get_async_client()
outputs = [ client.completions.create(model=model_name,
prompt=[p],
temperature=0,
stream=False,
max_tokens=max_tok,
logprobs=num_logprobs) \
for p, max_tok in zip(prompts, max_tokens) ]
outputs = await asyncio.gather(*outputs)
assert outputs is not None, "Completion API call failed."
return outputs
def get_client_text_generations(completions: List[Completion]) -> List[str]:
'''Extract generated tokens from the output of a
request made to an Open-AI-protocol completions endpoint.
'''
assert all([len(x.choices) == 1 for x in completions])
return [x.choices[0].text for x in completions]
def get_client_text_logprob_generations(
completions: List[Completion]) -> List[TextTextLogprobs]:
'''Operates on the output of a request made to an Open-AI-protocol
completions endpoint; obtains top-rank logprobs for each token in
each :class:`SequenceGroup`
'''
text_generations = get_client_text_generations(completions)
text = ''.join(text_generations)
return [(text_generations, text,
(None if x.logprobs is None else x.logprobs.top_logprobs))
for completion in completions for x in completion.choices]


@@ -511,9 +511,6 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
self.num_prefill_tokens += token_len
self.prefill_seq_lens.append(seq_len)
else:
assert query_len == 1, (
"seq_len: {}, context_len: {}, query_len: {}".format(
seq_len, context_len, query_len))
self.num_decode_tokens += query_len
self.curr_seq_lens.append(curr_seq_len)


@@ -13,4 +13,68 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
# What's Patched and how it works:
# ** File: worker/patch_common/patch_metrics.py **
# 1. `vllm.spec_decode.metrics.AsyncMetricsCollector.init_tensors` and
# `vllm.spec_decode.metrics.AsyncMetricsCollector._copy_rejsample_metrics_async`
# Why:
# There are cuda hard code (torch.cuda.Stream) in `AsyncMetricsCollector.init_tensors` and
# `AsyncMetricsCollector._copy_rejsample_metrics_async`
# How
# Replace it with the corresponding npu method
# Related PR (if no, explain why): 1. refused by vllm. 2. vllm doesn't support 3. prepare to submit....
# https://github.com/vllm-project/vllm/pull/14411
# Future Plan:
# Revert it when the related pr is merged in vllm.
#
# 2. `vllm.spec_decode.metrics.AsyncMetricsCollector.maybe_collect_rejsample_metrics`
# Why:
# There are cuda hard code (current_platform.is_cuda_alike()) in
# `AsyncMetricsCollector.maybe_collect_rejsample_metrics`
# How
# Change to use `current_platform.Event` to determine whether to return None
# Related PR (if no, explain why): 1. refused by vllm. 2. vllm doesn't support 3. prepare to submit....
# https://github.com/vllm-project/vllm/pull/14411
# Future Plan:
# Revert it when the related pr is merged in vllm.
#
# ** File: worker/patch_common/patch_multi_step_worker.py **
# 1. `vllm.spec_decode.multi_step_worker.MultiStepWorker.sampler_output`
# Why:
# There are cuda hard code (current_platform.is_cuda_alike()) in
# `MultiStepWorker.sampler_output`, and we need to use the patched `TP1DraftModelRunner` in it.
# How
# Make speculative decoding extensible to different backends.
# - support attention metadata register to the set supported spec decode
# - offer a api in platform to determine whether spec decode is supported,
# and deprecate is_cuda_alike in it.
# Related PR (if no, explain why): 1. refused by vllm. 2. vllm doesn't support 3. prepare to submit....
# - https://github.com/vllm-project/vllm/pull/15195
# - https://github.com/vllm-project/vllm-ascend/pull/395
# Future Plan:
# Revert it when the related pr is merged in vllm and vllm-ascend.
#
# ** File: worker/patch_common/patch_spec_decode_worker.py **
# 1. `vllm.spec_decode.spec_decode_worker.SpecDecodeWorker.create_worker`
# Why:
# We need to use the patched `TP1DraftModelRunner` in `SpecDecodeWorker.create_worker`.
# The main reason to overwrite `TP1DraftModelRunner` is the hard-coded
# `FlashAttentionMetadata`
# How
# ditto
# Related PR (if no, explain why): 1. refused by vllm. 2. vllm doesn't support 3. prepare to submit....
# - https://github.com/vllm-project/vllm/pull/15195
# - https://github.com/vllm-project/vllm-ascend/pull/395
# Future Plan:
# Revert it when the related pr is merged in vllm and vllm-ascend.
# current_platform.is_cuda_alike()
# 0.8.4 patch doc:
# platform-0.8.4 + platform-common + worker-0.8.4 + worker-common
# ...
import vllm_ascend.patch.worker.patch_common.patch_metrics # noqa
import vllm_ascend.patch.worker.patch_common.patch_multi_step_worker # noqa
import vllm_ascend.patch.worker.patch_common.patch_spec_decode_worker # noqa


@@ -0,0 +1,88 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Callable, Optional, Union
import torch
import torch_npu
from vllm.spec_decode.metrics import (AsyncMetricsCollector,
SpecDecodeWorkerMetrics)
Timer = Callable[[], float]
# TODO: revert this patch when the cuda hard code is removed in vllm
# init_tensors: Modified the hard-coded cuda judgment logic to npu;
# maybe_collect_rejsample_metrics: Removed the check for current_platform.is_cuda_alike()
def init_tensors(self,
rank: int,
device_type: Union[torch.device, str] = 'npu') -> None:
self._rank = rank
if isinstance(device_type, torch.device):
device_type = device_type.type
if device_type == 'npu':
self._copy_stream = torch_npu.npu.Stream()
def maybe_collect_rejsample_metrics(
self, k: int) -> Optional[SpecDecodeWorkerMetrics]:
# If a copy was initiated in the previous call, collect and return.
if self._in_flight_copy is not None:
ready_event = self._in_flight_copy
self._in_flight_copy = None
return self._collect_rejsample_metrics(k, ready_event)
# Otherwise, check if we should start a new copy.
if self._should_collect_rejsample_metrics(self._timer()):
assert self._in_flight_copy is None
self._in_flight_copy = self._copy_rejsample_metrics_async()
return None
def _copy_rejsample_metrics_async(self) -> torch.npu.Event:
"""
TODO: torch.cuda.xxx --> torch.npu.xxx
Copy rejection/typical-acceptance sampling metrics
(number of accepted tokens, etc) to CPU asynchronously.
Returns an NPU event recording when the copy is complete.
"""
assert self._copy_stream is not None
self._copy_stream.wait_stream(torch.npu.current_stream())
with torch.npu.stream(self._copy_stream):
self._aggregate_num_accepted_tokens.copy_(
self.spec_decode_sampler.num_accepted_tokens, non_blocking=True)
self._aggregate_num_emitted_tokens.copy_(
self.spec_decode_sampler.num_emitted_tokens, non_blocking=True)
# Number of draft tokens is calculated on CPU, so no copy is
# required.
self._aggregate_num_draft_tokens = (
self.spec_decode_sampler.num_draft_tokens)
aggregate_metrics_ready = torch.npu.Event()
aggregate_metrics_ready.record(self._copy_stream)
return aggregate_metrics_ready
AsyncMetricsCollector.init_tensors = init_tensors
AsyncMetricsCollector.maybe_collect_rejsample_metrics = maybe_collect_rejsample_metrics
AsyncMetricsCollector._copy_rejsample_metrics_async = _copy_rejsample_metrics_async


@@ -0,0 +1,87 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import List, Set, Tuple
import torch
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm_ascend.worker.draft_model_runner import TP1DraftModelRunner
def sampler_output(
self,
execute_model_req: ExecuteModelRequest,
sample_len: int,
seq_ids_with_bonus_token_in_last_step: Set[int],
) -> Tuple[List[SamplerOutput], bool]:
"""Run the model forward pass sample_len times. Returns the list of
sampler output, one per model forward pass, along with indicator of
whether torch tensor in sampler output need to be transposed in latter
sampler_output_to_torch logic.
For multi step worker, this indicator shall be True.
"""
self._raise_if_unsupported(execute_model_req)
# Expand the batch for sequences with a bonus token.
# Perform a forward pass on the expanded batch and filter the
# response to retain only the original sequences' responses.
expanded_request, indices_of_seq_with_bonus_tokens =\
self._expand_execute_model_request(
execute_model_req, seq_ids_with_bonus_token_in_last_step)
# Run model sample_len times.
model_outputs: List[SamplerOutput] = []
# TODO: supports_gpu_multi_step is False in ASCEND
if isinstance(self.model_runner, TP1DraftModelRunner) and \
self.model_runner.supports_gpu_multi_step(expanded_request):
# Here we run the draft_model_runner with multi-step prepare
# on the GPU directly
expanded_request.num_steps = sample_len
self.model_runner.set_indices_of_seq_with_bonus_tokens(
indices_of_seq_with_bonus_tokens)
model_outputs = self.execute_model(execute_model_req=expanded_request)
else:
# Here we run multi-step directly, with every step prepared
# on the CPU.
# TODO: Remove this branch once DraftModelRunner supports TP>1
# and other restrictions that are part of DraftModelRunner's
# supports_gpu_multi_step(..)
for _ in range(sample_len):
model_output: List[SamplerOutput] = self.worker.execute_model(
execute_model_req=expanded_request)
assert (len(model_output) == 1
), "composing multistep workers not supported"
model_output = model_output[0]
self._append_new_tokens(model_output,
expanded_request.seq_group_metadata_list,
indices_of_seq_with_bonus_tokens)
model_outputs.append(model_output)
# move indices to device to avoid stream sync
indices_of_seq_with_bonus_tokens = torch.tensor(
indices_of_seq_with_bonus_tokens, device=self.device)
filtered_model_outputs = self._filter_model_output(
model_outputs, indices_of_seq_with_bonus_tokens)
return filtered_model_outputs, True
MultiStepWorker.sampler_output = torch.inference_mode()(sampler_output)
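As a side note, the final line above applies `torch.inference_mode()` by calling it directly on the function rather than using decorator syntax; the two forms are equivalent, as this small standalone sketch illustrates:
```python
import torch

def double(x: torch.Tensor) -> torch.Tensor:
    return x * 2

# Calling the context-manager object on a function returns a wrapped function,
# exactly as if `@torch.inference_mode()` had been written above its definition.
wrapped = torch.inference_mode()(double)

x = torch.ones(2, requires_grad=True)
assert not wrapped(x).requires_grad
```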


@@ -0,0 +1,151 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Any, Dict, Optional
from vllm.config import ParallelConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.layers.spec_decode_base_sampler import \
SpecDecodeBaseSampler
from vllm.model_executor.layers.typical_acceptance_sampler import \
TypicalAcceptanceSampler
from vllm.spec_decode.medusa_worker import MedusaWorker
from vllm.spec_decode.mlp_speculator_worker import MLPSpeculatorWorker
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.ngram_worker import NGramWorker
from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker
from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
from vllm.worker.worker_base import WorkerBase
from vllm_ascend.worker.draft_model_runner import TP1DraftModelRunner
logger = init_logger(__name__)
def create_worker(
cls,
scorer_worker: WorkerBase,
draft_worker_kwargs: Dict[str, Any],
disable_mqa_scorer: bool,
disable_by_batch_size: Optional[int],
draft_token_acceptance_method: str,
typical_acceptance_sampler_posterior_threshold: float,
typical_acceptance_sampler_posterior_alpha: float,
disable_logprobs: bool,
disable_log_stats: bool,
num_speculative_tokens: int,
) -> "SpecDecodeWorker":
allow_zero_draft_token_step = True
enable_lm_head_weight_load = False
num_spec_prefill_steps = 1
ngram_prompt_lookup_max = (
draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
ngram_prompt_lookup_min = (
draft_worker_kwargs.pop("ngram_prompt_lookup_min"))
draft_model_config = draft_worker_kwargs["vllm_config"].model_config
draft_parallel_config: ParallelConfig = draft_worker_kwargs[
'vllm_config'].parallel_config
if ngram_prompt_lookup_max > 0:
draft_worker_kwargs[
"device_type"] = scorer_worker.device_config.device.type
proposer_worker = NGramWorker(**draft_worker_kwargs)
proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
ngram_prompt_lookup_max)
else:
draft_tp = draft_parallel_config.tensor_parallel_size
target_tp = scorer_worker.parallel_config.tensor_parallel_size
if draft_model_config.hf_config.model_type == "mlp_speculator":
proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs)
elif draft_model_config.hf_config.model_type == "medusa":
proposer_worker = MedusaWorker(**draft_worker_kwargs)
else:
# Note: The current version of the MTP module does not support
# the use of TP1DraftModelRunner
if draft_tp == 1 and draft_model_config.hf_config.model_type !=\
"deepseek_mtp":
draft_worker_kwargs["model_runner_cls"] = TP1DraftModelRunner
else:
if draft_model_config.hf_config.model_type == "eagle":
raise NotImplementedError(
f"{draft_model_config.hf_config.model_type} "
"does not support TP > 1 yet")
allow_zero_draft_token_step = False
# Load lm_head weight for eagle in init_device
if draft_model_config.hf_config.model_type == "eagle":
enable_lm_head_weight_load = True
proposer_worker = MultiStepWorker(**draft_worker_kwargs)
if draft_model_config.hf_config.model_type == "deepseek_mtp":
num_spec_prefill_steps = num_speculative_tokens
proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
proposer_worker, draft_tp, target_tp)
logger.info("Configuring SpecDecodeWorker with proposer=%s",
type(proposer_worker))
spec_decode_sampler: SpecDecodeBaseSampler = None
if draft_token_acceptance_method == "rejection_sampler":
spec_decode_sampler = RejectionSampler()
elif draft_token_acceptance_method == "typical_acceptance_sampler":
spec_decode_sampler = TypicalAcceptanceSampler(
posterior_threshold=\
typical_acceptance_sampler_posterior_threshold,
posterior_alpha=typical_acceptance_sampler_posterior_alpha,
)
logger.info(
"[Speculative Decoding] Configuring"
" SpecDecodeWorker with sampler=%s", type(spec_decode_sampler))
if not disable_mqa_scorer:
if scorer_worker.model_runner.attn_backend.get_name() != "FLASH_ATTN":
disable_mqa_scorer = True
logger.info("[Speculative Decoding] Disabling MQA scorer as the "
"MQA is only available with flash attn backend.")
if draft_model_config and \
draft_model_config.max_model_len < \
scorer_worker.model_config.max_model_len:
disable_mqa_scorer = True
logger.info("[Speculative Decoding] Disabling MQA scorer as the "
"draft model max_model_len is smaller than the target "
"model max_model_len.")
if not scorer_worker.model_runner.model_config.enforce_eager:
disable_mqa_scorer = True
logger.info("[Speculative Decoding] Disabling MQA scorer as the "
"target model is not running in eager mode.")
return SpecDecodeWorker(
proposer_worker,
scorer_worker,
disable_mqa_scorer=disable_mqa_scorer,
disable_logprobs=disable_logprobs,
disable_log_stats=disable_log_stats,
disable_by_batch_size=disable_by_batch_size,
spec_decode_sampler=spec_decode_sampler,
allow_zero_draft_token_step=allow_zero_draft_token_step,
enable_lm_head_weight_load=enable_lm_head_weight_load,
num_spec_prefill_steps=num_spec_prefill_steps)
SpecDecodeWorker.create_worker = classmethod(create_worker)


@@ -0,0 +1,320 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import List, Optional
import torch
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.multimodal import MultiModalKwargs
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
from vllm.worker.model_runner_base import (ModelRunnerBase,
ModelRunnerInputBase,
ModelRunnerWrapperBase)
from vllm_ascend.attention.attention import AscendMetadata
logger = init_logger(__name__)
# A flag to enable debug prints for the updated input tensors
# before each step.
debug_advance_input = False
# A flag to allow GPU advance step for draft model runner.
# Set to False for debugging.
allow_gpu_advance_step = True
class TP1DraftModelRunner(ModelRunnerWrapperBase):
"""Specialized model runner for speculative decoding draft model.
Since the draft model always executes k forward passes consecutively to
generate k speculative tokens in a single speculative decoding step,
we could get rid of most CPU-GPU synchronization and data transfer
overheads by keeping model input and output tensors on GPU all the time.
TODOs:
1. Currently supports only flash-attn, add support for other attn_backends.
2. Support TP > 1 (this requires some designs because we do not expect
any broadcasting inside execute_model).
"""
def __init__(self, model_runner: ModelRunnerBase):
if hasattr(
model_runner,
"return_hidden_states") and model_runner.return_hidden_states:
raise ValueError(
"return_hidden_states is not supported for TP1DraftModelRunner."
)
super().__init__(model_runner)
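        # Filled in later via set_indices_of_seq_with_bonus_tokens(); it marks
        # which sequences in the batch carry a bonus token (see execute_model).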
self.indices_of_seq_with_bonus_tokens = None
def _update_sampling_metadata(self, sampling_metadata, num_seqs,
num_queries):
assert sampling_metadata.num_prompts == 0
assert len(sampling_metadata.seq_groups) == num_queries
assert sampling_metadata.selected_token_indices.shape == (
num_queries, )
# assert sampling_metadata.categorized_sample_indices == TODO: Add if needed # noqa: E501
# Verify that all sequences are decodes
for i in range(num_queries):
seq_group = sampling_metadata.seq_groups[i]
assert seq_group.is_prompt is False # No prompt
assert seq_group.prompt_logprob_indices == [] # No prompt
assert seq_group.sample_indices == [i] # Simple
def _gpu_advance_step(self, model_input: ModelRunnerInputBase,
last_output: SamplerOutput) -> ModelRunnerInputBase:
# Currently, we expect "decode mode" only
assert not model_input.is_prompt
# Get num_seqs
num_seqs = len(model_input.seq_lens)
num_queries = len(model_input.query_lens)
# Get output tokens GPU tensor
sampled_token_ids = last_output.sampled_token_ids
assert sampled_token_ids is not None
# Update attn_metadata
attn_metadata = model_input.attn_metadata
assert isinstance(attn_metadata, AscendMetadata)
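        # advance_step() updates the attention metadata and the token/position
        # tensors it references in place on the device for the next decode step.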
attn_metadata.advance_step(model_input, sampled_token_ids,
self.block_size, num_seqs, num_queries)
# Update sampling_metadata
sampling_metadata = model_input.sampling_metadata
self._update_sampling_metadata(sampling_metadata, num_seqs,
num_queries)
# Create new input
new_model_input = self._model_input_cls(
input_tokens=model_input.input_tokens,
input_positions=model_input.input_positions,
attn_metadata=attn_metadata,
seq_lens=attn_metadata.seq_lens,
query_lens=model_input.query_lens,
# Notes: If vllm_ascend supports LORA, we need to
# add the following two params.
# lora_mapping=model_input.lora_mapping,
# lora_requests=model_input.lora_requests,
multi_modal_kwargs=model_input.multi_modal_kwargs,
sampling_metadata=model_input.sampling_metadata,
is_prompt=False,
)
# Ensure we skip CPU samples
assert new_model_input.sampling_metadata.skip_sampler_cpu_output is True
# We can reuse sampling tensors since every decode iteration is the same
new_model_input.sampling_metadata.reuse_sampling_tensors = True
if debug_advance_input:
logger.debug("NEW INPUT: ")
logger.debug(" input_tokens = %s", new_model_input.input_tokens)
logger.debug(" input_positions = %s",
new_model_input.input_positions)
logger.debug(" seq_lens = %d", new_model_input.seq_lens)
logger.debug(" query_lens = %d", new_model_input.query_lens)
logger.debug(" attn_metadata:")
logger.debug(" seq_lens_tensor: %s",
attn_metadata.seq_lens_tensor)
logger.debug(" slot_mapping: %s", attn_metadata.slot_mapping)
logger.debug(" block_tables: %s", attn_metadata.block_tables)
return new_model_input
def supports_gpu_multi_step(self, execute_model_req: ExecuteModelRequest):
"""Determines if draft_model_runner GPU multi-step can be used.
Currently required conditions are:
1. Only decodes
2. Only flash-attn
3. No LORA
4. No prompt_adapter_config
"""
if not allow_gpu_advance_step:
return False
# We allow multi-step GPU only in decode mode
for seq_group in execute_model_req.seq_group_metadata_list:
if seq_group.is_prompt:
return False
        # TODO: Add support for the ASCEND attention backend once the
        # outer multi_step_worker works correctly.
if self.attn_backend.get_name() not in ("FLASH_ATTN", "TRITON_MLA"):
return False
# TODO: Add support for LORA
if self.lora_config:
return False
# TODO: Add soft-tuning prompt adapter support
return not self.prompt_adapter_config
def set_indices_of_seq_with_bonus_tokens(self,
indices_of_seq_with_bonus_tokens):
self.indices_of_seq_with_bonus_tokens = indices_of_seq_with_bonus_tokens
@torch.inference_mode()
def execute_model(
self,
model_input: ModelRunnerInputBase,
kv_caches: List[torch.Tensor],
previous_hidden_states: Optional[torch.Tensor] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
**kwargs,
) -> Optional[List[SamplerOutput]]:
"""Executes num_steps forward passes with advacement of input tensors
on the GPU. Look at supports_gpu_multi_step(..) for pre-conditions.
Optimizations used:
1. Input tensors are updated on the GPU directly
2. Skips GPU=>CPU serialization of sampler outputs (we don't need
them since we do batch expansion later that uses GPU outputs)
3. Reuses sampling tensors (since we run only decodes and they have
a repeating sampling logic)
"""
        # When num_steps == 1, we are on the fallback path for GPU
        # advance_step: prepare_inputs runs on the CPU and this function is
        # invoked only once per speculative iteration
        # (see the multi-step worker code).
is_fallback = num_steps == 1
if not is_fallback:
# Since we do not broadcast data inside execute_model anymore,
# we need to figure out the best way to support TP > 1 in this
# case, because we will at least need to broadcast the sampled
# tokens to all workers.
if not self.is_driver_worker:
raise ValueError("TP1DraftModelRunner only supports TP=1.")
# Sanity
if self.lora_config is not None:
raise ValueError("TP1DraftModelRunner has no support for LORA")
if self.prompt_adapter_config is not None:
raise ValueError("TP1DraftModelRunner has no support for "
"prompt_adapter_config")
if model_input.multi_modal_kwargs:
raise ValueError(
"TP1DraftModelRunner has no support for multi_modal_kwargs"
)
else:
if self.lora_config:
assert model_input.lora_requests is not None
assert model_input.lora_mapping is not None
self.set_active_loras(model_input.lora_requests,
model_input.lora_mapping)
if self.prompt_adapter_config:
assert model_input.prompt_adapter_requests is not None
assert model_input.prompt_adapter_mapping is not None
self.set_active_prompt_adapters(
model_input.prompt_adapter_requests,
model_input.prompt_adapter_mapping)
self.attn_state.begin_forward(model_input)
# Detect exec mode
assert model_input.attn_metadata is not None
if model_input.attn_metadata.num_prefills > 0:
# In this case, execute_model(..) was called directly
if num_steps > 1:
raise ValueError(
"execute_model(..) of draft_model_runner can be called "
"directly only with a single-step prefill")
else:
# We can skip CPU samples for spec token generation.
# (We do allow CPU samples for num_steps == 1 to support the
# fallback case, where supports_gpu_multi_step(..) does not pass)
model_input.sampling_metadata.skip_sampler_cpu_output = (
not is_fallback)
model_executable = self.model
hidden_states = previous_hidden_states
outputs: List[SamplerOutput] = []
for step in range(num_steps):
multi_modal_kwargs = model_input.multi_modal_kwargs or {}
model_execute_kwargs = {"previous_hidden_states": hidden_states} \
if previous_hidden_states is not None else {}
compute_logits_kwargs = {}
# Run model
if hasattr(self.model.config, "num_nextn_predict_layers"):
                # For DeepSeek MTP only: use the layer corresponding to
                # the current step.
spec_step_idx = kwargs.get("spec_step_idx", step)
model_execute_kwargs["spec_step_idx"] = spec_step_idx
compute_logits_kwargs["spec_step_idx"] = spec_step_idx
with set_forward_context(model_input.attn_metadata,
self.vllm_config):
if model_input.attn_metadata is not None:
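                    # Mirror the positions onto the attention metadata; the
                    # Ascend MLA backend looks them up via
                    # attn_metadata.input_positions.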
model_input.attn_metadata.input_positions = model_input.input_positions
hidden_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
device=self.device),
**model_execute_kwargs,
)
# Compute the logits.
logits = self.model.compute_logits(hidden_states,
model_input.sampling_metadata,
**compute_logits_kwargs)
if not self.is_driver_worker:
return []
# Sample the next token.
output = self.model.sample(
logits=logits,
sampling_metadata=model_input.sampling_metadata,
)
outputs.append(output)
if model_input.attn_metadata.num_prefills == 0 \
and self.indices_of_seq_with_bonus_tokens is not None:
assert output.sampled_token_ids is not None
# output.sampled_token_ids should be of shape (num_seqs, 1)
nums_seqs, num_tokens_per_seq = output.sampled_token_ids.shape
assert num_tokens_per_seq == 1
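                # Sequences whose index is not in the bonus list have their
                # sampled token overwritten with the input token of the next
                # bonus-carrying sequence.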
count = 0
for i in range(nums_seqs):
bonus_seq_idx = self.indices_of_seq_with_bonus_tokens[
count]
if i != bonus_seq_idx:
# The following might cause a cpu->gpu sync
# However, the performance impact is negligible as we
# benchmarked on H100.
output.sampled_token_ids[
i, :] = model_input.input_tokens[bonus_seq_idx]
else:
count += 1
# Prepare inputs for the next step
if step != num_steps - 1:
model_input = self._gpu_advance_step(model_input, outputs[-1])
return outputs
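To make the fallback handling above easier to follow, here is a hypothetical driver sketch of the contract this runner expects from the multi-step worker. It is not the actual worker code; `prepare_next_inputs_on_cpu`, `draft_runner`, and `k` are placeholders.

```
# Hypothetical sketch of how a multi-step worker drives TP1DraftModelRunner.
if draft_runner.supports_gpu_multi_step(execute_model_req):
    # One call, k forward passes; inputs are advanced on the device between
    # steps by _gpu_advance_step().
    outputs = draft_runner.execute_model(model_input, kv_caches, num_steps=k)
else:
    # Fallback (num_steps == 1 per call): inputs are re-prepared on the CPU
    # and execute_model() is invoked once per speculative step.
    outputs = []
    for _ in range(k):
        outputs.extend(
            draft_runner.execute_model(model_input, kv_caches, num_steps=1))
        model_input = prepare_next_inputs_on_cpu(model_input, outputs[-1])  # placeholder helper
```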