[CI] Refactor CI (#952)

1. remove some useless test func and file
2. fix format.sh problem
3. enable full test for singlecard and multicard
4. move long term test to long_term folder. For this kind of test, it
only runs by labeled and daily test. Include: spec decode、accuracy test

## After refactor:
There are 4 test modules
- `singlecard`: contains the test running on one NPU. It'll be run for
each PR and daily test.
- `multicard`: contains the test running on multi NPUs. It'll be run for
each PR and daily test.
- `long_term`: contains the test that cost much time(Now include `spec
decode` and `accuracy` test). It'll be run for the PR with
`long-term-test` labeled and daily test.
- `e2e`: contains the test for doc and pd feature. It'll be run for the
PR with `pd-test` labeled and daily test.

## Todo:
1. some test are skipped, they should be fixed and reenabled in the
future.
2. pyhccl test for multicard doesn't work at all. It should be enabled
as well.
3. ensure long-term-test pass by daily test.

### Know issue
Now, `ready` labels is required to start pd test or long term test. And
when `long-term-test` or `pd-test` is labeled after another one, the old
labeled test will be re-run again. So the labeled test should be ran in
the following step:

1. decide which test need run, then label it. `long-term-test` or
`pd-test` or both.
2. add `ready-for-test` label, then the test will be ran.

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
wangxiyuan
2025-05-28 06:31:35 +08:00
committed by GitHub
parent 9f5ab59e30
commit e2a0c19cea
34 changed files with 171 additions and 1288 deletions

View File

@@ -0,0 +1,118 @@
# SPDX-License-Identifier: Apache-2.0
"""
Test the piecewise compilation with a simple model so that we
can exactly calculate the expected output and side effects.
"""
import pytest
import torch
from torch import nn
from torch.library import Library
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
set_current_vllm_config)
from vllm.utils import direct_register_custom_op
global_counter = 0
# create a library to hold the custom op
silly_lib = Library("silly", "FRAGMENT") # noqa
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out: torch.Tensor) -> None:
global global_counter
global_counter += 1
print(f"{global_counter=}")
out.copy_(q)
out[0] += 1
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out: torch.Tensor) -> None:
return
direct_register_custom_op(
op_name="attention",
op_func=silly_attention,
mutates_args=["out"],
fake_impl=silly_attention_fake,
dispatch_key="PrivateUse1",
target_lib=silly_lib,
)
@support_torch_compile
class SillyModel(nn.Module):
def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = "",
**kwargs) -> None:
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Overall effect:
x += 1
x[0] += 2
global_counter += 2
"""
x = x + 1
x = x + 2
out = torch.empty_like(x)
torch.ops.silly.attention(x, x, x, out)
x = out
x = x - 2
x = x - 1
out = torch.empty_like(x)
torch.ops.silly.attention(x, x, x, out)
x = out
x = x + 1
return x
@pytest.mark.skipif(True, reason="requires unreleased components")
def test_simple_piecewise_compile():
vllm_config = VllmConfig(compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_inductor=False,
use_cudagraph=True,
splitting_ops=["silly.attention"],
cudagraph_copy_inputs=True,
cudagraph_capture_sizes=[1, 2],
))
vllm_config.compilation_config.pass_config.enable_fusion = False
with set_current_vllm_config(vllm_config):
model = SillyModel(vllm_config=vllm_config, prefix="")
inputs = torch.randn(100).npu()
with compilation_counter.expect(
num_graphs_seen=1, # one graph for the model
num_piecewise_graphs_seen=5, # 2 * num_layers + 1
num_piecewise_capturable_graphs_seen=3, # 1 + num_layers
num_backend_compilations=3, # num_piecewise_capturable_graphs_seen
num_cudagraph_caputured=
6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
):
model(inputs)
model(torch.randn(2).npu())
model(torch.randn(1).npu())
input = torch.zeros(2).npu()
global global_counter
global_counter = 0
output = model(input)
assert global_counter == 2
assert torch.allclose(output.cpu(), torch.tensor([3.0, 1.0]))
if __name__ == "__main__":
test_simple_piecewise_compile()

View File

View File

@@ -0,0 +1,100 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-License-Identifier: Apache-2.0
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/kernels/test_moe.py
"""Tests for the MOE layers.
Run `pytest tests/ops/test_fused_moe.py`.
"""
# fused moe ops test will hit the infer_schema error, we need add the patch
# here to make the test pass.
import vllm_ascend.patch.worker.patch_common.patch_utils # type: ignore[import] # isort: skip # noqa
import pytest
import torch
from vllm.model_executor.layers.activation import SiluAndMul
from vllm_ascend.ops.fused_moe import fused_experts
NUM_EXPERTS = [8, 64]
EP_SIZE = [1, 4]
TOP_KS = [2, 6]
DEVICE = ["npu"]
def torch_moe(a, w1, w2, topk_weights, topk_ids, topk, expert_map):
B, D = a.shape
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
topk_weights = topk_weights.view(-1)
topk_ids = topk_ids.view(-1)
if expert_map is not None:
topk_ids = expert_map[topk_ids]
for i in range(w1.shape[0]):
mask = topk_ids == i
if mask.sum():
out[mask] = SiluAndMul()(
a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
return (out.view(B, -1, w2.shape[1]) *
topk_weights.view(B, -1, 1).to(out.dtype)).sum(dim=1)
@pytest.mark.parametrize("m", [1, 33, 64, 222, 1024 * 128])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("device", DEVICE)
def test_fused_experts(
m: int,
n: int,
k: int,
e: int,
topk: int,
ep_size: int,
dtype: torch.dtype,
device: str,
):
a = torch.randn((m, k), device=device, dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device=device, dtype=dtype) / 10
w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 10
score = torch.randn((m, e), device=device, dtype=dtype)
if ep_size > 1:
local_e = e // ep_size
e_ids = torch.randint(0,
e, (local_e, ),
device=device,
dtype=torch.int32)
e_map = torch.full((e, ), -1, device=device, dtype=torch.int32)
e_map[e_ids] = torch.arange(local_e, device=device, dtype=torch.int32)
w1 = w1[e_ids]
w2 = w2[e_ids]
else:
e_map = None
score = torch.softmax(score, dim=-1, dtype=dtype)
topk_weights, topk_ids = torch.topk(score, topk)
topk_ids = topk_ids.to(torch.int32)
output = fused_experts(a, w1, w2, topk_weights, topk_ids, topk, e_map)
torch_output = torch_moe(a, w1, w2, topk_weights, topk_ids, topk, e_map)
# TODO: The native params are: atol=2e-2, rtol=0, maybe related to the nan problem
torch.testing.assert_close(output, torch_output, atol=4e-2, rtol=1)
torch.npu.empty_cache()

View File

@@ -0,0 +1,190 @@
# Copyright (c) China Merchants Bank Co., Ltd. 2025. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#/
# to run this test, you need to cd to the upper package which is 'tests',
# and run with command 'pytest -s ops/test_multi_step.py'
import torch
import torch_npu # noqa: F401
DTYPES = [torch.int32, torch.int64]
DEVICES = [f"npu:{0}"]
# Set tolerance to 0 for equals
DEFAULT_ATOL = 0
DEFAULT_RTOL = 0
# test custom ops of https://github.com/vllm-project/vllm-ascend/tree/main/csrc/kernels/advance_step.cpp
@torch.inference_mode()
def test_single_generation_multi_step() -> None:
input_tokens_data = [2926]
input_tokens_ascendc = torch.tensor(input_tokens_data, device='npu:0')
input_tokens_python = torch.tensor(input_tokens_data, device='npu:0')
sampled_token_ids_data = [[13]]
sampled_token_ids = torch.tensor(sampled_token_ids_data, device='npu:0')
input_positions_data = [5]
input_positions_ascendc = torch.tensor(input_positions_data,
device='npu:0')
input_positions_python = torch.tensor(input_positions_data, device='npu:0')
seq_lens_data = [6]
seq_lens_ascendc = torch.tensor(seq_lens_data,
device='npu:0',
dtype=torch.int32)
seq_lens_python = torch.tensor(seq_lens_data,
device='npu:0',
dtype=torch.int32)
slot_mapping_data = [5]
slot_mapping_ascendc = torch.tensor(slot_mapping_data,
device='npu:0',
dtype=torch.int32)
slot_mapping_python = torch.tensor(slot_mapping_data,
device='npu:0',
dtype=torch.int32)
block_tables_data = [[0]]
block_tables = torch.tensor(block_tables_data,
device='npu:0',
dtype=torch.int32)
torch.ops._C.advance_step_flashattn_ascendc(
1, 1, 128, input_tokens_ascendc, sampled_token_ids,
input_positions_ascendc, seq_lens_ascendc, slot_mapping_ascendc,
block_tables)
normal(1, 1, 128, input_tokens_python, sampled_token_ids,
input_positions_python, seq_lens_python, slot_mapping_python,
block_tables)
# Compare the results.
torch.testing.assert_close(input_tokens_ascendc,
input_tokens_python,
atol=DEFAULT_ATOL,
rtol=DEFAULT_RTOL)
torch.testing.assert_close(input_positions_ascendc,
input_positions_python,
atol=DEFAULT_ATOL,
rtol=DEFAULT_RTOL)
torch.testing.assert_close(seq_lens_ascendc,
seq_lens_python,
atol=DEFAULT_ATOL,
rtol=DEFAULT_RTOL)
torch.testing.assert_close(slot_mapping_ascendc,
slot_mapping_python,
atol=DEFAULT_ATOL,
rtol=DEFAULT_RTOL)
@torch.inference_mode()
def test_multi_result_generation_multi_step() -> None:
input_tokens_data = [2926, 279, 12095, 1588]
input_tokens_ascendc = torch.tensor(input_tokens_data, device='npu:0')
input_tokens_python = torch.tensor(input_tokens_data, device='npu:0')
sampled_token_ids_data = [[13], [1968], [13], [13]]
sampled_token_ids = torch.tensor(sampled_token_ids_data, device='npu:0')
input_positions_data = [5, 7, 5, 5]
input_positions_ascendc = torch.tensor(input_positions_data,
device='npu:0')
input_positions_python = torch.tensor(input_positions_data, device='npu:0')
seq_lens_data = [6, 8, 6, 6]
seq_lens_ascendc = torch.tensor(seq_lens_data,
device='npu:0',
dtype=torch.int32)
seq_lens_python = torch.tensor(seq_lens_data,
device='npu:0',
dtype=torch.int32)
slot_mapping_data = [5, 135, 261, 389]
slot_mapping_ascendc = torch.tensor(slot_mapping_data,
device='npu:0',
dtype=torch.int32)
slot_mapping_python = torch.tensor(slot_mapping_data,
device='npu:0',
dtype=torch.int32)
block_tables_data = [[0], [1], [2], [3]]
block_tables = torch.tensor(block_tables_data,
device='npu:0',
dtype=torch.int32)
torch.ops._C.advance_step_flashattn_ascendc(
4, 4, 128, input_tokens_ascendc, sampled_token_ids,
input_positions_ascendc, seq_lens_ascendc, slot_mapping_ascendc,
block_tables)
normal(4, 4, 128, input_tokens_python, sampled_token_ids,
input_positions_python, seq_lens_python, slot_mapping_python,
block_tables)
# Compare the results.
torch.testing.assert_close(input_tokens_ascendc,
input_tokens_python,
atol=DEFAULT_ATOL,
rtol=DEFAULT_RTOL)
torch.testing.assert_close(input_positions_ascendc,
input_positions_python,
atol=DEFAULT_ATOL,
rtol=DEFAULT_RTOL)
torch.testing.assert_close(seq_lens_ascendc,
seq_lens_python,
atol=DEFAULT_ATOL,
rtol=DEFAULT_RTOL)
torch.testing.assert_close(slot_mapping_ascendc,
slot_mapping_python,
atol=DEFAULT_ATOL,
rtol=DEFAULT_RTOL)
def normal(num_seqs: int, num_queries: int, block_size: int,
input_tokens: torch.Tensor, sampled_token_ids: torch.Tensor,
input_positions: torch.Tensor, seq_lens_tensor: torch.Tensor,
slot_mapping: torch.Tensor, block_tables: torch.Tensor) -> None:
sampled_token_ids_list = sampled_token_ids[:num_queries].squeeze(-1)
input_tokens[:num_queries] = sampled_token_ids_list
# get seq_lens and input_positions
seq_lens = seq_lens_tensor[:num_queries]
next_seq_lens = seq_lens + 1
next_input_pos = next_seq_lens - 1
# update seq_lens and input_positions
seq_lens_tensor[:num_queries] = next_seq_lens
input_positions[:num_queries] = next_input_pos # type: ignore
# get block index and offset
block_idx = next_input_pos // block_size
block_offset = next_input_pos % block_size
current_block_table = block_tables.gather(
1, block_idx.unsqueeze(-1)).squeeze(-1)
slot_num = current_block_table * block_size + block_offset
# update slot_mapping
slot_mapping[:num_queries] = slot_num

View File

@@ -0,0 +1,198 @@
# Copyright 2023 The vLLM team.
# Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved.
# Adapted from
# https://github.com/vllm-project/vllm/blob/main/vllm/tests/kernels/test_rotary_embedding.py
from typing import Optional, Tuple, Union
import pytest
import torch
import torch.nn as nn
import vllm_ascend.platform # noqa: F401
# Only Neox style true scenario is supported for now
IS_NEOX_STYLE = [True]
DTYPES = [torch.half]
HEAD_SIZES = [64, 96, 128, 256]
ROTARY_DIMS = [None, 32] # None means rotary dim == head size
NUM_HEADS = [17] # Arbitrary values for testing
BATCH_SIZES = [5] # Arbitrary values for testing
SEQ_LENS = [11, 4096] # Arbitrary values for testing
SEEDS = [0]
DEVICES = [f"npu:{0}"]
# Set tolerance to 1 for quant ops
DEFAULT_ATOL = 1e-3
DEFAULT_RTOL = 1e-3
def _apply_rotary_emb(
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
is_neox_style: bool,
) -> torch.Tensor:
"""
Args:
x: [num_tokens, num_heads, head_size]
cos: [num_tokens, head_size // 2]
sin: [num_tokens, head_size // 2]
is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
positional embeddings.
"""
cos = cos.unsqueeze(-2).to(x.dtype)
sin = sin.unsqueeze(-2).to(x.dtype)
if is_neox_style:
x1, x2 = torch.chunk(x, 2, dim=-1)
else:
x1 = x[..., ::2]
x2 = x[..., 1::2]
o1 = x1 * cos - x2 * sin
o2 = x2 * cos + x1 * sin
if is_neox_style:
return torch.cat((o1, o2), dim=-1)
else:
return torch.stack((o1, o2), dim=-1).flatten(-2)
# adapted from https://github.com/vllm-project/vllm/vllm/model_executor/layers/rotary_embedding.py
class RotaryEmbedding(nn.Module):
"""Original rotary positional embedding."""
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
dtype: torch.dtype,
) -> None:
super().__init__()
self.head_size = head_size
self.rotary_dim = rotary_dim
self.max_position_embeddings = max_position_embeddings
self.base = base
self.is_neox_style = is_neox_style
self.dtype = dtype
cache = self._compute_cos_sin_cache()
cache = cache.to(dtype)
self.cos_sin_cache: torch.Tensor
self.register_buffer("cos_sin_cache", cache, persistent=False)
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
"""Compute the inverse frequency."""
# NOTE(woosuk): To exactly match the HF implementation, we need to
# use CPU to compute the cache and then move it to GPU. However, we
# create the cache on GPU for faster initialization. This may cause
# a slight numerical difference between the HF implementation and ours.
inv_freq = 1.0 / (base**(torch.arange(
0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
return inv_freq
def _compute_cos_sin_cache(self) -> torch.Tensor:
"""Compute the cos and sin cache."""
inv_freq = self._compute_inv_freq(self.base)
t = torch.arange(self.max_position_embeddings, dtype=torch.float)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
return cache
def forward_native(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""A PyTorch-native implementation of forward()."""
if offsets is not None:
positions = positions + offsets
positions = positions.flatten()
num_tokens = positions.shape[0]
cos_sin = self.cos_sin_cache.index_select(0, positions)
cos, sin = cos_sin.chunk(2, dim=-1)
query_shape = query.shape
query = query.view(num_tokens, -1, self.head_size)
query_rot = query[..., :self.rotary_dim]
query_pass = query[..., self.rotary_dim:]
query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., :self.rotary_dim]
key_pass = key[..., self.rotary_dim:]
key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key
# test with leading dimension and merge seqlen and batch_size as num_tokens
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_rotary_embedding_quant_with_leading_dim(
is_neox_style: bool,
batch_size: int,
seq_len: int,
num_heads: int,
head_size: int,
rotary_dim: Optional[int],
dtype: torch.dtype,
seed: int,
device: str,
max_position: int = 8192,
base: int = 10000,
) -> None:
if rotary_dim is None:
rotary_dim = head_size
torch.set_default_device(device)
if rotary_dim is None:
rotary_dim = head_size
rope = RotaryEmbedding(head_size, rotary_dim, max_position, base,
is_neox_style, dtype)
rope = rope.to(dtype=dtype)
num_tokens = batch_size * seq_len
positions = torch.randint(0, max_position, (batch_size * seq_len, ))
qkv_tensor = torch.randn(num_tokens,
num_heads * head_size * 3,
dtype=dtype)
query, key, _ = qkv_tensor.split(
[num_heads * head_size, num_heads * head_size, num_heads * head_size],
dim=-1,
)
ref_query, ref_key = rope.forward_native(positions, query, key)
query, key = torch.ops._C.rotary_embedding(
positions,
query,
key,
rope.head_size,
rope.cos_sin_cache,
rope.is_neox_style,
)
# Compare the results.
torch.testing.assert_close(query.view(ref_query.size()),
ref_query,
atol=DEFAULT_ATOL,
rtol=DEFAULT_RTOL)
torch.testing.assert_close(key.view(ref_key.size()),
ref_key,
atol=DEFAULT_ATOL,
rtol=DEFAULT_RTOL)

View File

View File

@@ -0,0 +1,611 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Optional
import pytest
import torch
import torch.nn.functional as F
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm_ascend.sample.rejection_sampler import (PLACEHOLDER_TOKEN_ID,
AscendRejectionSampler)
DEVICE = "npu"
@pytest.fixture
def rejection_sampler():
return AscendRejectionSampler()
def create_logits_tensor(output_token_ids: list[list[int]],
vocab_size: int = 100) -> torch.Tensor:
"""Helper function to create logits tensor that
will produce desired token ids on argmax"""
token_ids = [tokens[:-1] for tokens in output_token_ids]
num_total_tokens = sum(len(tokens) for tokens in token_ids)
logits = torch.full((num_total_tokens, vocab_size), -100.0, device=DEVICE)
start_loc = 0
for tokens in token_ids:
for j, token_id in enumerate(tokens):
logits[start_loc + j, token_id] = 100.0
start_loc += len(tokens)
return logits
def create_sampling_metadata(
all_greedy: bool,
temperature: Optional[torch.Tensor] = None,
top_k: Optional[torch.Tensor] = None,
top_p: Optional[torch.Tensor] = None,
generators: Optional[dict[int, Any]] = None,
) -> SamplingMetadata:
"""Create a v1 sampling metadata object with all_greedy set
to the given value. Either all greedy or all random sampling
is used.
"""
generators = generators or {}
if all_greedy:
temperature = None
else:
assert temperature is not None
return SamplingMetadata(
temperature=temperature,
all_greedy=all_greedy,
all_random=not all_greedy,
top_p=top_p,
top_k=top_k,
min_p=torch.empty(1, ),
generators=generators,
max_num_logprobs=0,
no_penalties=False,
prompt_token_ids=None,
frequency_penalties=torch.tensor([]),
presence_penalties=torch.tensor([]),
repetition_penalties=torch.tensor([]),
output_token_ids=[],
min_tokens={},
logit_bias=[None],
allowed_token_ids_mask=None,
bad_words_token_ids={},
)
########################### Tests for Greedy Sampling ###################
def test_perfect_match(rejection_sampler):
"""Test when output tokens perfectly match speculated tokens"""
spec_tokens = [[1, 2, 3]]
output_tokens = [[1, 2, 3, 4]] # 4 is the bonus token
metadata = create_sampling_metadata(all_greedy=True)
logits = create_logits_tensor(output_tokens)
bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
device=logits.device)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
device=logits.device)
output = rejection_sampler(
spec_decode_metadata,
draft_probs=None,
target_logits=logits,
bonus_token_ids=bonus_token_tensor,
sampling_metadata=metadata,
)
expected = torch.tensor([[1, 2, 3, 4]],
dtype=torch.int,
device=logits.device)
assert torch.equal(output, expected)
def test_early_mismatch(rejection_sampler):
"""Test when there's an early mismatch in tokens"""
spec_tokens = [[1, 2, 3]]
output_tokens = [[1, 5, 3, 4]] # Mismatch at position 1
metadata = create_sampling_metadata(all_greedy=True)
logits = create_logits_tensor(output_tokens)
bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
device=logits.device)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
device=logits.device)
output = rejection_sampler(
spec_decode_metadata,
draft_probs=None,
target_logits=logits,
bonus_token_ids=bonus_token_tensor,
sampling_metadata=metadata,
)
expected = torch.tensor(
[[1, 5, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID]],
dtype=torch.int,
device=logits.device,
)
assert torch.equal(output, expected)
def test_multiple_sequences(rejection_sampler):
"""Test handling multiple sequences of speculated tokens"""
spec_tokens = [[1, 2], [3]]
output_tokens = [[1, 2, 5], [3,
4]] # Two sequences with bonus tokens 5 and 4
metadata = create_sampling_metadata(all_greedy=True)
logits = create_logits_tensor(output_tokens)
bonus_token_tensor = torch.tensor(
[output_tokens[0][-1], output_tokens[1][-1]], device=logits.device)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
device=logits.device)
output = rejection_sampler(
spec_decode_metadata,
draft_probs=None,
target_logits=logits,
bonus_token_ids=bonus_token_tensor,
sampling_metadata=metadata,
)
expected = torch.tensor([[1, 2, 5], [3, 4, PLACEHOLDER_TOKEN_ID]],
dtype=torch.int,
device=logits.device)
assert torch.equal(output, expected)
def test_single_token_sequence(rejection_sampler):
"""Test handling sequences with single token"""
spec_tokens = [[1]]
output_tokens = [[1, 2]] # Single token with bonus token 2
metadata = create_sampling_metadata(all_greedy=True)
logits = create_logits_tensor(output_tokens)
bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
device=logits.device)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
device=logits.device)
output = rejection_sampler(
spec_decode_metadata,
draft_probs=None,
target_logits=logits,
bonus_token_ids=bonus_token_tensor,
sampling_metadata=metadata,
)
expected = torch.tensor([[1, 2]], dtype=torch.int, device=logits.device)
assert torch.equal(output, expected)
def test_empty_sequence(rejection_sampler):
"""Test handling empty sequence of speculated tokens"""
spec_tokens: list[list[int]] = [[]]
output_tokens = [[5]] # Just the bonus token
metadata = create_sampling_metadata(all_greedy=True)
logits = create_logits_tensor(output_tokens)
bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
device=logits.device)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
device=logits.device)
output = rejection_sampler(
spec_decode_metadata,
draft_probs=None,
target_logits=logits,
bonus_token_ids=bonus_token_tensor,
sampling_metadata=metadata,
)
expected = torch.tensor([[5]], dtype=torch.int, device=logits.device)
assert torch.equal(output, expected)
def test_multiple_mismatches(rejection_sampler):
"""Test handling multiple sequences with mismatches"""
spec_tokens = [[1, 2, 3], [4, 5, 6]]
output_tokens = [[1, 2, 7, 6], [4, 8, 6,
9]] # Mismatches in both sequences
metadata = create_sampling_metadata(all_greedy=True)
logits = create_logits_tensor(output_tokens)
bonus_token_tensor = torch.tensor(
[output_tokens[0][-1], output_tokens[1][-1]], device=logits.device)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
device=logits.device)
output = rejection_sampler(
spec_decode_metadata,
draft_probs=None,
target_logits=logits,
bonus_token_ids=bonus_token_tensor,
sampling_metadata=metadata,
)
expected = torch.tensor(
[[1, 2, 7, PLACEHOLDER_TOKEN_ID],
[4, 8, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID]],
dtype=torch.int,
device=logits.device,
)
assert torch.equal(output, expected)
@pytest.mark.parametrize(
"spec_tokens,output_tokens,expected",
[
([[1, 2]], [[1, 2, 3]], [[1, 2, 3]]), # Perfect match with bonus
([[1]], [[2, 3]], [[2, PLACEHOLDER_TOKEN_ID]]), # First mismatch
([[1, 2], [3, 4]], [[1, 5, 6], [3, 4, 7]],
[[1, 5, PLACEHOLDER_TOKEN_ID], [3, 4, 7]]), # Mixed matches
])
def test_parametrized_cases(rejection_sampler, spec_tokens, output_tokens,
expected):
"""Parametrized test for various matching scenarios"""
metadata = create_sampling_metadata(all_greedy=True)
logits = create_logits_tensor(output_tokens)
bonus_token_tensor = torch.tensor([tokens[-1] for tokens in output_tokens],
device=logits.device)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
device=logits.device)
output = rejection_sampler(
spec_decode_metadata,
draft_probs=None,
target_logits=logits,
bonus_token_ids=bonus_token_tensor,
sampling_metadata=metadata,
)
expected_tensor = torch.tensor(expected,
dtype=torch.int,
device=logits.device)
assert torch.equal(output, expected_tensor)
########################### Tests for Random Sampling ###################
@pytest.mark.parametrize("k", [1, 3, 5])
@pytest.mark.parametrize("vocab_size", [1000])
@pytest.mark.parametrize("batch_size", [1, 4, 8])
@pytest.mark.parametrize("frac_seeded", [0.0, 0.5])
@pytest.mark.parametrize("n_rep", [20])
def test_deterministic_when_seeded(
rejection_sampler,
k: int,
vocab_size: int,
batch_size: int,
frac_seeded: float,
n_rep: int,
):
num_tokens = batch_size * k
draft_probs = torch.rand(num_tokens,
vocab_size,
dtype=torch.float32,
device=DEVICE)
draft_probs = F.softmax(draft_probs, dim=-1)
target_logits = torch.rand_like(draft_probs)
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64,
device=DEVICE)
draft_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64,
device=DEVICE)
seeded_mask = torch.rand(batch_size, dtype=torch.float32) <= frac_seeded
results = []
for _ in range(n_rep):
seeded_seqs = {
i: torch.Generator(device=DEVICE).manual_seed(i)
for i in range(batch_size) if seeded_mask[i]
}
temperature = torch.ones(batch_size,
dtype=torch.float32,
device=DEVICE)
sampling_metadata = create_sampling_metadata(all_greedy=False,
temperature=temperature,
generators=seeded_seqs)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
draft_token_ids.tolist(), device=DEVICE)
rep_result = rejection_sampler(
spec_decode_metadata,
draft_probs=draft_probs,
target_logits=target_logits,
bonus_token_ids=bonus_token_ids,
sampling_metadata=sampling_metadata,
)
results.append(rep_result)
for i in range(batch_size):
if seeded_mask[i]:
for j in range(1, n_rep):
assert torch.equal(results[j][i], results[0][i])
@pytest.mark.skipif(True, reason="Test failed, need fix")
def test_rejection_sampling_approximates_target_distribution():
"""Verify rejection sampling approximates target distribution,
despite sampling from a potentially distinct draft distribution.
This is done by first creating a random target probability
distribution and a random draft probability distribution. We then
sample token ids from the rejection sampler using these draft
and target distributions. The samples are used to estimate
the output probability distribution, which we expect to approximate
the target distribution.
A basic distance metric is used to determine similarity between
distributions.
We expect that as we increase the number of samples,
the distance between the observed distribution and the target
distribution decreases. To measure this, we compare the distance
of the observed distribution against both the target distribution
and a uniform random distribution. We expect the distance between
the observed distribution and the target distribution to improve
much more than the distance improvement between the observed
distribution and the random distribution.
"""
torch.set_default_device(DEVICE)
vocab_size = 10
k = 2
num_reference_probs = 100
# Prepare draft, target, and reference probability distributions
draft_probs = F.softmax(torch.rand(vocab_size, dtype=torch.float32),
dim=-1)
target_logits = torch.rand(vocab_size, dtype=torch.float32)
target_probs = F.softmax(target_logits, dim=-1)
reference_probs = F.softmax(
torch.rand(num_reference_probs, vocab_size, dtype=torch.float32),
dim=-1,
)
sample_sizes = [10, 100, 1_000, 10_000, 100_000]
distance_wrt_reference: list[float] = []
distance_wrt_target: list[float] = []
for num_samples in sample_sizes:
# Sample using rejection sampling.
rej_sample_probs = estimate_rejection_sampling_pdf(
draft_probs, target_logits, k, vocab_size, num_samples)
rej_sample_probs = rej_sample_probs.to(DEVICE)
# Average distance from reference probs.
reference_vs_rejsample_dist = torch.dist(
reference_probs,
rej_sample_probs).item() / reference_probs.shape[0]
target_vs_rejsample_dist = torch.dist(target_probs,
rej_sample_probs).item()
distance_wrt_reference.append(reference_vs_rejsample_dist)
distance_wrt_target.append(target_vs_rejsample_dist)
relative_change_in_distance_wrt_target = get_ratio_first_to_last(
distance_wrt_target)
relative_change_in_distance_wrt_reference = get_ratio_first_to_last(
distance_wrt_reference)
print(f"{num_samples=} {target_vs_rejsample_dist=:.05f} "
f"{reference_vs_rejsample_dist=:.05f}")
print(f"{num_samples=} {relative_change_in_distance_wrt_target=:.02f} "
f"{relative_change_in_distance_wrt_reference=:.02f}")
relative_change_in_distance_wrt_target = get_ratio_first_to_last(
distance_wrt_target)
relative_change_in_distance_wrt_reference = get_ratio_first_to_last(
distance_wrt_reference)
expected_improvement_multiplier = 20
assert (relative_change_in_distance_wrt_target >
relative_change_in_distance_wrt_reference *
expected_improvement_multiplier)
def get_ratio_first_to_last(elements: list[float]) -> float:
return elements[0] / elements[-1]
def estimate_rejection_sampling_pdf(
draft_probs: torch.Tensor,
target_logits: torch.Tensor,
k: int,
vocab_size: int,
num_samples: int,
) -> torch.Tensor:
"""Estimate the probability distribution of the output tokens
using rejection sampling.
Args:
draft_probs: Draft probability distribution.
target_logits: Target logits.
num_samples: Number of samples to draw.
Returns:
Estimated probability distribution of the output tokens.
"""
rejection_sampler = AscendRejectionSampler()
num_tokens = num_samples * k
# Repeat draft probs num_samples * k times.
draft_probs = draft_probs.reshape(1, 1,
vocab_size).repeat(num_samples, k, 1)
# Repeat target probs num_tokens times.
target_logits = target_logits.reshape(1, vocab_size).repeat(num_tokens, 1)
# Randomly sample draft token ids from draft probs.
draft_token_ids = torch.multinomial(draft_probs[:, 0, :],
num_samples=k,
replacement=True).reshape(
num_samples, k)
draft_probs = draft_probs.view(num_tokens, vocab_size)
# Bonus tokens not used but required.
bonus_token_ids = torch.zeros((1, 1), dtype=torch.int64,
device=DEVICE).repeat(num_samples, 1)
temperature = torch.ones(num_samples, dtype=torch.float32, device=DEVICE)
sampling_metadata = create_sampling_metadata(all_greedy=False,
temperature=temperature)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
draft_token_ids.tolist(), device=bonus_token_ids.device)
output_token_ids = rejection_sampler(
spec_decode_metadata,
draft_probs=draft_probs,
target_logits=target_logits,
bonus_token_ids=bonus_token_ids,
sampling_metadata=sampling_metadata,
)
output_token_ids = output_token_ids[:, :-1].flatten()
hist = torch.histogram(output_token_ids.to(dtype=torch.float,
device="cpu"),
bins=vocab_size,
range=(0, vocab_size),
density=True)
return hist.hist
def _test_masked_logits(
rejection_sampler,
batch_size: int,
num_draft_tokens: int,
vocab_size: int,
target_logits: torch.Tensor,
unmasked_indices: torch.Tensor,
sampling_metadata: SamplingMetadata,
):
# Set up test parameters
num_tokens = batch_size * num_draft_tokens
# Create random draft probabilities.
draft_probs = torch.rand((num_tokens, vocab_size),
dtype=torch.float32,
device=DEVICE)
draft_probs = F.softmax(draft_probs, dim=-1)
# Randomly sample draft token ids from draft probs
draft_token_ids = torch.multinomial(draft_probs, num_samples=1)
draft_token_ids = draft_token_ids.reshape(batch_size, num_draft_tokens)
draft_token_ids = draft_token_ids.tolist()
# Bonus tokens not used but required
bonus_token_ids = torch.zeros((batch_size, 1),
dtype=torch.int64,
device=DEVICE)
# Create spec decode metadata
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
draft_token_ids,
device=DEVICE,
)
# Run rejection sampling
output_token_ids = rejection_sampler(
spec_decode_metadata,
draft_probs=draft_probs,
target_logits=target_logits,
bonus_token_ids=bonus_token_ids,
sampling_metadata=sampling_metadata,
)
# Remove bonus tokens and reshape
output_token_ids = output_token_ids[:, :-1].flatten().tolist()
# Check that all sampled tokens are within the unmasked indices.
for i in range(num_tokens):
token_id = output_token_ids[i]
if token_id == PLACEHOLDER_TOKEN_ID:
continue
assert token_id in unmasked_indices[i]
@pytest.mark.parametrize("top_k", [1, 5, 99])
def test_top_k(rejection_sampler, top_k):
"""Test rejection sampling with top-k sampling"""
vocab_size = 100
batch_size = 100
num_draft_tokens = 3
num_tokens = batch_size * num_draft_tokens
# Randomly create top-k indices.
top_k_indices = [
torch.randperm(vocab_size, device=DEVICE)[:top_k]
for _ in range(num_tokens)
]
top_k_indices = torch.stack(top_k_indices)
# Create logits with the uniform distribution.
target_logits = torch.zeros((num_tokens, vocab_size), device=DEVICE)
# Increment the logits for top-k indices, a little bit more than the other
# ones. If the masking is effective, the non-topk indices will never be
# sampled despite the small difference in logits.
for i in range(num_tokens):
target_logits[i, top_k_indices[i]] += 0.1
# Create sampling metadata
temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE)
sampling_metadata = create_sampling_metadata(
all_greedy=False,
temperature=temperature,
top_k=torch.tensor([top_k] * batch_size,
device=DEVICE,
dtype=torch.int64),
)
_test_masked_logits(
rejection_sampler,
batch_size=batch_size,
num_draft_tokens=num_draft_tokens,
vocab_size=vocab_size,
target_logits=target_logits,
unmasked_indices=top_k_indices,
sampling_metadata=sampling_metadata,
)
@pytest.mark.parametrize("top_p", [0.5, 0.9, 0.99])
def test_top_p(rejection_sampler, top_p):
"""Test rejection sampling with top-p sampling"""
vocab_size = 100
batch_size = 100
num_draft_tokens = 3
num_tokens = batch_size * num_draft_tokens
# Create logits with the uniform distribution.
target_logits = torch.randn((num_tokens, vocab_size), device=DEVICE)
temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE)
rescaled_logits = target_logits / temperature
logits_sort, logits_idx = rescaled_logits.sort(dim=-1, descending=False)
probs_sort = logits_sort.softmax(dim=-1)
probs_sum = probs_sort.cumsum(dim=-1)
top_p_mask = probs_sum <= 1 - top_p
# at least one
top_p_mask[:, -1] = False
# Get the top-p indices.
top_p_indices = []
for i in range(num_tokens):
top_p_indices.append(logits_idx[i][~top_p_mask[i]].tolist())
# Create sampling metadata
sampling_metadata = create_sampling_metadata(
all_greedy=False,
temperature=temperature,
top_p=torch.tensor([top_p] * batch_size,
device=DEVICE,
dtype=torch.float32),
)
_test_masked_logits(
rejection_sampler,
batch_size=batch_size,
num_draft_tokens=num_draft_tokens,
vocab_size=vocab_size,
target_logits=target_logits,
unmasked_indices=top_p_indices,
sampling_metadata=sampling_metadata,
)

View File

@@ -1,18 +0,0 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from vllm_ascend.patch import worker # noqa: F401

View File

@@ -1,28 +0,0 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/tests/spec_decode/conftest.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pytest
@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
"""
Since this module is V0 only, set VLLM_USE_V1=0 for
all tests in the module.
"""
monkeypatch.setenv('VLLM_USE_V1', '0')

View File

@@ -1,274 +0,0 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/tests/spec_decode/e2e/conftest.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import shutil
from itertools import cycle
from pathlib import Path
from typing import List, Optional, Sequence, Tuple, Union
import pytest
import torch
from vllm import LLM, SamplingParams
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import PromptLogprobs, SampleLogprobs
from ....model_utils import (TokensTextLogprobs,
TokensTextLogprobsPromptLogprobs,
check_logprobs_close, check_outputs_equal)
PROMPTS = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
"San Francisco is know for its",
"Facebook was created in 2004 by",
"Curious George is a",
"Python 3.11 brings improvements to its",
]
@pytest.fixture
def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
test_llm_kwargs, seed):
def generate():
kwargs = {
**common_llm_kwargs,
**per_test_common_llm_kwargs,
**test_llm_kwargs,
}
llm = LLM(**kwargs)
if seed is not None:
set_random_seed(seed)
yield llm
del llm
cleanup_dist_env_and_memory()
return generate
def maybe_assert_ngram_worker(llm):
# Verify the proposer worker is ngram if ngram is specified.
if (llm.llm_engine.speculative_config is not None
and llm.llm_engine.speculative_config.method == "ngram"):
from vllm.spec_decode.ngram_worker import NGramWorker
assert isinstance(
llm.llm_engine.model_executor.driver_worker.proposer_worker,
NGramWorker)
def get_output_from_llm_generator(
llm_generator, prompts,
sampling_params) -> Tuple[List[str], List[List[int]], float]:
tokens: List[str] = []
token_ids: List[List[int]] = []
acceptance_rate: float = -1.0
for llm in llm_generator():
maybe_assert_ngram_worker(llm)
outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
token_ids = [output.outputs[0].token_ids for output in outputs]
tokens = [output.outputs[0].text for output in outputs]
# Fetch acceptance rate if logging is enabled.
if stat_loggers := getattr(llm.llm_engine, "stat_loggers", None):
stat_logger = stat_loggers["prometheus"]
acceptance_rate = (stat_logger.metrics.
gauge_spec_decode_draft_acceptance_rate.labels(
**stat_logger.labels)._value.get())
del llm
return tokens, token_ids, acceptance_rate
def check_logprobs_correctness(
spec_outputs: Sequence[Union[TokensTextLogprobs,
TokensTextLogprobsPromptLogprobs]],
baseline_outputs: Sequence[Union[TokensTextLogprobs,
TokensTextLogprobsPromptLogprobs]],
disable_logprobs: bool = False,
):
"""Compare sampled and prompt logprobs between baseline and spec decoding
"""
if not disable_logprobs:
return check_logprobs_close(
outputs_0_lst=baseline_outputs,
outputs_1_lst=spec_outputs,
name_0="org",
name_1="sd",
)
# Check correctness when disable_logprobs == True
for spec_output, baseline_output in zip(spec_outputs, baseline_outputs):
# Check generated token logprobs.
spec_logprobs = spec_output[2]
baseline_logprobs = baseline_output[2]
_check_logprobs_when_output_disabled(spec_logprobs,
baseline_logprobs,
is_prompt_logprobs=False)
# Check prompt logprobs too, if they exist
if len(baseline_output) == 4:
assert len(spec_output) == 4
spec_prompt_logprobs = spec_output[3]
baseline_prompt_logprobs = baseline_output[3]
_check_logprobs_when_output_disabled(spec_prompt_logprobs,
baseline_prompt_logprobs,
is_prompt_logprobs=True)
def _check_logprobs_when_output_disabled(
spec_logprobs: Union[Optional[PromptLogprobs], SampleLogprobs],
baseline_logprobs: Union[Optional[PromptLogprobs], SampleLogprobs],
is_prompt_logprobs: bool = False,
):
# Prompt logprobs are optional
if is_prompt_logprobs and baseline_logprobs is None:
assert spec_logprobs is None
return
assert spec_logprobs is not None
assert baseline_logprobs is not None
assert len(spec_logprobs) == len(baseline_logprobs)
# For each generated position of the sequence.
for pos, (spec_pos_logprobs, baseline_pos_logprobs) in enumerate(
zip(spec_logprobs, baseline_logprobs)):
# First prompt logprob is expected to be None
if is_prompt_logprobs and baseline_pos_logprobs is None:
assert spec_pos_logprobs is None
assert pos == 0
continue
assert spec_pos_logprobs is not None
assert baseline_pos_logprobs is not None
# When disabled, the 1 logprob is returned with dummy values for the
# score and rank, but the token id should match the baseline model
assert len(spec_pos_logprobs) == 1
(spec_pos_logprob_token_id,
spec_pos_logprob) = next(iter(spec_pos_logprobs.items()))
assert spec_pos_logprob.rank == -1
assert spec_pos_logprob.logprob == 0.0
if isinstance(spec_pos_logprob_token_id, torch.Tensor):
spec_pos_logprob_token_id = spec_pos_logprob_token_id.item()
assert spec_pos_logprob_token_id in baseline_pos_logprobs
def _clean_torchair_cache():
cache_path = Path.cwd() / '.torchair_cache'
if cache_path.exists() and cache_path.is_dir():
shutil.rmtree(cache_path)
def run_equality_correctness_test(
vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size: int,
max_output_len: int,
seed: Optional[int] = 0,
temperature: float = 0.0,
disable_seed: bool = False,
ignore_eos: bool = True,
ensure_all_accepted: bool = False,
expected_acceptance_rate: Optional[float] = None,
logprobs: Optional[int] = None,
prompt_logprobs: Optional[int] = None,
disable_logprobs: bool = False):
org_args = {
**common_llm_kwargs,
**per_test_common_llm_kwargs,
**baseline_llm_kwargs,
}
sd_args = {
**common_llm_kwargs,
**per_test_common_llm_kwargs,
**test_llm_kwargs,
}
prompts = [prompt for prompt, _ in zip(cycle(PROMPTS), range(batch_size))]
if disable_seed:
seed = None
sampling_params = SamplingParams(temperature=temperature,
max_tokens=max_output_len,
seed=seed,
ignore_eos=ignore_eos,
logprobs=logprobs,
prompt_logprobs=prompt_logprobs)
# TODO current torchair graph mode needs clean torchair cache.
# if do not clean, it will raise error
additional_config = common_llm_kwargs.get("additional_config")
enable_graph_mode = additional_config.get(
"enable_graph_mode") if additional_config else False
with vllm_runner(**org_args) as vllm_model:
if enable_graph_mode:
_clean_torchair_cache()
org_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)
with vllm_runner(**sd_args) as vllm_model:
if enable_graph_mode:
_clean_torchair_cache()
if ensure_all_accepted or expected_acceptance_rate is not None:
# Force log interval to be 0 to catch all metrics.
stat_logger = vllm_model.model.llm_engine.stat_loggers[
'prometheus']
stat_logger.local_interval = -100
sd_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)
if ensure_all_accepted or expected_acceptance_rate is not None:
acceptance_rate = (stat_logger.metrics.
gauge_spec_decode_draft_acceptance_rate.labels(
**stat_logger.labels)._value.get())
if ensure_all_accepted:
assert True
# FIXME: ci fails to log acceptance rate.
# It works locally.
# assert acceptance_rate == 1.0
if expected_acceptance_rate is not None:
assert acceptance_rate >= expected_acceptance_rate - 1e-2
# Only pass token entries, not the logprobs
check_outputs_equal(outputs_0_lst=[out[0:2] for out in org_outputs],
outputs_1_lst=[out[0:2] for out in sd_outputs],
name_0="org",
name_1="sd")
# Check logprobs if requested
if logprobs is not None or prompt_logprobs is not None:
check_logprobs_correctness(sd_outputs, org_outputs, disable_logprobs)

View File

@@ -1,450 +0,0 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/tests/spec_decode/e2e/test_medusa_correctness.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""This docstring details important information on the testing methodology.
Most of the tests rely on "greedy equality", where we expect the output of
speculative decoding on a sequence to exactly match the output of normal non-
speculative decoding.
Since speculative decoding with rejection sampling guarantees that the output
distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality.
However, we still need to verify below scenario could be passed:
* Batch size 1 greedy equality
* Batch size >1 greedy equality
* Test greedy equality under preemption
* Test greedy equality under various number of speculative tokens.
With those tests, we can say at least, Medusa would not break the
correctess for the target model outputs.
"""
import os
import pytest
from tests.singlecard.spec_decode.e2e.conftest import \
run_equality_correctness_test
from tests.singlecard.spec_decode.utils import maybe_enable_chunked_prefill
# main model
# lmsys/vicuna-7b-v1.3 was to be used but it's causing
# OOM in CI pipeline, so using a smaller model.
MAIN_MODEL = "JackFram/llama-68m"
# speculative model
SPEC_MODEL = "abhigoyal/vllm-medusa-llama-68m-random"
# max number of speculative tokens: this corresponds to
# num_heads in the config.json of the speculator model.
MAX_SPEC_TOKENS = 5
# precision
# TODO: The vLLM here uses float32, but some op on the vllm-ascend
# do not support float32, such as ROPE, When it is fixed, it is
# recommended to change this to float32 to keep it consistent
# with vLLM.
PRECISION = "float16"
PREFILL_CHUNK_SIZE = [
-1,
# TODO:enable chunked prefill when it is supported
# 32
]
os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256"
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize("output_len", [
128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE)
def test_medusa_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
seed: int, prefill_chunk_size: int):
"""Verify greedy equality with different batch size."""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs": False,
},
},
{
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs": True,
},
},
])
@pytest.mark.parametrize("output_len", [
8,
])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6])
@pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE)
def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
seed: int, logprobs: int,
prefill_chunk_size: int):
"""Verify greedy equality with different batch size."""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(
vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs["speculative_config"]
["disable_logprobs"])
@pytest.mark.skipif(True, reason="Open it when graph mode ready.")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"enforce_eager": False,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize("output_len", [
128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE)
def test_medusa_e2e_greedy_correctness_cuda_graph(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int, prefill_chunk_size: int):
"""Verify greedy equality with cuda graph enabled and different
batch sizes."""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.skipif(True, reason="Open it when preempt ready.")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"block_size": 16,
# 2 for small prompt, 256//8 for generated.
"num_gpu_blocks_override": 2 + 256 // 8,
"max_model_len": (2 + 256 // 8) * 8,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize(
"output_len",
[
# Use small output len for fast test.
128,
])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE)
def test_medusa_e2e_greedy_correctness_with_preemption(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int, prefill_chunk_size: int):
"""Verify greedy equality, even when some sequences are preempted mid-
generation.
"""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
"test_llm_kwargs",
[
{
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": k,
},
}
# Try a range of num. speculative tokens
for k in range(1, 1 + MAX_SPEC_TOKENS)
])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE)
def test_medusa_different_k(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int, prefill_chunk_size: int):
"""Verify that medusa speculative decoding produces exact equality
to without spec decode with different values of num_speculative_tokens.
"""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_by_batch_size": 4,
},
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE)
def test_medusa_disable_queue(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int,
output_len: int, seed: int,
prefill_chunk_size: int):
"""Verify that medusa speculative decoding produces exact equality
to without spec decode when speculation is disabled for large
batch sizes.
"""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_by_batch_size": 4,
"disable_mqa_scorer": True,
},
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE)
def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
output_len: int, seed: int, prefill_chunk_size: int):
"""Verify that speculative decoding generates the same output
with batch expansion scorer and mqa scorer.
"""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
if __name__ == "__main__":
import pytest
pytest.main([__file__])

View File

@@ -1,560 +0,0 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/tests/spec_decode/e2e/test_mlp_correctness.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""This docstring details important information on the testing methodology.
Most of the tests rely on "greedy equality", where we expect the output of
speculative decoding on a sequence to exactly match the output of normal non-
speculative decoding.
Since speculative decoding with rejection sampling guarantees that the output
distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality.
However, we still need to verify below scenario could be passed:
* Batch size 1 greedy equality
* Batch size >1 greedy equality
* Test greedy equality under preemption
* Test greedy equality under various number of speculative tokens.
With those tests, we can say at least, MLPSpeculator would not break the
correctness for the target model outputs.
"""
import pytest
from vllm.model_executor.layers.vocab_parallel_embedding import \
pad_vocab_size # noqa: F401
from tests.singlecard.spec_decode.e2e.conftest import \
run_equality_correctness_test
from tests.singlecard.spec_decode.utils import maybe_enable_chunked_prefill
# main model
MAIN_MODEL = "JackFram/llama-160m"
# speculative model
SPEC_MODEL = "ibm-ai-platform/llama-160m-accelerator"
# max. number of speculative tokens: this corresponds to
# n_predict in the config.json of the speculator model.
MAX_SPEC_TOKENS = 3
PREFILL_CHUNK_SIZE_1 = [
-1,
# TODO:enable chunked prefill when it is supported
# 4
]
PREFILL_CHUNK_SIZE_2 = [
-1,
# TODO:enable chunked prefill when it is supported
# 32
]
# precision
# TODO: The vLLM here uses float32, but some op on the vllm-ascend
# do not support float32, such as ROPE, When it is fixed, it is
# recommended to change this to float32 to keep it consistent
# with vLLM.
PRECISION = "float16"
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"model": SPEC_MODEL,
},
},
])
@pytest.mark.parametrize("output_len", [
128,
])
@pytest.mark.parametrize("batch_size", [4, 32])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE_2)
def test_mlp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
seed: int, prefill_chunk_size: int):
"""Verify greedy equality with different batch size."""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"model": SPEC_MODEL,
"disable_logprobs": False,
},
},
{
"speculative_config": {
"model": SPEC_MODEL,
"disable_logprobs": True,
},
},
])
@pytest.mark.parametrize("output_len", [8])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6])
@pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE_1)
def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int, seed: int,
logprobs: int, prefill_chunk_size: int):
"""Verify greedy equality with different batch size."""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
# NOTE Test is sensitive enough st if we don't enable chunked prefill
# scheduling on baseline too, we get slightly different logprobs, ending
# up sampling different tokens at the tail (ie top tokens don't change).
# TL;DR: sd+cp == org+cp but sd+cp != org..is this expected?
maybe_enable_chunked_prefill(prefill_chunk_size, baseline_llm_kwargs)
run_equality_correctness_test(
vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs["speculative_config"]
["disable_logprobs"])
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"model": SPEC_MODEL,
},
},
])
@pytest.mark.parametrize("output_len", [2048])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE_1)
def test_mlp_e2e_acceptance_rate(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
prefill_chunk_size: int, seed: int):
"""Verify acceptance rate with different batch size and large output
length."""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
temperature=0.0,
seed=seed,
expected_acceptance_rate=0.48)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
# Speculative config
"speculative_config": {
"model": SPEC_MODEL,
},
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{"seed": 1}])
@pytest.mark.parametrize("test_llm_kwargs", [{"seed": 5}])
@pytest.mark.parametrize("output_len", [64])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("temperature", [1.0])
@pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE_1)
@pytest.mark.parametrize("seed", [1])
def test_mlp_e2e_seeded_correctness(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
temperature: float,
prefill_chunk_size: int, seed: int):
"""Verify seeded runs produce the same output."""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
maybe_enable_chunked_prefill(prefill_chunk_size, baseline_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
temperature=temperature,
seed=seed)
# Ensure this same test does fail if we _don't_ include per-request seeds
with pytest.raises(AssertionError):
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
temperature=temperature,
seed=seed,
disable_seed=True)
@pytest.mark.skipif(True, reason="Open it when preempt ready.")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"block_size": 16,
# 2 for small prompt, 256//8 for generated.
"num_gpu_blocks_override": 2 + 256 // 8,
"max_model_len": (2 + 256 // 8) * 8,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"model": SPEC_MODEL,
},
},
])
@pytest.mark.parametrize(
"output_len",
[
# Use small output len for fast test.
128,
])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE_1)
@pytest.mark.parametrize("seed", [1])
def test_mlp_e2e_greedy_correctness_with_preemption(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
prefill_chunk_size: int, seed: int):
"""Verify greedy equality, even when some sequences are preempted mid-
generation.
"""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.skipif(True, reason="Open it when preempt ready.")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"block_size": 16,
# 2 for small prompt, 256//8 for generated.
"num_gpu_blocks_override": 2 + 256 // 8,
"max_model_len": (2 + 256 // 8) * 8,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"model": SPEC_MODEL,
},
},
])
@pytest.mark.parametrize(
"output_len",
[
# Use small output len for fast test.
128,
])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE_1)
def test_mlp_e2e_greedy_correctness_with_padding(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
prefill_chunk_size: int, seed: int):
"""Verify greedy equality when the vocab dimension is padded
"""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
# Default pad_to is 64, test model has vocab_size of 32000
def patched_pad_vocab_size(vocab_size, pad_to=None):
return pad_vocab_size(vocab_size, pad_to=32064)
# NOTE: Compared with vLLM, the patch method has been modified
pad_vocab_size = patched_pad_vocab_size # noqa: F811
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
"test_llm_kwargs",
[
{
"speculative_config": {
"model": SPEC_MODEL,
"num_speculative_tokens": k,
},
}
# Try a range of num. speculative tokens
for k in range(1, 1 + MAX_SPEC_TOKENS)
])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE_1)
@pytest.mark.parametrize("seed", [1])
def test_mlp_different_k(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int,
prefill_chunk_size: int, seed: int, output_len: int):
"""Verify that mlp speculative decoding produces exact equality
to without spec decode with different values of num_speculative_tokens.
"""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Precision
"dtype": PRECISION,
# Main model
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"model": SPEC_MODEL,
"disable_by_batch_size": 4,
},
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
# Speculative decoding is disabled when sequences reach decoding and the batch
# consists of single-token requests. Hence we set `max_num_seqs`
# >= `speculative_disable_by_batch_size` to test feature interaction.
@pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE_1)
@pytest.mark.parametrize("seed", [1])
def test_mlp_disable_queue(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int,
prefill_chunk_size: int, seed: int,
output_len: int):
"""Verify that mlp speculative decoding produces exact equality
to without spec decode when speculation is disabled for large
batch sizes.
"""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model_name": MAIN_MODEL,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"model": SPEC_MODEL,
"disable_mqa_scorer": True,
},
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE_1)
@pytest.mark.parametrize("seed", [1])
def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
output_len: int, prefill_chunk_size: int, seed: int):
"""Verify that speculative decoding generates the same output
with batch expansion scorer and mqa scorer.
"""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)

View File

@@ -1,457 +0,0 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/tests/spec_decode/e2e/test_mtp_correctness.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""This docstring details important information on the testing methodology.
Most of the tests rely on "greedy equality", where we expect the output of
speculative decoding on a sequence to exactly match the output of normal non-
speculative decoding.
Since speculative decoding with rejection sampling guarantees that the output
distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality.
However, we still need to verify below scenario could be passed:
* Batch size 1 greedy equality
* Batch size >1 greedy equality
* Test greedy equality under preemption
* Test greedy equality under various number of speculative tokens.
With those tests, we can say at least, mtp would not break the
correctess for the target model outputs.
"""
import os
import pytest
from .conftest import run_equality_correctness_test
# NOTE both main model and MTP are bfloat16
FLOAT_MODEL = "wemaster/deepseek_mtp_main_random_bf16"
# NOTE main model is w8a8, MTP is bfloat16
QUANT_MODEL = "wemaster/deepseek_mtp_main_random_w8a8_part"
# TODO when msmodelslim can quantify both main and MTP model
# This UT should use w8a8 fully weights.
# max. number of speculative tokens: this corresponds to
# num_nextn_predict_layers in the config.json of the speculator model.
MAX_SPEC_TOKENS = 1
# precision
PRECISION = "bfloat16"
os.environ["VLLM_USE_MODELSCOPE"] = "True"
@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "1",
reason="mtp is not supported on v1")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model_name": FLOAT_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.85
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize("output_len", [
128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_mtp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
seed: int):
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size, output_len, seed)
@pytest.mark.skipif(True, reason="quant model is not ready.")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model_name": QUANT_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.8
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize("output_len", [
128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_mtp_e2e_quant_greedy_correctness(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
seed: int):
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size, output_len, seed)
@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "1",
reason="mtp is not supported on v1")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model_name": FLOAT_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.8
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs": False,
},
},
{
"speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_logprobs": True,
},
},
])
@pytest.mark.parametrize("output_len", [
128,
])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6])
def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int, seed: int,
logprobs: int):
run_equality_correctness_test(
vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs["speculative_config"]
["disable_logprobs"])
@pytest.mark.skipif(True, reason="torchair ut can not clean mem.")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"additional_config": {
'enable_graph_mode': True,
},
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model_name": FLOAT_MODEL,
"gpu_memory_utilization": 0.8
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize("output_len", [
128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_mtp_e2e_greedy_correctness_torchair_graph(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality with torchair graph enabled and different
batch sizes using bfloat16 weights."""
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size, output_len, seed)
@pytest.mark.skipif(True, reason="quant model is not ready.")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"additional_config": {
'enable_graph_mode': True,
},
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model_name": QUANT_MODEL,
"gpu_memory_utilization": 0.8
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize("output_len", [
128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_mtp_e2e_quant_greedy_correctness_torchair_graph(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality with torchair graph enabled and different
batch sizes using quant weights."""
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size, output_len, seed)
@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "1",
reason="mtp is not supported on v1")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"block_size": 16,
# 2 for small prompt, 256//8 for generated.
"num_gpu_blocks_override": 2 + 256 // 8,
"max_model_len": (2 + 256 // 8) * 8,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Precision
"dtype": PRECISION,
# Main model
"model_name": FLOAT_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.8
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
},
])
@pytest.mark.parametrize(
"output_len",
[
# Use small output len for fast test.
128,
])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_mtp_e2e_greedy_correctness_with_preemption(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality, even when some sequences are preempted mid-
generation.
"""
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size, output_len, seed)
@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "1",
reason="mtp is not supported on v1")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Precision
"dtype": PRECISION,
# Main model
"model_name": FLOAT_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.8
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
"test_llm_kwargs",
[
{
"speculative_config": {
"num_speculative_tokens": k,
},
}
# Try a range of num. speculative tokens
for k in range(1, 1 + MAX_SPEC_TOKENS)
])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_mtp_different_k(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify that mtp speculative decoding produces exact equality
to without spec decode with different values of num_speculative_tokens.
"""
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size, output_len, seed)
@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "1",
reason="mtp is not supported on v1")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Precision
"dtype": PRECISION,
# Main model
"model_name": FLOAT_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.8
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"num_speculative_tokens": MAX_SPEC_TOKENS,
"disable_by_batch_size": 4
},
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_mtp_disable_queue(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify that mtp speculative decoding produces exact equality
to without spec decode when speculation is disabled for large
batch sizes.
"""
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size, output_len, seed)
if __name__ == "__main__":
import pytest
pytest.main([__file__])

View File

@@ -1,404 +0,0 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/tests/spec_decode/e2e/test_ngram_correctness.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""This docstring details important information on the testing methodology.
Most of the tests rely on "greedy equality", where we expect the output of
speculative decoding on a sequence to exactly match the output of normal non-
speculative decoding.
Since speculative decoding with rejection sampling guarantees that the output
distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality.
For ngram lookup, its idea comes from https://github.com/apoorvumang/prompt-lookup-decoding,
and is merged into transform code base: https://github.com/huggingface/transformers/pull/27775.
Since there is no model is needed for generate the proposal, we could make
the testcase much simpler than drafter multi-step one.
However, we still need to verify below scenario could be passed:
* Batch size 1 greedy equality
* Batch size >1 greedy equality
* Test greedy equality under preemption
* Test greedy equality under various ngram sizes / speculative sizes
With those tests, we can say at least, ngram spec would not break the correctess
for the target model outputs.
"""
import pytest
from tests.singlecard.spec_decode.e2e.conftest import \
run_equality_correctness_test
from tests.singlecard.spec_decode.utils import maybe_enable_chunked_prefill
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
{
"model_name": "JackFram/llama-68m",
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": 5,
"prompt_lookup_max": 3,
"disable_mqa_scorer": False,
},
},
{
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": 5,
"prompt_lookup_max": 3,
"disable_mqa_scorer": True,
},
},
])
@pytest.mark.parametrize("output_len", [
256,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize(
"prefill_chunk_size",
[
-1,
# TODO:enable chunked prefill when it is supported
# 4
])
@pytest.mark.parametrize("seed", [1])
def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
prefill_chunk_size: int, seed: int):
"""Verify greedy equality on a tiny model with different batch size."""
maybe_enable_chunked_prefill(prefill_chunk_size, common_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Print spec metrics.
"disable_log_stats": False,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
{
"model_name": "JackFram/llama-68m",
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": 5,
"prompt_lookup_max": 3,
"disable_logprobs": False,
},
},
{
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": 5,
"prompt_lookup_max": 3,
"disable_logprobs": True,
},
},
])
@pytest.mark.parametrize("output_len", [
8,
])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6])
def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int, seed: int,
logprobs: int):
"""Verify greedy equality on a tiny model with different batch size."""
run_equality_correctness_test(
vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0,
logprobs=logprobs,
prompt_logprobs=logprobs,
disable_logprobs=test_llm_kwargs["speculative_config"]
["disable_logprobs"])
@pytest.mark.skipif(True, reason="Open it when preempt ready.")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"block_size": 16,
# 2 for small prompt, 256//8 for generated.
"num_gpu_blocks_override": 2 + 256 // 8,
"max_model_len": (2 + 256 // 8) * 8,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
{
"model_name": "JackFram/llama-160m",
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": 5,
"prompt_lookup_max": 3,
},
"enable_chunked_prefill": False,
},
{
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": 5,
"prompt_lookup_max": 3,
"disable_mqa_scorer": True,
},
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4
},
])
@pytest.mark.parametrize(
"output_len",
[
# Use small output len for fast test.
256,
])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_ngram_e2e_greedy_correctness_with_preemption(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality, even when some sequences are preempted mid-
generation.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
temperature=0,
seed=seed)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model_name": "JackFram/llama-68m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
"test_llm_kwargs",
[
{
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": k,
"prompt_lookup_max": 3,
},
}
# Try a range of common k, as well as large speculation.
for k in [1, 3, 5]
] + [
{
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": k,
"prompt_lookup_max": 1,
},
}
# Try a range of common k, as well as large speculation.
for k in [1, 3, 5]
])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_ngram_different_k(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify that ngram speculative decoding produces exact equality
to without spec decode with many different values of k and
different ngram_prompt_lookup_max.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model_name": "JackFram/llama-68m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
"test_llm_kwargs",
[
{
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": 5,
"prompt_lookup_max": 3,
"disable_by_batch_size": 4
},
},
{
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": 5,
"prompt_lookup_max": 3,
"disable_by_batch_size": 4,
"disable_mqa_scorer": True,
},
"enable_chunked_prefill": False,
# FIXME: enable me when chunked prefill is available
# "max_num_batched_tokens": 4,
"max_num_seqs": 4
}
])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_ngram_disable_queue(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify that ngram speculative decoding produces exact equality
to without spec decode with many different values of k and
different ngram_prompt_lookup_max.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model_name": "JackFram/llama-68m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"speculative_config": {
"method": "ngram",
"num_speculative_tokens": 5,
"prompt_lookup_max": 3,
"disable_mqa_scorer": True,
},
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_ngram_scorer(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify that ngram speculative decoding generates the same output
with batch expansion scorer and mqa scorer.
"""
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
seed=seed,
temperature=0.0)

View File

@@ -1,156 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import os
import random
from typing import Any
import pytest
from vllm import LLM, SamplingParams
os.environ["VLLM_USE_MODELSCOPE"] = "True"
@pytest.fixture
def test_prompts():
prompt_types = ["repeat", "sentence"]
num_prompts = 100
prompts = []
random.seed(0)
random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)
# Generate a mixed batch of prompts, some of which can be easily
# predicted by n-gram matching and some which likely cannot.
for kind in random_prompt_type_choices:
word_choices = ["test", "temp", "hello", "where"]
word = random.choice(word_choices)
if kind == "repeat":
prompt = f"""
please repeat the word '{word}' 10 times.
give no other output than the word at least ten times in a row,
in lowercase with spaces between each word and without quotes.
"""
elif kind == "sentence":
prompt = f"""
please give a ten-word sentence that
uses the word {word} at least once.
give no other output than that simple sentence without quotes.
"""
else:
raise ValueError(f"Unknown prompt type: {kind}")
prompts.append([{"role": "user", "content": prompt}])
return prompts
@pytest.fixture
def sampling_config():
return SamplingParams(temperature=0, max_tokens=10, ignore_eos=False)
@pytest.fixture
def model_name():
return "LLM-Research/Meta-Llama-3.1-8B-Instruct"
def eagle_model_name():
return "vllm-ascend/EAGLE-LLaMA3.1-Instruct-8B"
def eagle3_model_name():
return "vllm-ascend/EAGLE3-LLaMA3.1-Instruct-8B"
def test_ngram_correctness(
monkeypatch: pytest.MonkeyPatch,
test_prompts: list[list[dict[str, Any]]],
sampling_config: SamplingParams,
model_name: str,
):
'''
Compare the outputs of a original LLM and a speculative LLM
should be the same when using ngram speculative decoding.
'''
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
ref_llm = LLM(model=model_name, max_model_len=1024)
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm
spec_llm = LLM(
model=model_name,
speculative_config={
"method": "ngram",
"prompt_lookup_max": 5,
"prompt_lookup_min": 3,
"num_speculative_tokens": 3,
},
max_model_len=1024,
)
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
matches = 0
misses = 0
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
if ref_output.outputs[0].text == spec_output.outputs[0].text:
matches += 1
else:
misses += 1
print(f"ref_output: {ref_output.outputs[0].text}")
print(f"spec_output: {spec_output.outputs[0].text}")
# Heuristic: expect at least 70% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(0.7 * len(ref_outputs))
del spec_llm
@pytest.mark.parametrize("use_eagle3", [False, True], ids=["eagle", "eagle3"])
def test_eagle_correctness(
monkeypatch: pytest.MonkeyPatch,
test_prompts: list[list[dict[str, Any]]],
sampling_config: SamplingParams,
model_name: str,
use_eagle3: bool,
):
'''
Compare the outputs of a original LLM and a speculative LLM
should be the same when using eagle speculative decoding.
'''
pytest.skip("Not current support for the test.")
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
ref_llm = LLM(model=model_name, max_model_len=2048)
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm
spec_model_name = eagle3_model_name(
) if use_eagle3 else eagle_model_name()
spec_llm = LLM(
model=model_name,
trust_remote_code=True,
speculative_config={
"method": "eagle3" if use_eagle3 else "eagle",
"model": spec_model_name,
"num_speculative_tokens": 3,
"max_model_len": 2048,
},
max_model_len=2048,
)
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
matches = 0
misses = 0
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
if ref_output.outputs[0].text == spec_output.outputs[0].text:
matches += 1
else:
misses += 1
print(f"ref_output: {ref_output.outputs[0].text}")
print(f"spec_output: {spec_output.outputs[0].text}")
# Heuristic: expect at least 66% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(0.66 * len(ref_outputs))
del spec_llm

View File

@@ -1,105 +0,0 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/tests/spec_decode/test_dynamic_spec_decode.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from unittest.mock import MagicMock, patch
import pytest
import torch
from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.metrics import AsyncMetricsCollector
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
from vllm.spec_decode.top1_proposer import Top1Proposer
from tests.singlecard.spec_decode.test_utils import mock_spec_decode_sampler
from tests.singlecard.spec_decode.utils import create_batch, mock_worker
@pytest.mark.parametrize('queue_size', [4])
@pytest.mark.parametrize('batch_size', [1])
@pytest.mark.parametrize('k', [1])
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int,
acceptance_sampler_method: str):
"""Verify that speculative tokens are disabled when the batch size
exceeds the threshold.
"""
disable_by_batch_size = 3
draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker()
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
worker = SpecDecodeWorker(proposer_worker=draft_worker,
scorer_worker=target_worker,
spec_decode_sampler=mock_spec_decode_sampler(
acceptance_sampler_method),
disable_logprobs=False,
metrics_collector=metrics_collector,
disable_by_batch_size=disable_by_batch_size)
exception_secret = 'artificial stop'
draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k,
running_queue_size=queue_size)
if queue_size > disable_by_batch_size:
with patch.object(worker,
'_run_no_spec',
side_effect=ValueError(exception_secret)), \
pytest.raises(ValueError, match=exception_secret):
worker.execute_model(execute_model_req=execute_model_req)
# When the batch size is larger than the threshold,
# we expect no speculative tokens (0).
expected_num_spec_tokens = None if queue_size < disable_by_batch_size else 0
assert seq_group_metadata_list[
0].num_speculative_tokens == expected_num_spec_tokens
draft_worker.sampler_output.side_effect = ValueError(exception_secret)
proposer = Top1Proposer(
worker=draft_worker,
device='cpu', # not used
vocab_size=100, # not used
# Must be long enough to avoid being skipped due to length.
max_proposal_len=1024,
)
if queue_size < disable_by_batch_size:
# Should raise exception when executing the mocked draft model.
with pytest.raises(ValueError, match=exception_secret):
proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k),
seq_ids_with_bonus_token_in_last_step=set())
else:
# Should not execute the draft model because spec decode is disabled
# for all requests. Accordingly, the proposal length should be 0.
proposals = proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k),
seq_ids_with_bonus_token_in_last_step=set())
assert proposals.proposal_lens.tolist() == [0] * batch_size

View File

@@ -1,846 +0,0 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/tests/spec_decode/test_multi_step_worker.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import random
from unittest.mock import MagicMock
import pytest
import torch
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import (ExecuteModelRequest, HiddenStates, Logprob,
get_all_seq_ids)
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.top1_proposer import Top1Proposer
from tests.singlecard.spec_decode.utils import (
assert_logprobs_dict_allclose, create_batch,
create_seq_group_metadata_from_prompts, create_worker,
patch_execute_model_with_seeds, zero_kv_cache)
from vllm_ascend.worker.draft_model_runner import TP1DraftModelRunner
from vllm_ascend.worker.worker import NPUWorker
@pytest.mark.parametrize('num_steps', list(range(1, 17)))
def test_assert_enough_kv_space(num_steps: int):
"""Test that the multi step worker checks for sufficient space in the KV
cache. It should throw if it cannot run all the steps.
"""
block_size = 16
num_gpu_blocks = 2048 // block_size
prompts = [
list(range(block_size * 3)),
list(range(block_size * 2)),
]
prev_output_tokens = [
list(range(block_size * 1)),
list(range(block_size * 2)),
]
final_prompt_lens = [
len(prompt + output) + num_steps
for prompt, output in zip(prompts, prev_output_tokens)
]
inputs = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
final_prompt_lens,
continuations=prev_output_tokens)
assert_enough_kv_space = MultiStepWorker._assert_enough_kv_space # pylint: disable=protected-access
worker = MagicMock()
worker.model_runner.block_size = block_size
for seq_group_metadata in inputs:
original_block_tables = seq_group_metadata.block_tables
# No exception.
assert_enough_kv_space(worker, inputs, num_steps)
seq_group_metadata.block_tables = {
seq_id: []
for seq_id, physical_blocks in original_block_tables.items()
}
# Expect exception.
with pytest.raises(ValueError,
match='times but found insufficient KV space for'):
assert_enough_kv_space(worker, inputs, num_steps)
seq_group_metadata.block_tables = original_block_tables
@torch.inference_mode()
def test_same_output_for_single_step():
"""Verify the multi step worker produces the same output as the normal
worker for num_steps=1.
"""
seed = 100
model_name = 'JackFram/llama-68m'
block_size = 32
num_gpu_blocks = 2048 // block_size
multi_step_worker = create_worker(
MultiStepWorker,
model_name,
block_size,
num_gpu_blocks,
seed,
model_runner_cls=TP1DraftModelRunner,
)
worker = create_worker(
NPUWorker,
model_name,
block_size,
num_gpu_blocks,
seed,
)
# multi_step_worker.model_runner = worker.model_runner
# multi_step_worker.cache_engine = worker.cache_engine
num_steps = 1
prompts = [
[1, 2, 3, 4, 5],
[6, 7, 8, 9, 10],
]
final_prompt_lens = [len(prompt) + num_steps for prompt in prompts]
multi_step_seq_group = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
final_prompt_lens=final_prompt_lens)
zero_kv_cache(multi_step_worker.cache_engine)
set_random_seed(seed)
actual_output, _ = multi_step_worker.sampler_output(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=multi_step_seq_group),
sample_len=num_steps,
seq_ids_with_bonus_token_in_last_step=set())
assert len(actual_output) == num_steps
actual_output = actual_output[0]
single_step_seq_group = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
final_prompt_lens=final_prompt_lens)
zero_kv_cache(worker.cache_engine)
set_random_seed(seed)
expected_output = worker.execute_model(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=single_step_seq_group))[0]
actual_token_ids = [
output.samples[0].output_token for output in actual_output
]
actual_logprobs = [output.samples[0].logprobs for output in actual_output]
expected_token_ids = [
output.samples[0].output_token for output in expected_output
]
expected_logprobs = [
output.samples[0].logprobs for output in expected_output
]
assert actual_token_ids == expected_token_ids
print(f'{actual_logprobs=}')
print(f'{expected_logprobs=}')
assert_logprobs_dict_allclose(actual_logprobs, expected_logprobs)
@torch.inference_mode()
def test_same_output_for_multi_step():
"""Verify the multi-step worker produces the same output as the normal
worker when num_steps > 1. This test runs the multi-step worker once, and
then runs the worker num_steps times, and compares the output.
"""
seed = 100
model_name = 'JackFram/llama-68m'
block_size = 16
num_gpu_blocks = 2048 // block_size
multi_step_worker = create_worker(
MultiStepWorker,
model_name,
block_size,
num_gpu_blocks,
seed,
)
worker = create_worker(
NPUWorker,
model_name,
block_size,
num_gpu_blocks,
seed,
)
# Make sure we go over the block boundary.
num_steps = block_size + 1
random.seed(seed)
prompts = [[
random.randint(0, 1000) for _ in range(random.randint(10, 20))
] for _ in range(10)]
final_prompt_lens = [len(prompt) + num_steps for prompt in prompts]
rand_seeds = list(random.randint(0, 100) for _ in range(num_steps))
multi_step_worker.execute_model = patch_execute_model_with_seeds(
multi_step_worker, rand_seeds)
worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds)
continuations = [[1] for _ in prompts]
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
continuations=continuations,
final_prompt_lens=final_prompt_lens)
# Run multi-step.
zero_kv_cache(multi_step_worker.cache_engine)
set_random_seed(seed)
multi_step_output, _ = multi_step_worker.sampler_output(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list),
sample_len=num_steps,
seq_ids_with_bonus_token_in_last_step=set())
# Run single-step repeatedly.
zero_kv_cache(worker.cache_engine)
single_step_output: list[SamplerOutput] = []
continuations = [[1] for _ in prompts]
set_random_seed(seed)
for _ in multi_step_output:
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
continuations=continuations,
final_prompt_lens=final_prompt_lens)
single_step_output.extend(
worker.execute_model(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list)))
# Append output tokens to new sequence data.
for i, seq_group_output in enumerate(single_step_output[-1]):
continuations[i].append(seq_group_output.samples[0].output_token)
# Get token ids and logprobs for comparison.
multi_step_output_logprobs: list[list[dict[int,
Logprob]]] = [[]
for _ in prompts]
single_step_output_logprobs: list[list[dict[int,
Logprob]]] = [[]
for _ in prompts]
multi_step_output_token_ids: list[list[int]] = [[] for _ in prompts]
single_step_output_token_ids: list[list[int]] = [[] for _ in prompts]
for i, _ in enumerate(prompts):
for multi_step, single_step in zip(multi_step_output,
single_step_output):
multi_step_output_token_ids[i].append(
multi_step[i].samples[0].output_token)
single_step_output_token_ids[i].append(
single_step[i].samples[0].output_token)
multi_step_output_logprobs[i].append(
multi_step[i].samples[0].logprobs)
single_step_output_logprobs[i].append(
single_step[i].samples[0].logprobs)
# Print per-sequence token ids
for i, (multi_step_tokens, single_step_tokens) in enumerate(
zip(multi_step_output_token_ids, single_step_output_token_ids)):
print(f'{i=} {multi_step_tokens=}')
print(f'{i=} {single_step_tokens=}')
print(f'{i=} equal {multi_step_tokens == single_step_tokens}')
# Assert token ids are equal.
for multi_step_tokens, single_step_tokens in zip(
multi_step_output_token_ids, single_step_output_token_ids):
assert multi_step_tokens == single_step_tokens
# Assert logprobs are equal.
for multi_step_logprobs, single_step_logprobs in zip(
multi_step_output_logprobs, single_step_output_logprobs):
assert_logprobs_dict_allclose(multi_step_logprobs,
single_step_logprobs)
@torch.inference_mode()
def test_multi_step_with_batch_expansion_correct_output():
"""
In this test we verify that the MultiStepWorker is able to handle bonus
tokens correctly. The test verifies that if a sequence has a
bonus token then the MultiStepWorker is able to expand the batch by adding
new sequences corresponding to the sequences with bonus tokens. The
expanded batch is then used for predicting the next tokens.
"""
seed = 100
model_name = 'JackFram/llama-68m'
block_size = 16
num_gpu_blocks = 2048 // block_size
batch_size = 128
multi_step_worker = create_worker(
MultiStepWorker,
model_name,
block_size,
num_gpu_blocks,
seed,
model_runner_cls=TP1DraftModelRunner,
)
multi_step_worker.set_include_gpu_probs_tensor()
worker = create_worker(
NPUWorker,
model_name,
block_size,
num_gpu_blocks,
seed,
)
random.seed(seed)
prompts = [[0] for _ in range(batch_size)]
num_steps = 2
final_prompt_lens = [(num_steps + 1) for prompt in prompts]
rand_seeds = list(random.randint(0, 100) for _ in range(num_steps))
multi_step_worker.execute_model = patch_execute_model_with_seeds(
multi_step_worker, rand_seeds)
worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds)
# Create the test continuations
continuations = [[random.randint(0, 1000)] for _ in prompts]
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
continuations=continuations,
final_prompt_lens=final_prompt_lens)
# Run single-step twice to generate 2 tokens. This
# will simulate the bonus token case with the second token
# being the bonus token.
zero_kv_cache(worker.cache_engine)
single_step_output: list[SamplerOutput] = []
set_random_seed(seed)
for _ in range(num_steps):
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
continuations=continuations,
final_prompt_lens=final_prompt_lens)
single_step_output.extend(
worker.execute_model(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list)))
# Append output tokens to new sequence data.
for i, seq_group_output in enumerate(single_step_output[-1]):
continuations[i].append(seq_group_output.samples[0].output_token)
# Create continuations for the MultiStepWorker. The continuations have
# 2 tokens in order to simulate the bonus token case.
multi_step_continuations = []
for continuation in continuations:
multi_step_continuations.append(continuation[:2])
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
continuations=multi_step_continuations,
final_prompt_lens=final_prompt_lens)
# Run multi-step and verify that the third token prediction is accurate
# for all sequences.
zero_kv_cache(multi_step_worker.cache_engine)
all_seq_ids = {i for i in range(batch_size)}
multi_step_output, _ = multi_step_worker.sampler_output(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list),
sample_len=1,
seq_ids_with_bonus_token_in_last_step=all_seq_ids)
for index, output in enumerate(multi_step_output[-1].outputs):
assert (continuations[index][-1] == output.samples[0].output_token)
@torch.inference_mode()
def test_multi_step_with_batch_expansion_incorrect_output():
"""
Tests the MultiStepWorker's ability to handle batch expansion with bonus
tokens in a negative case scenario. This test provides the MultiStepWorker
with a batch containing sequences with bonus tokens but specifies the
sequence IDs with bonus tokens incorrectly. The test verifies that the
MultiStepWorker generates correct tokens for the sequences where the
sequence ID is specified correctly and incorrect tokens for those where
the sequence ID is specified incorrectly.
"""
seed = 100
model_name = 'JackFram/llama-68m'
block_size = 16
num_gpu_blocks = 2048 // block_size
batch_size = 128
multi_step_worker = create_worker(
MultiStepWorker,
model_name,
block_size,
num_gpu_blocks,
seed,
model_runner_cls=TP1DraftModelRunner,
)
multi_step_worker.set_include_gpu_probs_tensor()
worker = create_worker(
NPUWorker,
model_name,
block_size,
num_gpu_blocks,
seed,
)
random.seed(seed)
prompts = [[0] for _ in range(batch_size)]
num_steps = 2
final_prompt_lens = [(num_steps + 1) for prompt in prompts]
rand_seeds = list(random.randint(0, 100) for _ in range(num_steps))
multi_step_worker.execute_model = patch_execute_model_with_seeds(
multi_step_worker, rand_seeds)
worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds)
# Create the test continuations
continuations = [[random.randint(0, 1000)] for _ in prompts]
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
continuations=continuations,
final_prompt_lens=final_prompt_lens)
# Run single-step twice to generate 2 tokens. This
# will simulate the bonus token case with the second token
# being the bonus token.
zero_kv_cache(worker.cache_engine)
single_step_output: list[SamplerOutput] = []
set_random_seed(seed)
for _ in range(num_steps):
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
continuations=continuations,
final_prompt_lens=final_prompt_lens)
single_step_output.extend(
worker.execute_model(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list)))
# Append output tokens to new sequence data.
for i, seq_group_output in enumerate(single_step_output[-1]):
continuations[i].append(seq_group_output.samples[0].output_token)
# Create continuations for the MultiStepWorker. The continuations have
# 2 tokens in order to simulate the bonus token case.
multi_step_continuations = []
for continuation in continuations:
multi_step_continuations.append(continuation[:2])
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
continuations=multi_step_continuations,
final_prompt_lens=final_prompt_lens)
# Run multi-step. In this run INCORRECTLY specify that only the odd number
# sequences have bonus tokens. Verify that with this setting the third token
# prediction is accurate only for the odd numbered sequences. Also verify
# that the prediction might be wrong for some of the even numbered
# sequences.
zero_kv_cache(multi_step_worker.cache_engine)
set_random_seed(seed)
odd_seq_ids = {i for i in range(batch_size) if i % 2 != 0}
multi_step_output, _ = multi_step_worker.sampler_output(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list),
sample_len=1,
seq_ids_with_bonus_token_in_last_step=odd_seq_ids)
num_mismatch = 0
for index, output in enumerate(multi_step_output[-1].outputs):
if (index % 2) != 0:
assert (continuations[index][-1] == output.samples[0].output_token)
elif (continuations[index][-1] != output.samples[0].output_token):
num_mismatch += 1
# The prediction is accurate for some of the sequences even without proper
# handling of the bonus tokens. Hence verify that the number of sequences
# for which there is a mismatch is > 0.
assert (num_mismatch > 0)
@torch.inference_mode()
@pytest.mark.parametrize('num_steps', [1, 2, 3, 4])
def test_multi_step_correct_kvcache(num_steps):
"""Verify that the KV cache of the draft model
is correctly updated for sequences with bonus token.
"""
seed = 100
model_name = "JackFram/llama-68m"
block_size = 16
num_gpu_blocks = 2048 // block_size
batch_size = 1
dtype = 'float16'
multi_step_worker = create_worker(MultiStepWorker,
model_name,
block_size,
num_gpu_blocks,
seed,
model_runner_cls=TP1DraftModelRunner,
dtype=dtype)
multi_step_worker.set_include_gpu_probs_tensor()
worker = create_worker(NPUWorker,
model_name,
block_size,
num_gpu_blocks,
seed,
dtype=dtype)
prompts = [[0] for _ in range(batch_size)]
# Already generate two tokens for the sequence
# so that we can simulate the bonus token case
multi_step_continuations = [[
random.randint(0, 1000),
random.randint(0, 1000)
] for _ in prompts]
final_prompt_lens = [len(prompt) + 2 + num_steps for prompt in prompts]
seq_ids_with_bonus_token_in_last_step = set(range(batch_size))
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
continuations=multi_step_continuations,
final_prompt_lens=final_prompt_lens)
# Run multi-step.
zero_kv_cache(multi_step_worker.cache_engine)
multi_step_worker.sampler_output(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list),
sample_len=num_steps,
seq_ids_with_bonus_token_in_last_step=
seq_ids_with_bonus_token_in_last_step)
# Run single-step repeatedly.
zero_kv_cache(worker.cache_engine)
# Generate the kv cache for the bonus token first
single_step_continuations = [c[:1] for c in multi_step_continuations]
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
continuations=single_step_continuations,
final_prompt_lens=final_prompt_lens)
single_step_output = worker.execute_model(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list))
for _ in range(num_steps):
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
continuations=multi_step_continuations,
final_prompt_lens=final_prompt_lens)
single_step_output = worker.execute_model(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list))
for i, seq_group_output in enumerate(single_step_output[-1]):
multi_step_continuations[i].append(
seq_group_output.samples[0].output_token)
# Verify that the KV cache of the single-step and
# multi-step workers are the same.
single_step_gpu_cache = worker.cache_engine[0].gpu_cache
multi_step_gpu_cache = multi_step_worker.cache_engine[0].gpu_cache
num_layers = len(single_step_gpu_cache)
allclose = lambda a, b: torch.allclose( # noqa: E731
a.npu(), b.npu(), rtol=1e-2, atol=1e-2)
for i in range(num_layers):
assert allclose(single_step_gpu_cache[i][0],
multi_step_gpu_cache[i][0])
assert allclose(single_step_gpu_cache[i][1],
multi_step_gpu_cache[i][1])
@torch.inference_mode()
def test_draft_proposals_full_speculation_len():
"""Verify Top1Proposer correctly handles case where all sequences
can speculate.
"""
k = 10
batch_size = 32
vocab_size = 32_000
device = 'npu:0'
draft_worker = MagicMock()
proposer = Top1Proposer(
worker=draft_worker,
device=device,
vocab_size=vocab_size,
max_proposal_len=2048,
)
draft_worker.sampler_output.return_value = [
SamplerOutput(
outputs=[],
sampled_token_probs=torch.rand(batch_size,
vocab_size,
device=device,
dtype=torch.float32),
logprobs=torch.rand(batch_size,
vocab_size,
device=device,
dtype=torch.float32),
sampled_token_ids=torch.randint(low=0,
high=vocab_size,
size=(batch_size, ),
device=device,
dtype=torch.long),
) for _ in range(k)
], True
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
proposals = proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k),
seq_ids_with_bonus_token_in_last_step=set())
assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs)
assert proposals.proposal_token_ids.shape == torch.Size([batch_size, k])
assert proposals.proposal_probs.shape[:-1] == torch.Size([batch_size, k])
assert proposals.proposal_lens.shape == torch.Size([batch_size])
assert proposals.proposal_lens.tolist() == [k for _ in range(batch_size)]
@torch.inference_mode()
def test_draft_proposals_no_speculations():
"""Verify Top1Proposer correctly handles case where no sequences
can speculate.
"""
k = 10
batch_size = 32
vocab_size = 32_000
device = 'npu:0'
prompt_len = 10
draft_worker = MagicMock()
proposer = Top1Proposer(
worker=draft_worker,
device=device,
vocab_size=vocab_size,
max_proposal_len=prompt_len + k - 1,
)
seq_group_metadata_list, _, _ = create_batch(batch_size,
k,
prompt_len=prompt_len)
proposals = proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k),
seq_ids_with_bonus_token_in_last_step=set())
assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs)
assert proposals.proposal_token_ids.shape == torch.Size([batch_size, k])
assert proposals.proposal_probs.shape[:-1] == torch.Size([batch_size, k])
assert proposals.proposal_lens.shape == torch.Size([batch_size])
assert proposals.proposal_lens.tolist() == [0 for _ in range(batch_size)]
@torch.inference_mode()
def test_draft_proposals_mixed_k():
"""Verify Top1Proposer correctly handles case some sequences can
speculate and some can't.
"""
k = 10
batch_size = 32
vocab_size = 32_000
device = 'npu:0'
small_prompt_len = 5
long_prompt_len = 10
prev_output_token_len = 20
expected_num_proposal_seqs = 6
expected_num_no_proposal_seqs = batch_size - expected_num_proposal_seqs
prompt_len = [
small_prompt_len for _ in range(expected_num_proposal_seqs - 1)
] + [long_prompt_len
for _ in range(expected_num_no_proposal_seqs)] + [small_prompt_len]
draft_worker = MagicMock()
proposer = Top1Proposer(
worker=draft_worker,
device=device,
vocab_size=vocab_size,
max_proposal_len=long_prompt_len + prev_output_token_len + k - 1,
)
draft_worker.sampler_output.return_value = [
SamplerOutput(
outputs=[],
sampled_token_probs=torch.rand(expected_num_proposal_seqs,
vocab_size,
device=device,
dtype=torch.float32),
logprobs=torch.rand(expected_num_proposal_seqs,
vocab_size,
device=device,
dtype=torch.float32),
sampled_token_ids=torch.randint(
low=0,
high=vocab_size,
size=(expected_num_proposal_seqs, ),
device=device,
dtype=torch.long),
) for _ in range(k)
], True
seq_group_metadata_list, _, _ = create_batch(
batch_size,
k,
prompt_len=prompt_len,
prev_output_token_len=prev_output_token_len,
)
proposals = proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k),
seq_ids_with_bonus_token_in_last_step=set())
assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs)
assert proposals.proposal_token_ids.shape == torch.Size([batch_size, k])
assert proposals.proposal_probs.shape[:-1] == torch.Size([batch_size, k])
assert proposals.proposal_lens.shape == torch.Size([batch_size])
assert proposals.proposal_lens.tolist() == [
k for _ in range(expected_num_proposal_seqs - 1)
] + [0 for _ in range(expected_num_no_proposal_seqs)] + [k]
@torch.inference_mode()
def test_use_draft_model_runner_advance_step():
"""Verify that draft model runner triggers advance step
when applicable.
"""
seed = 100
model_name = 'JackFram/llama-68m'
k = 5
batch_size = 32
block_size = 32
num_gpu_blocks = 2048 // block_size
worker = create_worker(
MultiStepWorker,
model_name,
block_size,
num_gpu_blocks,
seed,
model_runner_cls=TP1DraftModelRunner,
)
# Mock "_gpu_advance_step" to raise an exception when called.
exception_secret = "artificial stop"
worker.model_runner._gpu_advance_step = MagicMock()
worker.model_runner._gpu_advance_step.side_effect = ValueError(
exception_secret)
seq_group_metadata_list, _, _ = create_batch(batch_size,
k,
block_size=block_size,
num_gpu_blocks=num_gpu_blocks)
# Fallback (should not call) when num_steps=1.
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k,
num_steps=1)
worker.execute_model(execute_model_req=execute_model_req)
# Expect exception if _gpu_advance_step is called.
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k,
num_steps=k)
with pytest.raises(ValueError, match=exception_secret):
worker.execute_model(execute_model_req=execute_model_req)
call_args_list = worker.model_runner._gpu_advance_step.call_args_list
assert len(call_args_list) == 1
@torch.inference_mode()
def test_expand_execute_model_request_sync_with_expand_hidden_states():
"""
In this test we verify that the logic for expanding the
seq_group_metadata_list remains in sync with the expansion logic of
the HiddenStates in _expand_execute_model_request.
"""
k = 5
batch_size = 16
seq_with_bonus_token_in_last_step = [1, 3, 8, 10, 13, 15]
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
execute_model_request = ExecuteModelRequest(
seq_group_metadata_list,
previous_hidden_states=HiddenStates(
torch.arange(batch_size), seq_group_metadata_list,
torch.arange(batch_size, 2 * batch_size)))
expanded_execute_model_request, orig_seq_group_ids = MultiStepWorker.\
_expand_execute_model_request(execute_model_request,
seq_with_bonus_token_in_last_step)
all_seq_ids = torch.tensor(
get_all_seq_ids(
expanded_execute_model_request.seq_group_metadata_list))
ref_expanded_hidden_states = all_seq_ids + batch_size
ref_expanded_hidden_states[orig_seq_group_ids] -= batch_size
assert (ref_expanded_hidden_states == expanded_execute_model_request.
previous_hidden_states.hidden_states).all().item()

View File

@@ -1,237 +0,0 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/tests/spec_decode/test_ngram_worker.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.ngram_worker import NGramWorker
from vllm.spec_decode.top1_proposer import Top1Proposer
from tests.singlecard.spec_decode.utils import (
create_seq_group_metadata_from_prompts, create_worker)
def test_ngram_algo_correctness_for_single_no_match():
"""Verify our ngram algo find the right candidate in the prompt
For the scenario cannot find any candidate in one single batch
"""
block_size = 32
num_gpu_blocks = 2048 // block_size
seed = 100
model_name = 'JackFram/llama-68m'
vocab_size = 32_000
device = 'npu:0'
ngram_worker = create_worker(
NGramWorker,
model_name,
block_size,
num_gpu_blocks,
seed,
)
proposer = Top1Proposer(
worker=ngram_worker,
device=device,
vocab_size=vocab_size,
max_proposal_len=20,
)
# set ngram window [1, 3], which is window=1/2/3
ngram_worker.set_ngram_window_size(1, 3)
prompts = [
# shall find no candidate
[1, 2, 3, 4, 5, 6, 7],
]
proposal_len = 5
final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts]
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
final_prompt_lens=final_prompt_lens)
proposals = proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=proposal_len),
seq_ids_with_bonus_token_in_last_step=None)
assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs)
assert proposals.proposal_token_ids.shape == torch.Size([1, proposal_len])
assert proposals.proposal_probs.shape[:-1] == torch.Size([1, proposal_len])
assert proposals.proposal_lens.shape == torch.Size([1])
assert proposals.proposal_lens.tolist() == [0]
def test_ngram_algo_correctness_for_batches_not_match_all():
"""Verify our ngram algo find the right candidate in the prompt
For the scenario find some candidate not full in batchs
"""
block_size = 32
num_gpu_blocks = 2048 // block_size
seed = 100
model_name = 'JackFram/llama-68m'
vocab_size = 32_000
device = 'npu:0'
ngram_worker = create_worker(
NGramWorker,
model_name,
block_size,
num_gpu_blocks,
seed,
)
proposer = Top1Proposer(
worker=ngram_worker,
device=device,
vocab_size=vocab_size,
max_proposal_len=20,
)
# set ngram window [1, 3], which is window=1/2/3
ngram_worker.set_ngram_window_size(1, 3)
prompts = [
# shall find no candidate
[1, 2, 3, 4, 5, 6, 7],
# shall find candidate 12,13,14,15,16
[11, 12, 13, 14, 15, 16, 11],
# shall find candidate 23,24,25,26,21
[21, 21, 22, 23, 24, 25, 26, 21, 22],
# shall find candidate 34,35,36,37,38
[31, 32, 31, 32, 33, 34, 35, 36, 37, 38, 31, 32, 33],
# shall find no candidate as exceed max_proposal_len
[
31, 32, 31, 32, 31, 32, 31, 32, 31, 32, 31, 32, 33, 34, 35, 36, 37,
38, 31, 32, 33
],
]
proposal_len = 5
final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts]
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
final_prompt_lens=final_prompt_lens)
for sg in seq_group_metadata_list:
sg.is_prompt = False
proposals = proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=proposal_len),
seq_ids_with_bonus_token_in_last_step=None)
assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs)
assert proposals.proposal_token_ids.shape == torch.Size([5, proposal_len])
assert proposals.proposal_probs.shape[:-1] == torch.Size([5, proposal_len])
assert proposals.proposal_lens.shape == torch.Size([5])
# the first sequence has no match so proposal_len should be overwritten to 0
assert proposals.proposal_lens.tolist(
) == [0] + [proposal_len for _ in range(3)] + [0]
for i in range(proposal_len):
assert proposals.proposal_token_ids[0][i] == -1
assert proposals.proposal_token_ids[1][i] == prompts[1][i + 1]
assert proposals.proposal_token_ids[2][i] == prompts[2][i + 3]
assert proposals.proposal_token_ids[3][i] == prompts[3][i + 5]
assert proposals.proposal_token_ids[4][i] == -1
def test_ngram_algo_correctness_for_batches_match_all():
"""Verify our ngram algo find the right candidate in the prompt
For the scenario find candidate in all batches
"""
block_size = 32
num_gpu_blocks = 2048 // block_size
seed = 100
model_name = 'JackFram/llama-68m'
vocab_size = 32_000
device = 'npu:0'
ngram_worker = create_worker(
NGramWorker,
model_name,
block_size,
num_gpu_blocks,
seed,
)
proposer = Top1Proposer(
worker=ngram_worker,
device=device,
vocab_size=vocab_size,
max_proposal_len=20,
)
# set ngram window [0, 3], which is window=1/2/3
ngram_worker.set_ngram_window_size(1, 3)
prompts = [
# shall find candidate 12,13,14,15,16
[11, 12, 13, 14, 15, 16, 11],
# shall find candidate 23,24,25,26,21
[21, 21, 22, 23, 24, 25, 26, 21, 22],
# shall find candidate 34,35,36,37,38
[31, 32, 31, 32, 33, 34, 35, 36, 37, 38, 31, 32, 33],
]
proposal_len = 5
final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts]
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
final_prompt_lens=final_prompt_lens)
# Normally drafter is run on decode requests only; here we check the output
# of the ngram worker as it is the sole proposer that has no forward.
for sg in seq_group_metadata_list:
sg.is_prompt = False
proposals = proposer.get_spec_proposals(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=proposal_len),
seq_ids_with_bonus_token_in_last_step=None)
assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs)
assert proposals.proposal_token_ids.shape == torch.Size([3, proposal_len])
assert proposals.proposal_probs.shape[:-1] == torch.Size([3, proposal_len])
assert proposals.proposal_lens.shape == torch.Size([3])
assert proposals.proposal_lens.tolist() == [proposal_len for _ in range(3)]
for i in range(proposal_len):
assert proposals.proposal_token_ids[0][i] == prompts[0][i + 1]
assert proposals.proposal_token_ids[1][i] == prompts[1][i + 3]
assert proposals.proposal_token_ids[2][i] == prompts[2][i + 5]

View File

@@ -1,957 +0,0 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/tests/spec_decode/test_spec_decode_worker.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import random
from collections import defaultdict
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
import torch
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import ExecuteModelRequest, SequenceOutput
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.metrics import (AsyncMetricsCollector,
SpecDecodeWorkerMetrics)
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.spec_decode_worker import (SpecDecodeWorker,
split_num_cache_blocks_evenly)
from tests.singlecard.spec_decode.test_utils import mock_spec_decode_sampler
from tests.singlecard.spec_decode.utils import (create_batch,
create_sampler_output_list,
create_worker, mock_worker)
from vllm_ascend.worker.draft_model_runner import TP1DraftModelRunner
from vllm_ascend.worker.worker import NPUWorker
@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_correctly_calls_draft_model(k: int, batch_size: int,
acceptance_sampler_method: str):
"""Verify SpecDecodeWorker calls the draft worker with correct
inputs. Everything else is mocked out.
"""
draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker()
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
worker = SpecDecodeWorker(
draft_worker,
target_worker,
mock_spec_decode_sampler(acceptance_sampler_method),
disable_logprobs=False,
metrics_collector=metrics_collector)
exception_secret = 'artificial stop'
draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)
with pytest.raises(ValueError, match=exception_secret):
worker.execute_model(execute_model_req=execute_model_req)
call_args_list = draft_worker.get_spec_proposals.call_args_list
assert len(call_args_list) == 1
for args, _ in call_args_list:
actual_execute_model_data = args[0]
assert actual_execute_model_data == execute_model_req
@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_batch_expansion_correctly_calls_target_model(
k: int, batch_size: int, acceptance_sampler_method: str):
"""Verify SpecDecodeWorker calls the target model with correct
inputs with batch expansion. Everything else is mocked out.
"""
draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False)
target_worker = mock_worker(use_spec=False)
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
draft_worker.device = 'npu'
target_worker.device = 'npu'
set_random_seed(1)
worker = SpecDecodeWorker(
draft_worker,
target_worker,
mock_spec_decode_sampler(acceptance_sampler_method),
disable_logprobs=False,
metrics_collector=metrics_collector,
disable_mqa_scorer=True)
worker.init_device()
vocab_size = 32_000
proposal_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64,
device='npu')
proposal_probs = torch.rand(batch_size,
k,
vocab_size,
dtype=torch.float32,
device='npu')
proposal_lens = torch.ones(batch_size, dtype=torch.int64, device='npu') * k
seq_group_metadata_list, prompts, prev_output_tokens = create_batch(
batch_size, k)
draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
proposal_token_ids=proposal_token_ids,
proposal_probs=proposal_probs,
proposal_lens=proposal_lens)
exception_secret = 'artificial stop'
target_worker.execute_model.side_effect = ValueError(exception_secret)
with pytest.raises(ValueError, match=exception_secret):
worker.execute_model(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k))
seen_contexts: list[list[int]] = []
call_args_list = target_worker.execute_model.call_args_list
assert len(call_args_list) == 1
for _, kwargs in call_args_list:
seq_group_metadata_list = kwargs[
"execute_model_req"].seq_group_metadata_list
assert len(seq_group_metadata_list) == (k + 1) * batch_size
for seq_group_metadata in seq_group_metadata_list:
for seq_data in seq_group_metadata.seq_data.values():
seen_contexts.append(seq_data.get_token_ids())
expected_seen_contexts: list[list[int]] = []
for prompt, prev_generated, draft_tokens in zip(
prompts, prev_output_tokens, proposal_token_ids.tolist()):
for i in range(len(draft_tokens) + 1):
expected_seen_contexts.append(prompt + prev_generated +
draft_tokens[:i])
seen_contexts.sort()
expected_seen_contexts.sort()
assert expected_seen_contexts == seen_contexts
@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_correctly_calls_spec_decode_sampler(k: int, batch_size: int,
acceptance_sampler_method: str):
"""Verify SpecDecodeWorker calls the rejection sampler with
correct inputs. Everything else is mocked out.
"""
vocab_size = 32_000
draft_worker = mock_worker(cls=MultiStepWorker,
vocab_size=vocab_size,
use_spec=False)
target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
draft_worker.device = 'npu'
target_worker.device = 'npu'
set_random_seed(1)
worker = SpecDecodeWorker(draft_worker,
target_worker,
spec_decode_sampler,
disable_logprobs=False,
metrics_collector=metrics_collector)
worker.init_device()
proposal_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64,
device='npu')
proposal_probs = torch.rand(batch_size,
k,
vocab_size,
dtype=torch.float32,
device='npu')
proposal_lens = torch.ones(batch_size, dtype=torch.int64, device='npu') * k
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
proposal_token_ids=proposal_token_ids,
proposal_probs=proposal_probs,
proposal_lens=proposal_lens)
target_token_ids = torch.randint(low=0,
high=vocab_size,
size=(1, batch_size * (k + 1)),
dtype=torch.int64,
device='npu')
target_token_probs = torch.rand(1,
batch_size * (k + 1),
vocab_size,
dtype=torch.float32,
device='npu')
target_token_logprobs = torch.rand(1,
batch_size * (k + 1),
vocab_size,
dtype=torch.float32,
device='npu')
target_output = create_sampler_output_list(target_token_ids,
target_token_probs,
target_token_logprobs)
target_worker.execute_model.return_value = [target_output[0]]
exception_secret = 'artificial stop'
spec_decode_sampler.side_effect = ValueError(exception_secret)
with pytest.raises(ValueError, match=exception_secret):
worker.execute_model(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k))
assert len(spec_decode_sampler.call_args_list) == 1
_, kwargs = spec_decode_sampler.call_args_list[0]
actual = SimpleNamespace(**kwargs)
assert torch.equal(actual.bonus_token_ids,
target_token_ids.reshape(batch_size, k + 1)[:, -1:])
assert torch.equal(actual.target_with_bonus_probs,
target_token_probs.reshape(batch_size, k + 1, -1))
assert torch.equal(actual.draft_token_ids, proposal_token_ids)
assert torch.equal(actual.draft_probs, proposal_probs)
@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_correctly_formats_output(k: int, batch_size: int,
acceptance_sampler_method: str):
"""Verify SpecDecodeWorker formats sampler output correctly.
Everything else is mocked out.
"""
vocab_size = 32_000
draft_worker = mock_worker(cls=MultiStepWorker,
vocab_size=vocab_size,
use_spec=False)
target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
draft_worker.device = 'npu'
target_worker.device = 'npu'
set_random_seed(1)
spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
worker = SpecDecodeWorker(draft_worker,
target_worker,
spec_decode_sampler,
disable_logprobs=False,
metrics_collector=metrics_collector)
worker.init_device()
proposal_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64,
device='npu')
proposal_probs = torch.rand(batch_size,
k,
vocab_size,
dtype=torch.float32,
device='npu')
proposal_lens = torch.ones(batch_size, dtype=torch.int64, device='npu') * k
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
proposal_token_ids=proposal_token_ids,
proposal_probs=proposal_probs,
proposal_lens=proposal_lens)
target_token_ids = torch.randint(low=0,
high=vocab_size,
size=(1, batch_size * (k + 1)),
dtype=torch.int64,
device='npu')
target_token_probs = torch.rand(1,
batch_size * (k + 1),
vocab_size,
dtype=torch.float32,
device='npu')
target_token_logprobs = torch.rand(1,
batch_size * (k + 1),
vocab_size,
dtype=torch.float32,
device='npu')
target_output = create_sampler_output_list(target_token_ids,
target_token_probs,
target_token_logprobs)
target_worker.execute_model.return_value = [target_output[0]]
spec_decode_sampler_output = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k + 1),
dtype=torch.int64,
device='npu')
for i in range(batch_size):
minimum_accepted_tokens = 1
spec_decode_sampler_output[i][
-random.randint(minimum_accepted_tokens, k + 1):] = -1
spec_decode_sampler.return_value = spec_decode_sampler_output
output = worker.execute_model(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k))
expected_output = create_sampler_output_list(
token_ids=spec_decode_sampler_output.transpose(0, 1),
probs=[None for _ in range(k + 1)],
logprobs=[None for _ in range(k + 1)])
seq_ids = [
next(iter(seq_group_metadata.seq_data.keys()))
for seq_group_metadata in seq_group_metadata_list
]
actual_output_by_seq: dict[int, list[SequenceOutput]] = {
seq_id: []
for seq_id in seq_ids
}
expected_output_by_seq: dict[int, list[SequenceOutput]] = {
seq_id: []
for seq_id in seq_ids
}
for step in output:
for seq_group in step:
for sample in seq_group.samples:
seq_id = sample.parent_seq_id
actual_output_by_seq[seq_id].append(sample)
for step in expected_output:
for seq_group in step:
for sample in seq_group.samples:
seq_id = sample.parent_seq_id
expected_output_by_seq[seq_id].append(sample)
all_seen_seq_ids = set(
list(actual_output_by_seq.keys()) +
list(expected_output_by_seq.keys()))
for seq_id in all_seen_seq_ids:
actual_by_step = actual_output_by_seq[seq_id]
expected_by_step = expected_output_by_seq[seq_id]
for i in range(k + 1):
if i >= len(actual_by_step):
assert expected_by_step[i].output_token == -1
continue
assert actual_by_step[i].output_token == expected_by_step[
i].output_token
@pytest.mark.parametrize('k', [1, 2])
@pytest.mark.parametrize('batch_size', [1])
@pytest.mark.parametrize('returns_metrics', [True, False])
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool,
acceptance_sampler_method: str):
"""Verify SpecDecodeWorker collects metrics.
"""
vocab_size = 32_000
draft_worker = mock_worker(cls=MultiStepWorker,
vocab_size=vocab_size,
use_spec=False)
target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
draft_worker.device = 'npu'
target_worker.device = 'npu'
set_random_seed(1)
worker = SpecDecodeWorker(draft_worker,
target_worker,
spec_decode_sampler,
disable_logprobs=False,
metrics_collector=metrics_collector)
worker.init_device()
proposal_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64,
device='npu')
proposal_probs = torch.rand(batch_size,
k,
vocab_size,
dtype=torch.float32,
device='npu')
proposal_lens = torch.ones(batch_size, dtype=torch.int64, device='npu') * k
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
proposal_token_ids=proposal_token_ids,
proposal_probs=proposal_probs,
proposal_lens=proposal_lens)
target_token_ids = torch.randint(low=0,
high=vocab_size,
size=(1, batch_size * (k + 1)),
dtype=torch.int64,
device='npu')
target_token_probs = torch.rand(1,
batch_size * (k + 1),
vocab_size,
dtype=torch.float32,
device='npu')
target_token_logprobs = torch.rand(1,
batch_size * (k + 1),
vocab_size,
dtype=torch.float32,
device='npu')
target_output = create_sampler_output_list(target_token_ids,
target_token_probs,
target_token_logprobs)
target_worker.execute_model.return_value = [target_output[0]]
spec_decode_sampler_output = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k + 1),
dtype=torch.int64,
device='npu')
for i in range(batch_size):
minimum_accepted_tokens = 1
spec_decode_sampler_output[i][
-random.randint(minimum_accepted_tokens, k + 1):] = -1
spec_decode_sampler.return_value = spec_decode_sampler_output
mock_rejsample_metrics = MagicMock(
spec=SpecDecodeWorkerMetrics) if returns_metrics else None
metrics_collector.maybe_collect_rejsample_metrics.return_value = (
mock_rejsample_metrics)
output = worker.execute_model(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k))
assert output[0].spec_decode_worker_metrics == mock_rejsample_metrics
call_args_list = (
metrics_collector.maybe_collect_rejsample_metrics.call_args_list)
assert len(call_args_list) == 1
args, kwargs = call_args_list[0]
assert args[0] == k or kwargs.get('k', -1) == k
@pytest.mark.parametrize('k', [0])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_k_equals_zero(k: int, batch_size: int,
acceptance_sampler_method: str):
"""Verify that the SpecDecodeWorker calls the draft and target workers
when k is zero. This happens during prefill.
"""
draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker()
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
sampler_output = MagicMock(spec=SamplerOutput)
sampler_output.hidden_states = None
target_worker.execute_model.return_value = [sampler_output]
draft_worker.device = 'npu'
target_worker.device = 'npu'
set_random_seed(1)
worker = SpecDecodeWorker(
proposer_worker=draft_worker,
scorer_worker=target_worker,
spec_decode_sampler=mock_spec_decode_sampler(
acceptance_sampler_method),
disable_logprobs=False,
metrics_collector=metrics_collector,
)
seq_group_metadata_list, _, _ = create_batch(batch_size,
k,
prev_output_token_len=0)
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)
out = worker.execute_model(execute_model_req=execute_model_req)
assert len(out) == 1, f"expected only one token output when {k=}"
assert out[0].sampled_token_probs is None, (
"expect gpu tensor references to be None")
assert out[
0].sampled_token_ids is None, "expect gpu tensor references to be None"
draft_worker.execute_model.assert_called_once_with(execute_model_req)
target_worker.execute_model.assert_called_once_with(execute_model_req)
@pytest.mark.parametrize('k', [0, 5])
@pytest.mark.parametrize('batch_size', [0])
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_empty_input_batch(k: int, batch_size: int,
acceptance_sampler_method: str):
"""Verify that the SpecDecodeWorker calls the draft and target workers
when the input batch is empty. This can happen if the engine communicates
to the workers information without scheduling a batch.
"""
draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker()
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
sampler_output = MagicMock(spec=SamplerOutput)
sampler_output.hidden_states = None
target_worker.execute_model.return_value = [sampler_output]
draft_worker.device = 'npu'
target_worker.device = 'npu'
set_random_seed(1)
worker = SpecDecodeWorker(
proposer_worker=draft_worker,
scorer_worker=target_worker,
spec_decode_sampler=mock_spec_decode_sampler(
acceptance_sampler_method),
disable_logprobs=False,
metrics_collector=metrics_collector,
)
seq_group_metadata_list, _, _ = create_batch(batch_size,
k,
prev_output_token_len=0)
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)
out = worker.execute_model(execute_model_req=execute_model_req)
assert len(out) == 1, f"expected only one token output when {k=}"
assert out[0].sampled_token_probs is None, (
"expect gpu tensor references to be None")
assert out[
0].sampled_token_ids is None, "expect gpu tensor references to be None"
draft_worker.execute_model.assert_called_once_with(execute_model_req)
target_worker.execute_model.assert_called_once_with(execute_model_req)
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@pytest.mark.skip_global_cleanup
def test_init_device(acceptance_sampler_method: str):
"""Verify SpecDecodeWorker invokes proposer/scorer worker init_device, as
well as other GPU initialization.
"""
draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False)
target_worker = mock_worker(use_spec=False)
spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
worker = SpecDecodeWorker(
proposer_worker=draft_worker,
scorer_worker=target_worker,
spec_decode_sampler=spec_decode_sampler,
disable_logprobs=False,
metrics_collector=metrics_collector,
)
worker.init_device()
draft_worker.init_device.assert_called_once()
target_worker.init_device.assert_called_once()
metrics_collector.init_tensors.assert_called_once()
spec_decode_sampler.init_tensors.assert_called_once()
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_initialize_cache(acceptance_sampler_method):
"""Verify SpecDecodeWorker invokes initialize_cache on proposer/scorer
workers.
"""
draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker()
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
worker = SpecDecodeWorker(proposer_worker=draft_worker,
scorer_worker=target_worker,
spec_decode_sampler=mock_spec_decode_sampler(
acceptance_sampler_method),
metrics_collector=metrics_collector)
kwargs = {"num_gpu_blocks": 1024, "num_cpu_blocks": 1023}
worker.initialize_cache(**kwargs)
draft_worker.initialize_cache.assert_called_once_with(**kwargs)
target_worker.initialize_cache.assert_called_once_with(**kwargs)
@pytest.mark.parametrize('available_gpu_blocks', [1, 1024])
@pytest.mark.parametrize('available_cpu_blocks', [500])
@pytest.mark.parametrize('target_cache_block_size_bytes', [2 * 2 * 4096])
@pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096])
@pytest.mark.parametrize("acceptance_sampler_method",
["rejection_sampler", "typical_acceptance_sampler"])
@pytest.mark.skip_global_cleanup
def test_determine_num_available_blocks(available_gpu_blocks: int,
available_cpu_blocks: int,
target_cache_block_size_bytes: int,
draft_kv_size_bytes: int,
acceptance_sampler_method: str):
"""Verify SpecDecodeWorker correctly profiles num available GPU blocks.
Specifically, it should run profiling in the scorer worker, and then evenly
split the blocks between proposer and scorer worker.
"""
draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker()
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
target_worker.determine_num_available_blocks.return_value = (
available_gpu_blocks, available_cpu_blocks)
target_worker.get_cache_block_size_bytes.return_value = (
target_cache_block_size_bytes)
draft_worker.get_cache_block_size_bytes.return_value = draft_kv_size_bytes
worker = SpecDecodeWorker(
draft_worker, target_worker,
mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)
num_gpu_blocks, num_cpu_blocks = worker.determine_num_available_blocks()
target_worker.determine_num_available_blocks.assert_called_once()
assert num_cpu_blocks == available_cpu_blocks
assert num_gpu_blocks == split_num_cache_blocks_evenly(
target_cache_block_size_bytes, draft_kv_size_bytes,
available_gpu_blocks)
@pytest.mark.parametrize('available_gpu_blocks',
list(range(20)) + [1024, 1024**2])
@pytest.mark.parametrize('target_cache_block_size_bytes',
[2 * 2 * 4096, 2 * 2 * 8192])
@pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096])
@pytest.mark.skip_global_cleanup
def test_split_num_cache_blocks_evenly(available_gpu_blocks: int,
target_cache_block_size_bytes: int,
draft_kv_size_bytes: int):
"""Verify split_num_cache_blocks_evenly does not exceed original memory
allocation in bytes.
"""
num_blocks = split_num_cache_blocks_evenly(target_cache_block_size_bytes,
draft_kv_size_bytes,
available_gpu_blocks)
assert (num_blocks * target_cache_block_size_bytes) + (
num_blocks * draft_kv_size_bytes) <= (available_gpu_blocks *
target_cache_block_size_bytes)
@torch.inference_mode()
def test_populate_seq_ids_with_bonus_tokens():
"""
Verify that a call to _create_output_sampler_list correctly updates
seq_with_bonus_token_in_last_step.
seq_with_bonus_token_in_last_step is an internal data structure in
SpecDecodeWorker that tracks the sequence IDs which are assigned bonus
tokens by the target model in their last forward pass. This state is
maintained only for models relying on the KV cache, such as those using
the MultiStepWorker.
"""
batch_size = 10
k = 5
vocab_size = 10000
num_sequences_with_bonus_tokens = 5
target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)]
target_worker.device = 'npu'
set_random_seed(1)
draft_worker = mock_worker(cls=MultiStepWorker)
draft_worker.device = 'npu'
# The sequence_ids attached to each sequence in the batch.
# The sequence at index i has seq_id assigned_seq_ids[i]
assigned_seq_ids = list(range(batch_size))
seq_group_metadata_list, _, _ = create_batch(batch_size,
k,
seq_ids=assigned_seq_ids,
prev_output_token_len=10)
target_token_logprobs = torch.rand(batch_size, (k + 1),
vocab_size,
dtype=torch.float32,
device='npu')
accepted_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, (k + 1)),
dtype=torch.int64,
device='npu')
expected_request_id_seq_ids_mapping: dict[str, set[int]] = defaultdict(set)
for seq_group_metadata in seq_group_metadata_list:
for seq_id in seq_group_metadata.seq_data:
expected_request_id_seq_ids_mapping[
seq_group_metadata.request_id].add(seq_id)
# Generate a random sample of sequence indexes with bonus tokens
seq_indexes_with_bonus_tokens = random.sample(
range(batch_size), num_sequences_with_bonus_tokens)
# Create a mask that is True for indices in seq_indexes_with_bonus_tokens
mask = torch.ones(batch_size, dtype=torch.bool, device='npu')
mask[seq_indexes_with_bonus_tokens] = False
# Set the last token ID to -1 for all indices not in
# seq_indexes_with_bonus_tokens to indicate the lack of bonus token in
# those indices.
accepted_token_ids[mask, -1:] = -1
worker = SpecDecodeWorker(draft_worker,
target_worker,
mock_spec_decode_sampler("rejection_sampler"),
disable_logprobs=False,
metrics_collector=metrics_collector)
# Initialize _seq_with_bonus_token_in_last_step with a set of sequence IDs.
# This set includes all sequence IDs in the batch as well as an additional
# `num_extra_sequence_ids` sequence IDs. Note that the sequence IDs are in
# the range [0, batch_size + num_extra_sequence_ids).
num_extra_sequence_ids = 10
worker._seq_with_bonus_token_in_last_step = set(
range(batch_size + num_extra_sequence_ids))
worker._create_output_sampler_list(
seq_group_metadata_list=seq_group_metadata_list,
accepted_token_ids=accepted_token_ids,
target_logprobs=target_token_logprobs,
prompt_logprobs=None,
k=k,
stage_times=(0, 0, 0))
# Verify that _seq_with_bonus_token_in_last_step contains the following:
# 1. Sequence IDs that were already present in
# _seq_with_bonus_token_in_last_step but were not part of the current
# batch are retained.
# 2. Of the sequence IDs present in the current batch, only those with a
# bonus token are retained in _seq_with_bonus_token_in_last_step.
# Sequence IDs that are present in the current batch but do not have
# bonus tokens are removed from _seq_with_bonus_token_in_last_step.
expected_seq_ids_with_bonus_tokens = \
set([assigned_seq_ids[i] for i in seq_indexes_with_bonus_tokens])
additional_sequence_ids = \
set(range(batch_size, batch_size + num_extra_sequence_ids))
assert worker._seq_with_bonus_token_in_last_step == \
expected_seq_ids_with_bonus_tokens.union(additional_sequence_ids)
assert worker._request_id_seq_id_mapping == \
expected_request_id_seq_ids_mapping
@torch.inference_mode()
def test_handle_finished_requests():
"""
Test to verify that finished request IDs are appropriately processed to
update the internal state of the SpecDecodeWorker.
This test initializes the SpecDecodeWorker with mock data, marks certain
requests as finished, and ensures that the corresponding sequence IDs are
correctly removed from the internal mappings.
"""
batch_size = 32
k = 3
draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker()
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
worker = SpecDecodeWorker(draft_worker, target_worker,
mock_spec_decode_sampler("rejection_sampler"),
metrics_collector)
# Initialize the request_id_seq_id_mapping mapping dict with a few fake
# request ids and corresponding sequence ids.
worker._request_id_seq_id_mapping = \
{'request-1': {1,2,3}, 'request-2': {4,5,6,7},
'request-3': {8,9}, 'request-4': {10,11}}
# Initialize seq_with_bonus_token_in_last_step with a few fake
# sequence ids.
worker._seq_with_bonus_token_in_last_step = {1, 4, 5, 8, 9, 10}
exception_secret = 'artificial stop'
draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
# Mark requests with ids request-1 and request-3 as finished.
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k,
finished_requests_ids=['request-1', 'request-3'])
with pytest.raises(ValueError, match=exception_secret):
worker.execute_model(execute_model_req=execute_model_req)
# Verify that request-1 and request-3 are removed from
# request_id_seq_id_mapping
assert worker._request_id_seq_id_mapping == \
{'request-2': {4,5,6,7}, 'request-4': {10,11}}
# Verify that all sequence ids corresponding to 'request-1'
# and 'request-3' are removed from seq_with_bonus_token_in_last_step.
assert worker._seq_with_bonus_token_in_last_step == \
{4,5,10}
@pytest.mark.parametrize('k', [3])
@pytest.mark.parametrize('batch_size', [2, 32])
@pytest.mark.parametrize("batch_composition",
["prefill_only", "decode_only", "mixed"])
@torch.inference_mode()
def test_chunked_prefill_flow(k: int, batch_size: int, batch_composition: str):
"""
Verify SpecDecodeWorker calls match the expected flow.
"""
vocab_size = 32_000
draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker()
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
worker = SpecDecodeWorker(draft_worker,
target_worker,
mock_spec_decode_sampler("rejection_sampler"),
disable_logprobs=False,
metrics_collector=metrics_collector)
exception_secret = 'artificial stop'
worker.scorer = mock_worker(BatchExpansionTop1Scorer)
worker.scorer.score_proposals.side_effect = ValueError(exception_secret)
# Create batch with combination of terminal/non-terminal prefill chunks
# and decodes (different seq_ids).
decodes, _, _ = create_batch(batch_size, k)
# Pre-chunking here, get 'batch_size' chunks.
prefill, _, _ = create_batch(batch_size,
k,
prefill_chunk_size=4,
seq_ids=list(range(batch_size,
batch_size * 2)))
if batch_composition == "prefill_only":
n_prefills = batch_size
elif batch_composition == "decode_only":
n_prefills = 0
else:
n_prefills = random.randint(1, batch_size - 1)
n_decodes = batch_size - n_prefills
prefill = random.sample(prefill, n_prefills)
decodes = random.sample(decodes, n_decodes)
target_group_metadata_list = prefill + decodes
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=target_group_metadata_list,
# For prefill only batches we expect num_lookahead_slots = 0.
num_lookahead_slots=k if n_decodes > 0 else 0)
target_token_ids = torch.randint(low=0,
high=vocab_size,
size=(1, batch_size * (k + 1)),
dtype=torch.int64,
device='npu')
target_token_probs = torch.rand(1,
batch_size * (k + 1),
vocab_size,
dtype=torch.float32,
device='npu')
target_token_logprobs = torch.rand(1,
batch_size * (k + 1),
vocab_size,
dtype=torch.float32,
device='npu')
target_output = create_sampler_output_list(target_token_ids,
target_token_probs,
target_token_logprobs)
target_worker.execute_model.return_value = [target_output[0]]
if not len(decodes):
worker.execute_model(execute_model_req=execute_model_req)
# no spec run (prefill only)
draft_worker.execute_model.assert_called_once_with(execute_model_req)
target_worker.execute_model.assert_called_once_with(execute_model_req)
else:
# Decode-only run OR mixed batch, scorer call fails (it's mocked)
with pytest.raises(ValueError, match=exception_secret):
worker.execute_model(execute_model_req=execute_model_req)
# but first draft still counted
assert draft_worker.get_spec_proposals.call_count == 1
def test_correctly_load_weight_for_eagle():
"""
Verify SpecDecodeWorker loads lm_head weight for eagle correctly.
"""
seed = 100
block_size = 32
num_gpu_blocks = 8096 // block_size
target_worker = create_worker(
NPUWorker,
"JackFram/llama-68m",
block_size,
num_gpu_blocks,
seed,
)
draft_worker = create_worker(
MultiStepWorker,
"abhigoyal/vllm-eagle-llama-68m-random",
block_size,
num_gpu_blocks,
seed,
model_runner_cls=TP1DraftModelRunner,
)
spec_decode_sampler = mock_spec_decode_sampler("rejection_sampler")
worker = SpecDecodeWorker(draft_worker,
target_worker,
spec_decode_sampler,
disable_logprobs=False)
worker.proposer_worker.maybe_load_lm_head_weight(
target_worker.model_runner.model.lm_head.weight.data)
assert torch.allclose(
worker.proposer_worker.worker.model_runner.model.lm_head.weight.data,
worker.scorer_worker.model_runner.model.lm_head.weight.data)

View File

@@ -1,165 +0,0 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/tests/spec_decode/test_utils.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from unittest.mock import MagicMock
import pytest
import torch
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.layers.sampler import _get_ranks
from vllm.model_executor.layers.typical_acceptance_sampler import \
TypicalAcceptanceSampler
from vllm.sequence import SequenceGroupMetadata, get_all_seq_ids
from vllm.spec_decode.util import (get_sampled_token_logprobs,
split_batch_by_proposal_len)
def test_get_all_seq_ids():
"""Verify get_all_seq_ids extracts all seq ids.
"""
expected_seq_ids = list(range(10)) + list(range(100, 110))
seq_group_metadata_list = [
SequenceGroupMetadata(
request_id=str(seq_id),
is_prompt=True,
seq_data={
seq_id: MagicMock(),
},
sampling_params=MagicMock(),
block_tables={
seq_id: MagicMock(),
},
lora_request=None,
) for seq_id in expected_seq_ids
]
actual_seq_ids = get_all_seq_ids(seq_group_metadata_list)
assert actual_seq_ids == expected_seq_ids
@pytest.fixture
def fake_sequence_group_metadata():
seq_ids = list(range(3))
return [
SequenceGroupMetadata(
request_id=str(i),
is_prompt=True,
seq_data={
i: MagicMock(),
},
sampling_params=MagicMock(),
block_tables={
i: MagicMock(),
},
lora_request=None,
) for i in seq_ids
]
def test_filter_zero_length_proposals(fake_sequence_group_metadata):
proposal_lens = [0, 1, 0]
_, (filtered_groups,
indices) = split_batch_by_proposal_len(fake_sequence_group_metadata,
proposal_lens)
expected_groups = [
fake_sequence_group_metadata[0], fake_sequence_group_metadata[2]
]
expected_indices = [0, 2]
assert filtered_groups == expected_groups
assert indices == expected_indices
def test_filter_non_zero_length_proposals(fake_sequence_group_metadata):
proposal_lens = [0, 1, 2]
(filtered_groups,
indices), _ = split_batch_by_proposal_len(fake_sequence_group_metadata,
proposal_lens)
expected_groups = [
fake_sequence_group_metadata[1], fake_sequence_group_metadata[2]
]
expected_indices = [1, 2]
assert filtered_groups == expected_groups
assert indices == expected_indices
def test_empty_inputs():
_, (filtered_groups, indices) = split_batch_by_proposal_len([], [])
assert filtered_groups == []
assert indices == []
def test_all_zero_with_non_zero_filter(fake_sequence_group_metadata):
proposal_lens = [0, 0, 0]
(filtered_groups,
indices), _ = split_batch_by_proposal_len(fake_sequence_group_metadata,
proposal_lens)
assert filtered_groups == []
assert indices == []
def test_all_non_zero_with_zero_filter(fake_sequence_group_metadata):
proposal_lens = [1, 1, 1]
_, (filtered_groups,
indices) = split_batch_by_proposal_len(fake_sequence_group_metadata,
proposal_lens)
assert filtered_groups == []
assert indices == []
def mock_spec_decode_sampler(acceptance_sampler_method):
"""
Returns either a RejectionSampler or TypicalAcceptanceSampler
object depending on whether acceptance_sampler_method is
'rejection_sampler' or 'typical_acceptance_sampler' respectively.
"""
if acceptance_sampler_method == "rejection_sampler":
sampler = MagicMock(spec=RejectionSampler)
sampler.token_id_dtype = torch.int64
return sampler
elif acceptance_sampler_method == "typical_acceptance_sampler":
sampler = MagicMock(spec=TypicalAcceptanceSampler)
sampler.token_id_dtype = torch.int64
return sampler
else:
raise ValueError(f"Invalid sampler name {acceptance_sampler_method}")
def test_get_sampled_token_logprobs():
"""Verify get_sampled_token_logprobs returns consistent rankings
with regular get_ranks when probabilities match exactly.
"""
logprob_tensor = torch.tensor(
[[[-.1, -.1]] * 2]) # shape (num_steps, batch_size, vocab_size)
sampled_token_tensor = torch.tensor([[1,
0]]) # shape (num_steps, batch_size)
ranks_spec_dec, _ = get_sampled_token_logprobs(logprob_tensor,
sampled_token_tensor)
ranks_regular = _get_ranks(logprob_tensor.reshape((2, -1)),
sampled_token_tensor.reshape(-1))
assert torch.equal(ranks_spec_dec.reshape(-1), ranks_regular)

View File

@@ -1,317 +0,0 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/tests/spec_decode/utils.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from collections.abc import Sequence as GenericSequence
from itertools import count
from typing import Callable, Optional, TypeVar, Union
from unittest.mock import MagicMock
import torch
from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.utils import set_random_seed
from vllm.sampling_params import SamplingParams
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
SequenceData, SequenceGroupMetadata, SequenceOutput)
from vllm.spec_decode.ngram_worker import NGramWorker # noqa: F401
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.worker.cache_engine import CacheEngine
from vllm_ascend.worker.model_runner import NPUModelRunner
from vllm_ascend.worker.worker import NPUWorker
T = TypeVar("T", bound=NPUWorker)
def round_up_to_next_block(seq_len: int, block_size: int) -> int:
return (seq_len + block_size - 1) // block_size
def mock_worker(cls=None,
vocab_size: int = 30_000,
max_model_len: int = 2048,
rank: int = 0,
use_spec: bool = True) -> MagicMock:
if cls is None:
cls = NPUWorker
spec = cls if use_spec else None
worker = MagicMock(spec=spec)
worker.vocab_size = vocab_size
worker.max_model_len = max_model_len
worker.rank = rank
worker.device = 'npu:0'
return worker
def patch_execute_model_with_seeds(worker: NPUWorker, rand_seeds: list[int]):
seed_iter = iter(rand_seeds)
original_execute_model = worker.execute_model
def new_execute_model(*args, **kwargs):
result = original_execute_model(*args, **kwargs)
set_random_seed(next(seed_iter))
return result
return new_execute_model
def zero_kv_cache(cache_engine: list[CacheEngine]):
assert cache_engine[0].gpu_cache
for key_blocks, value_blocks in cache_engine[0].gpu_cache:
key_blocks.zero_()
value_blocks.zero_()
def create_worker(cls: Callable[..., T],
model_name: str,
block_size: int,
num_gpu_blocks: int,
seed: int,
is_driver_worker: bool = True,
enforce_eager: bool = True,
model_runner_cls: Optional[NPUModelRunner] = None,
dtype: Optional[str] = "auto") -> T:
engine_args = EngineArgs(
model=model_name,
seed=seed,
block_size=block_size,
enforce_eager=enforce_eager,
dtype=dtype,
)
engine_config = engine_args.create_engine_config()
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
if cls.__name__ == "NGramWorker":
# we need to pass by device type to enable this on npu
worker = cls(vllm_config=engine_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
is_driver_worker=is_driver_worker,
model_runner_cls=model_runner_cls,
device_type="npu")
else:
worker = cls(
vllm_config=engine_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
is_driver_worker=is_driver_worker,
model_runner_cls=model_runner_cls,
)
worker.init_device()
worker.load_model()
engine_config.cache_config.num_gpu_blocks = num_gpu_blocks
engine_config.cache_config.num_cpu_blocks = 0
worker.initialize_cache(
num_gpu_blocks=engine_config.cache_config.num_gpu_blocks,
num_cpu_blocks=engine_config.cache_config.num_cpu_blocks)
return worker
def create_seq_group_metadata_from_prompts(
prompts: list[list[int]],
num_gpu_blocks: int,
block_size: int,
final_prompt_lens: list[int],
continuations: Optional[list[list[int]]] = None,
seq_ids: Optional[list[int]] = None,
) -> list[SequenceGroupMetadata]:
if continuations is None:
continuations = [[] for _ in prompts]
if seq_ids is None:
seq_ids = list(i for i, _ in enumerate(prompts))
free_gpu_blocks = list(range(num_gpu_blocks))
block_allocations = {
i: [
free_gpu_blocks.pop()
for _ in range(round_up_to_next_block(final_len, block_size))
]
for i, final_len in enumerate(final_prompt_lens)
}
seq_grou_metadata_list = []
for i, (prompt_token_ids,
cont_token_ids) in enumerate(zip(prompts, continuations)):
data = SequenceData.from_seqs(prompt_token_ids, cont_token_ids)
data.update_num_computed_tokens(
len(prompt_token_ids) + len(cont_token_ids) - 1)
seq_data = {i: data}
seq_grou_metadata_list.append(
SequenceGroupMetadata(
request_id=str(i),
is_prompt=len(cont_token_ids) == 0,
seq_data=seq_data,
sampling_params=SamplingParams(temperature=0.0),
block_tables={i: block_allocations[i][:]},
))
return seq_grou_metadata_list
def create_chunked_seq_group_metadata_from_prompt(
prompt: list[int],
num_gpu_blocks: int,
chunk_size: int,
block_size: int,
seq_id: Optional[int] = None) -> list[SequenceGroupMetadata]:
if seq_id is None:
seq_id = 0
free_gpu_blocks = list(range(num_gpu_blocks))
block_allocations = [
free_gpu_blocks.pop()
for _ in range(round_up_to_next_block(len(prompt), block_size))
]
seq_group_metadata_list = []
for i, idx in enumerate(range(0, len(prompt), chunk_size)):
chunk_ids = prompt[idx:idx + chunk_size]
data = SequenceData.from_seqs(prompt)
data.update_num_computed_tokens(idx)
seq_data = {i: data}
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=str(seq_id),
is_prompt=True,
do_sample=idx + chunk_size >= len(prompt), # terminal chunk
seq_data=seq_data,
sampling_params=SamplingParams(temperature=0.0),
block_tables={i: block_allocations},
token_chunk_size=len(chunk_ids)))
return seq_group_metadata_list
def assert_logprobs_dict_allclose(
actual_logprobs: list[dict[int, Logprob]],
expected_logprobs: list[dict[int, Logprob]]) -> None:
for single_step_actual_logprobs, single_step_expected_logprobs in zip(
actual_logprobs, expected_logprobs):
assert set(single_step_actual_logprobs.keys()) == set(
single_step_expected_logprobs.keys())
for token_id in single_step_actual_logprobs:
actual = torch.tensor(
single_step_actual_logprobs[token_id].logprob)
expected = torch.tensor(
single_step_expected_logprobs[token_id].logprob)
torch.testing.assert_close(actual, expected)
def create_sampler_output_list(
token_ids: torch.Tensor,
probs: GenericSequence[Optional[torch.Tensor]],
logprobs: GenericSequence[Optional[torch.Tensor]],
seq_ids: Optional[list[int]] = None) -> list[SamplerOutput]:
num_steps, batch_size = token_ids.shape
token_ids_by_step = token_ids.tolist()
if seq_ids is None:
seq_ids = list(range(batch_size))
return [
SamplerOutput(outputs=[
CompletionSequenceGroupOutput(
samples=[
SequenceOutput(
output_token=token_id,
parent_seq_id=seq_ids[seq_index],
logprobs={token_id: Logprob(0)},
)
],
prompt_logprobs=None,
) for seq_index, token_id in enumerate(token_ids_by_step[step])
],
sampled_token_probs=probs[step],
logprobs=logprobs[step],
sampled_token_ids=token_ids[step])
for step in range(num_steps)
]
def create_batch(batch_size,
k,
prompt_len: Union[int, list[int]] = 10,
prev_output_token_len: int = 10,
seq_ids: Optional[list[int]] = None,
num_gpu_blocks: Optional[int] = None,
block_size: Optional[int] = None,
prefill_chunk_size: Optional[int] = None):
if block_size is None:
block_size = 8
if num_gpu_blocks is None:
num_gpu_blocks = 2048 // block_size
iterator = count()
if isinstance(prompt_len, int):
prompt_lens = [prompt_len for _ in range(batch_size)]
else:
prompt_lens = prompt_len
prompts = [[next(iterator) for _ in range(p_len)] for p_len in prompt_lens]
if prefill_chunk_size:
# Create a batch of chunked prompts.
if not seq_ids:
seq_ids = list(range(len(prompts)))
seq_group_metadata_list = []
for p, sid in zip(prompts, seq_ids):
seq_group_metadata_list += \
create_chunked_seq_group_metadata_from_prompt(
p, num_gpu_blocks, prefill_chunk_size, block_size, sid)
seq_group_metadata_list = seq_group_metadata_list[:batch_size]
prev_output_tokens = []
else:
prev_output_tokens = [[
next(iterator) for _ in range(prev_output_token_len)
] for _ in range(batch_size)]
final_prompt_lens = [
len(prompt) + len(prev_output_token) + k + 1
for prompt, prev_output_token in zip(prompts, prev_output_tokens)
]
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts, num_gpu_blocks, block_size, final_prompt_lens,
prev_output_tokens, seq_ids)
return seq_group_metadata_list, prompts, prev_output_tokens
def maybe_enable_chunked_prefill(prefill_chunk_size, llm_kwargs):
if prefill_chunk_size > 0:
llm_kwargs.update(
**{
"enable_chunked_prefill": True,
"max_num_batched_tokens": prefill_chunk_size,
"max_num_seqs": prefill_chunk_size
})
else:
llm_kwargs["enable_chunked_prefill"] = False

View File

@@ -1,66 +0,0 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/blob/main/tests/entrypoints/llm/test_accuracy.py
#
import gc
import multiprocessing
from multiprocessing import Queue
import lm_eval
import pytest
import torch
# pre-trained model path on Hugging Face.
MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"
# Math reasoning benchmark (Grade School Math 8K).
TASK = "gsm8k"
# Answer validation requiring format consistency.
FILTER = "exact_match,strict-match"
# 3% relative tolerance for numerical accuracy.
RTOL = 0.03
# Baseline accuracy after VLLM optimization.
EXPECTED_VALUE = 0.316
def run_test(queue, more_args=None):
model_args = f"pretrained={MODEL_NAME},max_model_len=4096"
if more_args is not None:
model_args = f"{model_args},{more_args}"
results = lm_eval.simple_evaluate(
model="vllm",
model_args=model_args,
tasks=TASK,
batch_size="auto",
)
result = results["results"][TASK][FILTER]
print("result:", result)
queue.put(result)
del results
torch.npu.empty_cache()
gc.collect()
def test_lm_eval_accuracy(monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context():
result_queue: Queue[float] = multiprocessing.Queue()
p = multiprocessing.Process(target=run_test, args=(result_queue, ))
p.start()
p.join()
result = result_queue.get()
assert (EXPECTED_VALUE - RTOL < result < EXPECTED_VALUE + RTOL), \
f"Expected: {EXPECTED_VALUE}±{RTOL} | Measured: {result}"

View File

@@ -16,8 +16,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import pytest
import torch
from vllm import LLM, SamplingParams
from vllm.utils import GiB_bytes
@@ -26,7 +25,6 @@ from tests.utils import fork_new_process_for_each_test
from vllm_ascend.device_allocator.camem import CaMemAllocator
@fork_new_process_for_each_test
def test_basic_camem():
# some tensors from default memory pool
shape = (1024, 1024)
@@ -59,9 +57,9 @@ def test_basic_camem():
assert torch.allclose(output, torch.ones_like(output) * 3)
@pytest.mark.skipif(True, reason="test failed, should be fixed later")
@fork_new_process_for_each_test
def test_end_to_end():
os.environ["VLLM_USE_V1"] = "0"
free, total = torch.npu.mem_get_info()
used_bytes_baseline = total - free # in case other process is running
llm = LLM("Qwen/Qwen2.5-0.5B-Instruct", enable_sleep_mode=True)

View File

@@ -35,7 +35,6 @@ MODELS = [
"Qwen/Qwen3-0.6B-Base",
]
MULTIMODALITY_MODELS = ["Qwen/Qwen2.5-VL-3B-Instruct"]
os.environ["VLLM_USE_MODELSCOPE"] = "True"
os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256"
@@ -82,8 +81,3 @@ def test_multimodal(model, prompt_template, vllm_runner):
vllm_model.generate_greedy(prompts=prompts,
images=images,
max_tokens=64)
if __name__ == "__main__":
import pytest
pytest.main([__file__])