[Bugfix][CI] Remove V0 Spec Decode CI (#1656)
### What this PR does / why we need it?
To solve the error in the CI of long term test:
```bash
modelscope - ERROR - Repo JackFram/llama-68m not exists on either https://www.modelscope.cn/ or https://www.modelscope.ai/
```
Replace the hf model with modelscope model.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.9.1
- vLLM main:
71d1d75b7a
---------
Signed-off-by: Shanshan Shen <87969357+shen-shanshan@users.noreply.github.com>
This commit is contained in:
@@ -96,13 +96,8 @@ jobs:
|
||||
- name: Run vllm-project/vllm-ascend long term test
|
||||
run: |
|
||||
if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
|
||||
# v0 spec decode test
|
||||
# TODO: Revert me when test_mtp_correctness is fixed
|
||||
# VLLM_USE_MODELSCOPE=True pytest -sv tests/e2e/long_term/spec_decode_v0/e2e/test_mtp_correctness.py # it needs a clean process
|
||||
pytest -sv tests/e2e/long_term/spec_decode_v0 --ignore=tests/e2e/long_term/spec_decode_v0/e2e/test_mtp_correctness.py
|
||||
# accuracy test single card
|
||||
pytest -sv tests/e2e/long_term/test_accuracy.py
|
||||
else
|
||||
# else
|
||||
# accuracy test multi card
|
||||
VLLM_USE_MODELSCOPE=True pytest -sv tests/e2e/long_term/test_deepseek_v2_lite_tp2_accuracy.py
|
||||
# VLLM_USE_MODELSCOPE=True pytest -sv tests/e2e/long_term/test_deepseek_v2_lite_tp2_accuracy.py
|
||||
fi
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from vllm_ascend.patch import worker # noqa: F401
|
||||
@@ -1,28 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm-project/vllm/tests/spec_decode/conftest.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def use_v0_only(monkeypatch):
|
||||
"""
|
||||
Since this module is V0 only, set VLLM_USE_V1=0 for
|
||||
all tests in the module.
|
||||
"""
|
||||
monkeypatch.setenv('VLLM_USE_V1', '0')
|
||||
@@ -1,212 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm-project/vllm/tests/spec_decode/e2e/conftest.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import shutil
|
||||
from itertools import cycle
|
||||
from pathlib import Path
|
||||
from typing import Optional, Sequence, Union
|
||||
|
||||
import torch
|
||||
from vllm import SamplingParams
|
||||
from vllm.sequence import PromptLogprobs, SampleLogprobs
|
||||
|
||||
from tests.model_utils import (TokensTextLogprobs,
|
||||
TokensTextLogprobsPromptLogprobs,
|
||||
check_logprobs_close, check_outputs_equal)
|
||||
|
||||
PROMPTS = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
"San Francisco is know for its",
|
||||
"Facebook was created in 2004 by",
|
||||
"Curious George is a",
|
||||
"Python 3.11 brings improvements to its",
|
||||
]
|
||||
|
||||
|
||||
def check_logprobs_correctness(
|
||||
spec_outputs: Sequence[Union[TokensTextLogprobs,
|
||||
TokensTextLogprobsPromptLogprobs]],
|
||||
baseline_outputs: Sequence[Union[TokensTextLogprobs,
|
||||
TokensTextLogprobsPromptLogprobs]],
|
||||
disable_logprobs: bool = False,
|
||||
):
|
||||
"""Compare sampled and prompt logprobs between baseline and spec decoding
|
||||
"""
|
||||
if not disable_logprobs:
|
||||
return check_logprobs_close(
|
||||
outputs_0_lst=baseline_outputs,
|
||||
outputs_1_lst=spec_outputs,
|
||||
name_0="org",
|
||||
name_1="sd",
|
||||
)
|
||||
|
||||
# Check correctness when disable_logprobs == True
|
||||
for spec_output, baseline_output in zip(spec_outputs, baseline_outputs):
|
||||
# Check generated token logprobs.
|
||||
spec_logprobs = spec_output[2]
|
||||
baseline_logprobs = baseline_output[2]
|
||||
_check_logprobs_when_output_disabled(spec_logprobs,
|
||||
baseline_logprobs,
|
||||
is_prompt_logprobs=False)
|
||||
|
||||
# Check prompt logprobs too, if they exist
|
||||
if len(baseline_output) == 4:
|
||||
assert len(spec_output) == 4
|
||||
spec_prompt_logprobs = spec_output[3]
|
||||
baseline_prompt_logprobs = baseline_output[3]
|
||||
_check_logprobs_when_output_disabled(spec_prompt_logprobs,
|
||||
baseline_prompt_logprobs,
|
||||
is_prompt_logprobs=True)
|
||||
|
||||
|
||||
def _check_logprobs_when_output_disabled(
|
||||
spec_logprobs: Union[Optional[PromptLogprobs], SampleLogprobs],
|
||||
baseline_logprobs: Union[Optional[PromptLogprobs], SampleLogprobs],
|
||||
is_prompt_logprobs: bool = False,
|
||||
):
|
||||
# Prompt logprobs are optional
|
||||
if is_prompt_logprobs and baseline_logprobs is None:
|
||||
assert spec_logprobs is None
|
||||
return
|
||||
|
||||
assert spec_logprobs is not None
|
||||
assert baseline_logprobs is not None
|
||||
assert len(spec_logprobs) == len(baseline_logprobs)
|
||||
|
||||
# For each generated position of the sequence.
|
||||
for pos, (spec_pos_logprobs, baseline_pos_logprobs) in enumerate(
|
||||
zip(spec_logprobs, baseline_logprobs)):
|
||||
|
||||
# First prompt logprob is expected to be None
|
||||
if is_prompt_logprobs and baseline_pos_logprobs is None:
|
||||
assert spec_pos_logprobs is None
|
||||
assert pos == 0
|
||||
continue
|
||||
|
||||
assert spec_pos_logprobs is not None
|
||||
assert baseline_pos_logprobs is not None
|
||||
|
||||
# When disabled, the 1 logprob is returned with dummy values for the
|
||||
# score and rank, but the token id should match the baseline model
|
||||
assert len(spec_pos_logprobs) == 1
|
||||
(spec_pos_logprob_token_id,
|
||||
spec_pos_logprob) = next(iter(spec_pos_logprobs.items()))
|
||||
assert spec_pos_logprob.rank == -1
|
||||
assert spec_pos_logprob.logprob == 0.0
|
||||
if isinstance(spec_pos_logprob_token_id, torch.Tensor):
|
||||
spec_pos_logprob_token_id = spec_pos_logprob_token_id.item()
|
||||
assert spec_pos_logprob_token_id in baseline_pos_logprobs
|
||||
|
||||
|
||||
def _clean_torchair_cache():
|
||||
cache_path = Path.cwd() / '.torchair_cache'
|
||||
if cache_path.exists() and cache_path.is_dir():
|
||||
shutil.rmtree(cache_path)
|
||||
|
||||
|
||||
def run_equality_correctness_test(
|
||||
vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size: int,
|
||||
max_output_len: int,
|
||||
seed: Optional[int] = 0,
|
||||
temperature: float = 0.0,
|
||||
disable_seed: bool = False,
|
||||
ignore_eos: bool = True,
|
||||
ensure_all_accepted: bool = False,
|
||||
expected_acceptance_rate: Optional[float] = None,
|
||||
logprobs: Optional[int] = None,
|
||||
prompt_logprobs: Optional[int] = None,
|
||||
disable_logprobs: bool = False):
|
||||
|
||||
org_args = {
|
||||
**common_llm_kwargs,
|
||||
**per_test_common_llm_kwargs,
|
||||
**baseline_llm_kwargs,
|
||||
}
|
||||
|
||||
sd_args = {
|
||||
**common_llm_kwargs,
|
||||
**per_test_common_llm_kwargs,
|
||||
**test_llm_kwargs,
|
||||
}
|
||||
|
||||
prompts = [prompt for prompt, _ in zip(cycle(PROMPTS), range(batch_size))]
|
||||
|
||||
if disable_seed:
|
||||
seed = None
|
||||
|
||||
sampling_params = SamplingParams(temperature=temperature,
|
||||
max_tokens=max_output_len,
|
||||
seed=seed,
|
||||
ignore_eos=ignore_eos,
|
||||
logprobs=logprobs,
|
||||
prompt_logprobs=prompt_logprobs)
|
||||
|
||||
# TODO current torchair graph mode needs clean torchair cache.
|
||||
# if do not clean, it will raise error
|
||||
torchair_graph_enabled = common_llm_kwargs.get(
|
||||
"additional_config", {}).get("torchair_graph_config",
|
||||
{}).get("enabled", False)
|
||||
|
||||
with vllm_runner(**org_args) as vllm_model:
|
||||
if torchair_graph_enabled:
|
||||
_clean_torchair_cache()
|
||||
org_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)
|
||||
|
||||
with vllm_runner(**sd_args) as vllm_model:
|
||||
if torchair_graph_enabled:
|
||||
_clean_torchair_cache()
|
||||
if ensure_all_accepted or expected_acceptance_rate is not None:
|
||||
# Force log interval to be 0 to catch all metrics.
|
||||
stat_logger = vllm_model.model.llm_engine.stat_loggers[
|
||||
'prometheus']
|
||||
stat_logger.local_interval = -100
|
||||
|
||||
sd_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)
|
||||
|
||||
if ensure_all_accepted or expected_acceptance_rate is not None:
|
||||
acceptance_rate = (stat_logger.metrics.
|
||||
gauge_spec_decode_draft_acceptance_rate.labels(
|
||||
**stat_logger.labels)._value.get())
|
||||
|
||||
if ensure_all_accepted:
|
||||
assert True
|
||||
# FIXME: ci fails to log acceptance rate.
|
||||
# It works locally.
|
||||
# assert acceptance_rate == 1.0
|
||||
|
||||
if expected_acceptance_rate is not None:
|
||||
assert acceptance_rate >= expected_acceptance_rate - 1e-2
|
||||
|
||||
# Only pass token entries, not the logprobs
|
||||
check_outputs_equal(outputs_0_lst=[out[0:2] for out in org_outputs],
|
||||
outputs_1_lst=[out[0:2] for out in sd_outputs],
|
||||
name_0="org",
|
||||
name_1="sd")
|
||||
|
||||
# Check logprobs if requested
|
||||
if logprobs is not None or prompt_logprobs is not None:
|
||||
check_logprobs_correctness(sd_outputs, org_outputs, disable_logprobs)
|
||||
@@ -1,344 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm-project/vllm/tests/spec_decode/e2e/test_eagle_correctness.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""This docstring details important information on the testing methodology.
|
||||
|
||||
Most of the tests rely on "greedy equality", where we expect the output of
|
||||
speculative decoding on a sequence to exactly match the output of normal non-
|
||||
speculative decoding.
|
||||
|
||||
Since speculative decoding with rejection sampling guarantees that the output
|
||||
distribution matches the target model's output distribution (up to hardware
|
||||
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
|
||||
equality.
|
||||
|
||||
However, we still need to verify below scenario could be passed:
|
||||
* Batch size 1 greedy equality
|
||||
* Batch size >1 greedy equality
|
||||
* Test greedy equality under preemption
|
||||
* Test greedy equality under various number of speculative tokens.
|
||||
|
||||
With those tests, we can say at least, EAGLE would not break the
|
||||
correctness for the target model outputs.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.e2e.long_term.spec_decode_v0.e2e.conftest import \
|
||||
run_equality_correctness_test
|
||||
|
||||
# main model
|
||||
MAIN_MODEL = "JackFram/llama-68m"
|
||||
|
||||
# speculative model
|
||||
SPEC_MODEL = "abhigoyal/vllm-eagle-llama-68m-random"
|
||||
|
||||
# max. number of speculative tokens: this corresponds to
|
||||
# num_heads in the config.json of the speculator model.
|
||||
MAX_SPEC_TOKENS = 4
|
||||
|
||||
# precision
|
||||
# TODO The vLLM here uses float32, but some op on the vllm-ascend
|
||||
# do not support float32, such as ROPE, When it is fixed, it is
|
||||
# recommended to change this to float32.
|
||||
PRECISION = "float16"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
128,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [1, 32])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size: int, output_len: int,
|
||||
seed: int):
|
||||
|
||||
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size, output_len, seed)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"disable_logprobs": False,
|
||||
},
|
||||
}, {
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"disable_logprobs": True,
|
||||
},
|
||||
}])
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
128,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@pytest.mark.parametrize("logprobs", [1, 6])
|
||||
def test_eagle_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size: int, output_len: int, seed: int,
|
||||
logprobs: int):
|
||||
|
||||
run_equality_correctness_test(
|
||||
vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
output_len,
|
||||
seed,
|
||||
logprobs=logprobs,
|
||||
prompt_logprobs=logprobs,
|
||||
disable_logprobs=test_llm_kwargs["speculative_config"]
|
||||
["disable_logprobs"])
|
||||
|
||||
|
||||
@pytest.mark.skipif(True, reason="Open it when graph mode ready.")
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"enforce_eager": False,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
128,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [1, 32])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_eagle_e2e_greedy_correctness_cuda_graph(
|
||||
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
|
||||
seed: int):
|
||||
"""Verify greedy equality with cuda graph enabled and different
|
||||
batch sizes."""
|
||||
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size, output_len, seed)
|
||||
|
||||
|
||||
@pytest.mark.skipif(True, reason="Open it when preempt ready.")
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"block_size": 8,
|
||||
# 2 for small prompt, 256//8 for generated.
|
||||
"num_gpu_blocks_override": 2 + 256 // 8,
|
||||
"max_model_len": (2 + 256 // 8) * 8,
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use small output len for fast test.
|
||||
128,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [4])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_eagle_e2e_greedy_correctness_with_preemption(
|
||||
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
|
||||
seed: int):
|
||||
"""Verify greedy equality, even when some sequences are preempted mid-
|
||||
generation.
|
||||
"""
|
||||
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size, output_len, seed)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize(
|
||||
"test_llm_kwargs",
|
||||
[
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": k,
|
||||
},
|
||||
}
|
||||
# Try a range of num. speculative tokens
|
||||
for k in range(1, 1 + MAX_SPEC_TOKENS)
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [2])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use smaller output len for fast test.
|
||||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_eagle_different_k(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||
test_llm_kwargs, batch_size: int, output_len: int,
|
||||
seed: int):
|
||||
"""Verify that eagle speculative decoding produces exact equality
|
||||
to without spec decode with different values of num_speculative_tokens.
|
||||
"""
|
||||
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size, output_len, seed)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"disable_by_batch_size": 4,
|
||||
},
|
||||
}])
|
||||
@pytest.mark.parametrize("batch_size", [1, 5])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use smaller output len for fast test.
|
||||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_eagle_disable_queue(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||
test_llm_kwargs, batch_size: int, output_len: int,
|
||||
seed: int):
|
||||
"""Verify that eagle speculative decoding produces exact equality
|
||||
to without spec decode when speculation is disabled for large
|
||||
batch sizes.
|
||||
"""
|
||||
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size, output_len, seed)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
pytest.main([__file__])
|
||||
@@ -1,446 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm-project/vllm/tests/spec_decode/e2e/test_medusa_correctness.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""This docstring details important information on the testing methodology.
|
||||
|
||||
Most of the tests rely on "greedy equality", where we expect the output of
|
||||
speculative decoding on a sequence to exactly match the output of normal non-
|
||||
speculative decoding.
|
||||
|
||||
Since speculative decoding with rejection sampling guarantees that the output
|
||||
distribution matches the target model's output distribution (up to hardware
|
||||
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
|
||||
equality.
|
||||
|
||||
However, we still need to verify below scenario could be passed:
|
||||
* Batch size 1 greedy equality
|
||||
* Batch size >1 greedy equality
|
||||
* Test greedy equality under preemption
|
||||
* Test greedy equality under various number of speculative tokens.
|
||||
|
||||
With those tests, we can say at least, Medusa would not break the
|
||||
correctess for the target model outputs.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.e2e.long_term.spec_decode_v0.e2e.conftest import \
|
||||
run_equality_correctness_test
|
||||
from tests.e2e.long_term.spec_decode_v0.utils import \
|
||||
maybe_enable_chunked_prefill
|
||||
|
||||
# main model
|
||||
# lmsys/vicuna-7b-v1.3 was to be used but it's causing
|
||||
# OOM in CI pipeline, so using a smaller model.
|
||||
MAIN_MODEL = "JackFram/llama-68m"
|
||||
|
||||
# speculative model
|
||||
SPEC_MODEL = "abhigoyal/vllm-medusa-llama-68m-random"
|
||||
|
||||
# max number of speculative tokens: this corresponds to
|
||||
# num_heads in the config.json of the speculator model.
|
||||
MAX_SPEC_TOKENS = 5
|
||||
|
||||
# precision
|
||||
# TODO: The vLLM here uses float32, but some op on the vllm-ascend
|
||||
# do not support float32, such as ROPE, When it is fixed, it is
|
||||
# recommended to change this to float32 to keep it consistent
|
||||
# with vLLM.
|
||||
PRECISION = "float16"
|
||||
|
||||
PREFILL_CHUNK_SIZE = [
|
||||
-1,
|
||||
# TODO:enable chunked prefill when it is supported
|
||||
# 32
|
||||
]
|
||||
|
||||
os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
128,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [1, 32])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE)
|
||||
def test_medusa_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size: int, output_len: int,
|
||||
seed: int, prefill_chunk_size: int):
|
||||
"""Verify greedy equality with different batch size."""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"disable_logprobs": False,
|
||||
},
|
||||
},
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"disable_logprobs": True,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
8,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@pytest.mark.parametrize("logprobs", [1, 6])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE)
|
||||
def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size: int, output_len: int,
|
||||
seed: int, logprobs: int,
|
||||
prefill_chunk_size: int):
|
||||
"""Verify greedy equality with different batch size."""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
run_equality_correctness_test(
|
||||
vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0,
|
||||
logprobs=logprobs,
|
||||
prompt_logprobs=logprobs,
|
||||
disable_logprobs=test_llm_kwargs["speculative_config"]
|
||||
["disable_logprobs"])
|
||||
|
||||
|
||||
@pytest.mark.skipif(True, reason="Open it when graph mode ready.")
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"enforce_eager": False,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
128,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [1, 32])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE)
|
||||
def test_medusa_e2e_greedy_correctness_cuda_graph(
|
||||
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
|
||||
seed: int, prefill_chunk_size: int):
|
||||
"""Verify greedy equality with cuda graph enabled and different
|
||||
batch sizes."""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
|
||||
|
||||
@pytest.mark.skipif(True, reason="Open it when preempt ready.")
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"block_size": 16,
|
||||
# 2 for small prompt, 256//8 for generated.
|
||||
"num_gpu_blocks_override": 2 + 256 // 8,
|
||||
"max_model_len": (2 + 256 // 8) * 8,
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use small output len for fast test.
|
||||
128,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [4])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE)
|
||||
def test_medusa_e2e_greedy_correctness_with_preemption(
|
||||
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
|
||||
seed: int, prefill_chunk_size: int):
|
||||
"""Verify greedy equality, even when some sequences are preempted mid-
|
||||
generation.
|
||||
"""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize(
|
||||
"test_llm_kwargs",
|
||||
[
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": k,
|
||||
},
|
||||
}
|
||||
# Try a range of num. speculative tokens
|
||||
for k in range(1, 1 + MAX_SPEC_TOKENS)
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [2])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use smaller output len for fast test.
|
||||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE)
|
||||
def test_medusa_different_k(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||
test_llm_kwargs, batch_size: int, output_len: int,
|
||||
seed: int, prefill_chunk_size: int):
|
||||
"""Verify that medusa speculative decoding produces exact equality
|
||||
to without spec decode with different values of num_speculative_tokens.
|
||||
"""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"disable_by_batch_size": 4,
|
||||
},
|
||||
}])
|
||||
@pytest.mark.parametrize("batch_size", [1, 5])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use smaller output len for fast test.
|
||||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE)
|
||||
def test_medusa_disable_queue(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||
test_llm_kwargs, batch_size: int,
|
||||
output_len: int, seed: int,
|
||||
prefill_chunk_size: int):
|
||||
"""Verify that medusa speculative decoding produces exact equality
|
||||
to without spec decode when speculation is disabled for large
|
||||
batch sizes.
|
||||
"""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"disable_by_batch_size": 4,
|
||||
"disable_mqa_scorer": True,
|
||||
},
|
||||
}])
|
||||
@pytest.mark.parametrize("batch_size", [1, 5])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use smaller output len for fast test.
|
||||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE)
|
||||
def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
|
||||
output_len: int, seed: int, prefill_chunk_size: int):
|
||||
"""Verify that speculative decoding generates the same output
|
||||
with batch expansion scorer and mqa scorer.
|
||||
"""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
@@ -1,561 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm-project/vllm/tests/spec_decode/e2e/test_mlp_correctness.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""This docstring details important information on the testing methodology.
|
||||
|
||||
Most of the tests rely on "greedy equality", where we expect the output of
|
||||
speculative decoding on a sequence to exactly match the output of normal non-
|
||||
speculative decoding.
|
||||
|
||||
Since speculative decoding with rejection sampling guarantees that the output
|
||||
distribution matches the target model's output distribution (up to hardware
|
||||
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
|
||||
equality.
|
||||
|
||||
However, we still need to verify below scenario could be passed:
|
||||
* Batch size 1 greedy equality
|
||||
* Batch size >1 greedy equality
|
||||
* Test greedy equality under preemption
|
||||
* Test greedy equality under various number of speculative tokens.
|
||||
|
||||
With those tests, we can say at least, MLPSpeculator would not break the
|
||||
correctness for the target model outputs.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import \
|
||||
pad_vocab_size # noqa: F401
|
||||
|
||||
from tests.e2e.long_term.spec_decode_v0.e2e.conftest import \
|
||||
run_equality_correctness_test
|
||||
from tests.e2e.long_term.spec_decode_v0.utils import \
|
||||
maybe_enable_chunked_prefill
|
||||
|
||||
# main model
|
||||
MAIN_MODEL = "JackFram/llama-160m"
|
||||
|
||||
# speculative model
|
||||
SPEC_MODEL = "ibm-ai-platform/llama-160m-accelerator"
|
||||
|
||||
# max. number of speculative tokens: this corresponds to
|
||||
# n_predict in the config.json of the speculator model.
|
||||
MAX_SPEC_TOKENS = 3
|
||||
|
||||
PREFILL_CHUNK_SIZE_1 = [
|
||||
-1,
|
||||
# TODO:enable chunked prefill when it is supported
|
||||
# 4
|
||||
]
|
||||
PREFILL_CHUNK_SIZE_2 = [
|
||||
-1,
|
||||
# TODO:enable chunked prefill when it is supported
|
||||
# 32
|
||||
]
|
||||
# precision
|
||||
# TODO: The vLLM here uses float32, but some op on the vllm-ascend
|
||||
# do not support float32, such as ROPE, When it is fixed, it is
|
||||
# recommended to change this to float32 to keep it consistent
|
||||
# with vLLM.
|
||||
PRECISION = "float16"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
128,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [4, 32])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE_2)
|
||||
def test_mlp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size: int, output_len: int,
|
||||
seed: int, prefill_chunk_size: int):
|
||||
"""Verify greedy equality with different batch size."""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"disable_logprobs": False,
|
||||
},
|
||||
},
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"disable_logprobs": True,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [8])
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@pytest.mark.parametrize("logprobs", [1, 6])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE_1)
|
||||
def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size: int, output_len: int, seed: int,
|
||||
logprobs: int, prefill_chunk_size: int):
|
||||
"""Verify greedy equality with different batch size."""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
# NOTE Test is sensitive enough st if we don't enable chunked prefill
|
||||
# scheduling on baseline too, we get slightly different logprobs, ending
|
||||
# up sampling different tokens at the tail (ie top tokens don't change).
|
||||
# TL;DR: sd+cp == org+cp but sd+cp != org..is this expected?
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, baseline_llm_kwargs)
|
||||
run_equality_correctness_test(
|
||||
vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0,
|
||||
logprobs=logprobs,
|
||||
prompt_logprobs=logprobs,
|
||||
disable_logprobs=test_llm_kwargs["speculative_config"]
|
||||
["disable_logprobs"])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [2048])
|
||||
@pytest.mark.parametrize("batch_size", [1, 32])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE_1)
|
||||
def test_mlp_e2e_acceptance_rate(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size: int, output_len: int,
|
||||
prefill_chunk_size: int, seed: int):
|
||||
"""Verify acceptance rate with different batch size and large output
|
||||
length."""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
temperature=0.0,
|
||||
seed=seed,
|
||||
expected_acceptance_rate=0.48)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
|
||||
# Speculative config
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
},
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{"seed": 1}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{"seed": 5}])
|
||||
@pytest.mark.parametrize("output_len", [64])
|
||||
@pytest.mark.parametrize("batch_size", [1, 32])
|
||||
@pytest.mark.parametrize("temperature", [1.0])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE_1)
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_mlp_e2e_seeded_correctness(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size: int, output_len: int,
|
||||
temperature: float,
|
||||
prefill_chunk_size: int, seed: int):
|
||||
"""Verify seeded runs produce the same output."""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, baseline_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
temperature=temperature,
|
||||
seed=seed)
|
||||
|
||||
# Ensure this same test does fail if we _don't_ include per-request seeds
|
||||
with pytest.raises(AssertionError):
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
temperature=temperature,
|
||||
seed=seed,
|
||||
disable_seed=True)
|
||||
|
||||
|
||||
@pytest.mark.skipif(True, reason="Open it when preempt ready.")
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"block_size": 16,
|
||||
# 2 for small prompt, 256//8 for generated.
|
||||
"num_gpu_blocks_override": 2 + 256 // 8,
|
||||
"max_model_len": (2 + 256 // 8) * 8,
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use small output len for fast test.
|
||||
128,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [4])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE_1)
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_mlp_e2e_greedy_correctness_with_preemption(
|
||||
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
|
||||
prefill_chunk_size: int, seed: int):
|
||||
"""Verify greedy equality, even when some sequences are preempted mid-
|
||||
generation.
|
||||
"""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
|
||||
|
||||
@pytest.mark.skipif(True, reason="Open it when preempt ready.")
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"block_size": 16,
|
||||
# 2 for small prompt, 256//8 for generated.
|
||||
"num_gpu_blocks_override": 2 + 256 // 8,
|
||||
"max_model_len": (2 + 256 // 8) * 8,
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use small output len for fast test.
|
||||
128,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [4])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE_1)
|
||||
def test_mlp_e2e_greedy_correctness_with_padding(
|
||||
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
|
||||
prefill_chunk_size: int, seed: int):
|
||||
"""Verify greedy equality when the vocab dimension is padded
|
||||
"""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
|
||||
# Default pad_to is 64, test model has vocab_size of 32000
|
||||
def patched_pad_vocab_size(vocab_size, pad_to=None):
|
||||
return pad_vocab_size(vocab_size, pad_to=32064)
|
||||
|
||||
# NOTE: Compared with vLLM, the patch method has been modified
|
||||
pad_vocab_size = patched_pad_vocab_size # noqa: F811
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize(
|
||||
"test_llm_kwargs",
|
||||
[
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": k,
|
||||
},
|
||||
}
|
||||
# Try a range of num. speculative tokens
|
||||
for k in range(1, 1 + MAX_SPEC_TOKENS)
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [2])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use smaller output len for fast test.
|
||||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE_1)
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_mlp_different_k(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||
test_llm_kwargs, batch_size: int,
|
||||
prefill_chunk_size: int, seed: int, output_len: int):
|
||||
"""Verify that mlp speculative decoding produces exact equality
|
||||
to without spec decode with different values of num_speculative_tokens.
|
||||
"""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"disable_by_batch_size": 4,
|
||||
},
|
||||
}])
|
||||
@pytest.mark.parametrize("batch_size", [1, 5])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use smaller output len for fast test.
|
||||
32,
|
||||
])
|
||||
# Speculative decoding is disabled when sequences reach decoding and the batch
|
||||
# consists of single-token requests. Hence we set `max_num_seqs`
|
||||
# >= `speculative_disable_by_batch_size` to test feature interaction.
|
||||
@pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE_1)
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_mlp_disable_queue(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||
test_llm_kwargs, batch_size: int,
|
||||
prefill_chunk_size: int, seed: int,
|
||||
output_len: int):
|
||||
"""Verify that mlp speculative decoding produces exact equality
|
||||
to without spec decode when speculation is disabled for large
|
||||
batch sizes.
|
||||
"""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"model_name": MAIN_MODEL,
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"disable_mqa_scorer": True,
|
||||
},
|
||||
}])
|
||||
@pytest.mark.parametrize("batch_size", [1, 5])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use smaller output len for fast test.
|
||||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE_1)
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
|
||||
output_len: int, prefill_chunk_size: int, seed: int):
|
||||
"""Verify that speculative decoding generates the same output
|
||||
with batch expansion scorer and mqa scorer.
|
||||
"""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
@@ -1,455 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm-project/vllm/tests/spec_decode/e2e/test_mtp_correctness.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""This docstring details important information on the testing methodology.
|
||||
|
||||
Most of the tests rely on "greedy equality", where we expect the output of
|
||||
speculative decoding on a sequence to exactly match the output of normal non-
|
||||
speculative decoding.
|
||||
|
||||
Since speculative decoding with rejection sampling guarantees that the output
|
||||
distribution matches the target model's output distribution (up to hardware
|
||||
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
|
||||
equality.
|
||||
|
||||
However, we still need to verify below scenario could be passed:
|
||||
* Batch size 1 greedy equality
|
||||
* Batch size >1 greedy equality
|
||||
* Test greedy equality under preemption
|
||||
* Test greedy equality under various number of speculative tokens.
|
||||
|
||||
With those tests, we can say at least, mtp would not break the
|
||||
correctess for the target model outputs.
|
||||
"""
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from .conftest import run_equality_correctness_test
|
||||
|
||||
# NOTE both main model and MTP are bfloat16
|
||||
FLOAT_MODEL = "wemaster/deepseek_mtp_main_random_bf16"
|
||||
|
||||
# NOTE main model is w8a8, MTP is bfloat16
|
||||
QUANT_MODEL = "wemaster/deepseek_mtp_main_random_w8a8_part"
|
||||
|
||||
# TODO when msmodelslim can quantify both main and MTP model
|
||||
# This UT should use w8a8 fully weights.
|
||||
|
||||
# max. number of speculative tokens: this corresponds to
|
||||
# num_nextn_predict_layers in the config.json of the speculator model.
|
||||
MAX_SPEC_TOKENS = 1
|
||||
|
||||
# precision
|
||||
PRECISION = "bfloat16"
|
||||
|
||||
|
||||
@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "1",
|
||||
reason="mtp is not supported on v1")
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": FLOAT_MODEL,
|
||||
|
||||
# GPU memory utilization
|
||||
"gpu_memory_utilization": 0.85
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
128,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [1, 32])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_mtp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size: int, output_len: int,
|
||||
seed: int):
|
||||
|
||||
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size, output_len, seed)
|
||||
|
||||
|
||||
@pytest.mark.skipif(True, reason="quant model is not ready.")
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": QUANT_MODEL,
|
||||
|
||||
# GPU memory utilization
|
||||
"gpu_memory_utilization": 0.8
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
128,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [1, 32])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_mtp_e2e_quant_greedy_correctness(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size: int, output_len: int,
|
||||
seed: int):
|
||||
|
||||
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size, output_len, seed)
|
||||
|
||||
|
||||
@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "1",
|
||||
reason="mtp is not supported on v1")
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": FLOAT_MODEL,
|
||||
|
||||
# GPU memory utilization
|
||||
"gpu_memory_utilization": 0.8
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"disable_logprobs": False,
|
||||
},
|
||||
},
|
||||
{
|
||||
"speculative_config": {
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"disable_logprobs": True,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
128,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@pytest.mark.parametrize("logprobs", [1, 6])
|
||||
def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size: int, output_len: int, seed: int,
|
||||
logprobs: int):
|
||||
|
||||
run_equality_correctness_test(
|
||||
vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
output_len,
|
||||
seed,
|
||||
logprobs=logprobs,
|
||||
prompt_logprobs=logprobs,
|
||||
disable_logprobs=test_llm_kwargs["speculative_config"]
|
||||
["disable_logprobs"])
|
||||
|
||||
|
||||
@pytest.mark.skipif(True, reason="torchair ut can not clean mem.")
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"additional_config": {
|
||||
'torchair_graph_config': {
|
||||
"enabled": True,
|
||||
},
|
||||
},
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": FLOAT_MODEL,
|
||||
"gpu_memory_utilization": 0.8
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
128,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [1, 32])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_mtp_e2e_greedy_correctness_torchair_graph(
|
||||
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
|
||||
seed: int):
|
||||
"""Verify greedy equality with torchair graph enabled and different
|
||||
batch sizes using bfloat16 weights."""
|
||||
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size, output_len, seed)
|
||||
|
||||
|
||||
@pytest.mark.skipif(True, reason="quant model is not ready.")
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"additional_config": {
|
||||
'torchair_graph_config': {
|
||||
"enabled": True,
|
||||
},
|
||||
},
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": QUANT_MODEL,
|
||||
"gpu_memory_utilization": 0.8
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
128,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [1, 32])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_mtp_e2e_quant_greedy_correctness_torchair_graph(
|
||||
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
|
||||
seed: int):
|
||||
"""Verify greedy equality with torchair graph enabled and different
|
||||
batch sizes using quant weights."""
|
||||
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size, output_len, seed)
|
||||
|
||||
|
||||
@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "1",
|
||||
reason="mtp is not supported on v1")
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"block_size": 16,
|
||||
# 2 for small prompt, 256//8 for generated.
|
||||
"num_gpu_blocks_override": 2 + 256 // 8,
|
||||
"max_model_len": (2 + 256 // 8) * 8,
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": FLOAT_MODEL,
|
||||
|
||||
# GPU memory utilization
|
||||
"gpu_memory_utilization": 0.8
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use small output len for fast test.
|
||||
128,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [4])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_mtp_e2e_greedy_correctness_with_preemption(
|
||||
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
|
||||
seed: int):
|
||||
"""Verify greedy equality, even when some sequences are preempted mid-
|
||||
generation.
|
||||
"""
|
||||
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size, output_len, seed)
|
||||
|
||||
|
||||
@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "1",
|
||||
reason="mtp is not supported on v1")
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": FLOAT_MODEL,
|
||||
|
||||
# GPU memory utilization
|
||||
"gpu_memory_utilization": 0.8
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize(
|
||||
"test_llm_kwargs",
|
||||
[
|
||||
{
|
||||
"speculative_config": {
|
||||
"num_speculative_tokens": k,
|
||||
},
|
||||
}
|
||||
# Try a range of num. speculative tokens
|
||||
for k in range(1, 1 + MAX_SPEC_TOKENS)
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [2])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use smaller output len for fast test.
|
||||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_mtp_different_k(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||
test_llm_kwargs, batch_size: int, output_len: int,
|
||||
seed: int):
|
||||
"""Verify that mtp speculative decoding produces exact equality
|
||||
to without spec decode with different values of num_speculative_tokens.
|
||||
"""
|
||||
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size, output_len, seed)
|
||||
|
||||
|
||||
@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "1",
|
||||
reason="mtp is not supported on v1")
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": FLOAT_MODEL,
|
||||
|
||||
# GPU memory utilization
|
||||
"gpu_memory_utilization": 0.8
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||
"speculative_config": {
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"disable_by_batch_size": 4
|
||||
},
|
||||
}])
|
||||
@pytest.mark.parametrize("batch_size", [1, 5])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use smaller output len for fast test.
|
||||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_mtp_disable_queue(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||
test_llm_kwargs, batch_size: int, output_len: int,
|
||||
seed: int):
|
||||
"""Verify that mtp speculative decoding produces exact equality
|
||||
to without spec decode when speculation is disabled for large
|
||||
batch sizes.
|
||||
"""
|
||||
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size, output_len, seed)
|
||||
@@ -1,405 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm-project/vllm/tests/spec_decode/e2e/test_ngram_correctness.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""This docstring details important information on the testing methodology.
|
||||
|
||||
Most of the tests rely on "greedy equality", where we expect the output of
|
||||
speculative decoding on a sequence to exactly match the output of normal non-
|
||||
speculative decoding.
|
||||
|
||||
Since speculative decoding with rejection sampling guarantees that the output
|
||||
distribution matches the target model's output distribution (up to hardware
|
||||
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
|
||||
equality.
|
||||
|
||||
For ngram lookup, its idea comes from https://github.com/apoorvumang/prompt-lookup-decoding,
|
||||
and is merged into transform code base: https://github.com/huggingface/transformers/pull/27775.
|
||||
Since there is no model is needed for generate the proposal, we could make
|
||||
the testcase much simpler than drafter multi-step one.
|
||||
|
||||
However, we still need to verify below scenario could be passed:
|
||||
* Batch size 1 greedy equality
|
||||
* Batch size >1 greedy equality
|
||||
* Test greedy equality under preemption
|
||||
* Test greedy equality under various ngram sizes / speculative sizes
|
||||
|
||||
With those tests, we can say at least, ngram spec would not break the correctess
|
||||
for the target model outputs.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.e2e.long_term.spec_decode_v0.e2e.conftest import \
|
||||
run_equality_correctness_test
|
||||
from tests.e2e.long_term.spec_decode_v0.utils import \
|
||||
maybe_enable_chunked_prefill
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
|
||||
{
|
||||
"model_name": "JackFram/llama-68m",
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"method": "ngram",
|
||||
"num_speculative_tokens": 5,
|
||||
"prompt_lookup_max": 3,
|
||||
"disable_mqa_scorer": False,
|
||||
},
|
||||
},
|
||||
{
|
||||
"speculative_config": {
|
||||
"method": "ngram",
|
||||
"num_speculative_tokens": 5,
|
||||
"prompt_lookup_max": 3,
|
||||
"disable_mqa_scorer": True,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
256,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [1, 32])
|
||||
@pytest.mark.parametrize(
|
||||
"prefill_chunk_size",
|
||||
[
|
||||
-1,
|
||||
# TODO:enable chunked prefill when it is supported
|
||||
# 4
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size: int, output_len: int,
|
||||
prefill_chunk_size: int, seed: int):
|
||||
"""Verify greedy equality on a tiny model with different batch size."""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, common_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
|
||||
{
|
||||
"model_name": "JackFram/llama-68m",
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"method": "ngram",
|
||||
"num_speculative_tokens": 5,
|
||||
"prompt_lookup_max": 3,
|
||||
"disable_logprobs": False,
|
||||
},
|
||||
},
|
||||
{
|
||||
"speculative_config": {
|
||||
"method": "ngram",
|
||||
"num_speculative_tokens": 5,
|
||||
"prompt_lookup_max": 3,
|
||||
"disable_logprobs": True,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
8,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@pytest.mark.parametrize("logprobs", [1, 6])
|
||||
def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size: int, output_len: int, seed: int,
|
||||
logprobs: int):
|
||||
"""Verify greedy equality on a tiny model with different batch size."""
|
||||
run_equality_correctness_test(
|
||||
vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0,
|
||||
logprobs=logprobs,
|
||||
prompt_logprobs=logprobs,
|
||||
disable_logprobs=test_llm_kwargs["speculative_config"]
|
||||
["disable_logprobs"])
|
||||
|
||||
|
||||
@pytest.mark.skipif(True, reason="Open it when preempt ready.")
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"block_size": 16,
|
||||
# 2 for small prompt, 256//8 for generated.
|
||||
"num_gpu_blocks_override": 2 + 256 // 8,
|
||||
"max_model_len": (2 + 256 // 8) * 8,
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
|
||||
{
|
||||
"model_name": "JackFram/llama-160m",
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"method": "ngram",
|
||||
"num_speculative_tokens": 5,
|
||||
"prompt_lookup_max": 3,
|
||||
},
|
||||
"enable_chunked_prefill": False,
|
||||
},
|
||||
{
|
||||
"speculative_config": {
|
||||
"method": "ngram",
|
||||
"num_speculative_tokens": 5,
|
||||
"prompt_lookup_max": 3,
|
||||
"disable_mqa_scorer": True,
|
||||
},
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use small output len for fast test.
|
||||
256,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [4])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_ngram_e2e_greedy_correctness_with_preemption(
|
||||
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
|
||||
seed: int):
|
||||
"""Verify greedy equality, even when some sequences are preempted mid-
|
||||
generation.
|
||||
"""
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
temperature=0,
|
||||
seed=seed)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"model_name": "JackFram/llama-68m",
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize(
|
||||
"test_llm_kwargs",
|
||||
[
|
||||
{
|
||||
"speculative_config": {
|
||||
"method": "ngram",
|
||||
"num_speculative_tokens": k,
|
||||
"prompt_lookup_max": 3,
|
||||
},
|
||||
}
|
||||
# Try a range of common k, as well as large speculation.
|
||||
for k in [1, 3, 5]
|
||||
] + [
|
||||
{
|
||||
"speculative_config": {
|
||||
"method": "ngram",
|
||||
"num_speculative_tokens": k,
|
||||
"prompt_lookup_max": 1,
|
||||
},
|
||||
}
|
||||
# Try a range of common k, as well as large speculation.
|
||||
for k in [1, 3, 5]
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [2])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use smaller output len for fast test.
|
||||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_ngram_different_k(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||
test_llm_kwargs, batch_size: int, output_len: int,
|
||||
seed: int):
|
||||
"""Verify that ngram speculative decoding produces exact equality
|
||||
to without spec decode with many different values of k and
|
||||
different ngram_prompt_lookup_max.
|
||||
"""
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"model_name": "JackFram/llama-68m",
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize(
|
||||
"test_llm_kwargs",
|
||||
[
|
||||
{
|
||||
"speculative_config": {
|
||||
"method": "ngram",
|
||||
"num_speculative_tokens": 5,
|
||||
"prompt_lookup_max": 3,
|
||||
"disable_by_batch_size": 4
|
||||
},
|
||||
},
|
||||
{
|
||||
"speculative_config": {
|
||||
"method": "ngram",
|
||||
"num_speculative_tokens": 5,
|
||||
"prompt_lookup_max": 3,
|
||||
"disable_by_batch_size": 4,
|
||||
"disable_mqa_scorer": True,
|
||||
},
|
||||
"enable_chunked_prefill": False,
|
||||
# FIXME: enable me when chunked prefill is available
|
||||
# "max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4
|
||||
}
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [1, 5])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use smaller output len for fast test.
|
||||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_ngram_disable_queue(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||
test_llm_kwargs, batch_size: int, output_len: int,
|
||||
seed: int):
|
||||
"""Verify that ngram speculative decoding produces exact equality
|
||||
to without spec decode with many different values of k and
|
||||
different ngram_prompt_lookup_max.
|
||||
"""
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"model_name": "JackFram/llama-68m",
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||
"speculative_config": {
|
||||
"method": "ngram",
|
||||
"num_speculative_tokens": 5,
|
||||
"prompt_lookup_max": 3,
|
||||
"disable_mqa_scorer": True,
|
||||
},
|
||||
}])
|
||||
@pytest.mark.parametrize("batch_size", [1, 5])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use smaller output len for fast test.
|
||||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_ngram_scorer(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||
test_llm_kwargs, batch_size: int, output_len: int,
|
||||
seed: int):
|
||||
"""Verify that ngram speculative decoding generates the same output
|
||||
with batch expansion scorer and mqa scorer.
|
||||
"""
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
@@ -1,106 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm-project/vllm/tests/spec_decode/test_dynamic_spec_decode.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
from vllm.spec_decode.metrics import AsyncMetricsCollector
|
||||
from vllm.spec_decode.multi_step_worker import MultiStepWorker
|
||||
from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
|
||||
from vllm.spec_decode.top1_proposer import Top1Proposer
|
||||
|
||||
from tests.e2e.long_term.spec_decode_v0.test_utils import \
|
||||
mock_spec_decode_sampler
|
||||
from tests.e2e.long_term.spec_decode_v0.utils import create_batch, mock_worker
|
||||
|
||||
|
||||
@pytest.mark.parametrize('queue_size', [4])
|
||||
@pytest.mark.parametrize('batch_size', [1])
|
||||
@pytest.mark.parametrize('k', [1])
|
||||
@pytest.mark.parametrize("acceptance_sampler_method",
|
||||
["rejection_sampler", "typical_acceptance_sampler"])
|
||||
@torch.inference_mode()
|
||||
def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int,
|
||||
acceptance_sampler_method: str):
|
||||
"""Verify that speculative tokens are disabled when the batch size
|
||||
exceeds the threshold.
|
||||
"""
|
||||
disable_by_batch_size = 3
|
||||
draft_worker = mock_worker(cls=MultiStepWorker)
|
||||
target_worker = mock_worker()
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
worker = SpecDecodeWorker(proposer_worker=draft_worker,
|
||||
scorer_worker=target_worker,
|
||||
spec_decode_sampler=mock_spec_decode_sampler(
|
||||
acceptance_sampler_method),
|
||||
disable_logprobs=False,
|
||||
metrics_collector=metrics_collector,
|
||||
disable_by_batch_size=disable_by_batch_size)
|
||||
|
||||
exception_secret = 'artificial stop'
|
||||
draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)
|
||||
|
||||
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
|
||||
execute_model_req = ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=k,
|
||||
running_queue_size=queue_size)
|
||||
|
||||
if queue_size > disable_by_batch_size:
|
||||
with patch.object(worker,
|
||||
'_run_no_spec',
|
||||
side_effect=ValueError(exception_secret)), \
|
||||
pytest.raises(ValueError, match=exception_secret):
|
||||
worker.execute_model(execute_model_req=execute_model_req)
|
||||
|
||||
# When the batch size is larger than the threshold,
|
||||
# we expect no speculative tokens (0).
|
||||
expected_num_spec_tokens = None if queue_size < disable_by_batch_size else 0
|
||||
assert seq_group_metadata_list[
|
||||
0].num_speculative_tokens == expected_num_spec_tokens
|
||||
|
||||
draft_worker.sampler_output.side_effect = ValueError(exception_secret)
|
||||
|
||||
proposer = Top1Proposer(
|
||||
worker=draft_worker,
|
||||
device='cpu', # not used
|
||||
vocab_size=100, # not used
|
||||
# Must be long enough to avoid being skipped due to length.
|
||||
max_proposal_len=1024,
|
||||
)
|
||||
|
||||
if queue_size < disable_by_batch_size:
|
||||
# Should raise exception when executing the mocked draft model.
|
||||
with pytest.raises(ValueError, match=exception_secret):
|
||||
proposer.get_spec_proposals(
|
||||
execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=k),
|
||||
seq_ids_with_bonus_token_in_last_step=set())
|
||||
else:
|
||||
# Should not execute the draft model because spec decode is disabled
|
||||
# for all requests. Accordingly, the proposal length should be 0.
|
||||
proposals = proposer.get_spec_proposals(
|
||||
execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=k),
|
||||
seq_ids_with_bonus_token_in_last_step=set())
|
||||
assert proposals.proposal_lens.tolist() == [0] * batch_size
|
||||
@@ -1,846 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm-project/vllm/tests/spec_decode/test_multi_step_worker.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import random
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
from vllm.sequence import (ExecuteModelRequest, HiddenStates, Logprob,
|
||||
get_all_seq_ids)
|
||||
from vllm.spec_decode.multi_step_worker import MultiStepWorker
|
||||
from vllm.spec_decode.top1_proposer import Top1Proposer
|
||||
|
||||
from tests.e2e.long_term.spec_decode_v0.utils import (
|
||||
assert_logprobs_dict_allclose, create_batch,
|
||||
create_seq_group_metadata_from_prompts, create_worker,
|
||||
patch_execute_model_with_seeds, zero_kv_cache)
|
||||
from vllm_ascend.worker.draft_model_runner import TP1DraftModelRunner
|
||||
from vllm_ascend.worker.worker import NPUWorker
|
||||
|
||||
|
||||
@pytest.mark.parametrize('num_steps', list(range(1, 17)))
|
||||
def test_assert_enough_kv_space(num_steps: int):
|
||||
"""Test that the multi step worker checks for sufficient space in the KV
|
||||
cache. It should throw if it cannot run all the steps.
|
||||
"""
|
||||
block_size = 16
|
||||
num_gpu_blocks = 2048 // block_size
|
||||
|
||||
prompts = [
|
||||
list(range(block_size * 3)),
|
||||
list(range(block_size * 2)),
|
||||
]
|
||||
|
||||
prev_output_tokens = [
|
||||
list(range(block_size * 1)),
|
||||
list(range(block_size * 2)),
|
||||
]
|
||||
|
||||
final_prompt_lens = [
|
||||
len(prompt + output) + num_steps
|
||||
for prompt, output in zip(prompts, prev_output_tokens)
|
||||
]
|
||||
|
||||
inputs = create_seq_group_metadata_from_prompts(
|
||||
prompts,
|
||||
num_gpu_blocks,
|
||||
block_size,
|
||||
final_prompt_lens,
|
||||
continuations=prev_output_tokens)
|
||||
|
||||
assert_enough_kv_space = MultiStepWorker._assert_enough_kv_space # pylint: disable=protected-access
|
||||
worker = MagicMock()
|
||||
worker.model_runner.block_size = block_size
|
||||
|
||||
for seq_group_metadata in inputs:
|
||||
original_block_tables = seq_group_metadata.block_tables
|
||||
|
||||
# No exception.
|
||||
assert_enough_kv_space(worker, inputs, num_steps)
|
||||
|
||||
seq_group_metadata.block_tables = {
|
||||
seq_id: []
|
||||
for seq_id, physical_blocks in original_block_tables.items()
|
||||
}
|
||||
|
||||
# Expect exception.
|
||||
with pytest.raises(ValueError,
|
||||
match='times but found insufficient KV space for'):
|
||||
assert_enough_kv_space(worker, inputs, num_steps)
|
||||
|
||||
seq_group_metadata.block_tables = original_block_tables
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_same_output_for_single_step():
|
||||
"""Verify the multi step worker produces the same output as the normal
|
||||
worker for num_steps=1.
|
||||
"""
|
||||
seed = 100
|
||||
model_name = 'JackFram/llama-68m'
|
||||
|
||||
block_size = 32
|
||||
num_gpu_blocks = 2048 // block_size
|
||||
multi_step_worker = create_worker(
|
||||
MultiStepWorker,
|
||||
model_name,
|
||||
block_size,
|
||||
num_gpu_blocks,
|
||||
seed,
|
||||
model_runner_cls=TP1DraftModelRunner,
|
||||
)
|
||||
worker = create_worker(
|
||||
NPUWorker,
|
||||
model_name,
|
||||
block_size,
|
||||
num_gpu_blocks,
|
||||
seed,
|
||||
)
|
||||
# multi_step_worker.model_runner = worker.model_runner
|
||||
# multi_step_worker.cache_engine = worker.cache_engine
|
||||
|
||||
num_steps = 1
|
||||
|
||||
prompts = [
|
||||
[1, 2, 3, 4, 5],
|
||||
[6, 7, 8, 9, 10],
|
||||
]
|
||||
|
||||
final_prompt_lens = [len(prompt) + num_steps for prompt in prompts]
|
||||
|
||||
multi_step_seq_group = create_seq_group_metadata_from_prompts(
|
||||
prompts,
|
||||
num_gpu_blocks,
|
||||
block_size,
|
||||
final_prompt_lens=final_prompt_lens)
|
||||
|
||||
zero_kv_cache(multi_step_worker.cache_engine)
|
||||
set_random_seed(seed)
|
||||
actual_output, _ = multi_step_worker.sampler_output(
|
||||
execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=multi_step_seq_group),
|
||||
sample_len=num_steps,
|
||||
seq_ids_with_bonus_token_in_last_step=set())
|
||||
assert len(actual_output) == num_steps
|
||||
actual_output = actual_output[0]
|
||||
|
||||
single_step_seq_group = create_seq_group_metadata_from_prompts(
|
||||
prompts,
|
||||
num_gpu_blocks,
|
||||
block_size,
|
||||
final_prompt_lens=final_prompt_lens)
|
||||
|
||||
zero_kv_cache(worker.cache_engine)
|
||||
set_random_seed(seed)
|
||||
expected_output = worker.execute_model(
|
||||
execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=single_step_seq_group))[0]
|
||||
|
||||
actual_token_ids = [
|
||||
output.samples[0].output_token for output in actual_output
|
||||
]
|
||||
actual_logprobs = [output.samples[0].logprobs for output in actual_output]
|
||||
|
||||
expected_token_ids = [
|
||||
output.samples[0].output_token for output in expected_output
|
||||
]
|
||||
expected_logprobs = [
|
||||
output.samples[0].logprobs for output in expected_output
|
||||
]
|
||||
|
||||
assert actual_token_ids == expected_token_ids
|
||||
|
||||
print(f'{actual_logprobs=}')
|
||||
print(f'{expected_logprobs=}')
|
||||
assert_logprobs_dict_allclose(actual_logprobs, expected_logprobs)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_same_output_for_multi_step():
|
||||
"""Verify the multi-step worker produces the same output as the normal
|
||||
worker when num_steps > 1. This test runs the multi-step worker once, and
|
||||
then runs the worker num_steps times, and compares the output.
|
||||
"""
|
||||
seed = 100
|
||||
model_name = 'JackFram/llama-68m'
|
||||
|
||||
block_size = 16
|
||||
num_gpu_blocks = 2048 // block_size
|
||||
multi_step_worker = create_worker(
|
||||
MultiStepWorker,
|
||||
model_name,
|
||||
block_size,
|
||||
num_gpu_blocks,
|
||||
seed,
|
||||
)
|
||||
|
||||
worker = create_worker(
|
||||
NPUWorker,
|
||||
model_name,
|
||||
block_size,
|
||||
num_gpu_blocks,
|
||||
seed,
|
||||
)
|
||||
|
||||
# Make sure we go over the block boundary.
|
||||
num_steps = block_size + 1
|
||||
|
||||
random.seed(seed)
|
||||
prompts = [[
|
||||
random.randint(0, 1000) for _ in range(random.randint(10, 20))
|
||||
] for _ in range(10)]
|
||||
|
||||
final_prompt_lens = [len(prompt) + num_steps for prompt in prompts]
|
||||
|
||||
rand_seeds = list(random.randint(0, 100) for _ in range(num_steps))
|
||||
multi_step_worker.execute_model = patch_execute_model_with_seeds(
|
||||
multi_step_worker, rand_seeds)
|
||||
worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds)
|
||||
|
||||
continuations = [[1] for _ in prompts]
|
||||
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
|
||||
prompts,
|
||||
num_gpu_blocks,
|
||||
block_size,
|
||||
continuations=continuations,
|
||||
final_prompt_lens=final_prompt_lens)
|
||||
|
||||
# Run multi-step.
|
||||
zero_kv_cache(multi_step_worker.cache_engine)
|
||||
set_random_seed(seed)
|
||||
multi_step_output, _ = multi_step_worker.sampler_output(
|
||||
execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list),
|
||||
sample_len=num_steps,
|
||||
seq_ids_with_bonus_token_in_last_step=set())
|
||||
|
||||
# Run single-step repeatedly.
|
||||
zero_kv_cache(worker.cache_engine)
|
||||
single_step_output: list[SamplerOutput] = []
|
||||
continuations = [[1] for _ in prompts]
|
||||
set_random_seed(seed)
|
||||
|
||||
for _ in multi_step_output:
|
||||
|
||||
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
|
||||
prompts,
|
||||
num_gpu_blocks,
|
||||
block_size,
|
||||
continuations=continuations,
|
||||
final_prompt_lens=final_prompt_lens)
|
||||
|
||||
single_step_output.extend(
|
||||
worker.execute_model(execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list)))
|
||||
|
||||
# Append output tokens to new sequence data.
|
||||
for i, seq_group_output in enumerate(single_step_output[-1]):
|
||||
continuations[i].append(seq_group_output.samples[0].output_token)
|
||||
|
||||
# Get token ids and logprobs for comparison.
|
||||
multi_step_output_logprobs: list[list[dict[int,
|
||||
Logprob]]] = [[]
|
||||
for _ in prompts]
|
||||
single_step_output_logprobs: list[list[dict[int,
|
||||
Logprob]]] = [[]
|
||||
for _ in prompts]
|
||||
|
||||
multi_step_output_token_ids: list[list[int]] = [[] for _ in prompts]
|
||||
single_step_output_token_ids: list[list[int]] = [[] for _ in prompts]
|
||||
for i, _ in enumerate(prompts):
|
||||
for multi_step, single_step in zip(multi_step_output,
|
||||
single_step_output):
|
||||
multi_step_output_token_ids[i].append(
|
||||
multi_step[i].samples[0].output_token)
|
||||
single_step_output_token_ids[i].append(
|
||||
single_step[i].samples[0].output_token)
|
||||
|
||||
multi_step_output_logprobs[i].append(
|
||||
multi_step[i].samples[0].logprobs)
|
||||
single_step_output_logprobs[i].append(
|
||||
single_step[i].samples[0].logprobs)
|
||||
|
||||
# Print per-sequence token ids
|
||||
for i, (multi_step_tokens, single_step_tokens) in enumerate(
|
||||
zip(multi_step_output_token_ids, single_step_output_token_ids)):
|
||||
print(f'{i=} {multi_step_tokens=}')
|
||||
print(f'{i=} {single_step_tokens=}')
|
||||
print(f'{i=} equal {multi_step_tokens == single_step_tokens}')
|
||||
|
||||
# Assert token ids are equal.
|
||||
for multi_step_tokens, single_step_tokens in zip(
|
||||
multi_step_output_token_ids, single_step_output_token_ids):
|
||||
assert multi_step_tokens == single_step_tokens
|
||||
|
||||
# Assert logprobs are equal.
|
||||
for multi_step_logprobs, single_step_logprobs in zip(
|
||||
multi_step_output_logprobs, single_step_output_logprobs):
|
||||
assert_logprobs_dict_allclose(multi_step_logprobs,
|
||||
single_step_logprobs)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_multi_step_with_batch_expansion_correct_output():
|
||||
"""
|
||||
In this test we verify that the MultiStepWorker is able to handle bonus
|
||||
tokens correctly. The test verifies that if a sequence has a
|
||||
bonus token then the MultiStepWorker is able to expand the batch by adding
|
||||
new sequences corresponding to the sequences with bonus tokens. The
|
||||
expanded batch is then used for predicting the next tokens.
|
||||
"""
|
||||
seed = 100
|
||||
model_name = 'JackFram/llama-68m'
|
||||
|
||||
block_size = 16
|
||||
num_gpu_blocks = 2048 // block_size
|
||||
batch_size = 128
|
||||
multi_step_worker = create_worker(
|
||||
MultiStepWorker,
|
||||
model_name,
|
||||
block_size,
|
||||
num_gpu_blocks,
|
||||
seed,
|
||||
model_runner_cls=TP1DraftModelRunner,
|
||||
)
|
||||
multi_step_worker.set_include_gpu_probs_tensor()
|
||||
worker = create_worker(
|
||||
NPUWorker,
|
||||
model_name,
|
||||
block_size,
|
||||
num_gpu_blocks,
|
||||
seed,
|
||||
)
|
||||
random.seed(seed)
|
||||
prompts = [[0] for _ in range(batch_size)]
|
||||
num_steps = 2
|
||||
final_prompt_lens = [(num_steps + 1) for prompt in prompts]
|
||||
rand_seeds = list(random.randint(0, 100) for _ in range(num_steps))
|
||||
multi_step_worker.execute_model = patch_execute_model_with_seeds(
|
||||
multi_step_worker, rand_seeds)
|
||||
worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds)
|
||||
# Create the test continuations
|
||||
continuations = [[random.randint(0, 1000)] for _ in prompts]
|
||||
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
|
||||
prompts,
|
||||
num_gpu_blocks,
|
||||
block_size,
|
||||
continuations=continuations,
|
||||
final_prompt_lens=final_prompt_lens)
|
||||
|
||||
# Run single-step twice to generate 2 tokens. This
|
||||
# will simulate the bonus token case with the second token
|
||||
# being the bonus token.
|
||||
zero_kv_cache(worker.cache_engine)
|
||||
single_step_output: list[SamplerOutput] = []
|
||||
set_random_seed(seed)
|
||||
for _ in range(num_steps):
|
||||
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
|
||||
prompts,
|
||||
num_gpu_blocks,
|
||||
block_size,
|
||||
continuations=continuations,
|
||||
final_prompt_lens=final_prompt_lens)
|
||||
single_step_output.extend(
|
||||
worker.execute_model(execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list)))
|
||||
# Append output tokens to new sequence data.
|
||||
for i, seq_group_output in enumerate(single_step_output[-1]):
|
||||
continuations[i].append(seq_group_output.samples[0].output_token)
|
||||
|
||||
# Create continuations for the MultiStepWorker. The continuations have
|
||||
# 2 tokens in order to simulate the bonus token case.
|
||||
multi_step_continuations = []
|
||||
for continuation in continuations:
|
||||
multi_step_continuations.append(continuation[:2])
|
||||
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
|
||||
prompts,
|
||||
num_gpu_blocks,
|
||||
block_size,
|
||||
continuations=multi_step_continuations,
|
||||
final_prompt_lens=final_prompt_lens)
|
||||
|
||||
# Run multi-step and verify that the third token prediction is accurate
|
||||
# for all sequences.
|
||||
zero_kv_cache(multi_step_worker.cache_engine)
|
||||
all_seq_ids = {i for i in range(batch_size)}
|
||||
multi_step_output, _ = multi_step_worker.sampler_output(
|
||||
execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list),
|
||||
sample_len=1,
|
||||
seq_ids_with_bonus_token_in_last_step=all_seq_ids)
|
||||
for index, output in enumerate(multi_step_output[-1].outputs):
|
||||
assert (continuations[index][-1] == output.samples[0].output_token)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_multi_step_with_batch_expansion_incorrect_output():
|
||||
"""
|
||||
Tests the MultiStepWorker's ability to handle batch expansion with bonus
|
||||
tokens in a negative case scenario. This test provides the MultiStepWorker
|
||||
with a batch containing sequences with bonus tokens but specifies the
|
||||
sequence IDs with bonus tokens incorrectly. The test verifies that the
|
||||
MultiStepWorker generates correct tokens for the sequences where the
|
||||
sequence ID is specified correctly and incorrect tokens for those where
|
||||
the sequence ID is specified incorrectly.
|
||||
"""
|
||||
seed = 100
|
||||
model_name = 'JackFram/llama-68m'
|
||||
|
||||
block_size = 16
|
||||
num_gpu_blocks = 2048 // block_size
|
||||
batch_size = 128
|
||||
multi_step_worker = create_worker(
|
||||
MultiStepWorker,
|
||||
model_name,
|
||||
block_size,
|
||||
num_gpu_blocks,
|
||||
seed,
|
||||
model_runner_cls=TP1DraftModelRunner,
|
||||
)
|
||||
multi_step_worker.set_include_gpu_probs_tensor()
|
||||
worker = create_worker(
|
||||
NPUWorker,
|
||||
model_name,
|
||||
block_size,
|
||||
num_gpu_blocks,
|
||||
seed,
|
||||
)
|
||||
random.seed(seed)
|
||||
prompts = [[0] for _ in range(batch_size)]
|
||||
num_steps = 2
|
||||
final_prompt_lens = [(num_steps + 1) for prompt in prompts]
|
||||
rand_seeds = list(random.randint(0, 100) for _ in range(num_steps))
|
||||
multi_step_worker.execute_model = patch_execute_model_with_seeds(
|
||||
multi_step_worker, rand_seeds)
|
||||
worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds)
|
||||
# Create the test continuations
|
||||
continuations = [[random.randint(0, 1000)] for _ in prompts]
|
||||
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
|
||||
prompts,
|
||||
num_gpu_blocks,
|
||||
block_size,
|
||||
continuations=continuations,
|
||||
final_prompt_lens=final_prompt_lens)
|
||||
# Run single-step twice to generate 2 tokens. This
|
||||
# will simulate the bonus token case with the second token
|
||||
# being the bonus token.
|
||||
zero_kv_cache(worker.cache_engine)
|
||||
single_step_output: list[SamplerOutput] = []
|
||||
set_random_seed(seed)
|
||||
for _ in range(num_steps):
|
||||
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
|
||||
prompts,
|
||||
num_gpu_blocks,
|
||||
block_size,
|
||||
continuations=continuations,
|
||||
final_prompt_lens=final_prompt_lens)
|
||||
single_step_output.extend(
|
||||
worker.execute_model(execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list)))
|
||||
# Append output tokens to new sequence data.
|
||||
for i, seq_group_output in enumerate(single_step_output[-1]):
|
||||
continuations[i].append(seq_group_output.samples[0].output_token)
|
||||
|
||||
# Create continuations for the MultiStepWorker. The continuations have
|
||||
# 2 tokens in order to simulate the bonus token case.
|
||||
multi_step_continuations = []
|
||||
for continuation in continuations:
|
||||
multi_step_continuations.append(continuation[:2])
|
||||
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
|
||||
prompts,
|
||||
num_gpu_blocks,
|
||||
block_size,
|
||||
continuations=multi_step_continuations,
|
||||
final_prompt_lens=final_prompt_lens)
|
||||
|
||||
# Run multi-step. In this run INCORRECTLY specify that only the odd number
|
||||
# sequences have bonus tokens. Verify that with this setting the third token
|
||||
# prediction is accurate only for the odd numbered sequences. Also verify
|
||||
# that the prediction might be wrong for some of the even numbered
|
||||
# sequences.
|
||||
zero_kv_cache(multi_step_worker.cache_engine)
|
||||
set_random_seed(seed)
|
||||
odd_seq_ids = {i for i in range(batch_size) if i % 2 != 0}
|
||||
multi_step_output, _ = multi_step_worker.sampler_output(
|
||||
execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list),
|
||||
sample_len=1,
|
||||
seq_ids_with_bonus_token_in_last_step=odd_seq_ids)
|
||||
num_mismatch = 0
|
||||
for index, output in enumerate(multi_step_output[-1].outputs):
|
||||
if (index % 2) != 0:
|
||||
assert (continuations[index][-1] == output.samples[0].output_token)
|
||||
elif (continuations[index][-1] != output.samples[0].output_token):
|
||||
num_mismatch += 1
|
||||
# The prediction is accurate for some of the sequences even without proper
|
||||
# handling of the bonus tokens. Hence verify that the number of sequences
|
||||
# for which there is a mismatch is > 0.
|
||||
assert (num_mismatch > 0)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
@pytest.mark.parametrize('num_steps', [1, 2, 3, 4])
|
||||
def test_multi_step_correct_kvcache(num_steps):
|
||||
"""Verify that the KV cache of the draft model
|
||||
is correctly updated for sequences with bonus token.
|
||||
"""
|
||||
seed = 100
|
||||
model_name = "JackFram/llama-68m"
|
||||
|
||||
block_size = 16
|
||||
num_gpu_blocks = 2048 // block_size
|
||||
batch_size = 1
|
||||
|
||||
dtype = 'float16'
|
||||
multi_step_worker = create_worker(MultiStepWorker,
|
||||
model_name,
|
||||
block_size,
|
||||
num_gpu_blocks,
|
||||
seed,
|
||||
model_runner_cls=TP1DraftModelRunner,
|
||||
dtype=dtype)
|
||||
multi_step_worker.set_include_gpu_probs_tensor()
|
||||
worker = create_worker(NPUWorker,
|
||||
model_name,
|
||||
block_size,
|
||||
num_gpu_blocks,
|
||||
seed,
|
||||
dtype=dtype)
|
||||
|
||||
prompts = [[0] for _ in range(batch_size)]
|
||||
# Already generate two tokens for the sequence
|
||||
# so that we can simulate the bonus token case
|
||||
multi_step_continuations = [[
|
||||
random.randint(0, 1000),
|
||||
random.randint(0, 1000)
|
||||
] for _ in prompts]
|
||||
final_prompt_lens = [len(prompt) + 2 + num_steps for prompt in prompts]
|
||||
|
||||
seq_ids_with_bonus_token_in_last_step = set(range(batch_size))
|
||||
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
|
||||
prompts,
|
||||
num_gpu_blocks,
|
||||
block_size,
|
||||
continuations=multi_step_continuations,
|
||||
final_prompt_lens=final_prompt_lens)
|
||||
|
||||
# Run multi-step.
|
||||
zero_kv_cache(multi_step_worker.cache_engine)
|
||||
multi_step_worker.sampler_output(execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list),
|
||||
sample_len=num_steps,
|
||||
seq_ids_with_bonus_token_in_last_step=
|
||||
seq_ids_with_bonus_token_in_last_step)
|
||||
|
||||
# Run single-step repeatedly.
|
||||
zero_kv_cache(worker.cache_engine)
|
||||
# Generate the kv cache for the bonus token first
|
||||
single_step_continuations = [c[:1] for c in multi_step_continuations]
|
||||
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
|
||||
prompts,
|
||||
num_gpu_blocks,
|
||||
block_size,
|
||||
continuations=single_step_continuations,
|
||||
final_prompt_lens=final_prompt_lens)
|
||||
single_step_output = worker.execute_model(
|
||||
execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list))
|
||||
for _ in range(num_steps):
|
||||
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
|
||||
prompts,
|
||||
num_gpu_blocks,
|
||||
block_size,
|
||||
continuations=multi_step_continuations,
|
||||
final_prompt_lens=final_prompt_lens)
|
||||
|
||||
single_step_output = worker.execute_model(
|
||||
execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list))
|
||||
|
||||
for i, seq_group_output in enumerate(single_step_output[-1]):
|
||||
multi_step_continuations[i].append(
|
||||
seq_group_output.samples[0].output_token)
|
||||
|
||||
# Verify that the KV cache of the single-step and
|
||||
# multi-step workers are the same.
|
||||
single_step_gpu_cache = worker.cache_engine[0].gpu_cache
|
||||
multi_step_gpu_cache = multi_step_worker.cache_engine[0].gpu_cache
|
||||
num_layers = len(single_step_gpu_cache)
|
||||
allclose = lambda a, b: torch.allclose( # noqa: E731
|
||||
a.npu(), b.npu(), rtol=1e-2, atol=1e-2)
|
||||
for i in range(num_layers):
|
||||
assert allclose(single_step_gpu_cache[i][0],
|
||||
multi_step_gpu_cache[i][0])
|
||||
assert allclose(single_step_gpu_cache[i][1],
|
||||
multi_step_gpu_cache[i][1])
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_draft_proposals_full_speculation_len():
|
||||
"""Verify Top1Proposer correctly handles case where all sequences
|
||||
can speculate.
|
||||
"""
|
||||
k = 10
|
||||
batch_size = 32
|
||||
vocab_size = 32_000
|
||||
device = 'npu:0'
|
||||
|
||||
draft_worker = MagicMock()
|
||||
proposer = Top1Proposer(
|
||||
worker=draft_worker,
|
||||
device=device,
|
||||
vocab_size=vocab_size,
|
||||
max_proposal_len=2048,
|
||||
)
|
||||
draft_worker.sampler_output.return_value = [
|
||||
SamplerOutput(
|
||||
outputs=[],
|
||||
sampled_token_probs=torch.rand(batch_size,
|
||||
vocab_size,
|
||||
device=device,
|
||||
dtype=torch.float32),
|
||||
logprobs=torch.rand(batch_size,
|
||||
vocab_size,
|
||||
device=device,
|
||||
dtype=torch.float32),
|
||||
sampled_token_ids=torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, ),
|
||||
device=device,
|
||||
dtype=torch.long),
|
||||
) for _ in range(k)
|
||||
], True
|
||||
|
||||
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
|
||||
|
||||
proposals = proposer.get_spec_proposals(
|
||||
execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=k),
|
||||
seq_ids_with_bonus_token_in_last_step=set())
|
||||
|
||||
assert torch.is_tensor(proposals.proposal_token_ids)
|
||||
assert torch.is_tensor(proposals.proposal_probs)
|
||||
|
||||
assert proposals.proposal_token_ids.shape == torch.Size([batch_size, k])
|
||||
assert proposals.proposal_probs.shape[:-1] == torch.Size([batch_size, k])
|
||||
|
||||
assert proposals.proposal_lens.shape == torch.Size([batch_size])
|
||||
assert proposals.proposal_lens.tolist() == [k for _ in range(batch_size)]
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_draft_proposals_no_speculations():
|
||||
"""Verify Top1Proposer correctly handles case where no sequences
|
||||
can speculate.
|
||||
"""
|
||||
k = 10
|
||||
batch_size = 32
|
||||
vocab_size = 32_000
|
||||
device = 'npu:0'
|
||||
prompt_len = 10
|
||||
|
||||
draft_worker = MagicMock()
|
||||
proposer = Top1Proposer(
|
||||
worker=draft_worker,
|
||||
device=device,
|
||||
vocab_size=vocab_size,
|
||||
max_proposal_len=prompt_len + k - 1,
|
||||
)
|
||||
|
||||
seq_group_metadata_list, _, _ = create_batch(batch_size,
|
||||
k,
|
||||
prompt_len=prompt_len)
|
||||
|
||||
proposals = proposer.get_spec_proposals(
|
||||
execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=k),
|
||||
seq_ids_with_bonus_token_in_last_step=set())
|
||||
|
||||
assert torch.is_tensor(proposals.proposal_token_ids)
|
||||
assert torch.is_tensor(proposals.proposal_probs)
|
||||
|
||||
assert proposals.proposal_token_ids.shape == torch.Size([batch_size, k])
|
||||
assert proposals.proposal_probs.shape[:-1] == torch.Size([batch_size, k])
|
||||
|
||||
assert proposals.proposal_lens.shape == torch.Size([batch_size])
|
||||
assert proposals.proposal_lens.tolist() == [0 for _ in range(batch_size)]
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_draft_proposals_mixed_k():
|
||||
"""Verify Top1Proposer correctly handles case some sequences can
|
||||
speculate and some can't.
|
||||
"""
|
||||
k = 10
|
||||
batch_size = 32
|
||||
vocab_size = 32_000
|
||||
device = 'npu:0'
|
||||
|
||||
small_prompt_len = 5
|
||||
long_prompt_len = 10
|
||||
prev_output_token_len = 20
|
||||
|
||||
expected_num_proposal_seqs = 6
|
||||
expected_num_no_proposal_seqs = batch_size - expected_num_proposal_seqs
|
||||
|
||||
prompt_len = [
|
||||
small_prompt_len for _ in range(expected_num_proposal_seqs - 1)
|
||||
] + [long_prompt_len
|
||||
for _ in range(expected_num_no_proposal_seqs)] + [small_prompt_len]
|
||||
|
||||
draft_worker = MagicMock()
|
||||
proposer = Top1Proposer(
|
||||
worker=draft_worker,
|
||||
device=device,
|
||||
vocab_size=vocab_size,
|
||||
max_proposal_len=long_prompt_len + prev_output_token_len + k - 1,
|
||||
)
|
||||
|
||||
draft_worker.sampler_output.return_value = [
|
||||
SamplerOutput(
|
||||
outputs=[],
|
||||
sampled_token_probs=torch.rand(expected_num_proposal_seqs,
|
||||
vocab_size,
|
||||
device=device,
|
||||
dtype=torch.float32),
|
||||
logprobs=torch.rand(expected_num_proposal_seqs,
|
||||
vocab_size,
|
||||
device=device,
|
||||
dtype=torch.float32),
|
||||
sampled_token_ids=torch.randint(
|
||||
low=0,
|
||||
high=vocab_size,
|
||||
size=(expected_num_proposal_seqs, ),
|
||||
device=device,
|
||||
dtype=torch.long),
|
||||
) for _ in range(k)
|
||||
], True
|
||||
|
||||
seq_group_metadata_list, _, _ = create_batch(
|
||||
batch_size,
|
||||
k,
|
||||
prompt_len=prompt_len,
|
||||
prev_output_token_len=prev_output_token_len,
|
||||
)
|
||||
|
||||
proposals = proposer.get_spec_proposals(
|
||||
execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=k),
|
||||
seq_ids_with_bonus_token_in_last_step=set())
|
||||
|
||||
assert torch.is_tensor(proposals.proposal_token_ids)
|
||||
assert torch.is_tensor(proposals.proposal_probs)
|
||||
|
||||
assert proposals.proposal_token_ids.shape == torch.Size([batch_size, k])
|
||||
assert proposals.proposal_probs.shape[:-1] == torch.Size([batch_size, k])
|
||||
|
||||
assert proposals.proposal_lens.shape == torch.Size([batch_size])
|
||||
assert proposals.proposal_lens.tolist() == [
|
||||
k for _ in range(expected_num_proposal_seqs - 1)
|
||||
] + [0 for _ in range(expected_num_no_proposal_seqs)] + [k]
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_use_draft_model_runner_advance_step():
|
||||
"""Verify that draft model runner triggers advance step
|
||||
when applicable.
|
||||
"""
|
||||
seed = 100
|
||||
model_name = 'JackFram/llama-68m'
|
||||
|
||||
k = 5
|
||||
batch_size = 32
|
||||
block_size = 32
|
||||
num_gpu_blocks = 2048 // block_size
|
||||
worker = create_worker(
|
||||
MultiStepWorker,
|
||||
model_name,
|
||||
block_size,
|
||||
num_gpu_blocks,
|
||||
seed,
|
||||
model_runner_cls=TP1DraftModelRunner,
|
||||
)
|
||||
|
||||
# Mock "_gpu_advance_step" to raise an exception when called.
|
||||
exception_secret = "artificial stop"
|
||||
worker.model_runner._gpu_advance_step = MagicMock()
|
||||
worker.model_runner._gpu_advance_step.side_effect = ValueError(
|
||||
exception_secret)
|
||||
|
||||
seq_group_metadata_list, _, _ = create_batch(batch_size,
|
||||
k,
|
||||
block_size=block_size,
|
||||
num_gpu_blocks=num_gpu_blocks)
|
||||
|
||||
# Fallback (should not call) when num_steps=1.
|
||||
execute_model_req = ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=k,
|
||||
num_steps=1)
|
||||
worker.execute_model(execute_model_req=execute_model_req)
|
||||
|
||||
# Expect exception if _gpu_advance_step is called.
|
||||
execute_model_req = ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=k,
|
||||
num_steps=k)
|
||||
|
||||
with pytest.raises(ValueError, match=exception_secret):
|
||||
worker.execute_model(execute_model_req=execute_model_req)
|
||||
call_args_list = worker.model_runner._gpu_advance_step.call_args_list
|
||||
assert len(call_args_list) == 1
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_expand_execute_model_request_sync_with_expand_hidden_states():
|
||||
"""
|
||||
In this test we verify that the logic for expanding the
|
||||
seq_group_metadata_list remains in sync with the expansion logic of
|
||||
the HiddenStates in _expand_execute_model_request.
|
||||
"""
|
||||
k = 5
|
||||
batch_size = 16
|
||||
seq_with_bonus_token_in_last_step = [1, 3, 8, 10, 13, 15]
|
||||
|
||||
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
|
||||
|
||||
execute_model_request = ExecuteModelRequest(
|
||||
seq_group_metadata_list,
|
||||
previous_hidden_states=HiddenStates(
|
||||
torch.arange(batch_size), seq_group_metadata_list,
|
||||
torch.arange(batch_size, 2 * batch_size)))
|
||||
|
||||
expanded_execute_model_request, orig_seq_group_ids = MultiStepWorker.\
|
||||
_expand_execute_model_request(execute_model_request,
|
||||
seq_with_bonus_token_in_last_step)
|
||||
|
||||
all_seq_ids = torch.tensor(
|
||||
get_all_seq_ids(
|
||||
expanded_execute_model_request.seq_group_metadata_list))
|
||||
ref_expanded_hidden_states = all_seq_ids + batch_size
|
||||
ref_expanded_hidden_states[orig_seq_group_ids] -= batch_size
|
||||
|
||||
assert (ref_expanded_hidden_states == expanded_execute_model_request.
|
||||
previous_hidden_states.hidden_states).all().item()
|
||||
@@ -1,237 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm-project/vllm/tests/spec_decode/test_ngram_worker.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import torch
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
from vllm.spec_decode.ngram_worker import NGramWorker
|
||||
from vllm.spec_decode.top1_proposer import Top1Proposer
|
||||
|
||||
from tests.e2e.long_term.spec_decode_v0.utils import (
|
||||
create_seq_group_metadata_from_prompts, create_worker)
|
||||
|
||||
|
||||
def test_ngram_algo_correctness_for_single_no_match():
|
||||
"""Verify our ngram algo find the right candidate in the prompt
|
||||
|
||||
For the scenario cannot find any candidate in one single batch
|
||||
"""
|
||||
block_size = 32
|
||||
num_gpu_blocks = 2048 // block_size
|
||||
seed = 100
|
||||
model_name = 'JackFram/llama-68m'
|
||||
vocab_size = 32_000
|
||||
device = 'npu:0'
|
||||
|
||||
ngram_worker = create_worker(
|
||||
NGramWorker,
|
||||
model_name,
|
||||
block_size,
|
||||
num_gpu_blocks,
|
||||
seed,
|
||||
)
|
||||
|
||||
proposer = Top1Proposer(
|
||||
worker=ngram_worker,
|
||||
device=device,
|
||||
vocab_size=vocab_size,
|
||||
max_proposal_len=20,
|
||||
)
|
||||
|
||||
# set ngram window [1, 3], which is window=1/2/3
|
||||
ngram_worker.set_ngram_window_size(1, 3)
|
||||
|
||||
prompts = [
|
||||
# shall find no candidate
|
||||
[1, 2, 3, 4, 5, 6, 7],
|
||||
]
|
||||
|
||||
proposal_len = 5
|
||||
final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts]
|
||||
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
|
||||
prompts,
|
||||
num_gpu_blocks,
|
||||
block_size,
|
||||
final_prompt_lens=final_prompt_lens)
|
||||
|
||||
proposals = proposer.get_spec_proposals(
|
||||
execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=proposal_len),
|
||||
seq_ids_with_bonus_token_in_last_step=None)
|
||||
|
||||
assert torch.is_tensor(proposals.proposal_token_ids)
|
||||
assert torch.is_tensor(proposals.proposal_probs)
|
||||
|
||||
assert proposals.proposal_token_ids.shape == torch.Size([1, proposal_len])
|
||||
assert proposals.proposal_probs.shape[:-1] == torch.Size([1, proposal_len])
|
||||
assert proposals.proposal_lens.shape == torch.Size([1])
|
||||
assert proposals.proposal_lens.tolist() == [0]
|
||||
|
||||
|
||||
def test_ngram_algo_correctness_for_batches_not_match_all():
|
||||
"""Verify our ngram algo find the right candidate in the prompt
|
||||
|
||||
For the scenario find some candidate not full in batchs
|
||||
"""
|
||||
block_size = 32
|
||||
num_gpu_blocks = 2048 // block_size
|
||||
seed = 100
|
||||
model_name = 'JackFram/llama-68m'
|
||||
vocab_size = 32_000
|
||||
device = 'npu:0'
|
||||
|
||||
ngram_worker = create_worker(
|
||||
NGramWorker,
|
||||
model_name,
|
||||
block_size,
|
||||
num_gpu_blocks,
|
||||
seed,
|
||||
)
|
||||
|
||||
proposer = Top1Proposer(
|
||||
worker=ngram_worker,
|
||||
device=device,
|
||||
vocab_size=vocab_size,
|
||||
max_proposal_len=20,
|
||||
)
|
||||
|
||||
# set ngram window [1, 3], which is window=1/2/3
|
||||
ngram_worker.set_ngram_window_size(1, 3)
|
||||
|
||||
prompts = [
|
||||
# shall find no candidate
|
||||
[1, 2, 3, 4, 5, 6, 7],
|
||||
# shall find candidate 12,13,14,15,16
|
||||
[11, 12, 13, 14, 15, 16, 11],
|
||||
# shall find candidate 23,24,25,26,21
|
||||
[21, 21, 22, 23, 24, 25, 26, 21, 22],
|
||||
# shall find candidate 34,35,36,37,38
|
||||
[31, 32, 31, 32, 33, 34, 35, 36, 37, 38, 31, 32, 33],
|
||||
# shall find no candidate as exceed max_proposal_len
|
||||
[
|
||||
31, 32, 31, 32, 31, 32, 31, 32, 31, 32, 31, 32, 33, 34, 35, 36, 37,
|
||||
38, 31, 32, 33
|
||||
],
|
||||
]
|
||||
|
||||
proposal_len = 5
|
||||
final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts]
|
||||
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
|
||||
prompts,
|
||||
num_gpu_blocks,
|
||||
block_size,
|
||||
final_prompt_lens=final_prompt_lens)
|
||||
for sg in seq_group_metadata_list:
|
||||
sg.is_prompt = False
|
||||
proposals = proposer.get_spec_proposals(
|
||||
execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=proposal_len),
|
||||
seq_ids_with_bonus_token_in_last_step=None)
|
||||
|
||||
assert torch.is_tensor(proposals.proposal_token_ids)
|
||||
assert torch.is_tensor(proposals.proposal_probs)
|
||||
|
||||
assert proposals.proposal_token_ids.shape == torch.Size([5, proposal_len])
|
||||
assert proposals.proposal_probs.shape[:-1] == torch.Size([5, proposal_len])
|
||||
assert proposals.proposal_lens.shape == torch.Size([5])
|
||||
|
||||
# the first sequence has no match so proposal_len should be overwritten to 0
|
||||
assert proposals.proposal_lens.tolist(
|
||||
) == [0] + [proposal_len for _ in range(3)] + [0]
|
||||
|
||||
for i in range(proposal_len):
|
||||
assert proposals.proposal_token_ids[0][i] == -1
|
||||
assert proposals.proposal_token_ids[1][i] == prompts[1][i + 1]
|
||||
assert proposals.proposal_token_ids[2][i] == prompts[2][i + 3]
|
||||
assert proposals.proposal_token_ids[3][i] == prompts[3][i + 5]
|
||||
assert proposals.proposal_token_ids[4][i] == -1
|
||||
|
||||
|
||||
def test_ngram_algo_correctness_for_batches_match_all():
|
||||
"""Verify our ngram algo find the right candidate in the prompt
|
||||
|
||||
For the scenario find candidate in all batches
|
||||
"""
|
||||
|
||||
block_size = 32
|
||||
num_gpu_blocks = 2048 // block_size
|
||||
seed = 100
|
||||
model_name = 'JackFram/llama-68m'
|
||||
vocab_size = 32_000
|
||||
device = 'npu:0'
|
||||
|
||||
ngram_worker = create_worker(
|
||||
NGramWorker,
|
||||
model_name,
|
||||
block_size,
|
||||
num_gpu_blocks,
|
||||
seed,
|
||||
)
|
||||
|
||||
proposer = Top1Proposer(
|
||||
worker=ngram_worker,
|
||||
device=device,
|
||||
vocab_size=vocab_size,
|
||||
max_proposal_len=20,
|
||||
)
|
||||
|
||||
# set ngram window [0, 3], which is window=1/2/3
|
||||
ngram_worker.set_ngram_window_size(1, 3)
|
||||
|
||||
prompts = [
|
||||
# shall find candidate 12,13,14,15,16
|
||||
[11, 12, 13, 14, 15, 16, 11],
|
||||
# shall find candidate 23,24,25,26,21
|
||||
[21, 21, 22, 23, 24, 25, 26, 21, 22],
|
||||
# shall find candidate 34,35,36,37,38
|
||||
[31, 32, 31, 32, 33, 34, 35, 36, 37, 38, 31, 32, 33],
|
||||
]
|
||||
|
||||
proposal_len = 5
|
||||
final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts]
|
||||
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
|
||||
prompts,
|
||||
num_gpu_blocks,
|
||||
block_size,
|
||||
final_prompt_lens=final_prompt_lens)
|
||||
|
||||
# Normally drafter is run on decode requests only; here we check the output
|
||||
# of the ngram worker as it is the sole proposer that has no forward.
|
||||
for sg in seq_group_metadata_list:
|
||||
sg.is_prompt = False
|
||||
proposals = proposer.get_spec_proposals(
|
||||
execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=proposal_len),
|
||||
seq_ids_with_bonus_token_in_last_step=None)
|
||||
|
||||
assert torch.is_tensor(proposals.proposal_token_ids)
|
||||
assert torch.is_tensor(proposals.proposal_probs)
|
||||
|
||||
assert proposals.proposal_token_ids.shape == torch.Size([3, proposal_len])
|
||||
assert proposals.proposal_probs.shape[:-1] == torch.Size([3, proposal_len])
|
||||
assert proposals.proposal_lens.shape == torch.Size([3])
|
||||
|
||||
assert proposals.proposal_lens.tolist() == [proposal_len for _ in range(3)]
|
||||
|
||||
for i in range(proposal_len):
|
||||
assert proposals.proposal_token_ids[0][i] == prompts[0][i + 1]
|
||||
assert proposals.proposal_token_ids[1][i] == prompts[1][i + 3]
|
||||
assert proposals.proposal_token_ids[2][i] == prompts[2][i + 5]
|
||||
@@ -1,958 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm-project/vllm/tests/spec_decode/test_spec_decode_worker.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import random
|
||||
from collections import defaultdict
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
from vllm.sequence import ExecuteModelRequest, SequenceOutput
|
||||
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
|
||||
from vllm.spec_decode.interfaces import SpeculativeProposals
|
||||
from vllm.spec_decode.metrics import (AsyncMetricsCollector,
|
||||
SpecDecodeWorkerMetrics)
|
||||
from vllm.spec_decode.multi_step_worker import MultiStepWorker
|
||||
from vllm.spec_decode.spec_decode_worker import (SpecDecodeWorker,
|
||||
split_num_cache_blocks_evenly)
|
||||
|
||||
from tests.e2e.long_term.spec_decode_v0.test_utils import \
|
||||
mock_spec_decode_sampler
|
||||
from tests.e2e.long_term.spec_decode_v0.utils import (
|
||||
create_batch, create_sampler_output_list, create_worker, mock_worker)
|
||||
from vllm_ascend.worker.draft_model_runner import TP1DraftModelRunner
|
||||
from vllm_ascend.worker.worker import NPUWorker
|
||||
|
||||
|
||||
@pytest.mark.parametrize('k', [1, 2, 6])
|
||||
@pytest.mark.parametrize('batch_size', [1, 2, 32])
|
||||
@pytest.mark.parametrize("acceptance_sampler_method",
|
||||
["rejection_sampler", "typical_acceptance_sampler"])
|
||||
@torch.inference_mode()
|
||||
def test_correctly_calls_draft_model(k: int, batch_size: int,
|
||||
acceptance_sampler_method: str):
|
||||
"""Verify SpecDecodeWorker calls the draft worker with correct
|
||||
inputs. Everything else is mocked out.
|
||||
"""
|
||||
draft_worker = mock_worker(cls=MultiStepWorker)
|
||||
target_worker = mock_worker()
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
worker = SpecDecodeWorker(
|
||||
draft_worker,
|
||||
target_worker,
|
||||
mock_spec_decode_sampler(acceptance_sampler_method),
|
||||
disable_logprobs=False,
|
||||
metrics_collector=metrics_collector)
|
||||
exception_secret = 'artificial stop'
|
||||
draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)
|
||||
|
||||
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
|
||||
execute_model_req = ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)
|
||||
|
||||
with pytest.raises(ValueError, match=exception_secret):
|
||||
worker.execute_model(execute_model_req=execute_model_req)
|
||||
|
||||
call_args_list = draft_worker.get_spec_proposals.call_args_list
|
||||
assert len(call_args_list) == 1
|
||||
|
||||
for args, _ in call_args_list:
|
||||
actual_execute_model_data = args[0]
|
||||
assert actual_execute_model_data == execute_model_req
|
||||
|
||||
|
||||
@pytest.mark.parametrize('k', [1, 2, 6])
|
||||
@pytest.mark.parametrize('batch_size', [1, 2, 32])
|
||||
@pytest.mark.parametrize("acceptance_sampler_method",
|
||||
["rejection_sampler", "typical_acceptance_sampler"])
|
||||
@torch.inference_mode()
|
||||
def test_batch_expansion_correctly_calls_target_model(
|
||||
k: int, batch_size: int, acceptance_sampler_method: str):
|
||||
"""Verify SpecDecodeWorker calls the target model with correct
|
||||
inputs with batch expansion. Everything else is mocked out.
|
||||
"""
|
||||
draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False)
|
||||
target_worker = mock_worker(use_spec=False)
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
|
||||
draft_worker.device = 'npu'
|
||||
target_worker.device = 'npu'
|
||||
|
||||
set_random_seed(1)
|
||||
|
||||
worker = SpecDecodeWorker(
|
||||
draft_worker,
|
||||
target_worker,
|
||||
mock_spec_decode_sampler(acceptance_sampler_method),
|
||||
disable_logprobs=False,
|
||||
metrics_collector=metrics_collector,
|
||||
disable_mqa_scorer=True)
|
||||
worker.init_device()
|
||||
|
||||
vocab_size = 32_000
|
||||
|
||||
proposal_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, k),
|
||||
dtype=torch.int64,
|
||||
device='npu')
|
||||
proposal_probs = torch.rand(batch_size,
|
||||
k,
|
||||
vocab_size,
|
||||
dtype=torch.float32,
|
||||
device='npu')
|
||||
proposal_lens = torch.ones(batch_size, dtype=torch.int64, device='npu') * k
|
||||
|
||||
seq_group_metadata_list, prompts, prev_output_tokens = create_batch(
|
||||
batch_size, k)
|
||||
|
||||
draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
|
||||
proposal_token_ids=proposal_token_ids,
|
||||
proposal_probs=proposal_probs,
|
||||
proposal_lens=proposal_lens)
|
||||
|
||||
exception_secret = 'artificial stop'
|
||||
target_worker.execute_model.side_effect = ValueError(exception_secret)
|
||||
|
||||
with pytest.raises(ValueError, match=exception_secret):
|
||||
worker.execute_model(execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=k))
|
||||
|
||||
seen_contexts: list[list[int]] = []
|
||||
|
||||
call_args_list = target_worker.execute_model.call_args_list
|
||||
assert len(call_args_list) == 1
|
||||
for _, kwargs in call_args_list:
|
||||
seq_group_metadata_list = kwargs[
|
||||
"execute_model_req"].seq_group_metadata_list
|
||||
|
||||
assert len(seq_group_metadata_list) == (k + 1) * batch_size
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
for seq_data in seq_group_metadata.seq_data.values():
|
||||
seen_contexts.append(seq_data.get_token_ids())
|
||||
|
||||
expected_seen_contexts: list[list[int]] = []
|
||||
|
||||
for prompt, prev_generated, draft_tokens in zip(
|
||||
prompts, prev_output_tokens, proposal_token_ids.tolist()):
|
||||
|
||||
for i in range(len(draft_tokens) + 1):
|
||||
expected_seen_contexts.append(prompt + prev_generated +
|
||||
draft_tokens[:i])
|
||||
|
||||
seen_contexts.sort()
|
||||
expected_seen_contexts.sort()
|
||||
assert expected_seen_contexts == seen_contexts
|
||||
|
||||
|
||||
@pytest.mark.parametrize('k', [1, 2, 6])
|
||||
@pytest.mark.parametrize('batch_size', [1, 2, 32])
|
||||
@pytest.mark.parametrize("acceptance_sampler_method",
|
||||
["rejection_sampler", "typical_acceptance_sampler"])
|
||||
@torch.inference_mode()
|
||||
def test_correctly_calls_spec_decode_sampler(k: int, batch_size: int,
|
||||
acceptance_sampler_method: str):
|
||||
"""Verify SpecDecodeWorker calls the rejection sampler with
|
||||
correct inputs. Everything else is mocked out.
|
||||
"""
|
||||
vocab_size = 32_000
|
||||
|
||||
draft_worker = mock_worker(cls=MultiStepWorker,
|
||||
vocab_size=vocab_size,
|
||||
use_spec=False)
|
||||
target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
|
||||
spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
draft_worker.device = 'npu'
|
||||
target_worker.device = 'npu'
|
||||
|
||||
set_random_seed(1)
|
||||
|
||||
worker = SpecDecodeWorker(draft_worker,
|
||||
target_worker,
|
||||
spec_decode_sampler,
|
||||
disable_logprobs=False,
|
||||
metrics_collector=metrics_collector)
|
||||
worker.init_device()
|
||||
|
||||
proposal_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, k),
|
||||
dtype=torch.int64,
|
||||
device='npu')
|
||||
proposal_probs = torch.rand(batch_size,
|
||||
k,
|
||||
vocab_size,
|
||||
dtype=torch.float32,
|
||||
device='npu')
|
||||
|
||||
proposal_lens = torch.ones(batch_size, dtype=torch.int64, device='npu') * k
|
||||
|
||||
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
|
||||
|
||||
draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
|
||||
proposal_token_ids=proposal_token_ids,
|
||||
proposal_probs=proposal_probs,
|
||||
proposal_lens=proposal_lens)
|
||||
|
||||
target_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(1, batch_size * (k + 1)),
|
||||
dtype=torch.int64,
|
||||
device='npu')
|
||||
target_token_probs = torch.rand(1,
|
||||
batch_size * (k + 1),
|
||||
vocab_size,
|
||||
dtype=torch.float32,
|
||||
device='npu')
|
||||
target_token_logprobs = torch.rand(1,
|
||||
batch_size * (k + 1),
|
||||
vocab_size,
|
||||
dtype=torch.float32,
|
||||
device='npu')
|
||||
target_output = create_sampler_output_list(target_token_ids,
|
||||
target_token_probs,
|
||||
target_token_logprobs)
|
||||
|
||||
target_worker.execute_model.return_value = [target_output[0]]
|
||||
|
||||
exception_secret = 'artificial stop'
|
||||
|
||||
spec_decode_sampler.side_effect = ValueError(exception_secret)
|
||||
|
||||
with pytest.raises(ValueError, match=exception_secret):
|
||||
worker.execute_model(execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=k))
|
||||
|
||||
assert len(spec_decode_sampler.call_args_list) == 1
|
||||
_, kwargs = spec_decode_sampler.call_args_list[0]
|
||||
actual = SimpleNamespace(**kwargs)
|
||||
|
||||
assert torch.equal(actual.bonus_token_ids,
|
||||
target_token_ids.reshape(batch_size, k + 1)[:, -1:])
|
||||
assert torch.equal(actual.target_with_bonus_probs,
|
||||
target_token_probs.reshape(batch_size, k + 1, -1))
|
||||
assert torch.equal(actual.draft_token_ids, proposal_token_ids)
|
||||
assert torch.equal(actual.draft_probs, proposal_probs)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('k', [1, 2, 6])
|
||||
@pytest.mark.parametrize('batch_size', [1, 2, 32])
|
||||
@pytest.mark.parametrize("acceptance_sampler_method",
|
||||
["rejection_sampler", "typical_acceptance_sampler"])
|
||||
@torch.inference_mode()
|
||||
def test_correctly_formats_output(k: int, batch_size: int,
|
||||
acceptance_sampler_method: str):
|
||||
"""Verify SpecDecodeWorker formats sampler output correctly.
|
||||
Everything else is mocked out.
|
||||
"""
|
||||
vocab_size = 32_000
|
||||
|
||||
draft_worker = mock_worker(cls=MultiStepWorker,
|
||||
vocab_size=vocab_size,
|
||||
use_spec=False)
|
||||
target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
draft_worker.device = 'npu'
|
||||
target_worker.device = 'npu'
|
||||
|
||||
set_random_seed(1)
|
||||
spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
|
||||
worker = SpecDecodeWorker(draft_worker,
|
||||
target_worker,
|
||||
spec_decode_sampler,
|
||||
disable_logprobs=False,
|
||||
metrics_collector=metrics_collector)
|
||||
worker.init_device()
|
||||
|
||||
proposal_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, k),
|
||||
dtype=torch.int64,
|
||||
device='npu')
|
||||
proposal_probs = torch.rand(batch_size,
|
||||
k,
|
||||
vocab_size,
|
||||
dtype=torch.float32,
|
||||
device='npu')
|
||||
|
||||
proposal_lens = torch.ones(batch_size, dtype=torch.int64, device='npu') * k
|
||||
|
||||
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
|
||||
|
||||
draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
|
||||
proposal_token_ids=proposal_token_ids,
|
||||
proposal_probs=proposal_probs,
|
||||
proposal_lens=proposal_lens)
|
||||
|
||||
target_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(1, batch_size * (k + 1)),
|
||||
dtype=torch.int64,
|
||||
device='npu')
|
||||
target_token_probs = torch.rand(1,
|
||||
batch_size * (k + 1),
|
||||
vocab_size,
|
||||
dtype=torch.float32,
|
||||
device='npu')
|
||||
target_token_logprobs = torch.rand(1,
|
||||
batch_size * (k + 1),
|
||||
vocab_size,
|
||||
dtype=torch.float32,
|
||||
device='npu')
|
||||
target_output = create_sampler_output_list(target_token_ids,
|
||||
target_token_probs,
|
||||
target_token_logprobs)
|
||||
|
||||
target_worker.execute_model.return_value = [target_output[0]]
|
||||
|
||||
spec_decode_sampler_output = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, k + 1),
|
||||
dtype=torch.int64,
|
||||
device='npu')
|
||||
for i in range(batch_size):
|
||||
minimum_accepted_tokens = 1
|
||||
spec_decode_sampler_output[i][
|
||||
-random.randint(minimum_accepted_tokens, k + 1):] = -1
|
||||
|
||||
spec_decode_sampler.return_value = spec_decode_sampler_output
|
||||
output = worker.execute_model(execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=k))
|
||||
|
||||
expected_output = create_sampler_output_list(
|
||||
token_ids=spec_decode_sampler_output.transpose(0, 1),
|
||||
probs=[None for _ in range(k + 1)],
|
||||
logprobs=[None for _ in range(k + 1)])
|
||||
|
||||
seq_ids = [
|
||||
next(iter(seq_group_metadata.seq_data.keys()))
|
||||
for seq_group_metadata in seq_group_metadata_list
|
||||
]
|
||||
actual_output_by_seq: dict[int, list[SequenceOutput]] = {
|
||||
seq_id: []
|
||||
for seq_id in seq_ids
|
||||
}
|
||||
expected_output_by_seq: dict[int, list[SequenceOutput]] = {
|
||||
seq_id: []
|
||||
for seq_id in seq_ids
|
||||
}
|
||||
|
||||
for step in output:
|
||||
for seq_group in step:
|
||||
for sample in seq_group.samples:
|
||||
seq_id = sample.parent_seq_id
|
||||
actual_output_by_seq[seq_id].append(sample)
|
||||
|
||||
for step in expected_output:
|
||||
for seq_group in step:
|
||||
for sample in seq_group.samples:
|
||||
seq_id = sample.parent_seq_id
|
||||
expected_output_by_seq[seq_id].append(sample)
|
||||
|
||||
all_seen_seq_ids = set(
|
||||
list(actual_output_by_seq.keys()) +
|
||||
list(expected_output_by_seq.keys()))
|
||||
for seq_id in all_seen_seq_ids:
|
||||
actual_by_step = actual_output_by_seq[seq_id]
|
||||
expected_by_step = expected_output_by_seq[seq_id]
|
||||
|
||||
for i in range(k + 1):
|
||||
if i >= len(actual_by_step):
|
||||
assert expected_by_step[i].output_token == -1
|
||||
continue
|
||||
assert actual_by_step[i].output_token == expected_by_step[
|
||||
i].output_token
|
||||
|
||||
|
||||
@pytest.mark.parametrize('k', [1, 2])
|
||||
@pytest.mark.parametrize('batch_size', [1])
|
||||
@pytest.mark.parametrize('returns_metrics', [True, False])
|
||||
@pytest.mark.parametrize("acceptance_sampler_method",
|
||||
["rejection_sampler", "typical_acceptance_sampler"])
|
||||
@torch.inference_mode()
|
||||
def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool,
|
||||
acceptance_sampler_method: str):
|
||||
"""Verify SpecDecodeWorker collects metrics.
|
||||
"""
|
||||
vocab_size = 32_000
|
||||
|
||||
draft_worker = mock_worker(cls=MultiStepWorker,
|
||||
vocab_size=vocab_size,
|
||||
use_spec=False)
|
||||
target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
|
||||
spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
draft_worker.device = 'npu'
|
||||
target_worker.device = 'npu'
|
||||
|
||||
set_random_seed(1)
|
||||
|
||||
worker = SpecDecodeWorker(draft_worker,
|
||||
target_worker,
|
||||
spec_decode_sampler,
|
||||
disable_logprobs=False,
|
||||
metrics_collector=metrics_collector)
|
||||
worker.init_device()
|
||||
|
||||
proposal_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, k),
|
||||
dtype=torch.int64,
|
||||
device='npu')
|
||||
proposal_probs = torch.rand(batch_size,
|
||||
k,
|
||||
vocab_size,
|
||||
dtype=torch.float32,
|
||||
device='npu')
|
||||
|
||||
proposal_lens = torch.ones(batch_size, dtype=torch.int64, device='npu') * k
|
||||
|
||||
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
|
||||
|
||||
draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
|
||||
proposal_token_ids=proposal_token_ids,
|
||||
proposal_probs=proposal_probs,
|
||||
proposal_lens=proposal_lens)
|
||||
|
||||
target_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(1, batch_size * (k + 1)),
|
||||
dtype=torch.int64,
|
||||
device='npu')
|
||||
target_token_probs = torch.rand(1,
|
||||
batch_size * (k + 1),
|
||||
vocab_size,
|
||||
dtype=torch.float32,
|
||||
device='npu')
|
||||
target_token_logprobs = torch.rand(1,
|
||||
batch_size * (k + 1),
|
||||
vocab_size,
|
||||
dtype=torch.float32,
|
||||
device='npu')
|
||||
target_output = create_sampler_output_list(target_token_ids,
|
||||
target_token_probs,
|
||||
target_token_logprobs)
|
||||
|
||||
target_worker.execute_model.return_value = [target_output[0]]
|
||||
|
||||
spec_decode_sampler_output = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, k + 1),
|
||||
dtype=torch.int64,
|
||||
device='npu')
|
||||
for i in range(batch_size):
|
||||
minimum_accepted_tokens = 1
|
||||
spec_decode_sampler_output[i][
|
||||
-random.randint(minimum_accepted_tokens, k + 1):] = -1
|
||||
spec_decode_sampler.return_value = spec_decode_sampler_output
|
||||
|
||||
mock_rejsample_metrics = MagicMock(
|
||||
spec=SpecDecodeWorkerMetrics) if returns_metrics else None
|
||||
metrics_collector.maybe_collect_rejsample_metrics.return_value = (
|
||||
mock_rejsample_metrics)
|
||||
|
||||
output = worker.execute_model(execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=k))
|
||||
assert output[0].spec_decode_worker_metrics == mock_rejsample_metrics
|
||||
|
||||
call_args_list = (
|
||||
metrics_collector.maybe_collect_rejsample_metrics.call_args_list)
|
||||
assert len(call_args_list) == 1
|
||||
args, kwargs = call_args_list[0]
|
||||
assert args[0] == k or kwargs.get('k', -1) == k
|
||||
|
||||
|
||||
@pytest.mark.parametrize('k', [0])
|
||||
@pytest.mark.parametrize('batch_size', [1, 2, 32])
|
||||
@pytest.mark.parametrize("acceptance_sampler_method",
|
||||
["rejection_sampler", "typical_acceptance_sampler"])
|
||||
@torch.inference_mode()
|
||||
def test_k_equals_zero(k: int, batch_size: int,
|
||||
acceptance_sampler_method: str):
|
||||
"""Verify that the SpecDecodeWorker calls the draft and target workers
|
||||
when k is zero. This happens during prefill.
|
||||
"""
|
||||
draft_worker = mock_worker(cls=MultiStepWorker)
|
||||
target_worker = mock_worker()
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
|
||||
sampler_output = MagicMock(spec=SamplerOutput)
|
||||
sampler_output.hidden_states = None
|
||||
target_worker.execute_model.return_value = [sampler_output]
|
||||
|
||||
draft_worker.device = 'npu'
|
||||
target_worker.device = 'npu'
|
||||
|
||||
set_random_seed(1)
|
||||
|
||||
worker = SpecDecodeWorker(
|
||||
proposer_worker=draft_worker,
|
||||
scorer_worker=target_worker,
|
||||
spec_decode_sampler=mock_spec_decode_sampler(
|
||||
acceptance_sampler_method),
|
||||
disable_logprobs=False,
|
||||
metrics_collector=metrics_collector,
|
||||
)
|
||||
|
||||
seq_group_metadata_list, _, _ = create_batch(batch_size,
|
||||
k,
|
||||
prev_output_token_len=0)
|
||||
execute_model_req = ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)
|
||||
|
||||
out = worker.execute_model(execute_model_req=execute_model_req)
|
||||
|
||||
assert len(out) == 1, f"expected only one token output when {k=}"
|
||||
assert out[0].sampled_token_probs is None, (
|
||||
"expect gpu tensor references to be None")
|
||||
assert out[
|
||||
0].sampled_token_ids is None, "expect gpu tensor references to be None"
|
||||
|
||||
draft_worker.execute_model.assert_called_once_with(execute_model_req)
|
||||
target_worker.execute_model.assert_called_once_with(execute_model_req)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('k', [0, 5])
|
||||
@pytest.mark.parametrize('batch_size', [0])
|
||||
@pytest.mark.parametrize("acceptance_sampler_method",
|
||||
["rejection_sampler", "typical_acceptance_sampler"])
|
||||
@torch.inference_mode()
|
||||
def test_empty_input_batch(k: int, batch_size: int,
|
||||
acceptance_sampler_method: str):
|
||||
"""Verify that the SpecDecodeWorker calls the draft and target workers
|
||||
when the input batch is empty. This can happen if the engine communicates
|
||||
to the workers information without scheduling a batch.
|
||||
"""
|
||||
draft_worker = mock_worker(cls=MultiStepWorker)
|
||||
target_worker = mock_worker()
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
|
||||
sampler_output = MagicMock(spec=SamplerOutput)
|
||||
sampler_output.hidden_states = None
|
||||
target_worker.execute_model.return_value = [sampler_output]
|
||||
|
||||
draft_worker.device = 'npu'
|
||||
target_worker.device = 'npu'
|
||||
|
||||
set_random_seed(1)
|
||||
|
||||
worker = SpecDecodeWorker(
|
||||
proposer_worker=draft_worker,
|
||||
scorer_worker=target_worker,
|
||||
spec_decode_sampler=mock_spec_decode_sampler(
|
||||
acceptance_sampler_method),
|
||||
disable_logprobs=False,
|
||||
metrics_collector=metrics_collector,
|
||||
)
|
||||
|
||||
seq_group_metadata_list, _, _ = create_batch(batch_size,
|
||||
k,
|
||||
prev_output_token_len=0)
|
||||
execute_model_req = ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)
|
||||
|
||||
out = worker.execute_model(execute_model_req=execute_model_req)
|
||||
|
||||
assert len(out) == 1, f"expected only one token output when {k=}"
|
||||
assert out[0].sampled_token_probs is None, (
|
||||
"expect gpu tensor references to be None")
|
||||
assert out[
|
||||
0].sampled_token_ids is None, "expect gpu tensor references to be None"
|
||||
|
||||
draft_worker.execute_model.assert_called_once_with(execute_model_req)
|
||||
target_worker.execute_model.assert_called_once_with(execute_model_req)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("acceptance_sampler_method",
|
||||
["rejection_sampler", "typical_acceptance_sampler"])
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_init_device(acceptance_sampler_method: str):
|
||||
"""Verify SpecDecodeWorker invokes proposer/scorer worker init_device, as
|
||||
well as other GPU initialization.
|
||||
"""
|
||||
draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False)
|
||||
target_worker = mock_worker(use_spec=False)
|
||||
spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
|
||||
worker = SpecDecodeWorker(
|
||||
proposer_worker=draft_worker,
|
||||
scorer_worker=target_worker,
|
||||
spec_decode_sampler=spec_decode_sampler,
|
||||
disable_logprobs=False,
|
||||
metrics_collector=metrics_collector,
|
||||
)
|
||||
worker.init_device()
|
||||
|
||||
draft_worker.init_device.assert_called_once()
|
||||
|
||||
target_worker.init_device.assert_called_once()
|
||||
|
||||
metrics_collector.init_tensors.assert_called_once()
|
||||
spec_decode_sampler.init_tensors.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("acceptance_sampler_method",
|
||||
["rejection_sampler", "typical_acceptance_sampler"])
|
||||
@torch.inference_mode()
|
||||
def test_initialize_cache(acceptance_sampler_method):
|
||||
"""Verify SpecDecodeWorker invokes initialize_cache on proposer/scorer
|
||||
workers.
|
||||
"""
|
||||
draft_worker = mock_worker(cls=MultiStepWorker)
|
||||
target_worker = mock_worker()
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
|
||||
worker = SpecDecodeWorker(proposer_worker=draft_worker,
|
||||
scorer_worker=target_worker,
|
||||
spec_decode_sampler=mock_spec_decode_sampler(
|
||||
acceptance_sampler_method),
|
||||
metrics_collector=metrics_collector)
|
||||
|
||||
kwargs = {"num_gpu_blocks": 1024, "num_cpu_blocks": 1023}
|
||||
worker.initialize_cache(**kwargs)
|
||||
|
||||
draft_worker.initialize_cache.assert_called_once_with(**kwargs)
|
||||
target_worker.initialize_cache.assert_called_once_with(**kwargs)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('available_gpu_blocks', [1, 1024])
|
||||
@pytest.mark.parametrize('available_cpu_blocks', [500])
|
||||
@pytest.mark.parametrize('target_cache_block_size_bytes', [2 * 2 * 4096])
|
||||
@pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096])
|
||||
@pytest.mark.parametrize("acceptance_sampler_method",
|
||||
["rejection_sampler", "typical_acceptance_sampler"])
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_determine_num_available_blocks(available_gpu_blocks: int,
|
||||
available_cpu_blocks: int,
|
||||
target_cache_block_size_bytes: int,
|
||||
draft_kv_size_bytes: int,
|
||||
acceptance_sampler_method: str):
|
||||
"""Verify SpecDecodeWorker correctly profiles num available GPU blocks.
|
||||
Specifically, it should run profiling in the scorer worker, and then evenly
|
||||
split the blocks between proposer and scorer worker.
|
||||
"""
|
||||
draft_worker = mock_worker(cls=MultiStepWorker)
|
||||
target_worker = mock_worker()
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
|
||||
target_worker.determine_num_available_blocks.return_value = (
|
||||
available_gpu_blocks, available_cpu_blocks)
|
||||
target_worker.get_cache_block_size_bytes.return_value = (
|
||||
target_cache_block_size_bytes)
|
||||
draft_worker.get_cache_block_size_bytes.return_value = draft_kv_size_bytes
|
||||
|
||||
worker = SpecDecodeWorker(
|
||||
draft_worker, target_worker,
|
||||
mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)
|
||||
|
||||
num_gpu_blocks, num_cpu_blocks = worker.determine_num_available_blocks()
|
||||
|
||||
target_worker.determine_num_available_blocks.assert_called_once()
|
||||
assert num_cpu_blocks == available_cpu_blocks
|
||||
|
||||
assert num_gpu_blocks == split_num_cache_blocks_evenly(
|
||||
target_cache_block_size_bytes, draft_kv_size_bytes,
|
||||
available_gpu_blocks)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('available_gpu_blocks',
|
||||
list(range(20)) + [1024, 1024**2])
|
||||
@pytest.mark.parametrize('target_cache_block_size_bytes',
|
||||
[2 * 2 * 4096, 2 * 2 * 8192])
|
||||
@pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096])
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_split_num_cache_blocks_evenly(available_gpu_blocks: int,
|
||||
target_cache_block_size_bytes: int,
|
||||
draft_kv_size_bytes: int):
|
||||
"""Verify split_num_cache_blocks_evenly does not exceed original memory
|
||||
allocation in bytes.
|
||||
"""
|
||||
num_blocks = split_num_cache_blocks_evenly(target_cache_block_size_bytes,
|
||||
draft_kv_size_bytes,
|
||||
available_gpu_blocks)
|
||||
assert (num_blocks * target_cache_block_size_bytes) + (
|
||||
num_blocks * draft_kv_size_bytes) <= (available_gpu_blocks *
|
||||
target_cache_block_size_bytes)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_populate_seq_ids_with_bonus_tokens():
|
||||
"""
|
||||
Verify that a call to _create_output_sampler_list correctly updates
|
||||
seq_with_bonus_token_in_last_step.
|
||||
|
||||
seq_with_bonus_token_in_last_step is an internal data structure in
|
||||
SpecDecodeWorker that tracks the sequence IDs which are assigned bonus
|
||||
tokens by the target model in their last forward pass. This state is
|
||||
maintained only for models relying on the KV cache, such as those using
|
||||
the MultiStepWorker.
|
||||
"""
|
||||
batch_size = 10
|
||||
k = 5
|
||||
vocab_size = 10000
|
||||
num_sequences_with_bonus_tokens = 5
|
||||
target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)]
|
||||
target_worker.device = 'npu'
|
||||
|
||||
set_random_seed(1)
|
||||
draft_worker = mock_worker(cls=MultiStepWorker)
|
||||
draft_worker.device = 'npu'
|
||||
# The sequence_ids attached to each sequence in the batch.
|
||||
# The sequence at index i has seq_id assigned_seq_ids[i]
|
||||
assigned_seq_ids = list(range(batch_size))
|
||||
seq_group_metadata_list, _, _ = create_batch(batch_size,
|
||||
k,
|
||||
seq_ids=assigned_seq_ids,
|
||||
prev_output_token_len=10)
|
||||
target_token_logprobs = torch.rand(batch_size, (k + 1),
|
||||
vocab_size,
|
||||
dtype=torch.float32,
|
||||
device='npu')
|
||||
accepted_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, (k + 1)),
|
||||
dtype=torch.int64,
|
||||
device='npu')
|
||||
expected_request_id_seq_ids_mapping: dict[str, set[int]] = defaultdict(set)
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
for seq_id in seq_group_metadata.seq_data:
|
||||
expected_request_id_seq_ids_mapping[
|
||||
seq_group_metadata.request_id].add(seq_id)
|
||||
# Generate a random sample of sequence indexes with bonus tokens
|
||||
seq_indexes_with_bonus_tokens = random.sample(
|
||||
range(batch_size), num_sequences_with_bonus_tokens)
|
||||
# Create a mask that is True for indices in seq_indexes_with_bonus_tokens
|
||||
mask = torch.ones(batch_size, dtype=torch.bool, device='npu')
|
||||
mask[seq_indexes_with_bonus_tokens] = False
|
||||
# Set the last token ID to -1 for all indices not in
|
||||
# seq_indexes_with_bonus_tokens to indicate the lack of bonus token in
|
||||
# those indices.
|
||||
accepted_token_ids[mask, -1:] = -1
|
||||
worker = SpecDecodeWorker(draft_worker,
|
||||
target_worker,
|
||||
mock_spec_decode_sampler("rejection_sampler"),
|
||||
disable_logprobs=False,
|
||||
metrics_collector=metrics_collector)
|
||||
# Initialize _seq_with_bonus_token_in_last_step with a set of sequence IDs.
|
||||
# This set includes all sequence IDs in the batch as well as an additional
|
||||
# `num_extra_sequence_ids` sequence IDs. Note that the sequence IDs are in
|
||||
# the range [0, batch_size + num_extra_sequence_ids).
|
||||
num_extra_sequence_ids = 10
|
||||
worker._seq_with_bonus_token_in_last_step = set(
|
||||
range(batch_size + num_extra_sequence_ids))
|
||||
worker._create_output_sampler_list(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
accepted_token_ids=accepted_token_ids,
|
||||
target_logprobs=target_token_logprobs,
|
||||
prompt_logprobs=None,
|
||||
k=k,
|
||||
stage_times=(0, 0, 0))
|
||||
# Verify that _seq_with_bonus_token_in_last_step contains the following:
|
||||
# 1. Sequence IDs that were already present in
|
||||
# _seq_with_bonus_token_in_last_step but were not part of the current
|
||||
# batch are retained.
|
||||
# 2. Of the sequence IDs present in the current batch, only those with a
|
||||
# bonus token are retained in _seq_with_bonus_token_in_last_step.
|
||||
# Sequence IDs that are present in the current batch but do not have
|
||||
# bonus tokens are removed from _seq_with_bonus_token_in_last_step.
|
||||
expected_seq_ids_with_bonus_tokens = \
|
||||
set([assigned_seq_ids[i] for i in seq_indexes_with_bonus_tokens])
|
||||
additional_sequence_ids = \
|
||||
set(range(batch_size, batch_size + num_extra_sequence_ids))
|
||||
assert worker._seq_with_bonus_token_in_last_step == \
|
||||
expected_seq_ids_with_bonus_tokens.union(additional_sequence_ids)
|
||||
assert worker._request_id_seq_id_mapping == \
|
||||
expected_request_id_seq_ids_mapping
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_handle_finished_requests():
|
||||
"""
|
||||
Test to verify that finished request IDs are appropriately processed to
|
||||
update the internal state of the SpecDecodeWorker.
|
||||
|
||||
This test initializes the SpecDecodeWorker with mock data, marks certain
|
||||
requests as finished, and ensures that the corresponding sequence IDs are
|
||||
correctly removed from the internal mappings.
|
||||
"""
|
||||
batch_size = 32
|
||||
k = 3
|
||||
draft_worker = mock_worker(cls=MultiStepWorker)
|
||||
target_worker = mock_worker()
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
worker = SpecDecodeWorker(draft_worker, target_worker,
|
||||
mock_spec_decode_sampler("rejection_sampler"),
|
||||
metrics_collector)
|
||||
# Initialize the request_id_seq_id_mapping mapping dict with a few fake
|
||||
# request ids and corresponding sequence ids.
|
||||
worker._request_id_seq_id_mapping = \
|
||||
{'request-1': {1,2,3}, 'request-2': {4,5,6,7},
|
||||
'request-3': {8,9}, 'request-4': {10,11}}
|
||||
# Initialize seq_with_bonus_token_in_last_step with a few fake
|
||||
# sequence ids.
|
||||
worker._seq_with_bonus_token_in_last_step = {1, 4, 5, 8, 9, 10}
|
||||
exception_secret = 'artificial stop'
|
||||
draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)
|
||||
|
||||
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
|
||||
# Mark requests with ids request-1 and request-3 as finished.
|
||||
execute_model_req = ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=k,
|
||||
finished_requests_ids=['request-1', 'request-3'])
|
||||
|
||||
with pytest.raises(ValueError, match=exception_secret):
|
||||
worker.execute_model(execute_model_req=execute_model_req)
|
||||
# Verify that request-1 and request-3 are removed from
|
||||
# request_id_seq_id_mapping
|
||||
assert worker._request_id_seq_id_mapping == \
|
||||
{'request-2': {4,5,6,7}, 'request-4': {10,11}}
|
||||
# Verify that all sequence ids corresponding to 'request-1'
|
||||
# and 'request-3' are removed from seq_with_bonus_token_in_last_step.
|
||||
assert worker._seq_with_bonus_token_in_last_step == \
|
||||
{4,5,10}
|
||||
|
||||
|
||||
@pytest.mark.parametrize('k', [3])
|
||||
@pytest.mark.parametrize('batch_size', [2, 32])
|
||||
@pytest.mark.parametrize("batch_composition",
|
||||
["prefill_only", "decode_only", "mixed"])
|
||||
@torch.inference_mode()
|
||||
def test_chunked_prefill_flow(k: int, batch_size: int, batch_composition: str):
|
||||
"""
|
||||
Verify SpecDecodeWorker calls match the expected flow.
|
||||
"""
|
||||
vocab_size = 32_000
|
||||
draft_worker = mock_worker(cls=MultiStepWorker)
|
||||
target_worker = mock_worker()
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
worker = SpecDecodeWorker(draft_worker,
|
||||
target_worker,
|
||||
mock_spec_decode_sampler("rejection_sampler"),
|
||||
disable_logprobs=False,
|
||||
metrics_collector=metrics_collector)
|
||||
exception_secret = 'artificial stop'
|
||||
worker.scorer = mock_worker(BatchExpansionTop1Scorer)
|
||||
worker.scorer.score_proposals.side_effect = ValueError(exception_secret)
|
||||
|
||||
# Create batch with combination of terminal/non-terminal prefill chunks
|
||||
# and decodes (different seq_ids).
|
||||
decodes, _, _ = create_batch(batch_size, k)
|
||||
# Pre-chunking here, get 'batch_size' chunks.
|
||||
prefill, _, _ = create_batch(batch_size,
|
||||
k,
|
||||
prefill_chunk_size=4,
|
||||
seq_ids=list(range(batch_size,
|
||||
batch_size * 2)))
|
||||
|
||||
if batch_composition == "prefill_only":
|
||||
n_prefills = batch_size
|
||||
elif batch_composition == "decode_only":
|
||||
n_prefills = 0
|
||||
else:
|
||||
n_prefills = random.randint(1, batch_size - 1)
|
||||
n_decodes = batch_size - n_prefills
|
||||
|
||||
prefill = random.sample(prefill, n_prefills)
|
||||
decodes = random.sample(decodes, n_decodes)
|
||||
target_group_metadata_list = prefill + decodes
|
||||
execute_model_req = ExecuteModelRequest(
|
||||
seq_group_metadata_list=target_group_metadata_list,
|
||||
# For prefill only batches we expect num_lookahead_slots = 0.
|
||||
num_lookahead_slots=k if n_decodes > 0 else 0)
|
||||
|
||||
target_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(1, batch_size * (k + 1)),
|
||||
dtype=torch.int64,
|
||||
device='npu')
|
||||
target_token_probs = torch.rand(1,
|
||||
batch_size * (k + 1),
|
||||
vocab_size,
|
||||
dtype=torch.float32,
|
||||
device='npu')
|
||||
target_token_logprobs = torch.rand(1,
|
||||
batch_size * (k + 1),
|
||||
vocab_size,
|
||||
dtype=torch.float32,
|
||||
device='npu')
|
||||
target_output = create_sampler_output_list(target_token_ids,
|
||||
target_token_probs,
|
||||
target_token_logprobs)
|
||||
|
||||
target_worker.execute_model.return_value = [target_output[0]]
|
||||
|
||||
if not len(decodes):
|
||||
worker.execute_model(execute_model_req=execute_model_req)
|
||||
# no spec run (prefill only)
|
||||
draft_worker.execute_model.assert_called_once_with(execute_model_req)
|
||||
target_worker.execute_model.assert_called_once_with(execute_model_req)
|
||||
else:
|
||||
# Decode-only run OR mixed batch, scorer call fails (it's mocked)
|
||||
with pytest.raises(ValueError, match=exception_secret):
|
||||
worker.execute_model(execute_model_req=execute_model_req)
|
||||
# but first draft still counted
|
||||
assert draft_worker.get_spec_proposals.call_count == 1
|
||||
|
||||
|
||||
@pytest.mark.skipif(True, reason="TODO revert me after fix it by CMQ")
|
||||
def test_correctly_load_weight_for_eagle():
|
||||
"""
|
||||
Verify SpecDecodeWorker loads lm_head weight for eagle correctly.
|
||||
"""
|
||||
seed = 100
|
||||
block_size = 32
|
||||
num_gpu_blocks = 8096 // block_size
|
||||
target_worker = create_worker(
|
||||
NPUWorker,
|
||||
"JackFram/llama-68m",
|
||||
block_size,
|
||||
num_gpu_blocks,
|
||||
seed,
|
||||
)
|
||||
draft_worker = create_worker(
|
||||
MultiStepWorker,
|
||||
"abhigoyal/vllm-eagle-llama-68m-random",
|
||||
block_size,
|
||||
num_gpu_blocks,
|
||||
seed,
|
||||
model_runner_cls=TP1DraftModelRunner,
|
||||
)
|
||||
|
||||
spec_decode_sampler = mock_spec_decode_sampler("rejection_sampler")
|
||||
worker = SpecDecodeWorker(draft_worker,
|
||||
target_worker,
|
||||
spec_decode_sampler,
|
||||
disable_logprobs=False)
|
||||
worker.proposer_worker.maybe_load_lm_head_weight(
|
||||
target_worker.model_runner.model.lm_head.weight.data)
|
||||
assert torch.allclose(
|
||||
worker.proposer_worker.worker.model_runner.model.lm_head.weight.data,
|
||||
worker.scorer_worker.model_runner.model.lm_head.weight.data)
|
||||
@@ -1,165 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm-project/vllm/tests/spec_decode/test_utils.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
|
||||
from vllm.model_executor.layers.sampler import _get_ranks
|
||||
from vllm.model_executor.layers.typical_acceptance_sampler import \
|
||||
TypicalAcceptanceSampler
|
||||
from vllm.sequence import SequenceGroupMetadata, get_all_seq_ids
|
||||
from vllm.spec_decode.util import (get_sampled_token_logprobs,
|
||||
split_batch_by_proposal_len)
|
||||
|
||||
|
||||
def test_get_all_seq_ids():
|
||||
"""Verify get_all_seq_ids extracts all seq ids.
|
||||
"""
|
||||
expected_seq_ids = list(range(10)) + list(range(100, 110))
|
||||
|
||||
seq_group_metadata_list = [
|
||||
SequenceGroupMetadata(
|
||||
request_id=str(seq_id),
|
||||
is_prompt=True,
|
||||
seq_data={
|
||||
seq_id: MagicMock(),
|
||||
},
|
||||
sampling_params=MagicMock(),
|
||||
block_tables={
|
||||
seq_id: MagicMock(),
|
||||
},
|
||||
lora_request=None,
|
||||
) for seq_id in expected_seq_ids
|
||||
]
|
||||
|
||||
actual_seq_ids = get_all_seq_ids(seq_group_metadata_list)
|
||||
assert actual_seq_ids == expected_seq_ids
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_sequence_group_metadata():
|
||||
seq_ids = list(range(3))
|
||||
return [
|
||||
SequenceGroupMetadata(
|
||||
request_id=str(i),
|
||||
is_prompt=True,
|
||||
seq_data={
|
||||
i: MagicMock(),
|
||||
},
|
||||
sampling_params=MagicMock(),
|
||||
block_tables={
|
||||
i: MagicMock(),
|
||||
},
|
||||
lora_request=None,
|
||||
) for i in seq_ids
|
||||
]
|
||||
|
||||
|
||||
def test_filter_zero_length_proposals(fake_sequence_group_metadata):
|
||||
proposal_lens = [0, 1, 0]
|
||||
_, (filtered_groups,
|
||||
indices) = split_batch_by_proposal_len(fake_sequence_group_metadata,
|
||||
proposal_lens)
|
||||
|
||||
expected_groups = [
|
||||
fake_sequence_group_metadata[0], fake_sequence_group_metadata[2]
|
||||
]
|
||||
expected_indices = [0, 2]
|
||||
|
||||
assert filtered_groups == expected_groups
|
||||
assert indices == expected_indices
|
||||
|
||||
|
||||
def test_filter_non_zero_length_proposals(fake_sequence_group_metadata):
|
||||
proposal_lens = [0, 1, 2]
|
||||
(filtered_groups,
|
||||
indices), _ = split_batch_by_proposal_len(fake_sequence_group_metadata,
|
||||
proposal_lens)
|
||||
|
||||
expected_groups = [
|
||||
fake_sequence_group_metadata[1], fake_sequence_group_metadata[2]
|
||||
]
|
||||
expected_indices = [1, 2]
|
||||
|
||||
assert filtered_groups == expected_groups
|
||||
assert indices == expected_indices
|
||||
|
||||
|
||||
def test_empty_inputs():
|
||||
_, (filtered_groups, indices) = split_batch_by_proposal_len([], [])
|
||||
|
||||
assert filtered_groups == []
|
||||
assert indices == []
|
||||
|
||||
|
||||
def test_all_zero_with_non_zero_filter(fake_sequence_group_metadata):
|
||||
proposal_lens = [0, 0, 0]
|
||||
(filtered_groups,
|
||||
indices), _ = split_batch_by_proposal_len(fake_sequence_group_metadata,
|
||||
proposal_lens)
|
||||
|
||||
assert filtered_groups == []
|
||||
assert indices == []
|
||||
|
||||
|
||||
def test_all_non_zero_with_zero_filter(fake_sequence_group_metadata):
|
||||
proposal_lens = [1, 1, 1]
|
||||
_, (filtered_groups,
|
||||
indices) = split_batch_by_proposal_len(fake_sequence_group_metadata,
|
||||
proposal_lens)
|
||||
|
||||
assert filtered_groups == []
|
||||
assert indices == []
|
||||
|
||||
|
||||
def mock_spec_decode_sampler(acceptance_sampler_method):
|
||||
"""
|
||||
Returns either a RejectionSampler or TypicalAcceptanceSampler
|
||||
object depending on whether acceptance_sampler_method is
|
||||
'rejection_sampler' or 'typical_acceptance_sampler' respectively.
|
||||
"""
|
||||
if acceptance_sampler_method == "rejection_sampler":
|
||||
sampler = MagicMock(spec=RejectionSampler)
|
||||
sampler.token_id_dtype = torch.int64
|
||||
return sampler
|
||||
elif acceptance_sampler_method == "typical_acceptance_sampler":
|
||||
sampler = MagicMock(spec=TypicalAcceptanceSampler)
|
||||
sampler.token_id_dtype = torch.int64
|
||||
return sampler
|
||||
else:
|
||||
raise ValueError(f"Invalid sampler name {acceptance_sampler_method}")
|
||||
|
||||
|
||||
def test_get_sampled_token_logprobs():
|
||||
"""Verify get_sampled_token_logprobs returns consistent rankings
|
||||
with regular get_ranks when probabilities match exactly.
|
||||
"""
|
||||
logprob_tensor = torch.tensor(
|
||||
[[[-.1, -.1]] * 2]) # shape (num_steps, batch_size, vocab_size)
|
||||
sampled_token_tensor = torch.tensor([[1,
|
||||
0]]) # shape (num_steps, batch_size)
|
||||
ranks_spec_dec, _ = get_sampled_token_logprobs(logprob_tensor,
|
||||
sampled_token_tensor)
|
||||
|
||||
ranks_regular = _get_ranks(logprob_tensor.reshape((2, -1)),
|
||||
sampled_token_tensor.reshape(-1))
|
||||
|
||||
assert torch.equal(ranks_spec_dec.reshape(-1), ranks_regular)
|
||||
@@ -1,317 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm-project/vllm/tests/spec_decode/utils.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from collections.abc import Sequence as GenericSequence
|
||||
from itertools import count
|
||||
from typing import Callable, Optional, TypeVar, Union
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import torch
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
|
||||
SequenceData, SequenceGroupMetadata, SequenceOutput)
|
||||
from vllm.spec_decode.ngram_worker import NGramWorker # noqa: F401
|
||||
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
|
||||
from vllm.worker.cache_engine import CacheEngine
|
||||
|
||||
from vllm_ascend.worker.model_runner import NPUModelRunner
|
||||
from vllm_ascend.worker.worker import NPUWorker
|
||||
|
||||
T = TypeVar("T", bound=NPUWorker)
|
||||
|
||||
|
||||
def round_up_to_next_block(seq_len: int, block_size: int) -> int:
|
||||
return (seq_len + block_size - 1) // block_size
|
||||
|
||||
|
||||
def mock_worker(cls=None,
|
||||
vocab_size: int = 30_000,
|
||||
max_model_len: int = 2048,
|
||||
rank: int = 0,
|
||||
use_spec: bool = True) -> MagicMock:
|
||||
if cls is None:
|
||||
cls = NPUWorker
|
||||
|
||||
spec = cls if use_spec else None
|
||||
|
||||
worker = MagicMock(spec=spec)
|
||||
worker.vocab_size = vocab_size
|
||||
worker.max_model_len = max_model_len
|
||||
worker.rank = rank
|
||||
worker.device = 'npu:0'
|
||||
return worker
|
||||
|
||||
|
||||
def patch_execute_model_with_seeds(worker: NPUWorker, rand_seeds: list[int]):
|
||||
seed_iter = iter(rand_seeds)
|
||||
original_execute_model = worker.execute_model
|
||||
|
||||
def new_execute_model(*args, **kwargs):
|
||||
result = original_execute_model(*args, **kwargs)
|
||||
set_random_seed(next(seed_iter))
|
||||
return result
|
||||
|
||||
return new_execute_model
|
||||
|
||||
|
||||
def zero_kv_cache(cache_engine: list[CacheEngine]):
|
||||
assert cache_engine[0].gpu_cache
|
||||
for key_blocks, value_blocks in cache_engine[0].gpu_cache:
|
||||
key_blocks.zero_()
|
||||
value_blocks.zero_()
|
||||
|
||||
|
||||
def create_worker(cls: Callable[..., T],
|
||||
model_name: str,
|
||||
block_size: int,
|
||||
num_gpu_blocks: int,
|
||||
seed: int,
|
||||
is_driver_worker: bool = True,
|
||||
enforce_eager: bool = True,
|
||||
model_runner_cls: Optional[NPUModelRunner] = None,
|
||||
dtype: Optional[str] = "auto") -> T:
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
seed=seed,
|
||||
block_size=block_size,
|
||||
enforce_eager=enforce_eager,
|
||||
dtype=dtype,
|
||||
)
|
||||
engine_config = engine_args.create_engine_config()
|
||||
|
||||
distributed_init_method = get_distributed_init_method(
|
||||
get_ip(), get_open_port())
|
||||
|
||||
if cls.__name__ == "NGramWorker":
|
||||
# we need to pass by device type to enable this on npu
|
||||
worker = cls(vllm_config=engine_config,
|
||||
local_rank=0,
|
||||
rank=0,
|
||||
distributed_init_method=distributed_init_method,
|
||||
is_driver_worker=is_driver_worker,
|
||||
model_runner_cls=model_runner_cls,
|
||||
device_type="npu")
|
||||
else:
|
||||
worker = cls(
|
||||
vllm_config=engine_config,
|
||||
local_rank=0,
|
||||
rank=0,
|
||||
distributed_init_method=distributed_init_method,
|
||||
is_driver_worker=is_driver_worker,
|
||||
model_runner_cls=model_runner_cls,
|
||||
)
|
||||
|
||||
worker.init_device()
|
||||
worker.load_model()
|
||||
|
||||
engine_config.cache_config.num_gpu_blocks = num_gpu_blocks
|
||||
engine_config.cache_config.num_cpu_blocks = 0
|
||||
worker.initialize_cache(
|
||||
num_gpu_blocks=engine_config.cache_config.num_gpu_blocks,
|
||||
num_cpu_blocks=engine_config.cache_config.num_cpu_blocks)
|
||||
|
||||
return worker
|
||||
|
||||
|
||||
def create_seq_group_metadata_from_prompts(
|
||||
prompts: list[list[int]],
|
||||
num_gpu_blocks: int,
|
||||
block_size: int,
|
||||
final_prompt_lens: list[int],
|
||||
continuations: Optional[list[list[int]]] = None,
|
||||
seq_ids: Optional[list[int]] = None,
|
||||
) -> list[SequenceGroupMetadata]:
|
||||
|
||||
if continuations is None:
|
||||
continuations = [[] for _ in prompts]
|
||||
|
||||
if seq_ids is None:
|
||||
seq_ids = list(i for i, _ in enumerate(prompts))
|
||||
|
||||
free_gpu_blocks = list(range(num_gpu_blocks))
|
||||
|
||||
block_allocations = {
|
||||
i: [
|
||||
free_gpu_blocks.pop()
|
||||
for _ in range(round_up_to_next_block(final_len, block_size))
|
||||
]
|
||||
for i, final_len in enumerate(final_prompt_lens)
|
||||
}
|
||||
|
||||
seq_grou_metadata_list = []
|
||||
for i, (prompt_token_ids,
|
||||
cont_token_ids) in enumerate(zip(prompts, continuations)):
|
||||
data = SequenceData.from_seqs(prompt_token_ids, cont_token_ids)
|
||||
data.update_num_computed_tokens(
|
||||
len(prompt_token_ids) + len(cont_token_ids) - 1)
|
||||
seq_data = {i: data}
|
||||
seq_grou_metadata_list.append(
|
||||
SequenceGroupMetadata(
|
||||
request_id=str(i),
|
||||
is_prompt=len(cont_token_ids) == 0,
|
||||
seq_data=seq_data,
|
||||
sampling_params=SamplingParams(temperature=0.0),
|
||||
block_tables={i: block_allocations[i][:]},
|
||||
))
|
||||
return seq_grou_metadata_list
|
||||
|
||||
|
||||
def create_chunked_seq_group_metadata_from_prompt(
|
||||
prompt: list[int],
|
||||
num_gpu_blocks: int,
|
||||
chunk_size: int,
|
||||
block_size: int,
|
||||
seq_id: Optional[int] = None) -> list[SequenceGroupMetadata]:
|
||||
|
||||
if seq_id is None:
|
||||
seq_id = 0
|
||||
|
||||
free_gpu_blocks = list(range(num_gpu_blocks))
|
||||
|
||||
block_allocations = [
|
||||
free_gpu_blocks.pop()
|
||||
for _ in range(round_up_to_next_block(len(prompt), block_size))
|
||||
]
|
||||
|
||||
seq_group_metadata_list = []
|
||||
for i, idx in enumerate(range(0, len(prompt), chunk_size)):
|
||||
chunk_ids = prompt[idx:idx + chunk_size]
|
||||
data = SequenceData.from_seqs(prompt)
|
||||
data.update_num_computed_tokens(idx)
|
||||
seq_data = {i: data}
|
||||
seq_group_metadata_list.append(
|
||||
SequenceGroupMetadata(
|
||||
request_id=str(seq_id),
|
||||
is_prompt=True,
|
||||
do_sample=idx + chunk_size >= len(prompt), # terminal chunk
|
||||
seq_data=seq_data,
|
||||
sampling_params=SamplingParams(temperature=0.0),
|
||||
block_tables={i: block_allocations},
|
||||
token_chunk_size=len(chunk_ids)))
|
||||
return seq_group_metadata_list
|
||||
|
||||
|
||||
def assert_logprobs_dict_allclose(
|
||||
actual_logprobs: list[dict[int, Logprob]],
|
||||
expected_logprobs: list[dict[int, Logprob]]) -> None:
|
||||
for single_step_actual_logprobs, single_step_expected_logprobs in zip(
|
||||
actual_logprobs, expected_logprobs):
|
||||
assert set(single_step_actual_logprobs.keys()) == set(
|
||||
single_step_expected_logprobs.keys())
|
||||
for token_id in single_step_actual_logprobs:
|
||||
actual = torch.tensor(
|
||||
single_step_actual_logprobs[token_id].logprob)
|
||||
expected = torch.tensor(
|
||||
single_step_expected_logprobs[token_id].logprob)
|
||||
torch.testing.assert_close(actual, expected)
|
||||
|
||||
|
||||
def create_sampler_output_list(
|
||||
token_ids: torch.Tensor,
|
||||
probs: GenericSequence[Optional[torch.Tensor]],
|
||||
logprobs: GenericSequence[Optional[torch.Tensor]],
|
||||
seq_ids: Optional[list[int]] = None) -> list[SamplerOutput]:
|
||||
num_steps, batch_size = token_ids.shape
|
||||
token_ids_by_step = token_ids.tolist()
|
||||
|
||||
if seq_ids is None:
|
||||
seq_ids = list(range(batch_size))
|
||||
|
||||
return [
|
||||
SamplerOutput(outputs=[
|
||||
CompletionSequenceGroupOutput(
|
||||
samples=[
|
||||
SequenceOutput(
|
||||
output_token=token_id,
|
||||
parent_seq_id=seq_ids[seq_index],
|
||||
logprobs={token_id: Logprob(0)},
|
||||
)
|
||||
],
|
||||
prompt_logprobs=None,
|
||||
) for seq_index, token_id in enumerate(token_ids_by_step[step])
|
||||
],
|
||||
sampled_token_probs=probs[step],
|
||||
logprobs=logprobs[step],
|
||||
sampled_token_ids=token_ids[step])
|
||||
for step in range(num_steps)
|
||||
]
|
||||
|
||||
|
||||
def create_batch(batch_size,
|
||||
k,
|
||||
prompt_len: Union[int, list[int]] = 10,
|
||||
prev_output_token_len: int = 10,
|
||||
seq_ids: Optional[list[int]] = None,
|
||||
num_gpu_blocks: Optional[int] = None,
|
||||
block_size: Optional[int] = None,
|
||||
prefill_chunk_size: Optional[int] = None):
|
||||
if block_size is None:
|
||||
block_size = 8
|
||||
|
||||
if num_gpu_blocks is None:
|
||||
num_gpu_blocks = 2048 // block_size
|
||||
|
||||
iterator = count()
|
||||
|
||||
if isinstance(prompt_len, int):
|
||||
prompt_lens = [prompt_len for _ in range(batch_size)]
|
||||
else:
|
||||
prompt_lens = prompt_len
|
||||
|
||||
prompts = [[next(iterator) for _ in range(p_len)] for p_len in prompt_lens]
|
||||
|
||||
if prefill_chunk_size:
|
||||
# Create a batch of chunked prompts.
|
||||
if not seq_ids:
|
||||
seq_ids = list(range(len(prompts)))
|
||||
seq_group_metadata_list = []
|
||||
for p, sid in zip(prompts, seq_ids):
|
||||
seq_group_metadata_list += \
|
||||
create_chunked_seq_group_metadata_from_prompt(
|
||||
p, num_gpu_blocks, prefill_chunk_size, block_size, sid)
|
||||
seq_group_metadata_list = seq_group_metadata_list[:batch_size]
|
||||
prev_output_tokens = []
|
||||
else:
|
||||
prev_output_tokens = [[
|
||||
next(iterator) for _ in range(prev_output_token_len)
|
||||
] for _ in range(batch_size)]
|
||||
final_prompt_lens = [
|
||||
len(prompt) + len(prev_output_token) + k + 1
|
||||
for prompt, prev_output_token in zip(prompts, prev_output_tokens)
|
||||
]
|
||||
|
||||
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
|
||||
prompts, num_gpu_blocks, block_size, final_prompt_lens,
|
||||
prev_output_tokens, seq_ids)
|
||||
return seq_group_metadata_list, prompts, prev_output_tokens
|
||||
|
||||
|
||||
def maybe_enable_chunked_prefill(prefill_chunk_size, llm_kwargs):
|
||||
if prefill_chunk_size > 0:
|
||||
llm_kwargs.update(
|
||||
**{
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": prefill_chunk_size,
|
||||
"max_num_seqs": prefill_chunk_size
|
||||
})
|
||||
else:
|
||||
llm_kwargs["enable_chunked_prefill"] = False
|
||||
Reference in New Issue
Block a user