[CI] Refactor CI (#952)
1. remove some useless test func and file 2. fix format.sh problem 3. enable full test for singlecard and multicard 4. move long term test to long_term folder. For this kind of test, it only runs by labeled and daily test. Include: spec decode、accuracy test ## After refactor: There are 4 test modules - `singlecard`: contains the test running on one NPU. It'll be run for each PR and daily test. - `multicard`: contains the test running on multi NPUs. It'll be run for each PR and daily test. - `long_term`: contains the test that cost much time(Now include `spec decode` and `accuracy` test). It'll be run for the PR with `long-term-test` labeled and daily test. - `e2e`: contains the test for doc and pd feature. It'll be run for the PR with `pd-test` labeled and daily test. ## Todo: 1. some test are skipped, they should be fixed and reenabled in the future. 2. pyhccl test for multicard doesn't work at all. It should be enabled as well. 3. ensure long-term-test pass by daily test. ### Know issue Now, `ready` labels is required to start pd test or long term test. And when `long-term-test` or `pd-test` is labeled after another one, the old labeled test will be re-run again. So the labeled test should be ran in the following step: 1. decide which test need run, then label it. `long-term-test` or `pd-test` or both. 2. add `ready-for-test` label, then the test will be ran. Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
118
tests/singlecard/compile/test_simple.py
Normal file
118
tests/singlecard/compile/test_simple.py
Normal file
@@ -0,0 +1,118 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""
|
||||
Test the piecewise compilation with a simple model so that we
|
||||
can exactly calculate the expected output and side effects.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.library import Library
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
|
||||
set_current_vllm_config)
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
global_counter = 0
|
||||
|
||||
# create a library to hold the custom op
|
||||
silly_lib = Library("silly", "FRAGMENT") # noqa
|
||||
|
||||
|
||||
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
out: torch.Tensor) -> None:
|
||||
global global_counter
|
||||
global_counter += 1
|
||||
print(f"{global_counter=}")
|
||||
out.copy_(q)
|
||||
out[0] += 1
|
||||
|
||||
|
||||
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
out: torch.Tensor) -> None:
|
||||
return
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="attention",
|
||||
op_func=silly_attention,
|
||||
mutates_args=["out"],
|
||||
fake_impl=silly_attention_fake,
|
||||
dispatch_key="PrivateUse1",
|
||||
target_lib=silly_lib,
|
||||
)
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class SillyModel(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
**kwargs) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Overall effect:
|
||||
x += 1
|
||||
x[0] += 2
|
||||
global_counter += 2
|
||||
"""
|
||||
x = x + 1
|
||||
x = x + 2
|
||||
out = torch.empty_like(x)
|
||||
torch.ops.silly.attention(x, x, x, out)
|
||||
x = out
|
||||
x = x - 2
|
||||
x = x - 1
|
||||
out = torch.empty_like(x)
|
||||
torch.ops.silly.attention(x, x, x, out)
|
||||
x = out
|
||||
x = x + 1
|
||||
return x
|
||||
|
||||
|
||||
@pytest.mark.skipif(True, reason="requires unreleased components")
|
||||
def test_simple_piecewise_compile():
|
||||
|
||||
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
use_inductor=False,
|
||||
use_cudagraph=True,
|
||||
splitting_ops=["silly.attention"],
|
||||
cudagraph_copy_inputs=True,
|
||||
cudagraph_capture_sizes=[1, 2],
|
||||
))
|
||||
vllm_config.compilation_config.pass_config.enable_fusion = False
|
||||
with set_current_vllm_config(vllm_config):
|
||||
model = SillyModel(vllm_config=vllm_config, prefix="")
|
||||
|
||||
inputs = torch.randn(100).npu()
|
||||
|
||||
with compilation_counter.expect(
|
||||
num_graphs_seen=1, # one graph for the model
|
||||
num_piecewise_graphs_seen=5, # 2 * num_layers + 1
|
||||
num_piecewise_capturable_graphs_seen=3, # 1 + num_layers
|
||||
num_backend_compilations=3, # num_piecewise_capturable_graphs_seen
|
||||
num_cudagraph_caputured=
|
||||
6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||
):
|
||||
|
||||
model(inputs)
|
||||
|
||||
model(torch.randn(2).npu())
|
||||
model(torch.randn(1).npu())
|
||||
|
||||
input = torch.zeros(2).npu()
|
||||
global global_counter
|
||||
global_counter = 0
|
||||
output = model(input)
|
||||
assert global_counter == 2
|
||||
assert torch.allclose(output.cpu(), torch.tensor([3.0, 1.0]))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_simple_piecewise_compile()
|
||||
0
tests/singlecard/ops/__init__.py
Normal file
0
tests/singlecard/ops/__init__.py
Normal file
100
tests/singlecard/ops/test_fused_moe.py
Normal file
100
tests/singlecard/ops/test_fused_moe.py
Normal file
@@ -0,0 +1,100 @@
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm/tests/kernels/test_moe.py
|
||||
"""Tests for the MOE layers.
|
||||
|
||||
Run `pytest tests/ops/test_fused_moe.py`.
|
||||
"""
|
||||
# fused moe ops test will hit the infer_schema error, we need add the patch
|
||||
# here to make the test pass.
|
||||
import vllm_ascend.patch.worker.patch_common.patch_utils # type: ignore[import] # isort: skip # noqa
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
|
||||
from vllm_ascend.ops.fused_moe import fused_experts
|
||||
|
||||
NUM_EXPERTS = [8, 64]
|
||||
EP_SIZE = [1, 4]
|
||||
TOP_KS = [2, 6]
|
||||
DEVICE = ["npu"]
|
||||
|
||||
|
||||
def torch_moe(a, w1, w2, topk_weights, topk_ids, topk, expert_map):
|
||||
B, D = a.shape
|
||||
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
|
||||
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
|
||||
topk_weights = topk_weights.view(-1)
|
||||
topk_ids = topk_ids.view(-1)
|
||||
if expert_map is not None:
|
||||
topk_ids = expert_map[topk_ids]
|
||||
for i in range(w1.shape[0]):
|
||||
mask = topk_ids == i
|
||||
if mask.sum():
|
||||
out[mask] = SiluAndMul()(
|
||||
a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
|
||||
return (out.view(B, -1, w2.shape[1]) *
|
||||
topk_weights.view(B, -1, 1).to(out.dtype)).sum(dim=1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [1, 33, 64, 222, 1024 * 128])
|
||||
@pytest.mark.parametrize("n", [128, 1024, 2048])
|
||||
@pytest.mark.parametrize("k", [128, 511, 1024])
|
||||
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
@pytest.mark.parametrize("ep_size", EP_SIZE)
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("device", DEVICE)
|
||||
def test_fused_experts(
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
e: int,
|
||||
topk: int,
|
||||
ep_size: int,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
):
|
||||
a = torch.randn((m, k), device=device, dtype=dtype) / 10
|
||||
w1 = torch.randn((e, 2 * n, k), device=device, dtype=dtype) / 10
|
||||
w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 10
|
||||
|
||||
score = torch.randn((m, e), device=device, dtype=dtype)
|
||||
|
||||
if ep_size > 1:
|
||||
local_e = e // ep_size
|
||||
e_ids = torch.randint(0,
|
||||
e, (local_e, ),
|
||||
device=device,
|
||||
dtype=torch.int32)
|
||||
e_map = torch.full((e, ), -1, device=device, dtype=torch.int32)
|
||||
e_map[e_ids] = torch.arange(local_e, device=device, dtype=torch.int32)
|
||||
w1 = w1[e_ids]
|
||||
w2 = w2[e_ids]
|
||||
else:
|
||||
e_map = None
|
||||
|
||||
score = torch.softmax(score, dim=-1, dtype=dtype)
|
||||
topk_weights, topk_ids = torch.topk(score, topk)
|
||||
topk_ids = topk_ids.to(torch.int32)
|
||||
|
||||
output = fused_experts(a, w1, w2, topk_weights, topk_ids, topk, e_map)
|
||||
torch_output = torch_moe(a, w1, w2, topk_weights, topk_ids, topk, e_map)
|
||||
# TODO: The native params are: atol=2e-2, rtol=0, maybe related to the nan problem
|
||||
torch.testing.assert_close(output, torch_output, atol=4e-2, rtol=1)
|
||||
torch.npu.empty_cache()
|
||||
190
tests/singlecard/ops/test_multi_step.py
Normal file
190
tests/singlecard/ops/test_multi_step.py
Normal file
@@ -0,0 +1,190 @@
|
||||
# Copyright (c) China Merchants Bank Co., Ltd. 2025. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#/
|
||||
|
||||
# to run this test, you need to cd to the upper package which is 'tests',
|
||||
# and run with command 'pytest -s ops/test_multi_step.py'
|
||||
|
||||
import torch
|
||||
import torch_npu # noqa: F401
|
||||
|
||||
DTYPES = [torch.int32, torch.int64]
|
||||
DEVICES = [f"npu:{0}"]
|
||||
# Set tolerance to 0 for equals
|
||||
DEFAULT_ATOL = 0
|
||||
DEFAULT_RTOL = 0
|
||||
|
||||
# test custom ops of https://github.com/vllm-project/vllm-ascend/tree/main/csrc/kernels/advance_step.cpp
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_single_generation_multi_step() -> None:
|
||||
input_tokens_data = [2926]
|
||||
input_tokens_ascendc = torch.tensor(input_tokens_data, device='npu:0')
|
||||
input_tokens_python = torch.tensor(input_tokens_data, device='npu:0')
|
||||
|
||||
sampled_token_ids_data = [[13]]
|
||||
sampled_token_ids = torch.tensor(sampled_token_ids_data, device='npu:0')
|
||||
|
||||
input_positions_data = [5]
|
||||
input_positions_ascendc = torch.tensor(input_positions_data,
|
||||
device='npu:0')
|
||||
input_positions_python = torch.tensor(input_positions_data, device='npu:0')
|
||||
|
||||
seq_lens_data = [6]
|
||||
seq_lens_ascendc = torch.tensor(seq_lens_data,
|
||||
device='npu:0',
|
||||
dtype=torch.int32)
|
||||
seq_lens_python = torch.tensor(seq_lens_data,
|
||||
device='npu:0',
|
||||
dtype=torch.int32)
|
||||
|
||||
slot_mapping_data = [5]
|
||||
slot_mapping_ascendc = torch.tensor(slot_mapping_data,
|
||||
device='npu:0',
|
||||
dtype=torch.int32)
|
||||
slot_mapping_python = torch.tensor(slot_mapping_data,
|
||||
device='npu:0',
|
||||
dtype=torch.int32)
|
||||
|
||||
block_tables_data = [[0]]
|
||||
|
||||
block_tables = torch.tensor(block_tables_data,
|
||||
device='npu:0',
|
||||
dtype=torch.int32)
|
||||
|
||||
torch.ops._C.advance_step_flashattn_ascendc(
|
||||
1, 1, 128, input_tokens_ascendc, sampled_token_ids,
|
||||
input_positions_ascendc, seq_lens_ascendc, slot_mapping_ascendc,
|
||||
block_tables)
|
||||
|
||||
normal(1, 1, 128, input_tokens_python, sampled_token_ids,
|
||||
input_positions_python, seq_lens_python, slot_mapping_python,
|
||||
block_tables)
|
||||
|
||||
# Compare the results.
|
||||
torch.testing.assert_close(input_tokens_ascendc,
|
||||
input_tokens_python,
|
||||
atol=DEFAULT_ATOL,
|
||||
rtol=DEFAULT_RTOL)
|
||||
|
||||
torch.testing.assert_close(input_positions_ascendc,
|
||||
input_positions_python,
|
||||
atol=DEFAULT_ATOL,
|
||||
rtol=DEFAULT_RTOL)
|
||||
|
||||
torch.testing.assert_close(seq_lens_ascendc,
|
||||
seq_lens_python,
|
||||
atol=DEFAULT_ATOL,
|
||||
rtol=DEFAULT_RTOL)
|
||||
|
||||
torch.testing.assert_close(slot_mapping_ascendc,
|
||||
slot_mapping_python,
|
||||
atol=DEFAULT_ATOL,
|
||||
rtol=DEFAULT_RTOL)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_multi_result_generation_multi_step() -> None:
|
||||
input_tokens_data = [2926, 279, 12095, 1588]
|
||||
input_tokens_ascendc = torch.tensor(input_tokens_data, device='npu:0')
|
||||
input_tokens_python = torch.tensor(input_tokens_data, device='npu:0')
|
||||
|
||||
sampled_token_ids_data = [[13], [1968], [13], [13]]
|
||||
sampled_token_ids = torch.tensor(sampled_token_ids_data, device='npu:0')
|
||||
|
||||
input_positions_data = [5, 7, 5, 5]
|
||||
input_positions_ascendc = torch.tensor(input_positions_data,
|
||||
device='npu:0')
|
||||
input_positions_python = torch.tensor(input_positions_data, device='npu:0')
|
||||
|
||||
seq_lens_data = [6, 8, 6, 6]
|
||||
seq_lens_ascendc = torch.tensor(seq_lens_data,
|
||||
device='npu:0',
|
||||
dtype=torch.int32)
|
||||
seq_lens_python = torch.tensor(seq_lens_data,
|
||||
device='npu:0',
|
||||
dtype=torch.int32)
|
||||
|
||||
slot_mapping_data = [5, 135, 261, 389]
|
||||
slot_mapping_ascendc = torch.tensor(slot_mapping_data,
|
||||
device='npu:0',
|
||||
dtype=torch.int32)
|
||||
slot_mapping_python = torch.tensor(slot_mapping_data,
|
||||
device='npu:0',
|
||||
dtype=torch.int32)
|
||||
|
||||
block_tables_data = [[0], [1], [2], [3]]
|
||||
|
||||
block_tables = torch.tensor(block_tables_data,
|
||||
device='npu:0',
|
||||
dtype=torch.int32)
|
||||
|
||||
torch.ops._C.advance_step_flashattn_ascendc(
|
||||
4, 4, 128, input_tokens_ascendc, sampled_token_ids,
|
||||
input_positions_ascendc, seq_lens_ascendc, slot_mapping_ascendc,
|
||||
block_tables)
|
||||
|
||||
normal(4, 4, 128, input_tokens_python, sampled_token_ids,
|
||||
input_positions_python, seq_lens_python, slot_mapping_python,
|
||||
block_tables)
|
||||
|
||||
# Compare the results.
|
||||
torch.testing.assert_close(input_tokens_ascendc,
|
||||
input_tokens_python,
|
||||
atol=DEFAULT_ATOL,
|
||||
rtol=DEFAULT_RTOL)
|
||||
|
||||
torch.testing.assert_close(input_positions_ascendc,
|
||||
input_positions_python,
|
||||
atol=DEFAULT_ATOL,
|
||||
rtol=DEFAULT_RTOL)
|
||||
|
||||
torch.testing.assert_close(seq_lens_ascendc,
|
||||
seq_lens_python,
|
||||
atol=DEFAULT_ATOL,
|
||||
rtol=DEFAULT_RTOL)
|
||||
|
||||
torch.testing.assert_close(slot_mapping_ascendc,
|
||||
slot_mapping_python,
|
||||
atol=DEFAULT_ATOL,
|
||||
rtol=DEFAULT_RTOL)
|
||||
|
||||
|
||||
def normal(num_seqs: int, num_queries: int, block_size: int,
|
||||
input_tokens: torch.Tensor, sampled_token_ids: torch.Tensor,
|
||||
input_positions: torch.Tensor, seq_lens_tensor: torch.Tensor,
|
||||
slot_mapping: torch.Tensor, block_tables: torch.Tensor) -> None:
|
||||
sampled_token_ids_list = sampled_token_ids[:num_queries].squeeze(-1)
|
||||
input_tokens[:num_queries] = sampled_token_ids_list
|
||||
|
||||
# get seq_lens and input_positions
|
||||
seq_lens = seq_lens_tensor[:num_queries]
|
||||
next_seq_lens = seq_lens + 1
|
||||
next_input_pos = next_seq_lens - 1
|
||||
|
||||
# update seq_lens and input_positions
|
||||
seq_lens_tensor[:num_queries] = next_seq_lens
|
||||
input_positions[:num_queries] = next_input_pos # type: ignore
|
||||
|
||||
# get block index and offset
|
||||
block_idx = next_input_pos // block_size
|
||||
block_offset = next_input_pos % block_size
|
||||
|
||||
current_block_table = block_tables.gather(
|
||||
1, block_idx.unsqueeze(-1)).squeeze(-1)
|
||||
slot_num = current_block_table * block_size + block_offset
|
||||
|
||||
# update slot_mapping
|
||||
slot_mapping[:num_queries] = slot_num
|
||||
198
tests/singlecard/ops/test_rotary_embedding.py
Normal file
198
tests/singlecard/ops/test_rotary_embedding.py
Normal file
@@ -0,0 +1,198 @@
|
||||
# Copyright 2023 The vLLM team.
|
||||
|
||||
# Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved.
|
||||
# Adapted from
|
||||
# https://github.com/vllm-project/vllm/blob/main/vllm/tests/kernels/test_rotary_embedding.py
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import vllm_ascend.platform # noqa: F401
|
||||
|
||||
# Only Neox style true scenario is supported for now
|
||||
IS_NEOX_STYLE = [True]
|
||||
DTYPES = [torch.half]
|
||||
HEAD_SIZES = [64, 96, 128, 256]
|
||||
ROTARY_DIMS = [None, 32] # None means rotary dim == head size
|
||||
NUM_HEADS = [17] # Arbitrary values for testing
|
||||
BATCH_SIZES = [5] # Arbitrary values for testing
|
||||
SEQ_LENS = [11, 4096] # Arbitrary values for testing
|
||||
SEEDS = [0]
|
||||
DEVICES = [f"npu:{0}"]
|
||||
# Set tolerance to 1 for quant ops
|
||||
DEFAULT_ATOL = 1e-3
|
||||
DEFAULT_RTOL = 1e-3
|
||||
|
||||
|
||||
def _apply_rotary_emb(
|
||||
x: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
is_neox_style: bool,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x: [num_tokens, num_heads, head_size]
|
||||
cos: [num_tokens, head_size // 2]
|
||||
sin: [num_tokens, head_size // 2]
|
||||
is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
|
||||
positional embeddings.
|
||||
"""
|
||||
cos = cos.unsqueeze(-2).to(x.dtype)
|
||||
sin = sin.unsqueeze(-2).to(x.dtype)
|
||||
if is_neox_style:
|
||||
x1, x2 = torch.chunk(x, 2, dim=-1)
|
||||
else:
|
||||
x1 = x[..., ::2]
|
||||
x2 = x[..., 1::2]
|
||||
o1 = x1 * cos - x2 * sin
|
||||
o2 = x2 * cos + x1 * sin
|
||||
if is_neox_style:
|
||||
return torch.cat((o1, o2), dim=-1)
|
||||
else:
|
||||
return torch.stack((o1, o2), dim=-1).flatten(-2)
|
||||
|
||||
|
||||
# adapted from https://github.com/vllm-project/vllm/vllm/model_executor/layers/rotary_embedding.py
|
||||
class RotaryEmbedding(nn.Module):
|
||||
"""Original rotary positional embedding."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
max_position_embeddings: int,
|
||||
base: int,
|
||||
is_neox_style: bool,
|
||||
dtype: torch.dtype,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.head_size = head_size
|
||||
self.rotary_dim = rotary_dim
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.base = base
|
||||
self.is_neox_style = is_neox_style
|
||||
self.dtype = dtype
|
||||
|
||||
cache = self._compute_cos_sin_cache()
|
||||
cache = cache.to(dtype)
|
||||
self.cos_sin_cache: torch.Tensor
|
||||
self.register_buffer("cos_sin_cache", cache, persistent=False)
|
||||
|
||||
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
|
||||
"""Compute the inverse frequency."""
|
||||
# NOTE(woosuk): To exactly match the HF implementation, we need to
|
||||
# use CPU to compute the cache and then move it to GPU. However, we
|
||||
# create the cache on GPU for faster initialization. This may cause
|
||||
# a slight numerical difference between the HF implementation and ours.
|
||||
inv_freq = 1.0 / (base**(torch.arange(
|
||||
0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
|
||||
return inv_freq
|
||||
|
||||
def _compute_cos_sin_cache(self) -> torch.Tensor:
|
||||
"""Compute the cos and sin cache."""
|
||||
inv_freq = self._compute_inv_freq(self.base)
|
||||
t = torch.arange(self.max_position_embeddings, dtype=torch.float)
|
||||
|
||||
freqs = torch.einsum("i,j -> ij", t, inv_freq)
|
||||
cos = freqs.cos()
|
||||
sin = freqs.sin()
|
||||
cache = torch.cat((cos, sin), dim=-1)
|
||||
return cache
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""A PyTorch-native implementation of forward()."""
|
||||
if offsets is not None:
|
||||
positions = positions + offsets
|
||||
positions = positions.flatten()
|
||||
num_tokens = positions.shape[0]
|
||||
cos_sin = self.cos_sin_cache.index_select(0, positions)
|
||||
cos, sin = cos_sin.chunk(2, dim=-1)
|
||||
|
||||
query_shape = query.shape
|
||||
query = query.view(num_tokens, -1, self.head_size)
|
||||
query_rot = query[..., :self.rotary_dim]
|
||||
query_pass = query[..., self.rotary_dim:]
|
||||
query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
|
||||
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
|
||||
|
||||
key_shape = key.shape
|
||||
key = key.view(num_tokens, -1, self.head_size)
|
||||
key_rot = key[..., :self.rotary_dim]
|
||||
key_pass = key[..., self.rotary_dim:]
|
||||
key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
|
||||
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
||||
return query, key
|
||||
|
||||
|
||||
# test with leading dimension and merge seqlen and batch_size as num_tokens
|
||||
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
|
||||
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
|
||||
@pytest.mark.parametrize("seq_len", SEQ_LENS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_rotary_embedding_quant_with_leading_dim(
|
||||
is_neox_style: bool,
|
||||
batch_size: int,
|
||||
seq_len: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
rotary_dim: Optional[int],
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
device: str,
|
||||
max_position: int = 8192,
|
||||
base: int = 10000,
|
||||
) -> None:
|
||||
if rotary_dim is None:
|
||||
rotary_dim = head_size
|
||||
|
||||
torch.set_default_device(device)
|
||||
if rotary_dim is None:
|
||||
rotary_dim = head_size
|
||||
rope = RotaryEmbedding(head_size, rotary_dim, max_position, base,
|
||||
is_neox_style, dtype)
|
||||
rope = rope.to(dtype=dtype)
|
||||
num_tokens = batch_size * seq_len
|
||||
positions = torch.randint(0, max_position, (batch_size * seq_len, ))
|
||||
qkv_tensor = torch.randn(num_tokens,
|
||||
num_heads * head_size * 3,
|
||||
dtype=dtype)
|
||||
query, key, _ = qkv_tensor.split(
|
||||
[num_heads * head_size, num_heads * head_size, num_heads * head_size],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
ref_query, ref_key = rope.forward_native(positions, query, key)
|
||||
query, key = torch.ops._C.rotary_embedding(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
rope.head_size,
|
||||
rope.cos_sin_cache,
|
||||
rope.is_neox_style,
|
||||
)
|
||||
|
||||
# Compare the results.
|
||||
torch.testing.assert_close(query.view(ref_query.size()),
|
||||
ref_query,
|
||||
atol=DEFAULT_ATOL,
|
||||
rtol=DEFAULT_RTOL)
|
||||
torch.testing.assert_close(key.view(ref_key.size()),
|
||||
ref_key,
|
||||
atol=DEFAULT_ATOL,
|
||||
rtol=DEFAULT_RTOL)
|
||||
0
tests/singlecard/sample/__init__.py
Normal file
0
tests/singlecard/sample/__init__.py
Normal file
611
tests/singlecard/sample/test_rejection_sampler.py
Normal file
611
tests/singlecard/sample/test_rejection_sampler.py
Normal file
@@ -0,0 +1,611 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from typing import Any, Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||
|
||||
from vllm_ascend.sample.rejection_sampler import (PLACEHOLDER_TOKEN_ID,
|
||||
AscendRejectionSampler)
|
||||
|
||||
DEVICE = "npu"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def rejection_sampler():
|
||||
return AscendRejectionSampler()
|
||||
|
||||
|
||||
def create_logits_tensor(output_token_ids: list[list[int]],
|
||||
vocab_size: int = 100) -> torch.Tensor:
|
||||
"""Helper function to create logits tensor that
|
||||
will produce desired token ids on argmax"""
|
||||
token_ids = [tokens[:-1] for tokens in output_token_ids]
|
||||
num_total_tokens = sum(len(tokens) for tokens in token_ids)
|
||||
logits = torch.full((num_total_tokens, vocab_size), -100.0, device=DEVICE)
|
||||
start_loc = 0
|
||||
for tokens in token_ids:
|
||||
for j, token_id in enumerate(tokens):
|
||||
logits[start_loc + j, token_id] = 100.0
|
||||
start_loc += len(tokens)
|
||||
return logits
|
||||
|
||||
|
||||
def create_sampling_metadata(
|
||||
all_greedy: bool,
|
||||
temperature: Optional[torch.Tensor] = None,
|
||||
top_k: Optional[torch.Tensor] = None,
|
||||
top_p: Optional[torch.Tensor] = None,
|
||||
generators: Optional[dict[int, Any]] = None,
|
||||
) -> SamplingMetadata:
|
||||
"""Create a v1 sampling metadata object with all_greedy set
|
||||
to the given value. Either all greedy or all random sampling
|
||||
is used.
|
||||
"""
|
||||
generators = generators or {}
|
||||
if all_greedy:
|
||||
temperature = None
|
||||
else:
|
||||
assert temperature is not None
|
||||
|
||||
return SamplingMetadata(
|
||||
temperature=temperature,
|
||||
all_greedy=all_greedy,
|
||||
all_random=not all_greedy,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
min_p=torch.empty(1, ),
|
||||
generators=generators,
|
||||
max_num_logprobs=0,
|
||||
no_penalties=False,
|
||||
prompt_token_ids=None,
|
||||
frequency_penalties=torch.tensor([]),
|
||||
presence_penalties=torch.tensor([]),
|
||||
repetition_penalties=torch.tensor([]),
|
||||
output_token_ids=[],
|
||||
min_tokens={},
|
||||
logit_bias=[None],
|
||||
allowed_token_ids_mask=None,
|
||||
bad_words_token_ids={},
|
||||
)
|
||||
|
||||
|
||||
########################### Tests for Greedy Sampling ###################
|
||||
def test_perfect_match(rejection_sampler):
|
||||
"""Test when output tokens perfectly match speculated tokens"""
|
||||
spec_tokens = [[1, 2, 3]]
|
||||
output_tokens = [[1, 2, 3, 4]] # 4 is the bonus token
|
||||
|
||||
metadata = create_sampling_metadata(all_greedy=True)
|
||||
logits = create_logits_tensor(output_tokens)
|
||||
bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
|
||||
device=logits.device)
|
||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
|
||||
device=logits.device)
|
||||
|
||||
output = rejection_sampler(
|
||||
spec_decode_metadata,
|
||||
draft_probs=None,
|
||||
target_logits=logits,
|
||||
bonus_token_ids=bonus_token_tensor,
|
||||
sampling_metadata=metadata,
|
||||
)
|
||||
expected = torch.tensor([[1, 2, 3, 4]],
|
||||
dtype=torch.int,
|
||||
device=logits.device)
|
||||
assert torch.equal(output, expected)
|
||||
|
||||
|
||||
def test_early_mismatch(rejection_sampler):
|
||||
"""Test when there's an early mismatch in tokens"""
|
||||
spec_tokens = [[1, 2, 3]]
|
||||
output_tokens = [[1, 5, 3, 4]] # Mismatch at position 1
|
||||
|
||||
metadata = create_sampling_metadata(all_greedy=True)
|
||||
logits = create_logits_tensor(output_tokens)
|
||||
bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
|
||||
device=logits.device)
|
||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
|
||||
device=logits.device)
|
||||
|
||||
output = rejection_sampler(
|
||||
spec_decode_metadata,
|
||||
draft_probs=None,
|
||||
target_logits=logits,
|
||||
bonus_token_ids=bonus_token_tensor,
|
||||
sampling_metadata=metadata,
|
||||
)
|
||||
expected = torch.tensor(
|
||||
[[1, 5, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID]],
|
||||
dtype=torch.int,
|
||||
device=logits.device,
|
||||
)
|
||||
assert torch.equal(output, expected)
|
||||
|
||||
|
||||
def test_multiple_sequences(rejection_sampler):
|
||||
"""Test handling multiple sequences of speculated tokens"""
|
||||
spec_tokens = [[1, 2], [3]]
|
||||
output_tokens = [[1, 2, 5], [3,
|
||||
4]] # Two sequences with bonus tokens 5 and 4
|
||||
|
||||
metadata = create_sampling_metadata(all_greedy=True)
|
||||
logits = create_logits_tensor(output_tokens)
|
||||
bonus_token_tensor = torch.tensor(
|
||||
[output_tokens[0][-1], output_tokens[1][-1]], device=logits.device)
|
||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
|
||||
device=logits.device)
|
||||
|
||||
output = rejection_sampler(
|
||||
spec_decode_metadata,
|
||||
draft_probs=None,
|
||||
target_logits=logits,
|
||||
bonus_token_ids=bonus_token_tensor,
|
||||
sampling_metadata=metadata,
|
||||
)
|
||||
expected = torch.tensor([[1, 2, 5], [3, 4, PLACEHOLDER_TOKEN_ID]],
|
||||
dtype=torch.int,
|
||||
device=logits.device)
|
||||
assert torch.equal(output, expected)
|
||||
|
||||
|
||||
def test_single_token_sequence(rejection_sampler):
|
||||
"""Test handling sequences with single token"""
|
||||
spec_tokens = [[1]]
|
||||
output_tokens = [[1, 2]] # Single token with bonus token 2
|
||||
|
||||
metadata = create_sampling_metadata(all_greedy=True)
|
||||
logits = create_logits_tensor(output_tokens)
|
||||
bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
|
||||
device=logits.device)
|
||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
|
||||
device=logits.device)
|
||||
|
||||
output = rejection_sampler(
|
||||
spec_decode_metadata,
|
||||
draft_probs=None,
|
||||
target_logits=logits,
|
||||
bonus_token_ids=bonus_token_tensor,
|
||||
sampling_metadata=metadata,
|
||||
)
|
||||
expected = torch.tensor([[1, 2]], dtype=torch.int, device=logits.device)
|
||||
assert torch.equal(output, expected)
|
||||
|
||||
|
||||
def test_empty_sequence(rejection_sampler):
|
||||
"""Test handling empty sequence of speculated tokens"""
|
||||
spec_tokens: list[list[int]] = [[]]
|
||||
output_tokens = [[5]] # Just the bonus token
|
||||
|
||||
metadata = create_sampling_metadata(all_greedy=True)
|
||||
logits = create_logits_tensor(output_tokens)
|
||||
bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
|
||||
device=logits.device)
|
||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
|
||||
device=logits.device)
|
||||
|
||||
output = rejection_sampler(
|
||||
spec_decode_metadata,
|
||||
draft_probs=None,
|
||||
target_logits=logits,
|
||||
bonus_token_ids=bonus_token_tensor,
|
||||
sampling_metadata=metadata,
|
||||
)
|
||||
expected = torch.tensor([[5]], dtype=torch.int, device=logits.device)
|
||||
assert torch.equal(output, expected)
|
||||
|
||||
|
||||
def test_multiple_mismatches(rejection_sampler):
|
||||
"""Test handling multiple sequences with mismatches"""
|
||||
spec_tokens = [[1, 2, 3], [4, 5, 6]]
|
||||
output_tokens = [[1, 2, 7, 6], [4, 8, 6,
|
||||
9]] # Mismatches in both sequences
|
||||
|
||||
metadata = create_sampling_metadata(all_greedy=True)
|
||||
logits = create_logits_tensor(output_tokens)
|
||||
bonus_token_tensor = torch.tensor(
|
||||
[output_tokens[0][-1], output_tokens[1][-1]], device=logits.device)
|
||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
|
||||
device=logits.device)
|
||||
|
||||
output = rejection_sampler(
|
||||
spec_decode_metadata,
|
||||
draft_probs=None,
|
||||
target_logits=logits,
|
||||
bonus_token_ids=bonus_token_tensor,
|
||||
sampling_metadata=metadata,
|
||||
)
|
||||
expected = torch.tensor(
|
||||
[[1, 2, 7, PLACEHOLDER_TOKEN_ID],
|
||||
[4, 8, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID]],
|
||||
dtype=torch.int,
|
||||
device=logits.device,
|
||||
)
|
||||
assert torch.equal(output, expected)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"spec_tokens,output_tokens,expected",
|
||||
[
|
||||
([[1, 2]], [[1, 2, 3]], [[1, 2, 3]]), # Perfect match with bonus
|
||||
([[1]], [[2, 3]], [[2, PLACEHOLDER_TOKEN_ID]]), # First mismatch
|
||||
([[1, 2], [3, 4]], [[1, 5, 6], [3, 4, 7]],
|
||||
[[1, 5, PLACEHOLDER_TOKEN_ID], [3, 4, 7]]), # Mixed matches
|
||||
])
|
||||
def test_parametrized_cases(rejection_sampler, spec_tokens, output_tokens,
|
||||
expected):
|
||||
"""Parametrized test for various matching scenarios"""
|
||||
metadata = create_sampling_metadata(all_greedy=True)
|
||||
logits = create_logits_tensor(output_tokens)
|
||||
bonus_token_tensor = torch.tensor([tokens[-1] for tokens in output_tokens],
|
||||
device=logits.device)
|
||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
|
||||
device=logits.device)
|
||||
|
||||
output = rejection_sampler(
|
||||
spec_decode_metadata,
|
||||
draft_probs=None,
|
||||
target_logits=logits,
|
||||
bonus_token_ids=bonus_token_tensor,
|
||||
sampling_metadata=metadata,
|
||||
)
|
||||
expected_tensor = torch.tensor(expected,
|
||||
dtype=torch.int,
|
||||
device=logits.device)
|
||||
assert torch.equal(output, expected_tensor)
|
||||
|
||||
|
||||
########################### Tests for Random Sampling ###################
|
||||
@pytest.mark.parametrize("k", [1, 3, 5])
|
||||
@pytest.mark.parametrize("vocab_size", [1000])
|
||||
@pytest.mark.parametrize("batch_size", [1, 4, 8])
|
||||
@pytest.mark.parametrize("frac_seeded", [0.0, 0.5])
|
||||
@pytest.mark.parametrize("n_rep", [20])
|
||||
def test_deterministic_when_seeded(
|
||||
rejection_sampler,
|
||||
k: int,
|
||||
vocab_size: int,
|
||||
batch_size: int,
|
||||
frac_seeded: float,
|
||||
n_rep: int,
|
||||
):
|
||||
num_tokens = batch_size * k
|
||||
draft_probs = torch.rand(num_tokens,
|
||||
vocab_size,
|
||||
dtype=torch.float32,
|
||||
device=DEVICE)
|
||||
draft_probs = F.softmax(draft_probs, dim=-1)
|
||||
target_logits = torch.rand_like(draft_probs)
|
||||
bonus_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, 1),
|
||||
dtype=torch.int64,
|
||||
device=DEVICE)
|
||||
draft_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, k),
|
||||
dtype=torch.int64,
|
||||
device=DEVICE)
|
||||
|
||||
seeded_mask = torch.rand(batch_size, dtype=torch.float32) <= frac_seeded
|
||||
|
||||
results = []
|
||||
for _ in range(n_rep):
|
||||
seeded_seqs = {
|
||||
i: torch.Generator(device=DEVICE).manual_seed(i)
|
||||
for i in range(batch_size) if seeded_mask[i]
|
||||
}
|
||||
|
||||
temperature = torch.ones(batch_size,
|
||||
dtype=torch.float32,
|
||||
device=DEVICE)
|
||||
sampling_metadata = create_sampling_metadata(all_greedy=False,
|
||||
temperature=temperature,
|
||||
generators=seeded_seqs)
|
||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
|
||||
draft_token_ids.tolist(), device=DEVICE)
|
||||
rep_result = rejection_sampler(
|
||||
spec_decode_metadata,
|
||||
draft_probs=draft_probs,
|
||||
target_logits=target_logits,
|
||||
bonus_token_ids=bonus_token_ids,
|
||||
sampling_metadata=sampling_metadata,
|
||||
)
|
||||
|
||||
results.append(rep_result)
|
||||
|
||||
for i in range(batch_size):
|
||||
if seeded_mask[i]:
|
||||
for j in range(1, n_rep):
|
||||
assert torch.equal(results[j][i], results[0][i])
|
||||
|
||||
|
||||
@pytest.mark.skipif(True, reason="Test failed, need fix")
|
||||
def test_rejection_sampling_approximates_target_distribution():
|
||||
"""Verify rejection sampling approximates target distribution,
|
||||
despite sampling from a potentially distinct draft distribution.
|
||||
|
||||
This is done by first creating a random target probability
|
||||
distribution and a random draft probability distribution. We then
|
||||
sample token ids from the rejection sampler using these draft
|
||||
and target distributions. The samples are used to estimate
|
||||
the output probability distribution, which we expect to approximate
|
||||
the target distribution.
|
||||
|
||||
A basic distance metric is used to determine similarity between
|
||||
distributions.
|
||||
|
||||
We expect that as we increase the number of samples,
|
||||
the distance between the observed distribution and the target
|
||||
distribution decreases. To measure this, we compare the distance
|
||||
of the observed distribution against both the target distribution
|
||||
and a uniform random distribution. We expect the distance between
|
||||
the observed distribution and the target distribution to improve
|
||||
much more than the distance improvement between the observed
|
||||
distribution and the random distribution.
|
||||
"""
|
||||
torch.set_default_device(DEVICE)
|
||||
vocab_size = 10
|
||||
k = 2
|
||||
num_reference_probs = 100
|
||||
|
||||
# Prepare draft, target, and reference probability distributions
|
||||
draft_probs = F.softmax(torch.rand(vocab_size, dtype=torch.float32),
|
||||
dim=-1)
|
||||
target_logits = torch.rand(vocab_size, dtype=torch.float32)
|
||||
target_probs = F.softmax(target_logits, dim=-1)
|
||||
reference_probs = F.softmax(
|
||||
torch.rand(num_reference_probs, vocab_size, dtype=torch.float32),
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
sample_sizes = [10, 100, 1_000, 10_000, 100_000]
|
||||
distance_wrt_reference: list[float] = []
|
||||
distance_wrt_target: list[float] = []
|
||||
|
||||
for num_samples in sample_sizes:
|
||||
# Sample using rejection sampling.
|
||||
rej_sample_probs = estimate_rejection_sampling_pdf(
|
||||
draft_probs, target_logits, k, vocab_size, num_samples)
|
||||
rej_sample_probs = rej_sample_probs.to(DEVICE)
|
||||
|
||||
# Average distance from reference probs.
|
||||
reference_vs_rejsample_dist = torch.dist(
|
||||
reference_probs,
|
||||
rej_sample_probs).item() / reference_probs.shape[0]
|
||||
target_vs_rejsample_dist = torch.dist(target_probs,
|
||||
rej_sample_probs).item()
|
||||
|
||||
distance_wrt_reference.append(reference_vs_rejsample_dist)
|
||||
distance_wrt_target.append(target_vs_rejsample_dist)
|
||||
|
||||
relative_change_in_distance_wrt_target = get_ratio_first_to_last(
|
||||
distance_wrt_target)
|
||||
relative_change_in_distance_wrt_reference = get_ratio_first_to_last(
|
||||
distance_wrt_reference)
|
||||
|
||||
print(f"{num_samples=} {target_vs_rejsample_dist=:.05f} "
|
||||
f"{reference_vs_rejsample_dist=:.05f}")
|
||||
print(f"{num_samples=} {relative_change_in_distance_wrt_target=:.02f} "
|
||||
f"{relative_change_in_distance_wrt_reference=:.02f}")
|
||||
|
||||
relative_change_in_distance_wrt_target = get_ratio_first_to_last(
|
||||
distance_wrt_target)
|
||||
relative_change_in_distance_wrt_reference = get_ratio_first_to_last(
|
||||
distance_wrt_reference)
|
||||
|
||||
expected_improvement_multiplier = 20
|
||||
assert (relative_change_in_distance_wrt_target >
|
||||
relative_change_in_distance_wrt_reference *
|
||||
expected_improvement_multiplier)
|
||||
|
||||
|
||||
def get_ratio_first_to_last(elements: list[float]) -> float:
|
||||
return elements[0] / elements[-1]
|
||||
|
||||
|
||||
def estimate_rejection_sampling_pdf(
|
||||
draft_probs: torch.Tensor,
|
||||
target_logits: torch.Tensor,
|
||||
k: int,
|
||||
vocab_size: int,
|
||||
num_samples: int,
|
||||
) -> torch.Tensor:
|
||||
"""Estimate the probability distribution of the output tokens
|
||||
using rejection sampling.
|
||||
|
||||
Args:
|
||||
draft_probs: Draft probability distribution.
|
||||
target_logits: Target logits.
|
||||
num_samples: Number of samples to draw.
|
||||
|
||||
Returns:
|
||||
Estimated probability distribution of the output tokens.
|
||||
"""
|
||||
rejection_sampler = AscendRejectionSampler()
|
||||
num_tokens = num_samples * k
|
||||
# Repeat draft probs num_samples * k times.
|
||||
draft_probs = draft_probs.reshape(1, 1,
|
||||
vocab_size).repeat(num_samples, k, 1)
|
||||
|
||||
# Repeat target probs num_tokens times.
|
||||
target_logits = target_logits.reshape(1, vocab_size).repeat(num_tokens, 1)
|
||||
|
||||
# Randomly sample draft token ids from draft probs.
|
||||
draft_token_ids = torch.multinomial(draft_probs[:, 0, :],
|
||||
num_samples=k,
|
||||
replacement=True).reshape(
|
||||
num_samples, k)
|
||||
draft_probs = draft_probs.view(num_tokens, vocab_size)
|
||||
|
||||
# Bonus tokens not used but required.
|
||||
bonus_token_ids = torch.zeros((1, 1), dtype=torch.int64,
|
||||
device=DEVICE).repeat(num_samples, 1)
|
||||
|
||||
temperature = torch.ones(num_samples, dtype=torch.float32, device=DEVICE)
|
||||
sampling_metadata = create_sampling_metadata(all_greedy=False,
|
||||
temperature=temperature)
|
||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
|
||||
draft_token_ids.tolist(), device=bonus_token_ids.device)
|
||||
output_token_ids = rejection_sampler(
|
||||
spec_decode_metadata,
|
||||
draft_probs=draft_probs,
|
||||
target_logits=target_logits,
|
||||
bonus_token_ids=bonus_token_ids,
|
||||
sampling_metadata=sampling_metadata,
|
||||
)
|
||||
output_token_ids = output_token_ids[:, :-1].flatten()
|
||||
|
||||
hist = torch.histogram(output_token_ids.to(dtype=torch.float,
|
||||
device="cpu"),
|
||||
bins=vocab_size,
|
||||
range=(0, vocab_size),
|
||||
density=True)
|
||||
|
||||
return hist.hist
|
||||
|
||||
|
||||
def _test_masked_logits(
|
||||
rejection_sampler,
|
||||
batch_size: int,
|
||||
num_draft_tokens: int,
|
||||
vocab_size: int,
|
||||
target_logits: torch.Tensor,
|
||||
unmasked_indices: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
):
|
||||
# Set up test parameters
|
||||
num_tokens = batch_size * num_draft_tokens
|
||||
|
||||
# Create random draft probabilities.
|
||||
draft_probs = torch.rand((num_tokens, vocab_size),
|
||||
dtype=torch.float32,
|
||||
device=DEVICE)
|
||||
draft_probs = F.softmax(draft_probs, dim=-1)
|
||||
|
||||
# Randomly sample draft token ids from draft probs
|
||||
draft_token_ids = torch.multinomial(draft_probs, num_samples=1)
|
||||
draft_token_ids = draft_token_ids.reshape(batch_size, num_draft_tokens)
|
||||
draft_token_ids = draft_token_ids.tolist()
|
||||
|
||||
# Bonus tokens not used but required
|
||||
bonus_token_ids = torch.zeros((batch_size, 1),
|
||||
dtype=torch.int64,
|
||||
device=DEVICE)
|
||||
|
||||
# Create spec decode metadata
|
||||
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
|
||||
draft_token_ids,
|
||||
device=DEVICE,
|
||||
)
|
||||
|
||||
# Run rejection sampling
|
||||
output_token_ids = rejection_sampler(
|
||||
spec_decode_metadata,
|
||||
draft_probs=draft_probs,
|
||||
target_logits=target_logits,
|
||||
bonus_token_ids=bonus_token_ids,
|
||||
sampling_metadata=sampling_metadata,
|
||||
)
|
||||
|
||||
# Remove bonus tokens and reshape
|
||||
output_token_ids = output_token_ids[:, :-1].flatten().tolist()
|
||||
|
||||
# Check that all sampled tokens are within the unmasked indices.
|
||||
for i in range(num_tokens):
|
||||
token_id = output_token_ids[i]
|
||||
if token_id == PLACEHOLDER_TOKEN_ID:
|
||||
continue
|
||||
assert token_id in unmasked_indices[i]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("top_k", [1, 5, 99])
|
||||
def test_top_k(rejection_sampler, top_k):
|
||||
"""Test rejection sampling with top-k sampling"""
|
||||
vocab_size = 100
|
||||
batch_size = 100
|
||||
num_draft_tokens = 3
|
||||
num_tokens = batch_size * num_draft_tokens
|
||||
|
||||
# Randomly create top-k indices.
|
||||
top_k_indices = [
|
||||
torch.randperm(vocab_size, device=DEVICE)[:top_k]
|
||||
for _ in range(num_tokens)
|
||||
]
|
||||
top_k_indices = torch.stack(top_k_indices)
|
||||
|
||||
# Create logits with the uniform distribution.
|
||||
target_logits = torch.zeros((num_tokens, vocab_size), device=DEVICE)
|
||||
|
||||
# Increment the logits for top-k indices, a little bit more than the other
|
||||
# ones. If the masking is effective, the non-topk indices will never be
|
||||
# sampled despite the small difference in logits.
|
||||
for i in range(num_tokens):
|
||||
target_logits[i, top_k_indices[i]] += 0.1
|
||||
|
||||
# Create sampling metadata
|
||||
temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE)
|
||||
sampling_metadata = create_sampling_metadata(
|
||||
all_greedy=False,
|
||||
temperature=temperature,
|
||||
top_k=torch.tensor([top_k] * batch_size,
|
||||
device=DEVICE,
|
||||
dtype=torch.int64),
|
||||
)
|
||||
|
||||
_test_masked_logits(
|
||||
rejection_sampler,
|
||||
batch_size=batch_size,
|
||||
num_draft_tokens=num_draft_tokens,
|
||||
vocab_size=vocab_size,
|
||||
target_logits=target_logits,
|
||||
unmasked_indices=top_k_indices,
|
||||
sampling_metadata=sampling_metadata,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("top_p", [0.5, 0.9, 0.99])
|
||||
def test_top_p(rejection_sampler, top_p):
|
||||
"""Test rejection sampling with top-p sampling"""
|
||||
vocab_size = 100
|
||||
batch_size = 100
|
||||
num_draft_tokens = 3
|
||||
num_tokens = batch_size * num_draft_tokens
|
||||
|
||||
# Create logits with the uniform distribution.
|
||||
target_logits = torch.randn((num_tokens, vocab_size), device=DEVICE)
|
||||
temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE)
|
||||
rescaled_logits = target_logits / temperature
|
||||
|
||||
logits_sort, logits_idx = rescaled_logits.sort(dim=-1, descending=False)
|
||||
probs_sort = logits_sort.softmax(dim=-1)
|
||||
probs_sum = probs_sort.cumsum(dim=-1)
|
||||
top_p_mask = probs_sum <= 1 - top_p
|
||||
# at least one
|
||||
top_p_mask[:, -1] = False
|
||||
|
||||
# Get the top-p indices.
|
||||
top_p_indices = []
|
||||
for i in range(num_tokens):
|
||||
top_p_indices.append(logits_idx[i][~top_p_mask[i]].tolist())
|
||||
|
||||
# Create sampling metadata
|
||||
sampling_metadata = create_sampling_metadata(
|
||||
all_greedy=False,
|
||||
temperature=temperature,
|
||||
top_p=torch.tensor([top_p] * batch_size,
|
||||
device=DEVICE,
|
||||
dtype=torch.float32),
|
||||
)
|
||||
|
||||
_test_masked_logits(
|
||||
rejection_sampler,
|
||||
batch_size=batch_size,
|
||||
num_draft_tokens=num_draft_tokens,
|
||||
vocab_size=vocab_size,
|
||||
target_logits=target_logits,
|
||||
unmasked_indices=top_p_indices,
|
||||
sampling_metadata=sampling_metadata,
|
||||
)
|
||||
@@ -1,18 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from vllm_ascend.patch import worker # noqa: F401
|
||||
@@ -1,28 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm-project/vllm/tests/spec_decode/conftest.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def use_v0_only(monkeypatch):
|
||||
"""
|
||||
Since this module is V0 only, set VLLM_USE_V1=0 for
|
||||
all tests in the module.
|
||||
"""
|
||||
monkeypatch.setenv('VLLM_USE_V1', '0')
|
||||
@@ -1,274 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm-project/vllm/tests/spec_decode/e2e/conftest.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import shutil
|
||||
from itertools import cycle
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.distributed import cleanup_dist_env_and_memory
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
from vllm.sequence import PromptLogprobs, SampleLogprobs
|
||||
|
||||
from ....model_utils import (TokensTextLogprobs,
|
||||
TokensTextLogprobsPromptLogprobs,
|
||||
check_logprobs_close, check_outputs_equal)
|
||||
|
||||
PROMPTS = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
"San Francisco is know for its",
|
||||
"Facebook was created in 2004 by",
|
||||
"Curious George is a",
|
||||
"Python 3.11 brings improvements to its",
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
test_llm_kwargs, seed):
|
||||
|
||||
def generate():
|
||||
kwargs = {
|
||||
**common_llm_kwargs,
|
||||
**per_test_common_llm_kwargs,
|
||||
**test_llm_kwargs,
|
||||
}
|
||||
|
||||
llm = LLM(**kwargs)
|
||||
|
||||
if seed is not None:
|
||||
set_random_seed(seed)
|
||||
|
||||
yield llm
|
||||
|
||||
del llm
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
return generate
|
||||
|
||||
|
||||
def maybe_assert_ngram_worker(llm):
|
||||
# Verify the proposer worker is ngram if ngram is specified.
|
||||
if (llm.llm_engine.speculative_config is not None
|
||||
and llm.llm_engine.speculative_config.method == "ngram"):
|
||||
from vllm.spec_decode.ngram_worker import NGramWorker
|
||||
assert isinstance(
|
||||
llm.llm_engine.model_executor.driver_worker.proposer_worker,
|
||||
NGramWorker)
|
||||
|
||||
|
||||
def get_output_from_llm_generator(
|
||||
llm_generator, prompts,
|
||||
sampling_params) -> Tuple[List[str], List[List[int]], float]:
|
||||
tokens: List[str] = []
|
||||
token_ids: List[List[int]] = []
|
||||
acceptance_rate: float = -1.0
|
||||
for llm in llm_generator():
|
||||
maybe_assert_ngram_worker(llm)
|
||||
|
||||
outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
|
||||
|
||||
token_ids = [output.outputs[0].token_ids for output in outputs]
|
||||
tokens = [output.outputs[0].text for output in outputs]
|
||||
|
||||
# Fetch acceptance rate if logging is enabled.
|
||||
if stat_loggers := getattr(llm.llm_engine, "stat_loggers", None):
|
||||
stat_logger = stat_loggers["prometheus"]
|
||||
acceptance_rate = (stat_logger.metrics.
|
||||
gauge_spec_decode_draft_acceptance_rate.labels(
|
||||
**stat_logger.labels)._value.get())
|
||||
del llm
|
||||
|
||||
return tokens, token_ids, acceptance_rate
|
||||
|
||||
|
||||
def check_logprobs_correctness(
|
||||
spec_outputs: Sequence[Union[TokensTextLogprobs,
|
||||
TokensTextLogprobsPromptLogprobs]],
|
||||
baseline_outputs: Sequence[Union[TokensTextLogprobs,
|
||||
TokensTextLogprobsPromptLogprobs]],
|
||||
disable_logprobs: bool = False,
|
||||
):
|
||||
"""Compare sampled and prompt logprobs between baseline and spec decoding
|
||||
"""
|
||||
if not disable_logprobs:
|
||||
return check_logprobs_close(
|
||||
outputs_0_lst=baseline_outputs,
|
||||
outputs_1_lst=spec_outputs,
|
||||
name_0="org",
|
||||
name_1="sd",
|
||||
)
|
||||
|
||||
# Check correctness when disable_logprobs == True
|
||||
for spec_output, baseline_output in zip(spec_outputs, baseline_outputs):
|
||||
# Check generated token logprobs.
|
||||
spec_logprobs = spec_output[2]
|
||||
baseline_logprobs = baseline_output[2]
|
||||
_check_logprobs_when_output_disabled(spec_logprobs,
|
||||
baseline_logprobs,
|
||||
is_prompt_logprobs=False)
|
||||
|
||||
# Check prompt logprobs too, if they exist
|
||||
if len(baseline_output) == 4:
|
||||
assert len(spec_output) == 4
|
||||
spec_prompt_logprobs = spec_output[3]
|
||||
baseline_prompt_logprobs = baseline_output[3]
|
||||
_check_logprobs_when_output_disabled(spec_prompt_logprobs,
|
||||
baseline_prompt_logprobs,
|
||||
is_prompt_logprobs=True)
|
||||
|
||||
|
||||
def _check_logprobs_when_output_disabled(
|
||||
spec_logprobs: Union[Optional[PromptLogprobs], SampleLogprobs],
|
||||
baseline_logprobs: Union[Optional[PromptLogprobs], SampleLogprobs],
|
||||
is_prompt_logprobs: bool = False,
|
||||
):
|
||||
# Prompt logprobs are optional
|
||||
if is_prompt_logprobs and baseline_logprobs is None:
|
||||
assert spec_logprobs is None
|
||||
return
|
||||
|
||||
assert spec_logprobs is not None
|
||||
assert baseline_logprobs is not None
|
||||
assert len(spec_logprobs) == len(baseline_logprobs)
|
||||
|
||||
# For each generated position of the sequence.
|
||||
for pos, (spec_pos_logprobs, baseline_pos_logprobs) in enumerate(
|
||||
zip(spec_logprobs, baseline_logprobs)):
|
||||
|
||||
# First prompt logprob is expected to be None
|
||||
if is_prompt_logprobs and baseline_pos_logprobs is None:
|
||||
assert spec_pos_logprobs is None
|
||||
assert pos == 0
|
||||
continue
|
||||
|
||||
assert spec_pos_logprobs is not None
|
||||
assert baseline_pos_logprobs is not None
|
||||
|
||||
# When disabled, the 1 logprob is returned with dummy values for the
|
||||
# score and rank, but the token id should match the baseline model
|
||||
assert len(spec_pos_logprobs) == 1
|
||||
(spec_pos_logprob_token_id,
|
||||
spec_pos_logprob) = next(iter(spec_pos_logprobs.items()))
|
||||
assert spec_pos_logprob.rank == -1
|
||||
assert spec_pos_logprob.logprob == 0.0
|
||||
if isinstance(spec_pos_logprob_token_id, torch.Tensor):
|
||||
spec_pos_logprob_token_id = spec_pos_logprob_token_id.item()
|
||||
assert spec_pos_logprob_token_id in baseline_pos_logprobs
|
||||
|
||||
|
||||
def _clean_torchair_cache():
|
||||
cache_path = Path.cwd() / '.torchair_cache'
|
||||
if cache_path.exists() and cache_path.is_dir():
|
||||
shutil.rmtree(cache_path)
|
||||
|
||||
|
||||
def run_equality_correctness_test(
|
||||
vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size: int,
|
||||
max_output_len: int,
|
||||
seed: Optional[int] = 0,
|
||||
temperature: float = 0.0,
|
||||
disable_seed: bool = False,
|
||||
ignore_eos: bool = True,
|
||||
ensure_all_accepted: bool = False,
|
||||
expected_acceptance_rate: Optional[float] = None,
|
||||
logprobs: Optional[int] = None,
|
||||
prompt_logprobs: Optional[int] = None,
|
||||
disable_logprobs: bool = False):
|
||||
|
||||
org_args = {
|
||||
**common_llm_kwargs,
|
||||
**per_test_common_llm_kwargs,
|
||||
**baseline_llm_kwargs,
|
||||
}
|
||||
|
||||
sd_args = {
|
||||
**common_llm_kwargs,
|
||||
**per_test_common_llm_kwargs,
|
||||
**test_llm_kwargs,
|
||||
}
|
||||
|
||||
prompts = [prompt for prompt, _ in zip(cycle(PROMPTS), range(batch_size))]
|
||||
|
||||
if disable_seed:
|
||||
seed = None
|
||||
|
||||
sampling_params = SamplingParams(temperature=temperature,
|
||||
max_tokens=max_output_len,
|
||||
seed=seed,
|
||||
ignore_eos=ignore_eos,
|
||||
logprobs=logprobs,
|
||||
prompt_logprobs=prompt_logprobs)
|
||||
|
||||
# TODO current torchair graph mode needs clean torchair cache.
|
||||
# if do not clean, it will raise error
|
||||
additional_config = common_llm_kwargs.get("additional_config")
|
||||
enable_graph_mode = additional_config.get(
|
||||
"enable_graph_mode") if additional_config else False
|
||||
|
||||
with vllm_runner(**org_args) as vllm_model:
|
||||
if enable_graph_mode:
|
||||
_clean_torchair_cache()
|
||||
org_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)
|
||||
|
||||
with vllm_runner(**sd_args) as vllm_model:
|
||||
if enable_graph_mode:
|
||||
_clean_torchair_cache()
|
||||
if ensure_all_accepted or expected_acceptance_rate is not None:
|
||||
# Force log interval to be 0 to catch all metrics.
|
||||
stat_logger = vllm_model.model.llm_engine.stat_loggers[
|
||||
'prometheus']
|
||||
stat_logger.local_interval = -100
|
||||
|
||||
sd_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)
|
||||
|
||||
if ensure_all_accepted or expected_acceptance_rate is not None:
|
||||
acceptance_rate = (stat_logger.metrics.
|
||||
gauge_spec_decode_draft_acceptance_rate.labels(
|
||||
**stat_logger.labels)._value.get())
|
||||
|
||||
if ensure_all_accepted:
|
||||
assert True
|
||||
# FIXME: ci fails to log acceptance rate.
|
||||
# It works locally.
|
||||
# assert acceptance_rate == 1.0
|
||||
|
||||
if expected_acceptance_rate is not None:
|
||||
assert acceptance_rate >= expected_acceptance_rate - 1e-2
|
||||
|
||||
# Only pass token entries, not the logprobs
|
||||
check_outputs_equal(outputs_0_lst=[out[0:2] for out in org_outputs],
|
||||
outputs_1_lst=[out[0:2] for out in sd_outputs],
|
||||
name_0="org",
|
||||
name_1="sd")
|
||||
|
||||
# Check logprobs if requested
|
||||
if logprobs is not None or prompt_logprobs is not None:
|
||||
check_logprobs_correctness(sd_outputs, org_outputs, disable_logprobs)
|
||||
@@ -1,450 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm-project/vllm/tests/spec_decode/e2e/test_medusa_correctness.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""This docstring details important information on the testing methodology.
|
||||
|
||||
Most of the tests rely on "greedy equality", where we expect the output of
|
||||
speculative decoding on a sequence to exactly match the output of normal non-
|
||||
speculative decoding.
|
||||
|
||||
Since speculative decoding with rejection sampling guarantees that the output
|
||||
distribution matches the target model's output distribution (up to hardware
|
||||
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
|
||||
equality.
|
||||
|
||||
However, we still need to verify below scenario could be passed:
|
||||
* Batch size 1 greedy equality
|
||||
* Batch size >1 greedy equality
|
||||
* Test greedy equality under preemption
|
||||
* Test greedy equality under various number of speculative tokens.
|
||||
|
||||
With those tests, we can say at least, Medusa would not break the
|
||||
correctess for the target model outputs.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.singlecard.spec_decode.e2e.conftest import \
|
||||
run_equality_correctness_test
|
||||
from tests.singlecard.spec_decode.utils import maybe_enable_chunked_prefill
|
||||
|
||||
# main model
|
||||
# lmsys/vicuna-7b-v1.3 was to be used but it's causing
|
||||
# OOM in CI pipeline, so using a smaller model.
|
||||
MAIN_MODEL = "JackFram/llama-68m"
|
||||
|
||||
# speculative model
|
||||
SPEC_MODEL = "abhigoyal/vllm-medusa-llama-68m-random"
|
||||
|
||||
# max number of speculative tokens: this corresponds to
|
||||
# num_heads in the config.json of the speculator model.
|
||||
MAX_SPEC_TOKENS = 5
|
||||
|
||||
# precision
|
||||
# TODO: The vLLM here uses float32, but some op on the vllm-ascend
|
||||
# do not support float32, such as ROPE, When it is fixed, it is
|
||||
# recommended to change this to float32 to keep it consistent
|
||||
# with vLLM.
|
||||
PRECISION = "float16"
|
||||
|
||||
PREFILL_CHUNK_SIZE = [
|
||||
-1,
|
||||
# TODO:enable chunked prefill when it is supported
|
||||
# 32
|
||||
]
|
||||
|
||||
os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
128,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [1, 32])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE)
|
||||
def test_medusa_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size: int, output_len: int,
|
||||
seed: int, prefill_chunk_size: int):
|
||||
"""Verify greedy equality with different batch size."""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"disable_logprobs": False,
|
||||
},
|
||||
},
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"disable_logprobs": True,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
8,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@pytest.mark.parametrize("logprobs", [1, 6])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE)
|
||||
def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size: int, output_len: int,
|
||||
seed: int, logprobs: int,
|
||||
prefill_chunk_size: int):
|
||||
"""Verify greedy equality with different batch size."""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
run_equality_correctness_test(
|
||||
vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0,
|
||||
logprobs=logprobs,
|
||||
prompt_logprobs=logprobs,
|
||||
disable_logprobs=test_llm_kwargs["speculative_config"]
|
||||
["disable_logprobs"])
|
||||
|
||||
|
||||
@pytest.mark.skipif(True, reason="Open it when graph mode ready.")
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"enforce_eager": False,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
128,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [1, 32])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE)
|
||||
def test_medusa_e2e_greedy_correctness_cuda_graph(
|
||||
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
|
||||
seed: int, prefill_chunk_size: int):
|
||||
"""Verify greedy equality with cuda graph enabled and different
|
||||
batch sizes."""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
|
||||
|
||||
@pytest.mark.skipif(True, reason="Open it when preempt ready.")
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"block_size": 16,
|
||||
# 2 for small prompt, 256//8 for generated.
|
||||
"num_gpu_blocks_override": 2 + 256 // 8,
|
||||
"max_model_len": (2 + 256 // 8) * 8,
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use small output len for fast test.
|
||||
128,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [4])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE)
|
||||
def test_medusa_e2e_greedy_correctness_with_preemption(
|
||||
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
|
||||
seed: int, prefill_chunk_size: int):
|
||||
"""Verify greedy equality, even when some sequences are preempted mid-
|
||||
generation.
|
||||
"""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize(
|
||||
"test_llm_kwargs",
|
||||
[
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": k,
|
||||
},
|
||||
}
|
||||
# Try a range of num. speculative tokens
|
||||
for k in range(1, 1 + MAX_SPEC_TOKENS)
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [2])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use smaller output len for fast test.
|
||||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE)
|
||||
def test_medusa_different_k(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||
test_llm_kwargs, batch_size: int, output_len: int,
|
||||
seed: int, prefill_chunk_size: int):
|
||||
"""Verify that medusa speculative decoding produces exact equality
|
||||
to without spec decode with different values of num_speculative_tokens.
|
||||
"""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"disable_by_batch_size": 4,
|
||||
},
|
||||
}])
|
||||
@pytest.mark.parametrize("batch_size", [1, 5])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use smaller output len for fast test.
|
||||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE)
|
||||
def test_medusa_disable_queue(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||
test_llm_kwargs, batch_size: int,
|
||||
output_len: int, seed: int,
|
||||
prefill_chunk_size: int):
|
||||
"""Verify that medusa speculative decoding produces exact equality
|
||||
to without spec decode when speculation is disabled for large
|
||||
batch sizes.
|
||||
"""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"disable_by_batch_size": 4,
|
||||
"disable_mqa_scorer": True,
|
||||
},
|
||||
}])
|
||||
@pytest.mark.parametrize("batch_size", [1, 5])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use smaller output len for fast test.
|
||||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE)
|
||||
def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
|
||||
output_len: int, seed: int, prefill_chunk_size: int):
|
||||
"""Verify that speculative decoding generates the same output
|
||||
with batch expansion scorer and mqa scorer.
|
||||
"""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
pytest.main([__file__])
|
||||
@@ -1,560 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm-project/vllm/tests/spec_decode/e2e/test_mlp_correctness.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""This docstring details important information on the testing methodology.
|
||||
|
||||
Most of the tests rely on "greedy equality", where we expect the output of
|
||||
speculative decoding on a sequence to exactly match the output of normal non-
|
||||
speculative decoding.
|
||||
|
||||
Since speculative decoding with rejection sampling guarantees that the output
|
||||
distribution matches the target model's output distribution (up to hardware
|
||||
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
|
||||
equality.
|
||||
|
||||
However, we still need to verify below scenario could be passed:
|
||||
* Batch size 1 greedy equality
|
||||
* Batch size >1 greedy equality
|
||||
* Test greedy equality under preemption
|
||||
* Test greedy equality under various number of speculative tokens.
|
||||
|
||||
With those tests, we can say at least, MLPSpeculator would not break the
|
||||
correctness for the target model outputs.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import \
|
||||
pad_vocab_size # noqa: F401
|
||||
|
||||
from tests.singlecard.spec_decode.e2e.conftest import \
|
||||
run_equality_correctness_test
|
||||
from tests.singlecard.spec_decode.utils import maybe_enable_chunked_prefill
|
||||
|
||||
# main model
|
||||
MAIN_MODEL = "JackFram/llama-160m"
|
||||
|
||||
# speculative model
|
||||
SPEC_MODEL = "ibm-ai-platform/llama-160m-accelerator"
|
||||
|
||||
# max. number of speculative tokens: this corresponds to
|
||||
# n_predict in the config.json of the speculator model.
|
||||
MAX_SPEC_TOKENS = 3
|
||||
|
||||
PREFILL_CHUNK_SIZE_1 = [
|
||||
-1,
|
||||
# TODO:enable chunked prefill when it is supported
|
||||
# 4
|
||||
]
|
||||
PREFILL_CHUNK_SIZE_2 = [
|
||||
-1,
|
||||
# TODO:enable chunked prefill when it is supported
|
||||
# 32
|
||||
]
|
||||
# precision
|
||||
# TODO: The vLLM here uses float32, but some op on the vllm-ascend
|
||||
# do not support float32, such as ROPE, When it is fixed, it is
|
||||
# recommended to change this to float32 to keep it consistent
|
||||
# with vLLM.
|
||||
PRECISION = "float16"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
128,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [4, 32])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE_2)
|
||||
def test_mlp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size: int, output_len: int,
|
||||
seed: int, prefill_chunk_size: int):
|
||||
"""Verify greedy equality with different batch size."""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"disable_logprobs": False,
|
||||
},
|
||||
},
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"disable_logprobs": True,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [8])
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@pytest.mark.parametrize("logprobs", [1, 6])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE_1)
|
||||
def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size: int, output_len: int, seed: int,
|
||||
logprobs: int, prefill_chunk_size: int):
|
||||
"""Verify greedy equality with different batch size."""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
# NOTE Test is sensitive enough st if we don't enable chunked prefill
|
||||
# scheduling on baseline too, we get slightly different logprobs, ending
|
||||
# up sampling different tokens at the tail (ie top tokens don't change).
|
||||
# TL;DR: sd+cp == org+cp but sd+cp != org..is this expected?
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, baseline_llm_kwargs)
|
||||
run_equality_correctness_test(
|
||||
vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0,
|
||||
logprobs=logprobs,
|
||||
prompt_logprobs=logprobs,
|
||||
disable_logprobs=test_llm_kwargs["speculative_config"]
|
||||
["disable_logprobs"])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [2048])
|
||||
@pytest.mark.parametrize("batch_size", [1, 32])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE_1)
|
||||
def test_mlp_e2e_acceptance_rate(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size: int, output_len: int,
|
||||
prefill_chunk_size: int, seed: int):
|
||||
"""Verify acceptance rate with different batch size and large output
|
||||
length."""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
temperature=0.0,
|
||||
seed=seed,
|
||||
expected_acceptance_rate=0.48)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
|
||||
# Speculative config
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
},
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{"seed": 1}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{"seed": 5}])
|
||||
@pytest.mark.parametrize("output_len", [64])
|
||||
@pytest.mark.parametrize("batch_size", [1, 32])
|
||||
@pytest.mark.parametrize("temperature", [1.0])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE_1)
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_mlp_e2e_seeded_correctness(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size: int, output_len: int,
|
||||
temperature: float,
|
||||
prefill_chunk_size: int, seed: int):
|
||||
"""Verify seeded runs produce the same output."""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, baseline_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
temperature=temperature,
|
||||
seed=seed)
|
||||
|
||||
# Ensure this same test does fail if we _don't_ include per-request seeds
|
||||
with pytest.raises(AssertionError):
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
temperature=temperature,
|
||||
seed=seed,
|
||||
disable_seed=True)
|
||||
|
||||
|
||||
@pytest.mark.skipif(True, reason="Open it when preempt ready.")
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"block_size": 16,
|
||||
# 2 for small prompt, 256//8 for generated.
|
||||
"num_gpu_blocks_override": 2 + 256 // 8,
|
||||
"max_model_len": (2 + 256 // 8) * 8,
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use small output len for fast test.
|
||||
128,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [4])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE_1)
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_mlp_e2e_greedy_correctness_with_preemption(
|
||||
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
|
||||
prefill_chunk_size: int, seed: int):
|
||||
"""Verify greedy equality, even when some sequences are preempted mid-
|
||||
generation.
|
||||
"""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
|
||||
|
||||
@pytest.mark.skipif(True, reason="Open it when preempt ready.")
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"block_size": 16,
|
||||
# 2 for small prompt, 256//8 for generated.
|
||||
"num_gpu_blocks_override": 2 + 256 // 8,
|
||||
"max_model_len": (2 + 256 // 8) * 8,
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use small output len for fast test.
|
||||
128,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [4])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE_1)
|
||||
def test_mlp_e2e_greedy_correctness_with_padding(
|
||||
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
|
||||
prefill_chunk_size: int, seed: int):
|
||||
"""Verify greedy equality when the vocab dimension is padded
|
||||
"""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
|
||||
# Default pad_to is 64, test model has vocab_size of 32000
|
||||
def patched_pad_vocab_size(vocab_size, pad_to=None):
|
||||
return pad_vocab_size(vocab_size, pad_to=32064)
|
||||
|
||||
# NOTE: Compared with vLLM, the patch method has been modified
|
||||
pad_vocab_size = patched_pad_vocab_size # noqa: F811
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize(
|
||||
"test_llm_kwargs",
|
||||
[
|
||||
{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"num_speculative_tokens": k,
|
||||
},
|
||||
}
|
||||
# Try a range of num. speculative tokens
|
||||
for k in range(1, 1 + MAX_SPEC_TOKENS)
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [2])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use smaller output len for fast test.
|
||||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE_1)
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_mlp_different_k(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||
test_llm_kwargs, batch_size: int,
|
||||
prefill_chunk_size: int, seed: int, output_len: int):
|
||||
"""Verify that mlp speculative decoding produces exact equality
|
||||
to without spec decode with different values of num_speculative_tokens.
|
||||
"""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": MAIN_MODEL,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"disable_by_batch_size": 4,
|
||||
},
|
||||
}])
|
||||
@pytest.mark.parametrize("batch_size", [1, 5])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use smaller output len for fast test.
|
||||
32,
|
||||
])
|
||||
# Speculative decoding is disabled when sequences reach decoding and the batch
|
||||
# consists of single-token requests. Hence we set `max_num_seqs`
|
||||
# >= `speculative_disable_by_batch_size` to test feature interaction.
|
||||
@pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE_1)
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_mlp_disable_queue(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||
test_llm_kwargs, batch_size: int,
|
||||
prefill_chunk_size: int, seed: int,
|
||||
output_len: int):
|
||||
"""Verify that mlp speculative decoding produces exact equality
|
||||
to without spec decode when speculation is disabled for large
|
||||
batch sizes.
|
||||
"""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"model_name": MAIN_MODEL,
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||
"speculative_config": {
|
||||
"model": SPEC_MODEL,
|
||||
"disable_mqa_scorer": True,
|
||||
},
|
||||
}])
|
||||
@pytest.mark.parametrize("batch_size", [1, 5])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use smaller output len for fast test.
|
||||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("prefill_chunk_size", PREFILL_CHUNK_SIZE_1)
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
|
||||
output_len: int, prefill_chunk_size: int, seed: int):
|
||||
"""Verify that speculative decoding generates the same output
|
||||
with batch expansion scorer and mqa scorer.
|
||||
"""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
@@ -1,457 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm-project/vllm/tests/spec_decode/e2e/test_mtp_correctness.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""This docstring details important information on the testing methodology.
|
||||
|
||||
Most of the tests rely on "greedy equality", where we expect the output of
|
||||
speculative decoding on a sequence to exactly match the output of normal non-
|
||||
speculative decoding.
|
||||
|
||||
Since speculative decoding with rejection sampling guarantees that the output
|
||||
distribution matches the target model's output distribution (up to hardware
|
||||
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
|
||||
equality.
|
||||
|
||||
However, we still need to verify below scenario could be passed:
|
||||
* Batch size 1 greedy equality
|
||||
* Batch size >1 greedy equality
|
||||
* Test greedy equality under preemption
|
||||
* Test greedy equality under various number of speculative tokens.
|
||||
|
||||
With those tests, we can say at least, mtp would not break the
|
||||
correctess for the target model outputs.
|
||||
"""
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from .conftest import run_equality_correctness_test
|
||||
|
||||
# NOTE both main model and MTP are bfloat16
|
||||
FLOAT_MODEL = "wemaster/deepseek_mtp_main_random_bf16"
|
||||
|
||||
# NOTE main model is w8a8, MTP is bfloat16
|
||||
QUANT_MODEL = "wemaster/deepseek_mtp_main_random_w8a8_part"
|
||||
|
||||
# TODO when msmodelslim can quantify both main and MTP model
|
||||
# This UT should use w8a8 fully weights.
|
||||
|
||||
# max. number of speculative tokens: this corresponds to
|
||||
# num_nextn_predict_layers in the config.json of the speculator model.
|
||||
MAX_SPEC_TOKENS = 1
|
||||
|
||||
# precision
|
||||
PRECISION = "bfloat16"
|
||||
os.environ["VLLM_USE_MODELSCOPE"] = "True"
|
||||
|
||||
|
||||
@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "1",
|
||||
reason="mtp is not supported on v1")
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": FLOAT_MODEL,
|
||||
|
||||
# GPU memory utilization
|
||||
"gpu_memory_utilization": 0.85
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
128,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [1, 32])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_mtp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size: int, output_len: int,
|
||||
seed: int):
|
||||
|
||||
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size, output_len, seed)
|
||||
|
||||
|
||||
@pytest.mark.skipif(True, reason="quant model is not ready.")
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": QUANT_MODEL,
|
||||
|
||||
# GPU memory utilization
|
||||
"gpu_memory_utilization": 0.8
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
128,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [1, 32])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_mtp_e2e_quant_greedy_correctness(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size: int, output_len: int,
|
||||
seed: int):
|
||||
|
||||
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size, output_len, seed)
|
||||
|
||||
|
||||
@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "1",
|
||||
reason="mtp is not supported on v1")
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": FLOAT_MODEL,
|
||||
|
||||
# GPU memory utilization
|
||||
"gpu_memory_utilization": 0.8
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"disable_logprobs": False,
|
||||
},
|
||||
},
|
||||
{
|
||||
"speculative_config": {
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"disable_logprobs": True,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
128,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@pytest.mark.parametrize("logprobs", [1, 6])
|
||||
def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size: int, output_len: int, seed: int,
|
||||
logprobs: int):
|
||||
|
||||
run_equality_correctness_test(
|
||||
vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
output_len,
|
||||
seed,
|
||||
logprobs=logprobs,
|
||||
prompt_logprobs=logprobs,
|
||||
disable_logprobs=test_llm_kwargs["speculative_config"]
|
||||
["disable_logprobs"])
|
||||
|
||||
|
||||
@pytest.mark.skipif(True, reason="torchair ut can not clean mem.")
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"additional_config": {
|
||||
'enable_graph_mode': True,
|
||||
},
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": FLOAT_MODEL,
|
||||
"gpu_memory_utilization": 0.8
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
128,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [1, 32])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_mtp_e2e_greedy_correctness_torchair_graph(
|
||||
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
|
||||
seed: int):
|
||||
"""Verify greedy equality with torchair graph enabled and different
|
||||
batch sizes using bfloat16 weights."""
|
||||
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size, output_len, seed)
|
||||
|
||||
|
||||
@pytest.mark.skipif(True, reason="quant model is not ready.")
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"additional_config": {
|
||||
'enable_graph_mode': True,
|
||||
},
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": QUANT_MODEL,
|
||||
"gpu_memory_utilization": 0.8
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
128,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [1, 32])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_mtp_e2e_quant_greedy_correctness_torchair_graph(
|
||||
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
|
||||
seed: int):
|
||||
"""Verify greedy equality with torchair graph enabled and different
|
||||
batch sizes using quant weights."""
|
||||
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size, output_len, seed)
|
||||
|
||||
|
||||
@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "1",
|
||||
reason="mtp is not supported on v1")
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"block_size": 16,
|
||||
# 2 for small prompt, 256//8 for generated.
|
||||
"num_gpu_blocks_override": 2 + 256 // 8,
|
||||
"max_model_len": (2 + 256 // 8) * 8,
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": FLOAT_MODEL,
|
||||
|
||||
# GPU memory utilization
|
||||
"gpu_memory_utilization": 0.8
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use small output len for fast test.
|
||||
128,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [4])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_mtp_e2e_greedy_correctness_with_preemption(
|
||||
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
|
||||
seed: int):
|
||||
"""Verify greedy equality, even when some sequences are preempted mid-
|
||||
generation.
|
||||
"""
|
||||
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size, output_len, seed)
|
||||
|
||||
|
||||
@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "1",
|
||||
reason="mtp is not supported on v1")
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": FLOAT_MODEL,
|
||||
|
||||
# GPU memory utilization
|
||||
"gpu_memory_utilization": 0.8
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize(
|
||||
"test_llm_kwargs",
|
||||
[
|
||||
{
|
||||
"speculative_config": {
|
||||
"num_speculative_tokens": k,
|
||||
},
|
||||
}
|
||||
# Try a range of num. speculative tokens
|
||||
for k in range(1, 1 + MAX_SPEC_TOKENS)
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [2])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use smaller output len for fast test.
|
||||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_mtp_different_k(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||
test_llm_kwargs, batch_size: int, output_len: int,
|
||||
seed: int):
|
||||
"""Verify that mtp speculative decoding produces exact equality
|
||||
to without spec decode with different values of num_speculative_tokens.
|
||||
"""
|
||||
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size, output_len, seed)
|
||||
|
||||
|
||||
@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "1",
|
||||
reason="mtp is not supported on v1")
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Precision
|
||||
"dtype": PRECISION,
|
||||
|
||||
# Main model
|
||||
"model_name": FLOAT_MODEL,
|
||||
|
||||
# GPU memory utilization
|
||||
"gpu_memory_utilization": 0.8
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||
"speculative_config": {
|
||||
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||
"disable_by_batch_size": 4
|
||||
},
|
||||
}])
|
||||
@pytest.mark.parametrize("batch_size", [1, 5])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use smaller output len for fast test.
|
||||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_mtp_disable_queue(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||
test_llm_kwargs, batch_size: int, output_len: int,
|
||||
seed: int):
|
||||
"""Verify that mtp speculative decoding produces exact equality
|
||||
to without spec decode when speculation is disabled for large
|
||||
batch sizes.
|
||||
"""
|
||||
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size, output_len, seed)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
pytest.main([__file__])
|
||||
@@ -1,404 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm-project/vllm/tests/spec_decode/e2e/test_ngram_correctness.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""This docstring details important information on the testing methodology.
|
||||
|
||||
Most of the tests rely on "greedy equality", where we expect the output of
|
||||
speculative decoding on a sequence to exactly match the output of normal non-
|
||||
speculative decoding.
|
||||
|
||||
Since speculative decoding with rejection sampling guarantees that the output
|
||||
distribution matches the target model's output distribution (up to hardware
|
||||
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
|
||||
equality.
|
||||
|
||||
For ngram lookup, its idea comes from https://github.com/apoorvumang/prompt-lookup-decoding,
|
||||
and is merged into transform code base: https://github.com/huggingface/transformers/pull/27775.
|
||||
Since there is no model is needed for generate the proposal, we could make
|
||||
the testcase much simpler than drafter multi-step one.
|
||||
|
||||
However, we still need to verify below scenario could be passed:
|
||||
* Batch size 1 greedy equality
|
||||
* Batch size >1 greedy equality
|
||||
* Test greedy equality under preemption
|
||||
* Test greedy equality under various ngram sizes / speculative sizes
|
||||
|
||||
With those tests, we can say at least, ngram spec would not break the correctess
|
||||
for the target model outputs.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.singlecard.spec_decode.e2e.conftest import \
|
||||
run_equality_correctness_test
|
||||
from tests.singlecard.spec_decode.utils import maybe_enable_chunked_prefill
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
|
||||
{
|
||||
"model_name": "JackFram/llama-68m",
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"method": "ngram",
|
||||
"num_speculative_tokens": 5,
|
||||
"prompt_lookup_max": 3,
|
||||
"disable_mqa_scorer": False,
|
||||
},
|
||||
},
|
||||
{
|
||||
"speculative_config": {
|
||||
"method": "ngram",
|
||||
"num_speculative_tokens": 5,
|
||||
"prompt_lookup_max": 3,
|
||||
"disable_mqa_scorer": True,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
256,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [1, 32])
|
||||
@pytest.mark.parametrize(
|
||||
"prefill_chunk_size",
|
||||
[
|
||||
-1,
|
||||
# TODO:enable chunked prefill when it is supported
|
||||
# 4
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size: int, output_len: int,
|
||||
prefill_chunk_size: int, seed: int):
|
||||
"""Verify greedy equality on a tiny model with different batch size."""
|
||||
maybe_enable_chunked_prefill(prefill_chunk_size, common_llm_kwargs)
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Print spec metrics.
|
||||
"disable_log_stats": False,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
|
||||
{
|
||||
"model_name": "JackFram/llama-68m",
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"method": "ngram",
|
||||
"num_speculative_tokens": 5,
|
||||
"prompt_lookup_max": 3,
|
||||
"disable_logprobs": False,
|
||||
},
|
||||
},
|
||||
{
|
||||
"speculative_config": {
|
||||
"method": "ngram",
|
||||
"num_speculative_tokens": 5,
|
||||
"prompt_lookup_max": 3,
|
||||
"disable_logprobs": True,
|
||||
},
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("output_len", [
|
||||
8,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
@pytest.mark.parametrize("logprobs", [1, 6])
|
||||
def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size: int, output_len: int, seed: int,
|
||||
logprobs: int):
|
||||
"""Verify greedy equality on a tiny model with different batch size."""
|
||||
run_equality_correctness_test(
|
||||
vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0,
|
||||
logprobs=logprobs,
|
||||
prompt_logprobs=logprobs,
|
||||
disable_logprobs=test_llm_kwargs["speculative_config"]
|
||||
["disable_logprobs"])
|
||||
|
||||
|
||||
@pytest.mark.skipif(True, reason="Open it when preempt ready.")
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"block_size": 16,
|
||||
# 2 for small prompt, 256//8 for generated.
|
||||
"num_gpu_blocks_override": 2 + 256 // 8,
|
||||
"max_model_len": (2 + 256 // 8) * 8,
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
|
||||
{
|
||||
"model_name": "JackFram/llama-160m",
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||
{
|
||||
"speculative_config": {
|
||||
"method": "ngram",
|
||||
"num_speculative_tokens": 5,
|
||||
"prompt_lookup_max": 3,
|
||||
},
|
||||
"enable_chunked_prefill": False,
|
||||
},
|
||||
{
|
||||
"speculative_config": {
|
||||
"method": "ngram",
|
||||
"num_speculative_tokens": 5,
|
||||
"prompt_lookup_max": 3,
|
||||
"disable_mqa_scorer": True,
|
||||
},
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4
|
||||
},
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use small output len for fast test.
|
||||
256,
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [4])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_ngram_e2e_greedy_correctness_with_preemption(
|
||||
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
|
||||
seed: int):
|
||||
"""Verify greedy equality, even when some sequences are preempted mid-
|
||||
generation.
|
||||
"""
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
temperature=0,
|
||||
seed=seed)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"model_name": "JackFram/llama-68m",
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize(
|
||||
"test_llm_kwargs",
|
||||
[
|
||||
{
|
||||
"speculative_config": {
|
||||
"method": "ngram",
|
||||
"num_speculative_tokens": k,
|
||||
"prompt_lookup_max": 3,
|
||||
},
|
||||
}
|
||||
# Try a range of common k, as well as large speculation.
|
||||
for k in [1, 3, 5]
|
||||
] + [
|
||||
{
|
||||
"speculative_config": {
|
||||
"method": "ngram",
|
||||
"num_speculative_tokens": k,
|
||||
"prompt_lookup_max": 1,
|
||||
},
|
||||
}
|
||||
# Try a range of common k, as well as large speculation.
|
||||
for k in [1, 3, 5]
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [2])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use smaller output len for fast test.
|
||||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_ngram_different_k(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||
test_llm_kwargs, batch_size: int, output_len: int,
|
||||
seed: int):
|
||||
"""Verify that ngram speculative decoding produces exact equality
|
||||
to without spec decode with many different values of k and
|
||||
different ngram_prompt_lookup_max.
|
||||
"""
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"model_name": "JackFram/llama-68m",
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize(
|
||||
"test_llm_kwargs",
|
||||
[
|
||||
{
|
||||
"speculative_config": {
|
||||
"method": "ngram",
|
||||
"num_speculative_tokens": 5,
|
||||
"prompt_lookup_max": 3,
|
||||
"disable_by_batch_size": 4
|
||||
},
|
||||
},
|
||||
{
|
||||
"speculative_config": {
|
||||
"method": "ngram",
|
||||
"num_speculative_tokens": 5,
|
||||
"prompt_lookup_max": 3,
|
||||
"disable_by_batch_size": 4,
|
||||
"disable_mqa_scorer": True,
|
||||
},
|
||||
"enable_chunked_prefill": False,
|
||||
# FIXME: enable me when chunked prefill is available
|
||||
# "max_num_batched_tokens": 4,
|
||||
"max_num_seqs": 4
|
||||
}
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [1, 5])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use smaller output len for fast test.
|
||||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_ngram_disable_queue(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||
test_llm_kwargs, batch_size: int, output_len: int,
|
||||
seed: int):
|
||||
"""Verify that ngram speculative decoding produces exact equality
|
||||
to without spec decode with many different values of k and
|
||||
different ngram_prompt_lookup_max.
|
||||
"""
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"model_name": "JackFram/llama-68m",
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs", [{
|
||||
"speculative_config": {
|
||||
"method": "ngram",
|
||||
"num_speculative_tokens": 5,
|
||||
"prompt_lookup_max": 3,
|
||||
"disable_mqa_scorer": True,
|
||||
},
|
||||
}])
|
||||
@pytest.mark.parametrize("batch_size", [1, 5])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use smaller output len for fast test.
|
||||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_ngram_scorer(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||
test_llm_kwargs, batch_size: int, output_len: int,
|
||||
seed: int):
|
||||
"""Verify that ngram speculative decoding generates the same output
|
||||
with batch expansion scorer and mqa scorer.
|
||||
"""
|
||||
run_equality_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
seed=seed,
|
||||
temperature=0.0)
|
||||
@@ -1,156 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import random
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
os.environ["VLLM_USE_MODELSCOPE"] = "True"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_prompts():
|
||||
prompt_types = ["repeat", "sentence"]
|
||||
num_prompts = 100
|
||||
prompts = []
|
||||
|
||||
random.seed(0)
|
||||
random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)
|
||||
|
||||
# Generate a mixed batch of prompts, some of which can be easily
|
||||
# predicted by n-gram matching and some which likely cannot.
|
||||
for kind in random_prompt_type_choices:
|
||||
word_choices = ["test", "temp", "hello", "where"]
|
||||
word = random.choice(word_choices)
|
||||
if kind == "repeat":
|
||||
prompt = f"""
|
||||
please repeat the word '{word}' 10 times.
|
||||
give no other output than the word at least ten times in a row,
|
||||
in lowercase with spaces between each word and without quotes.
|
||||
"""
|
||||
elif kind == "sentence":
|
||||
prompt = f"""
|
||||
please give a ten-word sentence that
|
||||
uses the word {word} at least once.
|
||||
give no other output than that simple sentence without quotes.
|
||||
"""
|
||||
else:
|
||||
raise ValueError(f"Unknown prompt type: {kind}")
|
||||
prompts.append([{"role": "user", "content": prompt}])
|
||||
|
||||
return prompts
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sampling_config():
|
||||
return SamplingParams(temperature=0, max_tokens=10, ignore_eos=False)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_name():
|
||||
return "LLM-Research/Meta-Llama-3.1-8B-Instruct"
|
||||
|
||||
|
||||
def eagle_model_name():
|
||||
return "vllm-ascend/EAGLE-LLaMA3.1-Instruct-8B"
|
||||
|
||||
|
||||
def eagle3_model_name():
|
||||
return "vllm-ascend/EAGLE3-LLaMA3.1-Instruct-8B"
|
||||
|
||||
|
||||
def test_ngram_correctness(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
test_prompts: list[list[dict[str, Any]]],
|
||||
sampling_config: SamplingParams,
|
||||
model_name: str,
|
||||
):
|
||||
'''
|
||||
Compare the outputs of a original LLM and a speculative LLM
|
||||
should be the same when using ngram speculative decoding.
|
||||
'''
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
|
||||
ref_llm = LLM(model=model_name, max_model_len=1024)
|
||||
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
|
||||
del ref_llm
|
||||
|
||||
spec_llm = LLM(
|
||||
model=model_name,
|
||||
speculative_config={
|
||||
"method": "ngram",
|
||||
"prompt_lookup_max": 5,
|
||||
"prompt_lookup_min": 3,
|
||||
"num_speculative_tokens": 3,
|
||||
},
|
||||
max_model_len=1024,
|
||||
)
|
||||
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
|
||||
matches = 0
|
||||
misses = 0
|
||||
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
|
||||
if ref_output.outputs[0].text == spec_output.outputs[0].text:
|
||||
matches += 1
|
||||
else:
|
||||
misses += 1
|
||||
print(f"ref_output: {ref_output.outputs[0].text}")
|
||||
print(f"spec_output: {spec_output.outputs[0].text}")
|
||||
|
||||
# Heuristic: expect at least 70% of the prompts to match exactly
|
||||
# Upon failure, inspect the outputs to check for inaccuracy.
|
||||
assert matches > int(0.7 * len(ref_outputs))
|
||||
del spec_llm
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_eagle3", [False, True], ids=["eagle", "eagle3"])
|
||||
def test_eagle_correctness(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
test_prompts: list[list[dict[str, Any]]],
|
||||
sampling_config: SamplingParams,
|
||||
model_name: str,
|
||||
use_eagle3: bool,
|
||||
):
|
||||
'''
|
||||
Compare the outputs of a original LLM and a speculative LLM
|
||||
should be the same when using eagle speculative decoding.
|
||||
'''
|
||||
pytest.skip("Not current support for the test.")
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
|
||||
ref_llm = LLM(model=model_name, max_model_len=2048)
|
||||
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
|
||||
del ref_llm
|
||||
|
||||
spec_model_name = eagle3_model_name(
|
||||
) if use_eagle3 else eagle_model_name()
|
||||
spec_llm = LLM(
|
||||
model=model_name,
|
||||
trust_remote_code=True,
|
||||
speculative_config={
|
||||
"method": "eagle3" if use_eagle3 else "eagle",
|
||||
"model": spec_model_name,
|
||||
"num_speculative_tokens": 3,
|
||||
"max_model_len": 2048,
|
||||
},
|
||||
max_model_len=2048,
|
||||
)
|
||||
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
|
||||
matches = 0
|
||||
misses = 0
|
||||
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
|
||||
if ref_output.outputs[0].text == spec_output.outputs[0].text:
|
||||
matches += 1
|
||||
else:
|
||||
misses += 1
|
||||
print(f"ref_output: {ref_output.outputs[0].text}")
|
||||
print(f"spec_output: {spec_output.outputs[0].text}")
|
||||
|
||||
# Heuristic: expect at least 66% of the prompts to match exactly
|
||||
# Upon failure, inspect the outputs to check for inaccuracy.
|
||||
assert matches > int(0.66 * len(ref_outputs))
|
||||
del spec_llm
|
||||
@@ -1,105 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm-project/vllm/tests/spec_decode/test_dynamic_spec_decode.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
from vllm.spec_decode.metrics import AsyncMetricsCollector
|
||||
from vllm.spec_decode.multi_step_worker import MultiStepWorker
|
||||
from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
|
||||
from vllm.spec_decode.top1_proposer import Top1Proposer
|
||||
|
||||
from tests.singlecard.spec_decode.test_utils import mock_spec_decode_sampler
|
||||
from tests.singlecard.spec_decode.utils import create_batch, mock_worker
|
||||
|
||||
|
||||
@pytest.mark.parametrize('queue_size', [4])
|
||||
@pytest.mark.parametrize('batch_size', [1])
|
||||
@pytest.mark.parametrize('k', [1])
|
||||
@pytest.mark.parametrize("acceptance_sampler_method",
|
||||
["rejection_sampler", "typical_acceptance_sampler"])
|
||||
@torch.inference_mode()
|
||||
def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int,
|
||||
acceptance_sampler_method: str):
|
||||
"""Verify that speculative tokens are disabled when the batch size
|
||||
exceeds the threshold.
|
||||
"""
|
||||
disable_by_batch_size = 3
|
||||
draft_worker = mock_worker(cls=MultiStepWorker)
|
||||
target_worker = mock_worker()
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
worker = SpecDecodeWorker(proposer_worker=draft_worker,
|
||||
scorer_worker=target_worker,
|
||||
spec_decode_sampler=mock_spec_decode_sampler(
|
||||
acceptance_sampler_method),
|
||||
disable_logprobs=False,
|
||||
metrics_collector=metrics_collector,
|
||||
disable_by_batch_size=disable_by_batch_size)
|
||||
|
||||
exception_secret = 'artificial stop'
|
||||
draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)
|
||||
|
||||
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
|
||||
execute_model_req = ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=k,
|
||||
running_queue_size=queue_size)
|
||||
|
||||
if queue_size > disable_by_batch_size:
|
||||
with patch.object(worker,
|
||||
'_run_no_spec',
|
||||
side_effect=ValueError(exception_secret)), \
|
||||
pytest.raises(ValueError, match=exception_secret):
|
||||
worker.execute_model(execute_model_req=execute_model_req)
|
||||
|
||||
# When the batch size is larger than the threshold,
|
||||
# we expect no speculative tokens (0).
|
||||
expected_num_spec_tokens = None if queue_size < disable_by_batch_size else 0
|
||||
assert seq_group_metadata_list[
|
||||
0].num_speculative_tokens == expected_num_spec_tokens
|
||||
|
||||
draft_worker.sampler_output.side_effect = ValueError(exception_secret)
|
||||
|
||||
proposer = Top1Proposer(
|
||||
worker=draft_worker,
|
||||
device='cpu', # not used
|
||||
vocab_size=100, # not used
|
||||
# Must be long enough to avoid being skipped due to length.
|
||||
max_proposal_len=1024,
|
||||
)
|
||||
|
||||
if queue_size < disable_by_batch_size:
|
||||
# Should raise exception when executing the mocked draft model.
|
||||
with pytest.raises(ValueError, match=exception_secret):
|
||||
proposer.get_spec_proposals(
|
||||
execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=k),
|
||||
seq_ids_with_bonus_token_in_last_step=set())
|
||||
else:
|
||||
# Should not execute the draft model because spec decode is disabled
|
||||
# for all requests. Accordingly, the proposal length should be 0.
|
||||
proposals = proposer.get_spec_proposals(
|
||||
execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=k),
|
||||
seq_ids_with_bonus_token_in_last_step=set())
|
||||
assert proposals.proposal_lens.tolist() == [0] * batch_size
|
||||
@@ -1,846 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm-project/vllm/tests/spec_decode/test_multi_step_worker.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import random
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
from vllm.sequence import (ExecuteModelRequest, HiddenStates, Logprob,
|
||||
get_all_seq_ids)
|
||||
from vllm.spec_decode.multi_step_worker import MultiStepWorker
|
||||
from vllm.spec_decode.top1_proposer import Top1Proposer
|
||||
|
||||
from tests.singlecard.spec_decode.utils import (
|
||||
assert_logprobs_dict_allclose, create_batch,
|
||||
create_seq_group_metadata_from_prompts, create_worker,
|
||||
patch_execute_model_with_seeds, zero_kv_cache)
|
||||
from vllm_ascend.worker.draft_model_runner import TP1DraftModelRunner
|
||||
from vllm_ascend.worker.worker import NPUWorker
|
||||
|
||||
|
||||
@pytest.mark.parametrize('num_steps', list(range(1, 17)))
|
||||
def test_assert_enough_kv_space(num_steps: int):
|
||||
"""Test that the multi step worker checks for sufficient space in the KV
|
||||
cache. It should throw if it cannot run all the steps.
|
||||
"""
|
||||
block_size = 16
|
||||
num_gpu_blocks = 2048 // block_size
|
||||
|
||||
prompts = [
|
||||
list(range(block_size * 3)),
|
||||
list(range(block_size * 2)),
|
||||
]
|
||||
|
||||
prev_output_tokens = [
|
||||
list(range(block_size * 1)),
|
||||
list(range(block_size * 2)),
|
||||
]
|
||||
|
||||
final_prompt_lens = [
|
||||
len(prompt + output) + num_steps
|
||||
for prompt, output in zip(prompts, prev_output_tokens)
|
||||
]
|
||||
|
||||
inputs = create_seq_group_metadata_from_prompts(
|
||||
prompts,
|
||||
num_gpu_blocks,
|
||||
block_size,
|
||||
final_prompt_lens,
|
||||
continuations=prev_output_tokens)
|
||||
|
||||
assert_enough_kv_space = MultiStepWorker._assert_enough_kv_space # pylint: disable=protected-access
|
||||
worker = MagicMock()
|
||||
worker.model_runner.block_size = block_size
|
||||
|
||||
for seq_group_metadata in inputs:
|
||||
original_block_tables = seq_group_metadata.block_tables
|
||||
|
||||
# No exception.
|
||||
assert_enough_kv_space(worker, inputs, num_steps)
|
||||
|
||||
seq_group_metadata.block_tables = {
|
||||
seq_id: []
|
||||
for seq_id, physical_blocks in original_block_tables.items()
|
||||
}
|
||||
|
||||
# Expect exception.
|
||||
with pytest.raises(ValueError,
|
||||
match='times but found insufficient KV space for'):
|
||||
assert_enough_kv_space(worker, inputs, num_steps)
|
||||
|
||||
seq_group_metadata.block_tables = original_block_tables
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_same_output_for_single_step():
|
||||
"""Verify the multi step worker produces the same output as the normal
|
||||
worker for num_steps=1.
|
||||
"""
|
||||
seed = 100
|
||||
model_name = 'JackFram/llama-68m'
|
||||
|
||||
block_size = 32
|
||||
num_gpu_blocks = 2048 // block_size
|
||||
multi_step_worker = create_worker(
|
||||
MultiStepWorker,
|
||||
model_name,
|
||||
block_size,
|
||||
num_gpu_blocks,
|
||||
seed,
|
||||
model_runner_cls=TP1DraftModelRunner,
|
||||
)
|
||||
worker = create_worker(
|
||||
NPUWorker,
|
||||
model_name,
|
||||
block_size,
|
||||
num_gpu_blocks,
|
||||
seed,
|
||||
)
|
||||
# multi_step_worker.model_runner = worker.model_runner
|
||||
# multi_step_worker.cache_engine = worker.cache_engine
|
||||
|
||||
num_steps = 1
|
||||
|
||||
prompts = [
|
||||
[1, 2, 3, 4, 5],
|
||||
[6, 7, 8, 9, 10],
|
||||
]
|
||||
|
||||
final_prompt_lens = [len(prompt) + num_steps for prompt in prompts]
|
||||
|
||||
multi_step_seq_group = create_seq_group_metadata_from_prompts(
|
||||
prompts,
|
||||
num_gpu_blocks,
|
||||
block_size,
|
||||
final_prompt_lens=final_prompt_lens)
|
||||
|
||||
zero_kv_cache(multi_step_worker.cache_engine)
|
||||
set_random_seed(seed)
|
||||
actual_output, _ = multi_step_worker.sampler_output(
|
||||
execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=multi_step_seq_group),
|
||||
sample_len=num_steps,
|
||||
seq_ids_with_bonus_token_in_last_step=set())
|
||||
assert len(actual_output) == num_steps
|
||||
actual_output = actual_output[0]
|
||||
|
||||
single_step_seq_group = create_seq_group_metadata_from_prompts(
|
||||
prompts,
|
||||
num_gpu_blocks,
|
||||
block_size,
|
||||
final_prompt_lens=final_prompt_lens)
|
||||
|
||||
zero_kv_cache(worker.cache_engine)
|
||||
set_random_seed(seed)
|
||||
expected_output = worker.execute_model(
|
||||
execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=single_step_seq_group))[0]
|
||||
|
||||
actual_token_ids = [
|
||||
output.samples[0].output_token for output in actual_output
|
||||
]
|
||||
actual_logprobs = [output.samples[0].logprobs for output in actual_output]
|
||||
|
||||
expected_token_ids = [
|
||||
output.samples[0].output_token for output in expected_output
|
||||
]
|
||||
expected_logprobs = [
|
||||
output.samples[0].logprobs for output in expected_output
|
||||
]
|
||||
|
||||
assert actual_token_ids == expected_token_ids
|
||||
|
||||
print(f'{actual_logprobs=}')
|
||||
print(f'{expected_logprobs=}')
|
||||
assert_logprobs_dict_allclose(actual_logprobs, expected_logprobs)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_same_output_for_multi_step():
|
||||
"""Verify the multi-step worker produces the same output as the normal
|
||||
worker when num_steps > 1. This test runs the multi-step worker once, and
|
||||
then runs the worker num_steps times, and compares the output.
|
||||
"""
|
||||
seed = 100
|
||||
model_name = 'JackFram/llama-68m'
|
||||
|
||||
block_size = 16
|
||||
num_gpu_blocks = 2048 // block_size
|
||||
multi_step_worker = create_worker(
|
||||
MultiStepWorker,
|
||||
model_name,
|
||||
block_size,
|
||||
num_gpu_blocks,
|
||||
seed,
|
||||
)
|
||||
|
||||
worker = create_worker(
|
||||
NPUWorker,
|
||||
model_name,
|
||||
block_size,
|
||||
num_gpu_blocks,
|
||||
seed,
|
||||
)
|
||||
|
||||
# Make sure we go over the block boundary.
|
||||
num_steps = block_size + 1
|
||||
|
||||
random.seed(seed)
|
||||
prompts = [[
|
||||
random.randint(0, 1000) for _ in range(random.randint(10, 20))
|
||||
] for _ in range(10)]
|
||||
|
||||
final_prompt_lens = [len(prompt) + num_steps for prompt in prompts]
|
||||
|
||||
rand_seeds = list(random.randint(0, 100) for _ in range(num_steps))
|
||||
multi_step_worker.execute_model = patch_execute_model_with_seeds(
|
||||
multi_step_worker, rand_seeds)
|
||||
worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds)
|
||||
|
||||
continuations = [[1] for _ in prompts]
|
||||
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
|
||||
prompts,
|
||||
num_gpu_blocks,
|
||||
block_size,
|
||||
continuations=continuations,
|
||||
final_prompt_lens=final_prompt_lens)
|
||||
|
||||
# Run multi-step.
|
||||
zero_kv_cache(multi_step_worker.cache_engine)
|
||||
set_random_seed(seed)
|
||||
multi_step_output, _ = multi_step_worker.sampler_output(
|
||||
execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list),
|
||||
sample_len=num_steps,
|
||||
seq_ids_with_bonus_token_in_last_step=set())
|
||||
|
||||
# Run single-step repeatedly.
|
||||
zero_kv_cache(worker.cache_engine)
|
||||
single_step_output: list[SamplerOutput] = []
|
||||
continuations = [[1] for _ in prompts]
|
||||
set_random_seed(seed)
|
||||
|
||||
for _ in multi_step_output:
|
||||
|
||||
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
|
||||
prompts,
|
||||
num_gpu_blocks,
|
||||
block_size,
|
||||
continuations=continuations,
|
||||
final_prompt_lens=final_prompt_lens)
|
||||
|
||||
single_step_output.extend(
|
||||
worker.execute_model(execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list)))
|
||||
|
||||
# Append output tokens to new sequence data.
|
||||
for i, seq_group_output in enumerate(single_step_output[-1]):
|
||||
continuations[i].append(seq_group_output.samples[0].output_token)
|
||||
|
||||
# Get token ids and logprobs for comparison.
|
||||
multi_step_output_logprobs: list[list[dict[int,
|
||||
Logprob]]] = [[]
|
||||
for _ in prompts]
|
||||
single_step_output_logprobs: list[list[dict[int,
|
||||
Logprob]]] = [[]
|
||||
for _ in prompts]
|
||||
|
||||
multi_step_output_token_ids: list[list[int]] = [[] for _ in prompts]
|
||||
single_step_output_token_ids: list[list[int]] = [[] for _ in prompts]
|
||||
for i, _ in enumerate(prompts):
|
||||
for multi_step, single_step in zip(multi_step_output,
|
||||
single_step_output):
|
||||
multi_step_output_token_ids[i].append(
|
||||
multi_step[i].samples[0].output_token)
|
||||
single_step_output_token_ids[i].append(
|
||||
single_step[i].samples[0].output_token)
|
||||
|
||||
multi_step_output_logprobs[i].append(
|
||||
multi_step[i].samples[0].logprobs)
|
||||
single_step_output_logprobs[i].append(
|
||||
single_step[i].samples[0].logprobs)
|
||||
|
||||
# Print per-sequence token ids
|
||||
for i, (multi_step_tokens, single_step_tokens) in enumerate(
|
||||
zip(multi_step_output_token_ids, single_step_output_token_ids)):
|
||||
print(f'{i=} {multi_step_tokens=}')
|
||||
print(f'{i=} {single_step_tokens=}')
|
||||
print(f'{i=} equal {multi_step_tokens == single_step_tokens}')
|
||||
|
||||
# Assert token ids are equal.
|
||||
for multi_step_tokens, single_step_tokens in zip(
|
||||
multi_step_output_token_ids, single_step_output_token_ids):
|
||||
assert multi_step_tokens == single_step_tokens
|
||||
|
||||
# Assert logprobs are equal.
|
||||
for multi_step_logprobs, single_step_logprobs in zip(
|
||||
multi_step_output_logprobs, single_step_output_logprobs):
|
||||
assert_logprobs_dict_allclose(multi_step_logprobs,
|
||||
single_step_logprobs)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_multi_step_with_batch_expansion_correct_output():
|
||||
"""
|
||||
In this test we verify that the MultiStepWorker is able to handle bonus
|
||||
tokens correctly. The test verifies that if a sequence has a
|
||||
bonus token then the MultiStepWorker is able to expand the batch by adding
|
||||
new sequences corresponding to the sequences with bonus tokens. The
|
||||
expanded batch is then used for predicting the next tokens.
|
||||
"""
|
||||
seed = 100
|
||||
model_name = 'JackFram/llama-68m'
|
||||
|
||||
block_size = 16
|
||||
num_gpu_blocks = 2048 // block_size
|
||||
batch_size = 128
|
||||
multi_step_worker = create_worker(
|
||||
MultiStepWorker,
|
||||
model_name,
|
||||
block_size,
|
||||
num_gpu_blocks,
|
||||
seed,
|
||||
model_runner_cls=TP1DraftModelRunner,
|
||||
)
|
||||
multi_step_worker.set_include_gpu_probs_tensor()
|
||||
worker = create_worker(
|
||||
NPUWorker,
|
||||
model_name,
|
||||
block_size,
|
||||
num_gpu_blocks,
|
||||
seed,
|
||||
)
|
||||
random.seed(seed)
|
||||
prompts = [[0] for _ in range(batch_size)]
|
||||
num_steps = 2
|
||||
final_prompt_lens = [(num_steps + 1) for prompt in prompts]
|
||||
rand_seeds = list(random.randint(0, 100) for _ in range(num_steps))
|
||||
multi_step_worker.execute_model = patch_execute_model_with_seeds(
|
||||
multi_step_worker, rand_seeds)
|
||||
worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds)
|
||||
# Create the test continuations
|
||||
continuations = [[random.randint(0, 1000)] for _ in prompts]
|
||||
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
|
||||
prompts,
|
||||
num_gpu_blocks,
|
||||
block_size,
|
||||
continuations=continuations,
|
||||
final_prompt_lens=final_prompt_lens)
|
||||
|
||||
# Run single-step twice to generate 2 tokens. This
|
||||
# will simulate the bonus token case with the second token
|
||||
# being the bonus token.
|
||||
zero_kv_cache(worker.cache_engine)
|
||||
single_step_output: list[SamplerOutput] = []
|
||||
set_random_seed(seed)
|
||||
for _ in range(num_steps):
|
||||
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
|
||||
prompts,
|
||||
num_gpu_blocks,
|
||||
block_size,
|
||||
continuations=continuations,
|
||||
final_prompt_lens=final_prompt_lens)
|
||||
single_step_output.extend(
|
||||
worker.execute_model(execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list)))
|
||||
# Append output tokens to new sequence data.
|
||||
for i, seq_group_output in enumerate(single_step_output[-1]):
|
||||
continuations[i].append(seq_group_output.samples[0].output_token)
|
||||
|
||||
# Create continuations for the MultiStepWorker. The continuations have
|
||||
# 2 tokens in order to simulate the bonus token case.
|
||||
multi_step_continuations = []
|
||||
for continuation in continuations:
|
||||
multi_step_continuations.append(continuation[:2])
|
||||
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
|
||||
prompts,
|
||||
num_gpu_blocks,
|
||||
block_size,
|
||||
continuations=multi_step_continuations,
|
||||
final_prompt_lens=final_prompt_lens)
|
||||
|
||||
# Run multi-step and verify that the third token prediction is accurate
|
||||
# for all sequences.
|
||||
zero_kv_cache(multi_step_worker.cache_engine)
|
||||
all_seq_ids = {i for i in range(batch_size)}
|
||||
multi_step_output, _ = multi_step_worker.sampler_output(
|
||||
execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list),
|
||||
sample_len=1,
|
||||
seq_ids_with_bonus_token_in_last_step=all_seq_ids)
|
||||
for index, output in enumerate(multi_step_output[-1].outputs):
|
||||
assert (continuations[index][-1] == output.samples[0].output_token)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_multi_step_with_batch_expansion_incorrect_output():
|
||||
"""
|
||||
Tests the MultiStepWorker's ability to handle batch expansion with bonus
|
||||
tokens in a negative case scenario. This test provides the MultiStepWorker
|
||||
with a batch containing sequences with bonus tokens but specifies the
|
||||
sequence IDs with bonus tokens incorrectly. The test verifies that the
|
||||
MultiStepWorker generates correct tokens for the sequences where the
|
||||
sequence ID is specified correctly and incorrect tokens for those where
|
||||
the sequence ID is specified incorrectly.
|
||||
"""
|
||||
seed = 100
|
||||
model_name = 'JackFram/llama-68m'
|
||||
|
||||
block_size = 16
|
||||
num_gpu_blocks = 2048 // block_size
|
||||
batch_size = 128
|
||||
multi_step_worker = create_worker(
|
||||
MultiStepWorker,
|
||||
model_name,
|
||||
block_size,
|
||||
num_gpu_blocks,
|
||||
seed,
|
||||
model_runner_cls=TP1DraftModelRunner,
|
||||
)
|
||||
multi_step_worker.set_include_gpu_probs_tensor()
|
||||
worker = create_worker(
|
||||
NPUWorker,
|
||||
model_name,
|
||||
block_size,
|
||||
num_gpu_blocks,
|
||||
seed,
|
||||
)
|
||||
random.seed(seed)
|
||||
prompts = [[0] for _ in range(batch_size)]
|
||||
num_steps = 2
|
||||
final_prompt_lens = [(num_steps + 1) for prompt in prompts]
|
||||
rand_seeds = list(random.randint(0, 100) for _ in range(num_steps))
|
||||
multi_step_worker.execute_model = patch_execute_model_with_seeds(
|
||||
multi_step_worker, rand_seeds)
|
||||
worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds)
|
||||
# Create the test continuations
|
||||
continuations = [[random.randint(0, 1000)] for _ in prompts]
|
||||
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
|
||||
prompts,
|
||||
num_gpu_blocks,
|
||||
block_size,
|
||||
continuations=continuations,
|
||||
final_prompt_lens=final_prompt_lens)
|
||||
# Run single-step twice to generate 2 tokens. This
|
||||
# will simulate the bonus token case with the second token
|
||||
# being the bonus token.
|
||||
zero_kv_cache(worker.cache_engine)
|
||||
single_step_output: list[SamplerOutput] = []
|
||||
set_random_seed(seed)
|
||||
for _ in range(num_steps):
|
||||
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
|
||||
prompts,
|
||||
num_gpu_blocks,
|
||||
block_size,
|
||||
continuations=continuations,
|
||||
final_prompt_lens=final_prompt_lens)
|
||||
single_step_output.extend(
|
||||
worker.execute_model(execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list)))
|
||||
# Append output tokens to new sequence data.
|
||||
for i, seq_group_output in enumerate(single_step_output[-1]):
|
||||
continuations[i].append(seq_group_output.samples[0].output_token)
|
||||
|
||||
# Create continuations for the MultiStepWorker. The continuations have
|
||||
# 2 tokens in order to simulate the bonus token case.
|
||||
multi_step_continuations = []
|
||||
for continuation in continuations:
|
||||
multi_step_continuations.append(continuation[:2])
|
||||
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
|
||||
prompts,
|
||||
num_gpu_blocks,
|
||||
block_size,
|
||||
continuations=multi_step_continuations,
|
||||
final_prompt_lens=final_prompt_lens)
|
||||
|
||||
# Run multi-step. In this run INCORRECTLY specify that only the odd number
|
||||
# sequences have bonus tokens. Verify that with this setting the third token
|
||||
# prediction is accurate only for the odd numbered sequences. Also verify
|
||||
# that the prediction might be wrong for some of the even numbered
|
||||
# sequences.
|
||||
zero_kv_cache(multi_step_worker.cache_engine)
|
||||
set_random_seed(seed)
|
||||
odd_seq_ids = {i for i in range(batch_size) if i % 2 != 0}
|
||||
multi_step_output, _ = multi_step_worker.sampler_output(
|
||||
execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list),
|
||||
sample_len=1,
|
||||
seq_ids_with_bonus_token_in_last_step=odd_seq_ids)
|
||||
num_mismatch = 0
|
||||
for index, output in enumerate(multi_step_output[-1].outputs):
|
||||
if (index % 2) != 0:
|
||||
assert (continuations[index][-1] == output.samples[0].output_token)
|
||||
elif (continuations[index][-1] != output.samples[0].output_token):
|
||||
num_mismatch += 1
|
||||
# The prediction is accurate for some of the sequences even without proper
|
||||
# handling of the bonus tokens. Hence verify that the number of sequences
|
||||
# for which there is a mismatch is > 0.
|
||||
assert (num_mismatch > 0)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
@pytest.mark.parametrize('num_steps', [1, 2, 3, 4])
|
||||
def test_multi_step_correct_kvcache(num_steps):
|
||||
"""Verify that the KV cache of the draft model
|
||||
is correctly updated for sequences with bonus token.
|
||||
"""
|
||||
seed = 100
|
||||
model_name = "JackFram/llama-68m"
|
||||
|
||||
block_size = 16
|
||||
num_gpu_blocks = 2048 // block_size
|
||||
batch_size = 1
|
||||
|
||||
dtype = 'float16'
|
||||
multi_step_worker = create_worker(MultiStepWorker,
|
||||
model_name,
|
||||
block_size,
|
||||
num_gpu_blocks,
|
||||
seed,
|
||||
model_runner_cls=TP1DraftModelRunner,
|
||||
dtype=dtype)
|
||||
multi_step_worker.set_include_gpu_probs_tensor()
|
||||
worker = create_worker(NPUWorker,
|
||||
model_name,
|
||||
block_size,
|
||||
num_gpu_blocks,
|
||||
seed,
|
||||
dtype=dtype)
|
||||
|
||||
prompts = [[0] for _ in range(batch_size)]
|
||||
# Already generate two tokens for the sequence
|
||||
# so that we can simulate the bonus token case
|
||||
multi_step_continuations = [[
|
||||
random.randint(0, 1000),
|
||||
random.randint(0, 1000)
|
||||
] for _ in prompts]
|
||||
final_prompt_lens = [len(prompt) + 2 + num_steps for prompt in prompts]
|
||||
|
||||
seq_ids_with_bonus_token_in_last_step = set(range(batch_size))
|
||||
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
|
||||
prompts,
|
||||
num_gpu_blocks,
|
||||
block_size,
|
||||
continuations=multi_step_continuations,
|
||||
final_prompt_lens=final_prompt_lens)
|
||||
|
||||
# Run multi-step.
|
||||
zero_kv_cache(multi_step_worker.cache_engine)
|
||||
multi_step_worker.sampler_output(execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list),
|
||||
sample_len=num_steps,
|
||||
seq_ids_with_bonus_token_in_last_step=
|
||||
seq_ids_with_bonus_token_in_last_step)
|
||||
|
||||
# Run single-step repeatedly.
|
||||
zero_kv_cache(worker.cache_engine)
|
||||
# Generate the kv cache for the bonus token first
|
||||
single_step_continuations = [c[:1] for c in multi_step_continuations]
|
||||
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
|
||||
prompts,
|
||||
num_gpu_blocks,
|
||||
block_size,
|
||||
continuations=single_step_continuations,
|
||||
final_prompt_lens=final_prompt_lens)
|
||||
single_step_output = worker.execute_model(
|
||||
execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list))
|
||||
for _ in range(num_steps):
|
||||
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
|
||||
prompts,
|
||||
num_gpu_blocks,
|
||||
block_size,
|
||||
continuations=multi_step_continuations,
|
||||
final_prompt_lens=final_prompt_lens)
|
||||
|
||||
single_step_output = worker.execute_model(
|
||||
execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list))
|
||||
|
||||
for i, seq_group_output in enumerate(single_step_output[-1]):
|
||||
multi_step_continuations[i].append(
|
||||
seq_group_output.samples[0].output_token)
|
||||
|
||||
# Verify that the KV cache of the single-step and
|
||||
# multi-step workers are the same.
|
||||
single_step_gpu_cache = worker.cache_engine[0].gpu_cache
|
||||
multi_step_gpu_cache = multi_step_worker.cache_engine[0].gpu_cache
|
||||
num_layers = len(single_step_gpu_cache)
|
||||
allclose = lambda a, b: torch.allclose( # noqa: E731
|
||||
a.npu(), b.npu(), rtol=1e-2, atol=1e-2)
|
||||
for i in range(num_layers):
|
||||
assert allclose(single_step_gpu_cache[i][0],
|
||||
multi_step_gpu_cache[i][0])
|
||||
assert allclose(single_step_gpu_cache[i][1],
|
||||
multi_step_gpu_cache[i][1])
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_draft_proposals_full_speculation_len():
|
||||
"""Verify Top1Proposer correctly handles case where all sequences
|
||||
can speculate.
|
||||
"""
|
||||
k = 10
|
||||
batch_size = 32
|
||||
vocab_size = 32_000
|
||||
device = 'npu:0'
|
||||
|
||||
draft_worker = MagicMock()
|
||||
proposer = Top1Proposer(
|
||||
worker=draft_worker,
|
||||
device=device,
|
||||
vocab_size=vocab_size,
|
||||
max_proposal_len=2048,
|
||||
)
|
||||
draft_worker.sampler_output.return_value = [
|
||||
SamplerOutput(
|
||||
outputs=[],
|
||||
sampled_token_probs=torch.rand(batch_size,
|
||||
vocab_size,
|
||||
device=device,
|
||||
dtype=torch.float32),
|
||||
logprobs=torch.rand(batch_size,
|
||||
vocab_size,
|
||||
device=device,
|
||||
dtype=torch.float32),
|
||||
sampled_token_ids=torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, ),
|
||||
device=device,
|
||||
dtype=torch.long),
|
||||
) for _ in range(k)
|
||||
], True
|
||||
|
||||
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
|
||||
|
||||
proposals = proposer.get_spec_proposals(
|
||||
execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=k),
|
||||
seq_ids_with_bonus_token_in_last_step=set())
|
||||
|
||||
assert torch.is_tensor(proposals.proposal_token_ids)
|
||||
assert torch.is_tensor(proposals.proposal_probs)
|
||||
|
||||
assert proposals.proposal_token_ids.shape == torch.Size([batch_size, k])
|
||||
assert proposals.proposal_probs.shape[:-1] == torch.Size([batch_size, k])
|
||||
|
||||
assert proposals.proposal_lens.shape == torch.Size([batch_size])
|
||||
assert proposals.proposal_lens.tolist() == [k for _ in range(batch_size)]
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_draft_proposals_no_speculations():
|
||||
"""Verify Top1Proposer correctly handles case where no sequences
|
||||
can speculate.
|
||||
"""
|
||||
k = 10
|
||||
batch_size = 32
|
||||
vocab_size = 32_000
|
||||
device = 'npu:0'
|
||||
prompt_len = 10
|
||||
|
||||
draft_worker = MagicMock()
|
||||
proposer = Top1Proposer(
|
||||
worker=draft_worker,
|
||||
device=device,
|
||||
vocab_size=vocab_size,
|
||||
max_proposal_len=prompt_len + k - 1,
|
||||
)
|
||||
|
||||
seq_group_metadata_list, _, _ = create_batch(batch_size,
|
||||
k,
|
||||
prompt_len=prompt_len)
|
||||
|
||||
proposals = proposer.get_spec_proposals(
|
||||
execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=k),
|
||||
seq_ids_with_bonus_token_in_last_step=set())
|
||||
|
||||
assert torch.is_tensor(proposals.proposal_token_ids)
|
||||
assert torch.is_tensor(proposals.proposal_probs)
|
||||
|
||||
assert proposals.proposal_token_ids.shape == torch.Size([batch_size, k])
|
||||
assert proposals.proposal_probs.shape[:-1] == torch.Size([batch_size, k])
|
||||
|
||||
assert proposals.proposal_lens.shape == torch.Size([batch_size])
|
||||
assert proposals.proposal_lens.tolist() == [0 for _ in range(batch_size)]
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_draft_proposals_mixed_k():
|
||||
"""Verify Top1Proposer correctly handles case some sequences can
|
||||
speculate and some can't.
|
||||
"""
|
||||
k = 10
|
||||
batch_size = 32
|
||||
vocab_size = 32_000
|
||||
device = 'npu:0'
|
||||
|
||||
small_prompt_len = 5
|
||||
long_prompt_len = 10
|
||||
prev_output_token_len = 20
|
||||
|
||||
expected_num_proposal_seqs = 6
|
||||
expected_num_no_proposal_seqs = batch_size - expected_num_proposal_seqs
|
||||
|
||||
prompt_len = [
|
||||
small_prompt_len for _ in range(expected_num_proposal_seqs - 1)
|
||||
] + [long_prompt_len
|
||||
for _ in range(expected_num_no_proposal_seqs)] + [small_prompt_len]
|
||||
|
||||
draft_worker = MagicMock()
|
||||
proposer = Top1Proposer(
|
||||
worker=draft_worker,
|
||||
device=device,
|
||||
vocab_size=vocab_size,
|
||||
max_proposal_len=long_prompt_len + prev_output_token_len + k - 1,
|
||||
)
|
||||
|
||||
draft_worker.sampler_output.return_value = [
|
||||
SamplerOutput(
|
||||
outputs=[],
|
||||
sampled_token_probs=torch.rand(expected_num_proposal_seqs,
|
||||
vocab_size,
|
||||
device=device,
|
||||
dtype=torch.float32),
|
||||
logprobs=torch.rand(expected_num_proposal_seqs,
|
||||
vocab_size,
|
||||
device=device,
|
||||
dtype=torch.float32),
|
||||
sampled_token_ids=torch.randint(
|
||||
low=0,
|
||||
high=vocab_size,
|
||||
size=(expected_num_proposal_seqs, ),
|
||||
device=device,
|
||||
dtype=torch.long),
|
||||
) for _ in range(k)
|
||||
], True
|
||||
|
||||
seq_group_metadata_list, _, _ = create_batch(
|
||||
batch_size,
|
||||
k,
|
||||
prompt_len=prompt_len,
|
||||
prev_output_token_len=prev_output_token_len,
|
||||
)
|
||||
|
||||
proposals = proposer.get_spec_proposals(
|
||||
execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=k),
|
||||
seq_ids_with_bonus_token_in_last_step=set())
|
||||
|
||||
assert torch.is_tensor(proposals.proposal_token_ids)
|
||||
assert torch.is_tensor(proposals.proposal_probs)
|
||||
|
||||
assert proposals.proposal_token_ids.shape == torch.Size([batch_size, k])
|
||||
assert proposals.proposal_probs.shape[:-1] == torch.Size([batch_size, k])
|
||||
|
||||
assert proposals.proposal_lens.shape == torch.Size([batch_size])
|
||||
assert proposals.proposal_lens.tolist() == [
|
||||
k for _ in range(expected_num_proposal_seqs - 1)
|
||||
] + [0 for _ in range(expected_num_no_proposal_seqs)] + [k]
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_use_draft_model_runner_advance_step():
|
||||
"""Verify that draft model runner triggers advance step
|
||||
when applicable.
|
||||
"""
|
||||
seed = 100
|
||||
model_name = 'JackFram/llama-68m'
|
||||
|
||||
k = 5
|
||||
batch_size = 32
|
||||
block_size = 32
|
||||
num_gpu_blocks = 2048 // block_size
|
||||
worker = create_worker(
|
||||
MultiStepWorker,
|
||||
model_name,
|
||||
block_size,
|
||||
num_gpu_blocks,
|
||||
seed,
|
||||
model_runner_cls=TP1DraftModelRunner,
|
||||
)
|
||||
|
||||
# Mock "_gpu_advance_step" to raise an exception when called.
|
||||
exception_secret = "artificial stop"
|
||||
worker.model_runner._gpu_advance_step = MagicMock()
|
||||
worker.model_runner._gpu_advance_step.side_effect = ValueError(
|
||||
exception_secret)
|
||||
|
||||
seq_group_metadata_list, _, _ = create_batch(batch_size,
|
||||
k,
|
||||
block_size=block_size,
|
||||
num_gpu_blocks=num_gpu_blocks)
|
||||
|
||||
# Fallback (should not call) when num_steps=1.
|
||||
execute_model_req = ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=k,
|
||||
num_steps=1)
|
||||
worker.execute_model(execute_model_req=execute_model_req)
|
||||
|
||||
# Expect exception if _gpu_advance_step is called.
|
||||
execute_model_req = ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=k,
|
||||
num_steps=k)
|
||||
|
||||
with pytest.raises(ValueError, match=exception_secret):
|
||||
worker.execute_model(execute_model_req=execute_model_req)
|
||||
call_args_list = worker.model_runner._gpu_advance_step.call_args_list
|
||||
assert len(call_args_list) == 1
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_expand_execute_model_request_sync_with_expand_hidden_states():
|
||||
"""
|
||||
In this test we verify that the logic for expanding the
|
||||
seq_group_metadata_list remains in sync with the expansion logic of
|
||||
the HiddenStates in _expand_execute_model_request.
|
||||
"""
|
||||
k = 5
|
||||
batch_size = 16
|
||||
seq_with_bonus_token_in_last_step = [1, 3, 8, 10, 13, 15]
|
||||
|
||||
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
|
||||
|
||||
execute_model_request = ExecuteModelRequest(
|
||||
seq_group_metadata_list,
|
||||
previous_hidden_states=HiddenStates(
|
||||
torch.arange(batch_size), seq_group_metadata_list,
|
||||
torch.arange(batch_size, 2 * batch_size)))
|
||||
|
||||
expanded_execute_model_request, orig_seq_group_ids = MultiStepWorker.\
|
||||
_expand_execute_model_request(execute_model_request,
|
||||
seq_with_bonus_token_in_last_step)
|
||||
|
||||
all_seq_ids = torch.tensor(
|
||||
get_all_seq_ids(
|
||||
expanded_execute_model_request.seq_group_metadata_list))
|
||||
ref_expanded_hidden_states = all_seq_ids + batch_size
|
||||
ref_expanded_hidden_states[orig_seq_group_ids] -= batch_size
|
||||
|
||||
assert (ref_expanded_hidden_states == expanded_execute_model_request.
|
||||
previous_hidden_states.hidden_states).all().item()
|
||||
@@ -1,237 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm-project/vllm/tests/spec_decode/test_ngram_worker.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import torch
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
from vllm.spec_decode.ngram_worker import NGramWorker
|
||||
from vllm.spec_decode.top1_proposer import Top1Proposer
|
||||
|
||||
from tests.singlecard.spec_decode.utils import (
|
||||
create_seq_group_metadata_from_prompts, create_worker)
|
||||
|
||||
|
||||
def test_ngram_algo_correctness_for_single_no_match():
|
||||
"""Verify our ngram algo find the right candidate in the prompt
|
||||
|
||||
For the scenario cannot find any candidate in one single batch
|
||||
"""
|
||||
block_size = 32
|
||||
num_gpu_blocks = 2048 // block_size
|
||||
seed = 100
|
||||
model_name = 'JackFram/llama-68m'
|
||||
vocab_size = 32_000
|
||||
device = 'npu:0'
|
||||
|
||||
ngram_worker = create_worker(
|
||||
NGramWorker,
|
||||
model_name,
|
||||
block_size,
|
||||
num_gpu_blocks,
|
||||
seed,
|
||||
)
|
||||
|
||||
proposer = Top1Proposer(
|
||||
worker=ngram_worker,
|
||||
device=device,
|
||||
vocab_size=vocab_size,
|
||||
max_proposal_len=20,
|
||||
)
|
||||
|
||||
# set ngram window [1, 3], which is window=1/2/3
|
||||
ngram_worker.set_ngram_window_size(1, 3)
|
||||
|
||||
prompts = [
|
||||
# shall find no candidate
|
||||
[1, 2, 3, 4, 5, 6, 7],
|
||||
]
|
||||
|
||||
proposal_len = 5
|
||||
final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts]
|
||||
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
|
||||
prompts,
|
||||
num_gpu_blocks,
|
||||
block_size,
|
||||
final_prompt_lens=final_prompt_lens)
|
||||
|
||||
proposals = proposer.get_spec_proposals(
|
||||
execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=proposal_len),
|
||||
seq_ids_with_bonus_token_in_last_step=None)
|
||||
|
||||
assert torch.is_tensor(proposals.proposal_token_ids)
|
||||
assert torch.is_tensor(proposals.proposal_probs)
|
||||
|
||||
assert proposals.proposal_token_ids.shape == torch.Size([1, proposal_len])
|
||||
assert proposals.proposal_probs.shape[:-1] == torch.Size([1, proposal_len])
|
||||
assert proposals.proposal_lens.shape == torch.Size([1])
|
||||
assert proposals.proposal_lens.tolist() == [0]
|
||||
|
||||
|
||||
def test_ngram_algo_correctness_for_batches_not_match_all():
|
||||
"""Verify our ngram algo find the right candidate in the prompt
|
||||
|
||||
For the scenario find some candidate not full in batchs
|
||||
"""
|
||||
block_size = 32
|
||||
num_gpu_blocks = 2048 // block_size
|
||||
seed = 100
|
||||
model_name = 'JackFram/llama-68m'
|
||||
vocab_size = 32_000
|
||||
device = 'npu:0'
|
||||
|
||||
ngram_worker = create_worker(
|
||||
NGramWorker,
|
||||
model_name,
|
||||
block_size,
|
||||
num_gpu_blocks,
|
||||
seed,
|
||||
)
|
||||
|
||||
proposer = Top1Proposer(
|
||||
worker=ngram_worker,
|
||||
device=device,
|
||||
vocab_size=vocab_size,
|
||||
max_proposal_len=20,
|
||||
)
|
||||
|
||||
# set ngram window [1, 3], which is window=1/2/3
|
||||
ngram_worker.set_ngram_window_size(1, 3)
|
||||
|
||||
prompts = [
|
||||
# shall find no candidate
|
||||
[1, 2, 3, 4, 5, 6, 7],
|
||||
# shall find candidate 12,13,14,15,16
|
||||
[11, 12, 13, 14, 15, 16, 11],
|
||||
# shall find candidate 23,24,25,26,21
|
||||
[21, 21, 22, 23, 24, 25, 26, 21, 22],
|
||||
# shall find candidate 34,35,36,37,38
|
||||
[31, 32, 31, 32, 33, 34, 35, 36, 37, 38, 31, 32, 33],
|
||||
# shall find no candidate as exceed max_proposal_len
|
||||
[
|
||||
31, 32, 31, 32, 31, 32, 31, 32, 31, 32, 31, 32, 33, 34, 35, 36, 37,
|
||||
38, 31, 32, 33
|
||||
],
|
||||
]
|
||||
|
||||
proposal_len = 5
|
||||
final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts]
|
||||
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
|
||||
prompts,
|
||||
num_gpu_blocks,
|
||||
block_size,
|
||||
final_prompt_lens=final_prompt_lens)
|
||||
for sg in seq_group_metadata_list:
|
||||
sg.is_prompt = False
|
||||
proposals = proposer.get_spec_proposals(
|
||||
execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=proposal_len),
|
||||
seq_ids_with_bonus_token_in_last_step=None)
|
||||
|
||||
assert torch.is_tensor(proposals.proposal_token_ids)
|
||||
assert torch.is_tensor(proposals.proposal_probs)
|
||||
|
||||
assert proposals.proposal_token_ids.shape == torch.Size([5, proposal_len])
|
||||
assert proposals.proposal_probs.shape[:-1] == torch.Size([5, proposal_len])
|
||||
assert proposals.proposal_lens.shape == torch.Size([5])
|
||||
|
||||
# the first sequence has no match so proposal_len should be overwritten to 0
|
||||
assert proposals.proposal_lens.tolist(
|
||||
) == [0] + [proposal_len for _ in range(3)] + [0]
|
||||
|
||||
for i in range(proposal_len):
|
||||
assert proposals.proposal_token_ids[0][i] == -1
|
||||
assert proposals.proposal_token_ids[1][i] == prompts[1][i + 1]
|
||||
assert proposals.proposal_token_ids[2][i] == prompts[2][i + 3]
|
||||
assert proposals.proposal_token_ids[3][i] == prompts[3][i + 5]
|
||||
assert proposals.proposal_token_ids[4][i] == -1
|
||||
|
||||
|
||||
def test_ngram_algo_correctness_for_batches_match_all():
|
||||
"""Verify our ngram algo find the right candidate in the prompt
|
||||
|
||||
For the scenario find candidate in all batches
|
||||
"""
|
||||
|
||||
block_size = 32
|
||||
num_gpu_blocks = 2048 // block_size
|
||||
seed = 100
|
||||
model_name = 'JackFram/llama-68m'
|
||||
vocab_size = 32_000
|
||||
device = 'npu:0'
|
||||
|
||||
ngram_worker = create_worker(
|
||||
NGramWorker,
|
||||
model_name,
|
||||
block_size,
|
||||
num_gpu_blocks,
|
||||
seed,
|
||||
)
|
||||
|
||||
proposer = Top1Proposer(
|
||||
worker=ngram_worker,
|
||||
device=device,
|
||||
vocab_size=vocab_size,
|
||||
max_proposal_len=20,
|
||||
)
|
||||
|
||||
# set ngram window [0, 3], which is window=1/2/3
|
||||
ngram_worker.set_ngram_window_size(1, 3)
|
||||
|
||||
prompts = [
|
||||
# shall find candidate 12,13,14,15,16
|
||||
[11, 12, 13, 14, 15, 16, 11],
|
||||
# shall find candidate 23,24,25,26,21
|
||||
[21, 21, 22, 23, 24, 25, 26, 21, 22],
|
||||
# shall find candidate 34,35,36,37,38
|
||||
[31, 32, 31, 32, 33, 34, 35, 36, 37, 38, 31, 32, 33],
|
||||
]
|
||||
|
||||
proposal_len = 5
|
||||
final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts]
|
||||
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
|
||||
prompts,
|
||||
num_gpu_blocks,
|
||||
block_size,
|
||||
final_prompt_lens=final_prompt_lens)
|
||||
|
||||
# Normally drafter is run on decode requests only; here we check the output
|
||||
# of the ngram worker as it is the sole proposer that has no forward.
|
||||
for sg in seq_group_metadata_list:
|
||||
sg.is_prompt = False
|
||||
proposals = proposer.get_spec_proposals(
|
||||
execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=proposal_len),
|
||||
seq_ids_with_bonus_token_in_last_step=None)
|
||||
|
||||
assert torch.is_tensor(proposals.proposal_token_ids)
|
||||
assert torch.is_tensor(proposals.proposal_probs)
|
||||
|
||||
assert proposals.proposal_token_ids.shape == torch.Size([3, proposal_len])
|
||||
assert proposals.proposal_probs.shape[:-1] == torch.Size([3, proposal_len])
|
||||
assert proposals.proposal_lens.shape == torch.Size([3])
|
||||
|
||||
assert proposals.proposal_lens.tolist() == [proposal_len for _ in range(3)]
|
||||
|
||||
for i in range(proposal_len):
|
||||
assert proposals.proposal_token_ids[0][i] == prompts[0][i + 1]
|
||||
assert proposals.proposal_token_ids[1][i] == prompts[1][i + 3]
|
||||
assert proposals.proposal_token_ids[2][i] == prompts[2][i + 5]
|
||||
@@ -1,957 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm-project/vllm/tests/spec_decode/test_spec_decode_worker.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import random
|
||||
from collections import defaultdict
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
from vllm.sequence import ExecuteModelRequest, SequenceOutput
|
||||
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
|
||||
from vllm.spec_decode.interfaces import SpeculativeProposals
|
||||
from vllm.spec_decode.metrics import (AsyncMetricsCollector,
|
||||
SpecDecodeWorkerMetrics)
|
||||
from vllm.spec_decode.multi_step_worker import MultiStepWorker
|
||||
from vllm.spec_decode.spec_decode_worker import (SpecDecodeWorker,
|
||||
split_num_cache_blocks_evenly)
|
||||
|
||||
from tests.singlecard.spec_decode.test_utils import mock_spec_decode_sampler
|
||||
from tests.singlecard.spec_decode.utils import (create_batch,
|
||||
create_sampler_output_list,
|
||||
create_worker, mock_worker)
|
||||
from vllm_ascend.worker.draft_model_runner import TP1DraftModelRunner
|
||||
from vllm_ascend.worker.worker import NPUWorker
|
||||
|
||||
|
||||
@pytest.mark.parametrize('k', [1, 2, 6])
|
||||
@pytest.mark.parametrize('batch_size', [1, 2, 32])
|
||||
@pytest.mark.parametrize("acceptance_sampler_method",
|
||||
["rejection_sampler", "typical_acceptance_sampler"])
|
||||
@torch.inference_mode()
|
||||
def test_correctly_calls_draft_model(k: int, batch_size: int,
|
||||
acceptance_sampler_method: str):
|
||||
"""Verify SpecDecodeWorker calls the draft worker with correct
|
||||
inputs. Everything else is mocked out.
|
||||
"""
|
||||
draft_worker = mock_worker(cls=MultiStepWorker)
|
||||
target_worker = mock_worker()
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
worker = SpecDecodeWorker(
|
||||
draft_worker,
|
||||
target_worker,
|
||||
mock_spec_decode_sampler(acceptance_sampler_method),
|
||||
disable_logprobs=False,
|
||||
metrics_collector=metrics_collector)
|
||||
exception_secret = 'artificial stop'
|
||||
draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)
|
||||
|
||||
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
|
||||
execute_model_req = ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)
|
||||
|
||||
with pytest.raises(ValueError, match=exception_secret):
|
||||
worker.execute_model(execute_model_req=execute_model_req)
|
||||
|
||||
call_args_list = draft_worker.get_spec_proposals.call_args_list
|
||||
assert len(call_args_list) == 1
|
||||
|
||||
for args, _ in call_args_list:
|
||||
actual_execute_model_data = args[0]
|
||||
assert actual_execute_model_data == execute_model_req
|
||||
|
||||
|
||||
@pytest.mark.parametrize('k', [1, 2, 6])
|
||||
@pytest.mark.parametrize('batch_size', [1, 2, 32])
|
||||
@pytest.mark.parametrize("acceptance_sampler_method",
|
||||
["rejection_sampler", "typical_acceptance_sampler"])
|
||||
@torch.inference_mode()
|
||||
def test_batch_expansion_correctly_calls_target_model(
|
||||
k: int, batch_size: int, acceptance_sampler_method: str):
|
||||
"""Verify SpecDecodeWorker calls the target model with correct
|
||||
inputs with batch expansion. Everything else is mocked out.
|
||||
"""
|
||||
draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False)
|
||||
target_worker = mock_worker(use_spec=False)
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
|
||||
draft_worker.device = 'npu'
|
||||
target_worker.device = 'npu'
|
||||
|
||||
set_random_seed(1)
|
||||
|
||||
worker = SpecDecodeWorker(
|
||||
draft_worker,
|
||||
target_worker,
|
||||
mock_spec_decode_sampler(acceptance_sampler_method),
|
||||
disable_logprobs=False,
|
||||
metrics_collector=metrics_collector,
|
||||
disable_mqa_scorer=True)
|
||||
worker.init_device()
|
||||
|
||||
vocab_size = 32_000
|
||||
|
||||
proposal_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, k),
|
||||
dtype=torch.int64,
|
||||
device='npu')
|
||||
proposal_probs = torch.rand(batch_size,
|
||||
k,
|
||||
vocab_size,
|
||||
dtype=torch.float32,
|
||||
device='npu')
|
||||
proposal_lens = torch.ones(batch_size, dtype=torch.int64, device='npu') * k
|
||||
|
||||
seq_group_metadata_list, prompts, prev_output_tokens = create_batch(
|
||||
batch_size, k)
|
||||
|
||||
draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
|
||||
proposal_token_ids=proposal_token_ids,
|
||||
proposal_probs=proposal_probs,
|
||||
proposal_lens=proposal_lens)
|
||||
|
||||
exception_secret = 'artificial stop'
|
||||
target_worker.execute_model.side_effect = ValueError(exception_secret)
|
||||
|
||||
with pytest.raises(ValueError, match=exception_secret):
|
||||
worker.execute_model(execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=k))
|
||||
|
||||
seen_contexts: list[list[int]] = []
|
||||
|
||||
call_args_list = target_worker.execute_model.call_args_list
|
||||
assert len(call_args_list) == 1
|
||||
for _, kwargs in call_args_list:
|
||||
seq_group_metadata_list = kwargs[
|
||||
"execute_model_req"].seq_group_metadata_list
|
||||
|
||||
assert len(seq_group_metadata_list) == (k + 1) * batch_size
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
for seq_data in seq_group_metadata.seq_data.values():
|
||||
seen_contexts.append(seq_data.get_token_ids())
|
||||
|
||||
expected_seen_contexts: list[list[int]] = []
|
||||
|
||||
for prompt, prev_generated, draft_tokens in zip(
|
||||
prompts, prev_output_tokens, proposal_token_ids.tolist()):
|
||||
|
||||
for i in range(len(draft_tokens) + 1):
|
||||
expected_seen_contexts.append(prompt + prev_generated +
|
||||
draft_tokens[:i])
|
||||
|
||||
seen_contexts.sort()
|
||||
expected_seen_contexts.sort()
|
||||
assert expected_seen_contexts == seen_contexts
|
||||
|
||||
|
||||
@pytest.mark.parametrize('k', [1, 2, 6])
|
||||
@pytest.mark.parametrize('batch_size', [1, 2, 32])
|
||||
@pytest.mark.parametrize("acceptance_sampler_method",
|
||||
["rejection_sampler", "typical_acceptance_sampler"])
|
||||
@torch.inference_mode()
|
||||
def test_correctly_calls_spec_decode_sampler(k: int, batch_size: int,
|
||||
acceptance_sampler_method: str):
|
||||
"""Verify SpecDecodeWorker calls the rejection sampler with
|
||||
correct inputs. Everything else is mocked out.
|
||||
"""
|
||||
vocab_size = 32_000
|
||||
|
||||
draft_worker = mock_worker(cls=MultiStepWorker,
|
||||
vocab_size=vocab_size,
|
||||
use_spec=False)
|
||||
target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
|
||||
spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
draft_worker.device = 'npu'
|
||||
target_worker.device = 'npu'
|
||||
|
||||
set_random_seed(1)
|
||||
|
||||
worker = SpecDecodeWorker(draft_worker,
|
||||
target_worker,
|
||||
spec_decode_sampler,
|
||||
disable_logprobs=False,
|
||||
metrics_collector=metrics_collector)
|
||||
worker.init_device()
|
||||
|
||||
proposal_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, k),
|
||||
dtype=torch.int64,
|
||||
device='npu')
|
||||
proposal_probs = torch.rand(batch_size,
|
||||
k,
|
||||
vocab_size,
|
||||
dtype=torch.float32,
|
||||
device='npu')
|
||||
|
||||
proposal_lens = torch.ones(batch_size, dtype=torch.int64, device='npu') * k
|
||||
|
||||
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
|
||||
|
||||
draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
|
||||
proposal_token_ids=proposal_token_ids,
|
||||
proposal_probs=proposal_probs,
|
||||
proposal_lens=proposal_lens)
|
||||
|
||||
target_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(1, batch_size * (k + 1)),
|
||||
dtype=torch.int64,
|
||||
device='npu')
|
||||
target_token_probs = torch.rand(1,
|
||||
batch_size * (k + 1),
|
||||
vocab_size,
|
||||
dtype=torch.float32,
|
||||
device='npu')
|
||||
target_token_logprobs = torch.rand(1,
|
||||
batch_size * (k + 1),
|
||||
vocab_size,
|
||||
dtype=torch.float32,
|
||||
device='npu')
|
||||
target_output = create_sampler_output_list(target_token_ids,
|
||||
target_token_probs,
|
||||
target_token_logprobs)
|
||||
|
||||
target_worker.execute_model.return_value = [target_output[0]]
|
||||
|
||||
exception_secret = 'artificial stop'
|
||||
|
||||
spec_decode_sampler.side_effect = ValueError(exception_secret)
|
||||
|
||||
with pytest.raises(ValueError, match=exception_secret):
|
||||
worker.execute_model(execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=k))
|
||||
|
||||
assert len(spec_decode_sampler.call_args_list) == 1
|
||||
_, kwargs = spec_decode_sampler.call_args_list[0]
|
||||
actual = SimpleNamespace(**kwargs)
|
||||
|
||||
assert torch.equal(actual.bonus_token_ids,
|
||||
target_token_ids.reshape(batch_size, k + 1)[:, -1:])
|
||||
assert torch.equal(actual.target_with_bonus_probs,
|
||||
target_token_probs.reshape(batch_size, k + 1, -1))
|
||||
assert torch.equal(actual.draft_token_ids, proposal_token_ids)
|
||||
assert torch.equal(actual.draft_probs, proposal_probs)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('k', [1, 2, 6])
|
||||
@pytest.mark.parametrize('batch_size', [1, 2, 32])
|
||||
@pytest.mark.parametrize("acceptance_sampler_method",
|
||||
["rejection_sampler", "typical_acceptance_sampler"])
|
||||
@torch.inference_mode()
|
||||
def test_correctly_formats_output(k: int, batch_size: int,
|
||||
acceptance_sampler_method: str):
|
||||
"""Verify SpecDecodeWorker formats sampler output correctly.
|
||||
Everything else is mocked out.
|
||||
"""
|
||||
vocab_size = 32_000
|
||||
|
||||
draft_worker = mock_worker(cls=MultiStepWorker,
|
||||
vocab_size=vocab_size,
|
||||
use_spec=False)
|
||||
target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
draft_worker.device = 'npu'
|
||||
target_worker.device = 'npu'
|
||||
|
||||
set_random_seed(1)
|
||||
spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
|
||||
worker = SpecDecodeWorker(draft_worker,
|
||||
target_worker,
|
||||
spec_decode_sampler,
|
||||
disable_logprobs=False,
|
||||
metrics_collector=metrics_collector)
|
||||
worker.init_device()
|
||||
|
||||
proposal_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, k),
|
||||
dtype=torch.int64,
|
||||
device='npu')
|
||||
proposal_probs = torch.rand(batch_size,
|
||||
k,
|
||||
vocab_size,
|
||||
dtype=torch.float32,
|
||||
device='npu')
|
||||
|
||||
proposal_lens = torch.ones(batch_size, dtype=torch.int64, device='npu') * k
|
||||
|
||||
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
|
||||
|
||||
draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
|
||||
proposal_token_ids=proposal_token_ids,
|
||||
proposal_probs=proposal_probs,
|
||||
proposal_lens=proposal_lens)
|
||||
|
||||
target_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(1, batch_size * (k + 1)),
|
||||
dtype=torch.int64,
|
||||
device='npu')
|
||||
target_token_probs = torch.rand(1,
|
||||
batch_size * (k + 1),
|
||||
vocab_size,
|
||||
dtype=torch.float32,
|
||||
device='npu')
|
||||
target_token_logprobs = torch.rand(1,
|
||||
batch_size * (k + 1),
|
||||
vocab_size,
|
||||
dtype=torch.float32,
|
||||
device='npu')
|
||||
target_output = create_sampler_output_list(target_token_ids,
|
||||
target_token_probs,
|
||||
target_token_logprobs)
|
||||
|
||||
target_worker.execute_model.return_value = [target_output[0]]
|
||||
|
||||
spec_decode_sampler_output = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, k + 1),
|
||||
dtype=torch.int64,
|
||||
device='npu')
|
||||
for i in range(batch_size):
|
||||
minimum_accepted_tokens = 1
|
||||
spec_decode_sampler_output[i][
|
||||
-random.randint(minimum_accepted_tokens, k + 1):] = -1
|
||||
|
||||
spec_decode_sampler.return_value = spec_decode_sampler_output
|
||||
output = worker.execute_model(execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=k))
|
||||
|
||||
expected_output = create_sampler_output_list(
|
||||
token_ids=spec_decode_sampler_output.transpose(0, 1),
|
||||
probs=[None for _ in range(k + 1)],
|
||||
logprobs=[None for _ in range(k + 1)])
|
||||
|
||||
seq_ids = [
|
||||
next(iter(seq_group_metadata.seq_data.keys()))
|
||||
for seq_group_metadata in seq_group_metadata_list
|
||||
]
|
||||
actual_output_by_seq: dict[int, list[SequenceOutput]] = {
|
||||
seq_id: []
|
||||
for seq_id in seq_ids
|
||||
}
|
||||
expected_output_by_seq: dict[int, list[SequenceOutput]] = {
|
||||
seq_id: []
|
||||
for seq_id in seq_ids
|
||||
}
|
||||
|
||||
for step in output:
|
||||
for seq_group in step:
|
||||
for sample in seq_group.samples:
|
||||
seq_id = sample.parent_seq_id
|
||||
actual_output_by_seq[seq_id].append(sample)
|
||||
|
||||
for step in expected_output:
|
||||
for seq_group in step:
|
||||
for sample in seq_group.samples:
|
||||
seq_id = sample.parent_seq_id
|
||||
expected_output_by_seq[seq_id].append(sample)
|
||||
|
||||
all_seen_seq_ids = set(
|
||||
list(actual_output_by_seq.keys()) +
|
||||
list(expected_output_by_seq.keys()))
|
||||
for seq_id in all_seen_seq_ids:
|
||||
actual_by_step = actual_output_by_seq[seq_id]
|
||||
expected_by_step = expected_output_by_seq[seq_id]
|
||||
|
||||
for i in range(k + 1):
|
||||
if i >= len(actual_by_step):
|
||||
assert expected_by_step[i].output_token == -1
|
||||
continue
|
||||
assert actual_by_step[i].output_token == expected_by_step[
|
||||
i].output_token
|
||||
|
||||
|
||||
@pytest.mark.parametrize('k', [1, 2])
|
||||
@pytest.mark.parametrize('batch_size', [1])
|
||||
@pytest.mark.parametrize('returns_metrics', [True, False])
|
||||
@pytest.mark.parametrize("acceptance_sampler_method",
|
||||
["rejection_sampler", "typical_acceptance_sampler"])
|
||||
@torch.inference_mode()
|
||||
def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool,
|
||||
acceptance_sampler_method: str):
|
||||
"""Verify SpecDecodeWorker collects metrics.
|
||||
"""
|
||||
vocab_size = 32_000
|
||||
|
||||
draft_worker = mock_worker(cls=MultiStepWorker,
|
||||
vocab_size=vocab_size,
|
||||
use_spec=False)
|
||||
target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
|
||||
spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
draft_worker.device = 'npu'
|
||||
target_worker.device = 'npu'
|
||||
|
||||
set_random_seed(1)
|
||||
|
||||
worker = SpecDecodeWorker(draft_worker,
|
||||
target_worker,
|
||||
spec_decode_sampler,
|
||||
disable_logprobs=False,
|
||||
metrics_collector=metrics_collector)
|
||||
worker.init_device()
|
||||
|
||||
proposal_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, k),
|
||||
dtype=torch.int64,
|
||||
device='npu')
|
||||
proposal_probs = torch.rand(batch_size,
|
||||
k,
|
||||
vocab_size,
|
||||
dtype=torch.float32,
|
||||
device='npu')
|
||||
|
||||
proposal_lens = torch.ones(batch_size, dtype=torch.int64, device='npu') * k
|
||||
|
||||
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
|
||||
|
||||
draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
|
||||
proposal_token_ids=proposal_token_ids,
|
||||
proposal_probs=proposal_probs,
|
||||
proposal_lens=proposal_lens)
|
||||
|
||||
target_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(1, batch_size * (k + 1)),
|
||||
dtype=torch.int64,
|
||||
device='npu')
|
||||
target_token_probs = torch.rand(1,
|
||||
batch_size * (k + 1),
|
||||
vocab_size,
|
||||
dtype=torch.float32,
|
||||
device='npu')
|
||||
target_token_logprobs = torch.rand(1,
|
||||
batch_size * (k + 1),
|
||||
vocab_size,
|
||||
dtype=torch.float32,
|
||||
device='npu')
|
||||
target_output = create_sampler_output_list(target_token_ids,
|
||||
target_token_probs,
|
||||
target_token_logprobs)
|
||||
|
||||
target_worker.execute_model.return_value = [target_output[0]]
|
||||
|
||||
spec_decode_sampler_output = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, k + 1),
|
||||
dtype=torch.int64,
|
||||
device='npu')
|
||||
for i in range(batch_size):
|
||||
minimum_accepted_tokens = 1
|
||||
spec_decode_sampler_output[i][
|
||||
-random.randint(minimum_accepted_tokens, k + 1):] = -1
|
||||
spec_decode_sampler.return_value = spec_decode_sampler_output
|
||||
|
||||
mock_rejsample_metrics = MagicMock(
|
||||
spec=SpecDecodeWorkerMetrics) if returns_metrics else None
|
||||
metrics_collector.maybe_collect_rejsample_metrics.return_value = (
|
||||
mock_rejsample_metrics)
|
||||
|
||||
output = worker.execute_model(execute_model_req=ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=k))
|
||||
assert output[0].spec_decode_worker_metrics == mock_rejsample_metrics
|
||||
|
||||
call_args_list = (
|
||||
metrics_collector.maybe_collect_rejsample_metrics.call_args_list)
|
||||
assert len(call_args_list) == 1
|
||||
args, kwargs = call_args_list[0]
|
||||
assert args[0] == k or kwargs.get('k', -1) == k
|
||||
|
||||
|
||||
@pytest.mark.parametrize('k', [0])
|
||||
@pytest.mark.parametrize('batch_size', [1, 2, 32])
|
||||
@pytest.mark.parametrize("acceptance_sampler_method",
|
||||
["rejection_sampler", "typical_acceptance_sampler"])
|
||||
@torch.inference_mode()
|
||||
def test_k_equals_zero(k: int, batch_size: int,
|
||||
acceptance_sampler_method: str):
|
||||
"""Verify that the SpecDecodeWorker calls the draft and target workers
|
||||
when k is zero. This happens during prefill.
|
||||
"""
|
||||
draft_worker = mock_worker(cls=MultiStepWorker)
|
||||
target_worker = mock_worker()
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
|
||||
sampler_output = MagicMock(spec=SamplerOutput)
|
||||
sampler_output.hidden_states = None
|
||||
target_worker.execute_model.return_value = [sampler_output]
|
||||
|
||||
draft_worker.device = 'npu'
|
||||
target_worker.device = 'npu'
|
||||
|
||||
set_random_seed(1)
|
||||
|
||||
worker = SpecDecodeWorker(
|
||||
proposer_worker=draft_worker,
|
||||
scorer_worker=target_worker,
|
||||
spec_decode_sampler=mock_spec_decode_sampler(
|
||||
acceptance_sampler_method),
|
||||
disable_logprobs=False,
|
||||
metrics_collector=metrics_collector,
|
||||
)
|
||||
|
||||
seq_group_metadata_list, _, _ = create_batch(batch_size,
|
||||
k,
|
||||
prev_output_token_len=0)
|
||||
execute_model_req = ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)
|
||||
|
||||
out = worker.execute_model(execute_model_req=execute_model_req)
|
||||
|
||||
assert len(out) == 1, f"expected only one token output when {k=}"
|
||||
assert out[0].sampled_token_probs is None, (
|
||||
"expect gpu tensor references to be None")
|
||||
assert out[
|
||||
0].sampled_token_ids is None, "expect gpu tensor references to be None"
|
||||
|
||||
draft_worker.execute_model.assert_called_once_with(execute_model_req)
|
||||
target_worker.execute_model.assert_called_once_with(execute_model_req)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('k', [0, 5])
|
||||
@pytest.mark.parametrize('batch_size', [0])
|
||||
@pytest.mark.parametrize("acceptance_sampler_method",
|
||||
["rejection_sampler", "typical_acceptance_sampler"])
|
||||
@torch.inference_mode()
|
||||
def test_empty_input_batch(k: int, batch_size: int,
|
||||
acceptance_sampler_method: str):
|
||||
"""Verify that the SpecDecodeWorker calls the draft and target workers
|
||||
when the input batch is empty. This can happen if the engine communicates
|
||||
to the workers information without scheduling a batch.
|
||||
"""
|
||||
draft_worker = mock_worker(cls=MultiStepWorker)
|
||||
target_worker = mock_worker()
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
|
||||
sampler_output = MagicMock(spec=SamplerOutput)
|
||||
sampler_output.hidden_states = None
|
||||
target_worker.execute_model.return_value = [sampler_output]
|
||||
|
||||
draft_worker.device = 'npu'
|
||||
target_worker.device = 'npu'
|
||||
|
||||
set_random_seed(1)
|
||||
|
||||
worker = SpecDecodeWorker(
|
||||
proposer_worker=draft_worker,
|
||||
scorer_worker=target_worker,
|
||||
spec_decode_sampler=mock_spec_decode_sampler(
|
||||
acceptance_sampler_method),
|
||||
disable_logprobs=False,
|
||||
metrics_collector=metrics_collector,
|
||||
)
|
||||
|
||||
seq_group_metadata_list, _, _ = create_batch(batch_size,
|
||||
k,
|
||||
prev_output_token_len=0)
|
||||
execute_model_req = ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)
|
||||
|
||||
out = worker.execute_model(execute_model_req=execute_model_req)
|
||||
|
||||
assert len(out) == 1, f"expected only one token output when {k=}"
|
||||
assert out[0].sampled_token_probs is None, (
|
||||
"expect gpu tensor references to be None")
|
||||
assert out[
|
||||
0].sampled_token_ids is None, "expect gpu tensor references to be None"
|
||||
|
||||
draft_worker.execute_model.assert_called_once_with(execute_model_req)
|
||||
target_worker.execute_model.assert_called_once_with(execute_model_req)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("acceptance_sampler_method",
|
||||
["rejection_sampler", "typical_acceptance_sampler"])
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_init_device(acceptance_sampler_method: str):
|
||||
"""Verify SpecDecodeWorker invokes proposer/scorer worker init_device, as
|
||||
well as other GPU initialization.
|
||||
"""
|
||||
draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False)
|
||||
target_worker = mock_worker(use_spec=False)
|
||||
spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
|
||||
worker = SpecDecodeWorker(
|
||||
proposer_worker=draft_worker,
|
||||
scorer_worker=target_worker,
|
||||
spec_decode_sampler=spec_decode_sampler,
|
||||
disable_logprobs=False,
|
||||
metrics_collector=metrics_collector,
|
||||
)
|
||||
worker.init_device()
|
||||
|
||||
draft_worker.init_device.assert_called_once()
|
||||
|
||||
target_worker.init_device.assert_called_once()
|
||||
|
||||
metrics_collector.init_tensors.assert_called_once()
|
||||
spec_decode_sampler.init_tensors.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("acceptance_sampler_method",
|
||||
["rejection_sampler", "typical_acceptance_sampler"])
|
||||
@torch.inference_mode()
|
||||
def test_initialize_cache(acceptance_sampler_method):
|
||||
"""Verify SpecDecodeWorker invokes initialize_cache on proposer/scorer
|
||||
workers.
|
||||
"""
|
||||
draft_worker = mock_worker(cls=MultiStepWorker)
|
||||
target_worker = mock_worker()
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
|
||||
worker = SpecDecodeWorker(proposer_worker=draft_worker,
|
||||
scorer_worker=target_worker,
|
||||
spec_decode_sampler=mock_spec_decode_sampler(
|
||||
acceptance_sampler_method),
|
||||
metrics_collector=metrics_collector)
|
||||
|
||||
kwargs = {"num_gpu_blocks": 1024, "num_cpu_blocks": 1023}
|
||||
worker.initialize_cache(**kwargs)
|
||||
|
||||
draft_worker.initialize_cache.assert_called_once_with(**kwargs)
|
||||
target_worker.initialize_cache.assert_called_once_with(**kwargs)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('available_gpu_blocks', [1, 1024])
|
||||
@pytest.mark.parametrize('available_cpu_blocks', [500])
|
||||
@pytest.mark.parametrize('target_cache_block_size_bytes', [2 * 2 * 4096])
|
||||
@pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096])
|
||||
@pytest.mark.parametrize("acceptance_sampler_method",
|
||||
["rejection_sampler", "typical_acceptance_sampler"])
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_determine_num_available_blocks(available_gpu_blocks: int,
|
||||
available_cpu_blocks: int,
|
||||
target_cache_block_size_bytes: int,
|
||||
draft_kv_size_bytes: int,
|
||||
acceptance_sampler_method: str):
|
||||
"""Verify SpecDecodeWorker correctly profiles num available GPU blocks.
|
||||
Specifically, it should run profiling in the scorer worker, and then evenly
|
||||
split the blocks between proposer and scorer worker.
|
||||
"""
|
||||
draft_worker = mock_worker(cls=MultiStepWorker)
|
||||
target_worker = mock_worker()
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
|
||||
target_worker.determine_num_available_blocks.return_value = (
|
||||
available_gpu_blocks, available_cpu_blocks)
|
||||
target_worker.get_cache_block_size_bytes.return_value = (
|
||||
target_cache_block_size_bytes)
|
||||
draft_worker.get_cache_block_size_bytes.return_value = draft_kv_size_bytes
|
||||
|
||||
worker = SpecDecodeWorker(
|
||||
draft_worker, target_worker,
|
||||
mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)
|
||||
|
||||
num_gpu_blocks, num_cpu_blocks = worker.determine_num_available_blocks()
|
||||
|
||||
target_worker.determine_num_available_blocks.assert_called_once()
|
||||
assert num_cpu_blocks == available_cpu_blocks
|
||||
|
||||
assert num_gpu_blocks == split_num_cache_blocks_evenly(
|
||||
target_cache_block_size_bytes, draft_kv_size_bytes,
|
||||
available_gpu_blocks)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('available_gpu_blocks',
|
||||
list(range(20)) + [1024, 1024**2])
|
||||
@pytest.mark.parametrize('target_cache_block_size_bytes',
|
||||
[2 * 2 * 4096, 2 * 2 * 8192])
|
||||
@pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096])
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_split_num_cache_blocks_evenly(available_gpu_blocks: int,
|
||||
target_cache_block_size_bytes: int,
|
||||
draft_kv_size_bytes: int):
|
||||
"""Verify split_num_cache_blocks_evenly does not exceed original memory
|
||||
allocation in bytes.
|
||||
"""
|
||||
num_blocks = split_num_cache_blocks_evenly(target_cache_block_size_bytes,
|
||||
draft_kv_size_bytes,
|
||||
available_gpu_blocks)
|
||||
assert (num_blocks * target_cache_block_size_bytes) + (
|
||||
num_blocks * draft_kv_size_bytes) <= (available_gpu_blocks *
|
||||
target_cache_block_size_bytes)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_populate_seq_ids_with_bonus_tokens():
|
||||
"""
|
||||
Verify that a call to _create_output_sampler_list correctly updates
|
||||
seq_with_bonus_token_in_last_step.
|
||||
|
||||
seq_with_bonus_token_in_last_step is an internal data structure in
|
||||
SpecDecodeWorker that tracks the sequence IDs which are assigned bonus
|
||||
tokens by the target model in their last forward pass. This state is
|
||||
maintained only for models relying on the KV cache, such as those using
|
||||
the MultiStepWorker.
|
||||
"""
|
||||
batch_size = 10
|
||||
k = 5
|
||||
vocab_size = 10000
|
||||
num_sequences_with_bonus_tokens = 5
|
||||
target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)]
|
||||
target_worker.device = 'npu'
|
||||
|
||||
set_random_seed(1)
|
||||
draft_worker = mock_worker(cls=MultiStepWorker)
|
||||
draft_worker.device = 'npu'
|
||||
# The sequence_ids attached to each sequence in the batch.
|
||||
# The sequence at index i has seq_id assigned_seq_ids[i]
|
||||
assigned_seq_ids = list(range(batch_size))
|
||||
seq_group_metadata_list, _, _ = create_batch(batch_size,
|
||||
k,
|
||||
seq_ids=assigned_seq_ids,
|
||||
prev_output_token_len=10)
|
||||
target_token_logprobs = torch.rand(batch_size, (k + 1),
|
||||
vocab_size,
|
||||
dtype=torch.float32,
|
||||
device='npu')
|
||||
accepted_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, (k + 1)),
|
||||
dtype=torch.int64,
|
||||
device='npu')
|
||||
expected_request_id_seq_ids_mapping: dict[str, set[int]] = defaultdict(set)
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
for seq_id in seq_group_metadata.seq_data:
|
||||
expected_request_id_seq_ids_mapping[
|
||||
seq_group_metadata.request_id].add(seq_id)
|
||||
# Generate a random sample of sequence indexes with bonus tokens
|
||||
seq_indexes_with_bonus_tokens = random.sample(
|
||||
range(batch_size), num_sequences_with_bonus_tokens)
|
||||
# Create a mask that is True for indices in seq_indexes_with_bonus_tokens
|
||||
mask = torch.ones(batch_size, dtype=torch.bool, device='npu')
|
||||
mask[seq_indexes_with_bonus_tokens] = False
|
||||
# Set the last token ID to -1 for all indices not in
|
||||
# seq_indexes_with_bonus_tokens to indicate the lack of bonus token in
|
||||
# those indices.
|
||||
accepted_token_ids[mask, -1:] = -1
|
||||
worker = SpecDecodeWorker(draft_worker,
|
||||
target_worker,
|
||||
mock_spec_decode_sampler("rejection_sampler"),
|
||||
disable_logprobs=False,
|
||||
metrics_collector=metrics_collector)
|
||||
# Initialize _seq_with_bonus_token_in_last_step with a set of sequence IDs.
|
||||
# This set includes all sequence IDs in the batch as well as an additional
|
||||
# `num_extra_sequence_ids` sequence IDs. Note that the sequence IDs are in
|
||||
# the range [0, batch_size + num_extra_sequence_ids).
|
||||
num_extra_sequence_ids = 10
|
||||
worker._seq_with_bonus_token_in_last_step = set(
|
||||
range(batch_size + num_extra_sequence_ids))
|
||||
worker._create_output_sampler_list(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
accepted_token_ids=accepted_token_ids,
|
||||
target_logprobs=target_token_logprobs,
|
||||
prompt_logprobs=None,
|
||||
k=k,
|
||||
stage_times=(0, 0, 0))
|
||||
# Verify that _seq_with_bonus_token_in_last_step contains the following:
|
||||
# 1. Sequence IDs that were already present in
|
||||
# _seq_with_bonus_token_in_last_step but were not part of the current
|
||||
# batch are retained.
|
||||
# 2. Of the sequence IDs present in the current batch, only those with a
|
||||
# bonus token are retained in _seq_with_bonus_token_in_last_step.
|
||||
# Sequence IDs that are present in the current batch but do not have
|
||||
# bonus tokens are removed from _seq_with_bonus_token_in_last_step.
|
||||
expected_seq_ids_with_bonus_tokens = \
|
||||
set([assigned_seq_ids[i] for i in seq_indexes_with_bonus_tokens])
|
||||
additional_sequence_ids = \
|
||||
set(range(batch_size, batch_size + num_extra_sequence_ids))
|
||||
assert worker._seq_with_bonus_token_in_last_step == \
|
||||
expected_seq_ids_with_bonus_tokens.union(additional_sequence_ids)
|
||||
assert worker._request_id_seq_id_mapping == \
|
||||
expected_request_id_seq_ids_mapping
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_handle_finished_requests():
|
||||
"""
|
||||
Test to verify that finished request IDs are appropriately processed to
|
||||
update the internal state of the SpecDecodeWorker.
|
||||
|
||||
This test initializes the SpecDecodeWorker with mock data, marks certain
|
||||
requests as finished, and ensures that the corresponding sequence IDs are
|
||||
correctly removed from the internal mappings.
|
||||
"""
|
||||
batch_size = 32
|
||||
k = 3
|
||||
draft_worker = mock_worker(cls=MultiStepWorker)
|
||||
target_worker = mock_worker()
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
worker = SpecDecodeWorker(draft_worker, target_worker,
|
||||
mock_spec_decode_sampler("rejection_sampler"),
|
||||
metrics_collector)
|
||||
# Initialize the request_id_seq_id_mapping mapping dict with a few fake
|
||||
# request ids and corresponding sequence ids.
|
||||
worker._request_id_seq_id_mapping = \
|
||||
{'request-1': {1,2,3}, 'request-2': {4,5,6,7},
|
||||
'request-3': {8,9}, 'request-4': {10,11}}
|
||||
# Initialize seq_with_bonus_token_in_last_step with a few fake
|
||||
# sequence ids.
|
||||
worker._seq_with_bonus_token_in_last_step = {1, 4, 5, 8, 9, 10}
|
||||
exception_secret = 'artificial stop'
|
||||
draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)
|
||||
|
||||
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
|
||||
# Mark requests with ids request-1 and request-3 as finished.
|
||||
execute_model_req = ExecuteModelRequest(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
num_lookahead_slots=k,
|
||||
finished_requests_ids=['request-1', 'request-3'])
|
||||
|
||||
with pytest.raises(ValueError, match=exception_secret):
|
||||
worker.execute_model(execute_model_req=execute_model_req)
|
||||
# Verify that request-1 and request-3 are removed from
|
||||
# request_id_seq_id_mapping
|
||||
assert worker._request_id_seq_id_mapping == \
|
||||
{'request-2': {4,5,6,7}, 'request-4': {10,11}}
|
||||
# Verify that all sequence ids corresponding to 'request-1'
|
||||
# and 'request-3' are removed from seq_with_bonus_token_in_last_step.
|
||||
assert worker._seq_with_bonus_token_in_last_step == \
|
||||
{4,5,10}
|
||||
|
||||
|
||||
@pytest.mark.parametrize('k', [3])
|
||||
@pytest.mark.parametrize('batch_size', [2, 32])
|
||||
@pytest.mark.parametrize("batch_composition",
|
||||
["prefill_only", "decode_only", "mixed"])
|
||||
@torch.inference_mode()
|
||||
def test_chunked_prefill_flow(k: int, batch_size: int, batch_composition: str):
|
||||
"""
|
||||
Verify SpecDecodeWorker calls match the expected flow.
|
||||
"""
|
||||
vocab_size = 32_000
|
||||
draft_worker = mock_worker(cls=MultiStepWorker)
|
||||
target_worker = mock_worker()
|
||||
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||
worker = SpecDecodeWorker(draft_worker,
|
||||
target_worker,
|
||||
mock_spec_decode_sampler("rejection_sampler"),
|
||||
disable_logprobs=False,
|
||||
metrics_collector=metrics_collector)
|
||||
exception_secret = 'artificial stop'
|
||||
worker.scorer = mock_worker(BatchExpansionTop1Scorer)
|
||||
worker.scorer.score_proposals.side_effect = ValueError(exception_secret)
|
||||
|
||||
# Create batch with combination of terminal/non-terminal prefill chunks
|
||||
# and decodes (different seq_ids).
|
||||
decodes, _, _ = create_batch(batch_size, k)
|
||||
# Pre-chunking here, get 'batch_size' chunks.
|
||||
prefill, _, _ = create_batch(batch_size,
|
||||
k,
|
||||
prefill_chunk_size=4,
|
||||
seq_ids=list(range(batch_size,
|
||||
batch_size * 2)))
|
||||
|
||||
if batch_composition == "prefill_only":
|
||||
n_prefills = batch_size
|
||||
elif batch_composition == "decode_only":
|
||||
n_prefills = 0
|
||||
else:
|
||||
n_prefills = random.randint(1, batch_size - 1)
|
||||
n_decodes = batch_size - n_prefills
|
||||
|
||||
prefill = random.sample(prefill, n_prefills)
|
||||
decodes = random.sample(decodes, n_decodes)
|
||||
target_group_metadata_list = prefill + decodes
|
||||
execute_model_req = ExecuteModelRequest(
|
||||
seq_group_metadata_list=target_group_metadata_list,
|
||||
# For prefill only batches we expect num_lookahead_slots = 0.
|
||||
num_lookahead_slots=k if n_decodes > 0 else 0)
|
||||
|
||||
target_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(1, batch_size * (k + 1)),
|
||||
dtype=torch.int64,
|
||||
device='npu')
|
||||
target_token_probs = torch.rand(1,
|
||||
batch_size * (k + 1),
|
||||
vocab_size,
|
||||
dtype=torch.float32,
|
||||
device='npu')
|
||||
target_token_logprobs = torch.rand(1,
|
||||
batch_size * (k + 1),
|
||||
vocab_size,
|
||||
dtype=torch.float32,
|
||||
device='npu')
|
||||
target_output = create_sampler_output_list(target_token_ids,
|
||||
target_token_probs,
|
||||
target_token_logprobs)
|
||||
|
||||
target_worker.execute_model.return_value = [target_output[0]]
|
||||
|
||||
if not len(decodes):
|
||||
worker.execute_model(execute_model_req=execute_model_req)
|
||||
# no spec run (prefill only)
|
||||
draft_worker.execute_model.assert_called_once_with(execute_model_req)
|
||||
target_worker.execute_model.assert_called_once_with(execute_model_req)
|
||||
else:
|
||||
# Decode-only run OR mixed batch, scorer call fails (it's mocked)
|
||||
with pytest.raises(ValueError, match=exception_secret):
|
||||
worker.execute_model(execute_model_req=execute_model_req)
|
||||
# but first draft still counted
|
||||
assert draft_worker.get_spec_proposals.call_count == 1
|
||||
|
||||
|
||||
def test_correctly_load_weight_for_eagle():
|
||||
"""
|
||||
Verify SpecDecodeWorker loads lm_head weight for eagle correctly.
|
||||
"""
|
||||
seed = 100
|
||||
block_size = 32
|
||||
num_gpu_blocks = 8096 // block_size
|
||||
target_worker = create_worker(
|
||||
NPUWorker,
|
||||
"JackFram/llama-68m",
|
||||
block_size,
|
||||
num_gpu_blocks,
|
||||
seed,
|
||||
)
|
||||
draft_worker = create_worker(
|
||||
MultiStepWorker,
|
||||
"abhigoyal/vllm-eagle-llama-68m-random",
|
||||
block_size,
|
||||
num_gpu_blocks,
|
||||
seed,
|
||||
model_runner_cls=TP1DraftModelRunner,
|
||||
)
|
||||
|
||||
spec_decode_sampler = mock_spec_decode_sampler("rejection_sampler")
|
||||
worker = SpecDecodeWorker(draft_worker,
|
||||
target_worker,
|
||||
spec_decode_sampler,
|
||||
disable_logprobs=False)
|
||||
worker.proposer_worker.maybe_load_lm_head_weight(
|
||||
target_worker.model_runner.model.lm_head.weight.data)
|
||||
assert torch.allclose(
|
||||
worker.proposer_worker.worker.model_runner.model.lm_head.weight.data,
|
||||
worker.scorer_worker.model_runner.model.lm_head.weight.data)
|
||||
@@ -1,165 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm-project/vllm/tests/spec_decode/test_utils.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
|
||||
from vllm.model_executor.layers.sampler import _get_ranks
|
||||
from vllm.model_executor.layers.typical_acceptance_sampler import \
|
||||
TypicalAcceptanceSampler
|
||||
from vllm.sequence import SequenceGroupMetadata, get_all_seq_ids
|
||||
from vllm.spec_decode.util import (get_sampled_token_logprobs,
|
||||
split_batch_by_proposal_len)
|
||||
|
||||
|
||||
def test_get_all_seq_ids():
|
||||
"""Verify get_all_seq_ids extracts all seq ids.
|
||||
"""
|
||||
expected_seq_ids = list(range(10)) + list(range(100, 110))
|
||||
|
||||
seq_group_metadata_list = [
|
||||
SequenceGroupMetadata(
|
||||
request_id=str(seq_id),
|
||||
is_prompt=True,
|
||||
seq_data={
|
||||
seq_id: MagicMock(),
|
||||
},
|
||||
sampling_params=MagicMock(),
|
||||
block_tables={
|
||||
seq_id: MagicMock(),
|
||||
},
|
||||
lora_request=None,
|
||||
) for seq_id in expected_seq_ids
|
||||
]
|
||||
|
||||
actual_seq_ids = get_all_seq_ids(seq_group_metadata_list)
|
||||
assert actual_seq_ids == expected_seq_ids
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_sequence_group_metadata():
|
||||
seq_ids = list(range(3))
|
||||
return [
|
||||
SequenceGroupMetadata(
|
||||
request_id=str(i),
|
||||
is_prompt=True,
|
||||
seq_data={
|
||||
i: MagicMock(),
|
||||
},
|
||||
sampling_params=MagicMock(),
|
||||
block_tables={
|
||||
i: MagicMock(),
|
||||
},
|
||||
lora_request=None,
|
||||
) for i in seq_ids
|
||||
]
|
||||
|
||||
|
||||
def test_filter_zero_length_proposals(fake_sequence_group_metadata):
|
||||
proposal_lens = [0, 1, 0]
|
||||
_, (filtered_groups,
|
||||
indices) = split_batch_by_proposal_len(fake_sequence_group_metadata,
|
||||
proposal_lens)
|
||||
|
||||
expected_groups = [
|
||||
fake_sequence_group_metadata[0], fake_sequence_group_metadata[2]
|
||||
]
|
||||
expected_indices = [0, 2]
|
||||
|
||||
assert filtered_groups == expected_groups
|
||||
assert indices == expected_indices
|
||||
|
||||
|
||||
def test_filter_non_zero_length_proposals(fake_sequence_group_metadata):
|
||||
proposal_lens = [0, 1, 2]
|
||||
(filtered_groups,
|
||||
indices), _ = split_batch_by_proposal_len(fake_sequence_group_metadata,
|
||||
proposal_lens)
|
||||
|
||||
expected_groups = [
|
||||
fake_sequence_group_metadata[1], fake_sequence_group_metadata[2]
|
||||
]
|
||||
expected_indices = [1, 2]
|
||||
|
||||
assert filtered_groups == expected_groups
|
||||
assert indices == expected_indices
|
||||
|
||||
|
||||
def test_empty_inputs():
|
||||
_, (filtered_groups, indices) = split_batch_by_proposal_len([], [])
|
||||
|
||||
assert filtered_groups == []
|
||||
assert indices == []
|
||||
|
||||
|
||||
def test_all_zero_with_non_zero_filter(fake_sequence_group_metadata):
|
||||
proposal_lens = [0, 0, 0]
|
||||
(filtered_groups,
|
||||
indices), _ = split_batch_by_proposal_len(fake_sequence_group_metadata,
|
||||
proposal_lens)
|
||||
|
||||
assert filtered_groups == []
|
||||
assert indices == []
|
||||
|
||||
|
||||
def test_all_non_zero_with_zero_filter(fake_sequence_group_metadata):
|
||||
proposal_lens = [1, 1, 1]
|
||||
_, (filtered_groups,
|
||||
indices) = split_batch_by_proposal_len(fake_sequence_group_metadata,
|
||||
proposal_lens)
|
||||
|
||||
assert filtered_groups == []
|
||||
assert indices == []
|
||||
|
||||
|
||||
def mock_spec_decode_sampler(acceptance_sampler_method):
|
||||
"""
|
||||
Returns either a RejectionSampler or TypicalAcceptanceSampler
|
||||
object depending on whether acceptance_sampler_method is
|
||||
'rejection_sampler' or 'typical_acceptance_sampler' respectively.
|
||||
"""
|
||||
if acceptance_sampler_method == "rejection_sampler":
|
||||
sampler = MagicMock(spec=RejectionSampler)
|
||||
sampler.token_id_dtype = torch.int64
|
||||
return sampler
|
||||
elif acceptance_sampler_method == "typical_acceptance_sampler":
|
||||
sampler = MagicMock(spec=TypicalAcceptanceSampler)
|
||||
sampler.token_id_dtype = torch.int64
|
||||
return sampler
|
||||
else:
|
||||
raise ValueError(f"Invalid sampler name {acceptance_sampler_method}")
|
||||
|
||||
|
||||
def test_get_sampled_token_logprobs():
|
||||
"""Verify get_sampled_token_logprobs returns consistent rankings
|
||||
with regular get_ranks when probabilities match exactly.
|
||||
"""
|
||||
logprob_tensor = torch.tensor(
|
||||
[[[-.1, -.1]] * 2]) # shape (num_steps, batch_size, vocab_size)
|
||||
sampled_token_tensor = torch.tensor([[1,
|
||||
0]]) # shape (num_steps, batch_size)
|
||||
ranks_spec_dec, _ = get_sampled_token_logprobs(logprob_tensor,
|
||||
sampled_token_tensor)
|
||||
|
||||
ranks_regular = _get_ranks(logprob_tensor.reshape((2, -1)),
|
||||
sampled_token_tensor.reshape(-1))
|
||||
|
||||
assert torch.equal(ranks_spec_dec.reshape(-1), ranks_regular)
|
||||
@@ -1,317 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm-project/vllm/tests/spec_decode/utils.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from collections.abc import Sequence as GenericSequence
|
||||
from itertools import count
|
||||
from typing import Callable, Optional, TypeVar, Union
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import torch
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
|
||||
SequenceData, SequenceGroupMetadata, SequenceOutput)
|
||||
from vllm.spec_decode.ngram_worker import NGramWorker # noqa: F401
|
||||
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
|
||||
from vllm.worker.cache_engine import CacheEngine
|
||||
|
||||
from vllm_ascend.worker.model_runner import NPUModelRunner
|
||||
from vllm_ascend.worker.worker import NPUWorker
|
||||
|
||||
T = TypeVar("T", bound=NPUWorker)
|
||||
|
||||
|
||||
def round_up_to_next_block(seq_len: int, block_size: int) -> int:
|
||||
return (seq_len + block_size - 1) // block_size
|
||||
|
||||
|
||||
def mock_worker(cls=None,
|
||||
vocab_size: int = 30_000,
|
||||
max_model_len: int = 2048,
|
||||
rank: int = 0,
|
||||
use_spec: bool = True) -> MagicMock:
|
||||
if cls is None:
|
||||
cls = NPUWorker
|
||||
|
||||
spec = cls if use_spec else None
|
||||
|
||||
worker = MagicMock(spec=spec)
|
||||
worker.vocab_size = vocab_size
|
||||
worker.max_model_len = max_model_len
|
||||
worker.rank = rank
|
||||
worker.device = 'npu:0'
|
||||
return worker
|
||||
|
||||
|
||||
def patch_execute_model_with_seeds(worker: NPUWorker, rand_seeds: list[int]):
|
||||
seed_iter = iter(rand_seeds)
|
||||
original_execute_model = worker.execute_model
|
||||
|
||||
def new_execute_model(*args, **kwargs):
|
||||
result = original_execute_model(*args, **kwargs)
|
||||
set_random_seed(next(seed_iter))
|
||||
return result
|
||||
|
||||
return new_execute_model
|
||||
|
||||
|
||||
def zero_kv_cache(cache_engine: list[CacheEngine]):
|
||||
assert cache_engine[0].gpu_cache
|
||||
for key_blocks, value_blocks in cache_engine[0].gpu_cache:
|
||||
key_blocks.zero_()
|
||||
value_blocks.zero_()
|
||||
|
||||
|
||||
def create_worker(cls: Callable[..., T],
|
||||
model_name: str,
|
||||
block_size: int,
|
||||
num_gpu_blocks: int,
|
||||
seed: int,
|
||||
is_driver_worker: bool = True,
|
||||
enforce_eager: bool = True,
|
||||
model_runner_cls: Optional[NPUModelRunner] = None,
|
||||
dtype: Optional[str] = "auto") -> T:
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
seed=seed,
|
||||
block_size=block_size,
|
||||
enforce_eager=enforce_eager,
|
||||
dtype=dtype,
|
||||
)
|
||||
engine_config = engine_args.create_engine_config()
|
||||
|
||||
distributed_init_method = get_distributed_init_method(
|
||||
get_ip(), get_open_port())
|
||||
|
||||
if cls.__name__ == "NGramWorker":
|
||||
# we need to pass by device type to enable this on npu
|
||||
worker = cls(vllm_config=engine_config,
|
||||
local_rank=0,
|
||||
rank=0,
|
||||
distributed_init_method=distributed_init_method,
|
||||
is_driver_worker=is_driver_worker,
|
||||
model_runner_cls=model_runner_cls,
|
||||
device_type="npu")
|
||||
else:
|
||||
worker = cls(
|
||||
vllm_config=engine_config,
|
||||
local_rank=0,
|
||||
rank=0,
|
||||
distributed_init_method=distributed_init_method,
|
||||
is_driver_worker=is_driver_worker,
|
||||
model_runner_cls=model_runner_cls,
|
||||
)
|
||||
|
||||
worker.init_device()
|
||||
worker.load_model()
|
||||
|
||||
engine_config.cache_config.num_gpu_blocks = num_gpu_blocks
|
||||
engine_config.cache_config.num_cpu_blocks = 0
|
||||
worker.initialize_cache(
|
||||
num_gpu_blocks=engine_config.cache_config.num_gpu_blocks,
|
||||
num_cpu_blocks=engine_config.cache_config.num_cpu_blocks)
|
||||
|
||||
return worker
|
||||
|
||||
|
||||
def create_seq_group_metadata_from_prompts(
|
||||
prompts: list[list[int]],
|
||||
num_gpu_blocks: int,
|
||||
block_size: int,
|
||||
final_prompt_lens: list[int],
|
||||
continuations: Optional[list[list[int]]] = None,
|
||||
seq_ids: Optional[list[int]] = None,
|
||||
) -> list[SequenceGroupMetadata]:
|
||||
|
||||
if continuations is None:
|
||||
continuations = [[] for _ in prompts]
|
||||
|
||||
if seq_ids is None:
|
||||
seq_ids = list(i for i, _ in enumerate(prompts))
|
||||
|
||||
free_gpu_blocks = list(range(num_gpu_blocks))
|
||||
|
||||
block_allocations = {
|
||||
i: [
|
||||
free_gpu_blocks.pop()
|
||||
for _ in range(round_up_to_next_block(final_len, block_size))
|
||||
]
|
||||
for i, final_len in enumerate(final_prompt_lens)
|
||||
}
|
||||
|
||||
seq_grou_metadata_list = []
|
||||
for i, (prompt_token_ids,
|
||||
cont_token_ids) in enumerate(zip(prompts, continuations)):
|
||||
data = SequenceData.from_seqs(prompt_token_ids, cont_token_ids)
|
||||
data.update_num_computed_tokens(
|
||||
len(prompt_token_ids) + len(cont_token_ids) - 1)
|
||||
seq_data = {i: data}
|
||||
seq_grou_metadata_list.append(
|
||||
SequenceGroupMetadata(
|
||||
request_id=str(i),
|
||||
is_prompt=len(cont_token_ids) == 0,
|
||||
seq_data=seq_data,
|
||||
sampling_params=SamplingParams(temperature=0.0),
|
||||
block_tables={i: block_allocations[i][:]},
|
||||
))
|
||||
return seq_grou_metadata_list
|
||||
|
||||
|
||||
def create_chunked_seq_group_metadata_from_prompt(
|
||||
prompt: list[int],
|
||||
num_gpu_blocks: int,
|
||||
chunk_size: int,
|
||||
block_size: int,
|
||||
seq_id: Optional[int] = None) -> list[SequenceGroupMetadata]:
|
||||
|
||||
if seq_id is None:
|
||||
seq_id = 0
|
||||
|
||||
free_gpu_blocks = list(range(num_gpu_blocks))
|
||||
|
||||
block_allocations = [
|
||||
free_gpu_blocks.pop()
|
||||
for _ in range(round_up_to_next_block(len(prompt), block_size))
|
||||
]
|
||||
|
||||
seq_group_metadata_list = []
|
||||
for i, idx in enumerate(range(0, len(prompt), chunk_size)):
|
||||
chunk_ids = prompt[idx:idx + chunk_size]
|
||||
data = SequenceData.from_seqs(prompt)
|
||||
data.update_num_computed_tokens(idx)
|
||||
seq_data = {i: data}
|
||||
seq_group_metadata_list.append(
|
||||
SequenceGroupMetadata(
|
||||
request_id=str(seq_id),
|
||||
is_prompt=True,
|
||||
do_sample=idx + chunk_size >= len(prompt), # terminal chunk
|
||||
seq_data=seq_data,
|
||||
sampling_params=SamplingParams(temperature=0.0),
|
||||
block_tables={i: block_allocations},
|
||||
token_chunk_size=len(chunk_ids)))
|
||||
return seq_group_metadata_list
|
||||
|
||||
|
||||
def assert_logprobs_dict_allclose(
|
||||
actual_logprobs: list[dict[int, Logprob]],
|
||||
expected_logprobs: list[dict[int, Logprob]]) -> None:
|
||||
for single_step_actual_logprobs, single_step_expected_logprobs in zip(
|
||||
actual_logprobs, expected_logprobs):
|
||||
assert set(single_step_actual_logprobs.keys()) == set(
|
||||
single_step_expected_logprobs.keys())
|
||||
for token_id in single_step_actual_logprobs:
|
||||
actual = torch.tensor(
|
||||
single_step_actual_logprobs[token_id].logprob)
|
||||
expected = torch.tensor(
|
||||
single_step_expected_logprobs[token_id].logprob)
|
||||
torch.testing.assert_close(actual, expected)
|
||||
|
||||
|
||||
def create_sampler_output_list(
|
||||
token_ids: torch.Tensor,
|
||||
probs: GenericSequence[Optional[torch.Tensor]],
|
||||
logprobs: GenericSequence[Optional[torch.Tensor]],
|
||||
seq_ids: Optional[list[int]] = None) -> list[SamplerOutput]:
|
||||
num_steps, batch_size = token_ids.shape
|
||||
token_ids_by_step = token_ids.tolist()
|
||||
|
||||
if seq_ids is None:
|
||||
seq_ids = list(range(batch_size))
|
||||
|
||||
return [
|
||||
SamplerOutput(outputs=[
|
||||
CompletionSequenceGroupOutput(
|
||||
samples=[
|
||||
SequenceOutput(
|
||||
output_token=token_id,
|
||||
parent_seq_id=seq_ids[seq_index],
|
||||
logprobs={token_id: Logprob(0)},
|
||||
)
|
||||
],
|
||||
prompt_logprobs=None,
|
||||
) for seq_index, token_id in enumerate(token_ids_by_step[step])
|
||||
],
|
||||
sampled_token_probs=probs[step],
|
||||
logprobs=logprobs[step],
|
||||
sampled_token_ids=token_ids[step])
|
||||
for step in range(num_steps)
|
||||
]
|
||||
|
||||
|
||||
def create_batch(batch_size,
|
||||
k,
|
||||
prompt_len: Union[int, list[int]] = 10,
|
||||
prev_output_token_len: int = 10,
|
||||
seq_ids: Optional[list[int]] = None,
|
||||
num_gpu_blocks: Optional[int] = None,
|
||||
block_size: Optional[int] = None,
|
||||
prefill_chunk_size: Optional[int] = None):
|
||||
if block_size is None:
|
||||
block_size = 8
|
||||
|
||||
if num_gpu_blocks is None:
|
||||
num_gpu_blocks = 2048 // block_size
|
||||
|
||||
iterator = count()
|
||||
|
||||
if isinstance(prompt_len, int):
|
||||
prompt_lens = [prompt_len for _ in range(batch_size)]
|
||||
else:
|
||||
prompt_lens = prompt_len
|
||||
|
||||
prompts = [[next(iterator) for _ in range(p_len)] for p_len in prompt_lens]
|
||||
|
||||
if prefill_chunk_size:
|
||||
# Create a batch of chunked prompts.
|
||||
if not seq_ids:
|
||||
seq_ids = list(range(len(prompts)))
|
||||
seq_group_metadata_list = []
|
||||
for p, sid in zip(prompts, seq_ids):
|
||||
seq_group_metadata_list += \
|
||||
create_chunked_seq_group_metadata_from_prompt(
|
||||
p, num_gpu_blocks, prefill_chunk_size, block_size, sid)
|
||||
seq_group_metadata_list = seq_group_metadata_list[:batch_size]
|
||||
prev_output_tokens = []
|
||||
else:
|
||||
prev_output_tokens = [[
|
||||
next(iterator) for _ in range(prev_output_token_len)
|
||||
] for _ in range(batch_size)]
|
||||
final_prompt_lens = [
|
||||
len(prompt) + len(prev_output_token) + k + 1
|
||||
for prompt, prev_output_token in zip(prompts, prev_output_tokens)
|
||||
]
|
||||
|
||||
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
|
||||
prompts, num_gpu_blocks, block_size, final_prompt_lens,
|
||||
prev_output_tokens, seq_ids)
|
||||
return seq_group_metadata_list, prompts, prev_output_tokens
|
||||
|
||||
|
||||
def maybe_enable_chunked_prefill(prefill_chunk_size, llm_kwargs):
|
||||
if prefill_chunk_size > 0:
|
||||
llm_kwargs.update(
|
||||
**{
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": prefill_chunk_size,
|
||||
"max_num_seqs": prefill_chunk_size
|
||||
})
|
||||
else:
|
||||
llm_kwargs["enable_chunked_prefill"] = False
|
||||
@@ -1,66 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm-project/blob/main/tests/entrypoints/llm/test_accuracy.py
|
||||
#
|
||||
|
||||
import gc
|
||||
import multiprocessing
|
||||
from multiprocessing import Queue
|
||||
|
||||
import lm_eval
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
# pre-trained model path on Hugging Face.
|
||||
MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"
|
||||
# Math reasoning benchmark (Grade School Math 8K).
|
||||
TASK = "gsm8k"
|
||||
# Answer validation requiring format consistency.
|
||||
FILTER = "exact_match,strict-match"
|
||||
# 3% relative tolerance for numerical accuracy.
|
||||
RTOL = 0.03
|
||||
# Baseline accuracy after VLLM optimization.
|
||||
EXPECTED_VALUE = 0.316
|
||||
|
||||
|
||||
def run_test(queue, more_args=None):
|
||||
model_args = f"pretrained={MODEL_NAME},max_model_len=4096"
|
||||
if more_args is not None:
|
||||
model_args = f"{model_args},{more_args}"
|
||||
results = lm_eval.simple_evaluate(
|
||||
model="vllm",
|
||||
model_args=model_args,
|
||||
tasks=TASK,
|
||||
batch_size="auto",
|
||||
)
|
||||
result = results["results"][TASK][FILTER]
|
||||
print("result:", result)
|
||||
queue.put(result)
|
||||
del results
|
||||
torch.npu.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
|
||||
def test_lm_eval_accuracy(monkeypatch: pytest.MonkeyPatch):
|
||||
with monkeypatch.context():
|
||||
result_queue: Queue[float] = multiprocessing.Queue()
|
||||
p = multiprocessing.Process(target=run_test, args=(result_queue, ))
|
||||
p.start()
|
||||
p.join()
|
||||
result = result_queue.get()
|
||||
assert (EXPECTED_VALUE - RTOL < result < EXPECTED_VALUE + RTOL), \
|
||||
f"Expected: {EXPECTED_VALUE}±{RTOL} | Measured: {result}"
|
||||
@@ -16,8 +16,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.utils import GiB_bytes
|
||||
@@ -26,7 +25,6 @@ from tests.utils import fork_new_process_for_each_test
|
||||
from vllm_ascend.device_allocator.camem import CaMemAllocator
|
||||
|
||||
|
||||
@fork_new_process_for_each_test
|
||||
def test_basic_camem():
|
||||
# some tensors from default memory pool
|
||||
shape = (1024, 1024)
|
||||
@@ -59,9 +57,9 @@ def test_basic_camem():
|
||||
assert torch.allclose(output, torch.ones_like(output) * 3)
|
||||
|
||||
|
||||
@pytest.mark.skipif(True, reason="test failed, should be fixed later")
|
||||
@fork_new_process_for_each_test
|
||||
def test_end_to_end():
|
||||
os.environ["VLLM_USE_V1"] = "0"
|
||||
free, total = torch.npu.mem_get_info()
|
||||
used_bytes_baseline = total - free # in case other process is running
|
||||
llm = LLM("Qwen/Qwen2.5-0.5B-Instruct", enable_sleep_mode=True)
|
||||
|
||||
@@ -35,7 +35,6 @@ MODELS = [
|
||||
"Qwen/Qwen3-0.6B-Base",
|
||||
]
|
||||
MULTIMODALITY_MODELS = ["Qwen/Qwen2.5-VL-3B-Instruct"]
|
||||
os.environ["VLLM_USE_MODELSCOPE"] = "True"
|
||||
os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256"
|
||||
|
||||
|
||||
@@ -82,8 +81,3 @@ def test_multimodal(model, prompt_template, vllm_runner):
|
||||
vllm_model.generate_greedy(prompts=prompts,
|
||||
images=images,
|
||||
max_tokens=64)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
pytest.main([__file__])
|
||||
|
||||
Reference in New Issue
Block a user