Files
xc-llm-ascend/tests/ut/spec_decode/test_mtp_proposer.py
lilinsiman 52863c4165 [Refactor][EAGLE] 2/N: load model and generate token (#5437)
### What this PR does / why we need it?
1. Refactor the eagle and mtp functions `load_model` and `generate_token_ids`
2. Remove redundant code in the mtp and eagle files
3. Refactor the corresponding unit tests (UT)

2/N of Refactor and merge mtp and eagle
Related RFC: https://github.com/vllm-project/vllm-ascend/issues/5467

### Does this PR introduce _any_ user-facing change?
no

### How was this patch tested?
ut and tests

- vLLM version: release/v0.13.0
- vLLM main:
81786c8774

---------

Signed-off-by: lilinsiman <lilinsiman@gmail.com>
2026-01-05 14:07:54 +08:00

339 lines
15 KiB
Python

from unittest.mock import MagicMock, patch
import numpy as np
import pytest
import torch
from vllm.config import (CacheConfig, CompilationConfig, CUDAGraphMode,
ModelConfig, SchedulerConfig, SpeculativeConfig,
VllmConfig)
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm_ascend.ascend_config import init_ascend_config
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
class TestMtpProposer:
    """Unit tests for ``MtpProposer`` using mocked configs, runners and buffers.

    Every test runs on ``torch.device("cpu")``; the vLLM config objects, the
    model runner and ``CpuGpuBuffer`` are replaced with ``MagicMock`` objects.
    """

    @pytest.fixture(autouse=True)
    def patch_supports_multimodal_inputs(self):
        """Patch ``MultiModalRegistry.supports_multimodal_inputs`` for each test."""
        with patch(
                "vllm.multimodal.registry.MultiModalRegistry.supports_multimodal_inputs"
        ):
            yield

    @pytest.fixture
    def vllm_config(self):
        """Return a mocked ``VllmConfig`` set up for ``deepseek_mtp`` with 2 draft tokens."""
        config = MagicMock(spec=VllmConfig)
        config.additional_config = None

        config.speculative_config = MagicMock(spec=SpeculativeConfig)
        config.speculative_config.num_speculative_tokens = 2
        config.speculative_config.method = "deepseek_mtp"
        config.speculative_config.draft_model_config = MagicMock()
        config.speculative_config.draft_model_config.get_hidden_size.return_value = 4096
        config.speculative_config.speculative_token_tree = str([
            (i + 1) * (0, ) for i in range(2)
        ])

        config.model_config = MagicMock(spec=ModelConfig)
        config.model_config.dtype = torch.float16
        config.model_config.max_model_len = 2048
        config.model_config.uses_mrope = False
        config.model_config.hf_config = None

        config.load_config = None

        config.cache_config = MagicMock(spec=CacheConfig)
        config.cache_config.block_size = 16

        config.scheduler_config = MagicMock(spec=SchedulerConfig)
        config.scheduler_config.max_num_batched_tokens = 4096
        config.scheduler_config.max_num_seqs = 256

        config.compilation_config = MagicMock(spec=CompilationConfig)
        config.compilation_config.cudagraph_capture_sizes = [1, 2, 4, 8]
        config.compilation_config.static_forward_context = dict()

        config.device_config = MagicMock()
        config.device_config.device = torch.device("cpu")

        init_ascend_config(config)
        return config

    @pytest.fixture
    def runner(self):
        """Return a mocked model runner with ACL graph disabled and pcp/dcp size 1."""
        runner = MagicMock()
        runner.pcp_size = 1
        runner.dcp_size = 1
        runner.pcp_rank = 0
        runner.max_num_tokens = 4096
        runner.max_num_reqs = 256
        runner._use_aclgraph.return_value = False
        runner.reserved_mc2_mask = None
        return runner

    @patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")
    def test_init(self, mock_cpu_gpu_buffer, vllm_config, runner):
        """The constructor should copy the basic settings from the config."""
        mock_cpu_gpu_buffer.return_value = MagicMock()

        # Test basic initialization
        proposer = MtpProposer(vllm_config, torch.device("cpu"), runner)

        assert proposer.vllm_config == vllm_config
        assert proposer.device == torch.device("cpu")
        assert proposer.dtype == torch.float16
        assert proposer.num_speculative_tokens == 2
        assert proposer.hidden_size == 4096
        assert proposer.block_size == 16

        # The fixture sets uses_mrope = False, so only the plain positions
        # buffer is expected and no mrope_positions attribute.
        assert hasattr(proposer, "positions")
        assert not hasattr(proposer, "mrope_positions")
        assert proposer.use_sparse is False

    @patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")
    def test_init_with_aclgraph(self, mock_cpu_gpu_buffer, vllm_config,
                                runner):
        """``use_aclgraph`` should be True when the runner reports ACL graph use."""
        mock_cpu_gpu_buffer.return_value = MagicMock()
        runner._use_aclgraph.return_value = True

        proposer = MtpProposer(vllm_config, torch.device("cpu"), runner)

        assert proposer.use_aclgraph is True

    @patch("vllm_ascend.spec_decode.mtp_proposer.get_forward_context")
    @patch("vllm_ascend.spec_decode.mtp_proposer.set_ascend_forward_context")
    @patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")
    def test_dummy_run(self, mock_cpu_gpu_buffer, mock_set_context,
                       mock_get_forward_context, vllm_config, runner):
        """``dummy_run(8)`` should call the model once per speculative token."""
        mock_cpu_gpu_buffer.return_value = MagicMock()
        proposer = MtpProposer(vllm_config, torch.device("cpu"), runner)
        proposer.model = MagicMock()
        proposer.enable_shared_expert_dp = False
        runner._sync_metadata_across_dp.return_value = (8, 8, False)

        # Configure the object returned by the patched get_forward_context().
        # Rebinding the argument to a new MagicMock would leave the patch
        # itself unconfigured.
        forward_context = mock_get_forward_context.return_value
        forward_context.cudagraph_runtime_mode = None
        forward_context.capturing = True

        # Execute
        proposer.dummy_run(8)

        # Verify
        runner._sync_metadata_across_dp.assert_called_once()
        mock_set_context.assert_called()

        # Check that model was called correct number of times
        assert proposer.model.call_count == vllm_config.speculative_config.num_speculative_tokens

    @patch("vllm_ascend.spec_decode.mtp_proposer.get_forward_context")
    @patch("vllm_ascend.spec_decode.mtp_proposer.set_ascend_forward_context")
    @patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")
    def test_dummy_run_full_graph(self, mock_cpu_gpu_buffer, mock_set_context,
                                  mock_get_forward_context, vllm_config,
                                  runner):
        """``dummy_run`` in FULL graph mode should still call the model per draft token."""
        # Setup
        mock_cpu_gpu_buffer.return_value = MagicMock()
        proposer = MtpProposer(vllm_config, torch.device("cpu"), runner)
        proposer.enable_shared_expert_dp = False
        proposer.model = MagicMock()
        runner._sync_metadata_across_dp.return_value = (8, 8, False)
        runner.attn_groups = []

        # Configure the object returned by the patched get_forward_context().
        # Rebinding the argument to a new MagicMock would leave the patch
        # itself unconfigured.
        forward_context = mock_get_forward_context.return_value
        forward_context.cudagraph_runtime_mode = None
        forward_context.capturing = True

        # Execute
        proposer.dummy_run(num_tokens=8,
                           num_reqs=5,
                           aclgraph_runtime_mode=CUDAGraphMode.FULL)

        # Verify
        runner._sync_metadata_across_dp.assert_called_once()
        mock_set_context.assert_called()

        # Check that model was called correct number of times
        assert proposer.model.call_count == vllm_config.speculative_config.num_speculative_tokens

    @patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")
    def test_prepare_next_token_ids_cpu(self, mock_cpu_gpu_buffer):
        """The CPU path should return the last sampled token of each request."""
        mock_cpu_gpu_buffer.return_value = MagicMock()

        sampled_token_ids = [[10, 20, 30], [40, 50], [60]]
        mock_gpu_batch = MagicMock()
        mock_gpu_batch.req_ids = ["req1", "req2", "req3"]
        mock_num_scheduled = {"req1": 0, "req2": 0, "req3": 0}

        # Bind the real method onto a spec'd mock so no constructor runs.
        proposer = MagicMock(spec=MtpProposer)
        proposer.input_ids = MagicMock(device=torch.device("cpu"))
        proposer.prepare_next_token_ids_cpu = MtpProposer.prepare_next_token_ids_cpu.__get__(
            proposer)

        result = proposer.prepare_next_token_ids_cpu(
            sampled_token_ids=sampled_token_ids,
            requests={},
            gpu_input_batch=mock_gpu_batch,
            num_scheduled_tokens=mock_num_scheduled)

        assert torch.all(
            result == torch.tensor([30, 50, 60], dtype=torch.int32))

    @patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")
    def test_prepare_next_token_ids_padded(self, mock_cpu_gpu_buffer):
        """The padded path should fill backup tokens and count valid samples."""
        mock_common_attn_metadata = MagicMock(spec=CommonAttentionMetadata)
        mock_common_attn_metadata.seq_lens_cpu = torch.tensor(
            [10, 8, 5, 12], dtype=torch.int32)

        mock_sampled_token_ids = torch.tensor([
            [101, 102, 103],
            [201, -1, 203],
            [-1, -1, -1],
            [301, 10000, 303],
        ],
                                              dtype=torch.int32,
                                              device=torch.device("cpu"))

        # One mocked CachedRequestState per request; get_token_id supplies
        # the value expected in the backup buffer below.
        mock_requests = {}  # dict[str, CachedRequestState]
        for idx, backup_token in enumerate((1000, 2000, 3000, 4000)):
            req = MagicMock(spec=CachedRequestState)
            req.get_token_id = MagicMock(return_value=backup_token)
            mock_requests[f"req_{idx}"] = req

        mock_gpu_input_batch = MagicMock(spec=InputBatch)
        mock_gpu_input_batch.num_reqs = 4
        mock_gpu_input_batch.req_ids = ["req_0", "req_1", "req_2", "req_3"]
        mock_gpu_input_batch.vocab_size = 5000

        mock_backup = MagicMock()
        mock_backup.np = np.array([1, 2, 3, 4, 5, 6, 7], dtype=np.int32)
        mock_backup.gpu = torch.tensor([1, 2, 3, 4, 5, 6, 7],
                                       dtype=torch.int32)
        mock_backup.copy_to_gpu = MagicMock()
        mock_cpu_gpu_buffer.return_value = mock_backup

        # Bind the real method onto a spec'd mock so no constructor runs.
        proposer = MagicMock(spec=MtpProposer)
        proposer.backup_next_token_ids = mock_backup
        proposer.input_ids = MagicMock(device=torch.device("cpu"))
        proposer.prepare_next_token_ids_padded = MtpProposer.prepare_next_token_ids_padded.__get__(
            proposer)

        discard_request_indices = torch.tensor([1, 3], dtype=torch.int64)
        num_discarded_requests = 2

        next_token_ids, valid_sampled_tokens_count = proposer.prepare_next_token_ids_padded(
            common_attn_metadata=mock_common_attn_metadata,
            sampled_token_ids=mock_sampled_token_ids,
            requests=mock_requests,
            gpu_input_batch=mock_gpu_input_batch,
            discard_request_indices=discard_request_indices,
            num_discarded_requests=num_discarded_requests)

        # The first four CPU backup slots should hold the per-request values.
        mock_backup_output = proposer.backup_next_token_ids
        expected_backup_cpu = np.array([1000, 2000, 3000, 4000])
        assert np.array_equal(mock_backup_output.np[:4], expected_backup_cpu)
        mock_backup_output.copy_to_gpu.assert_called_once_with(4)

        # Requests 1 and 3 are in discard_request_indices.
        assert valid_sampled_tokens_count[1].item() == 0
        assert valid_sampled_tokens_count[3].item() == 0
        expected_valid_count = torch.tensor([3, 0, 0, 0], dtype=torch.int32)
        assert torch.equal(valid_sampled_tokens_count, expected_valid_count)

        # Rows with no valid sample take their value from mock_backup.gpu.
        expected_next_tokens = torch.tensor([103, 2, 3, 4],
                                            dtype=torch.int32,
                                            device=torch.device("cpu"))
        assert torch.equal(next_token_ids, expected_next_tokens)

    @patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")
    def test_prepare_inputs_padded(self, mock_cpu_gpu_buffer):
        """``prepare_inputs_padded`` should return metadata and sample indices."""
        mock_cpu_gpu_buffer.return_value = MagicMock()

        mock_common_attn_metadata = MagicMock(spec=CommonAttentionMetadata)
        mock_common_attn_metadata.query_start_loc_cpu = torch.tensor(
            [0, 8, 16, 24], dtype=torch.int32)
        mock_common_attn_metadata.seq_lens_cpu = torch.tensor(
            [8, 8, 8], dtype=torch.int32)
        mock_common_attn_metadata.num_input_tokens = 3
        mock_common_attn_metadata.query_start_loc = torch.tensor(
            [0, 8, 16, 24], dtype=torch.int32)
        mock_common_attn_metadata.seq_lens = torch.tensor([8, 8, 8],
                                                          dtype=torch.int32)
        mock_common_attn_metadata.num_reqs = 3
        mock_common_attn_metadata.num_computed_tokens_cpu = torch.tensor(
            [5, 6, 7], dtype=torch.int32)
        mock_common_attn_metadata.block_table_tensor = MagicMock()
        mock_common_attn_metadata.slot_mapping = MagicMock()
        mock_common_attn_metadata.positions = MagicMock()

        mock_spec_decode_metadata = MagicMock(spec=SpecDecodeMetadata)
        mock_spec_decode_metadata.cu_num_draft_tokens = torch.tensor(
            [3, 5, 7], dtype=torch.int32)

        mock_runner = MagicMock()
        mock_runner.actual_seq_lengths_q = MagicMock()
        mock_runner.attn_mask = MagicMock()
        mock_runner.spec_attn_mask = MagicMock()
        mock_runner.attn_state = MagicMock()
        mock_runner.graph_pad_size = 0
        mock_runner.decode_token_per_req = MagicMock()

        # Bind the real method onto a spec'd mock so no constructor runs.
        proposer = MagicMock(spec=MtpProposer)
        proposer.runner = mock_runner
        proposer.arange = torch.arange(100, dtype=torch.int32)
        proposer.prepare_inputs_padded = MtpProposer.prepare_inputs_padded.__get__(
            proposer)

        mock_valid_sampled_tokens_count = torch.tensor([2, 1, 2],
                                                       dtype=torch.int32)

        (spec_common_attn_metadata, token_indices,
         token_indices_to_sample) = proposer.prepare_inputs_padded(
             common_attn_metadata=mock_common_attn_metadata,
             spec_decode_metadata=mock_spec_decode_metadata,
             valid_sampled_tokens_count=mock_valid_sampled_tokens_count)

        # query_start_loc_cpu ends at 24, so all 24 token slots are expected.
        total_num_tokens = mock_common_attn_metadata.query_start_loc_cpu[
            -1].item()
        expected_token_indices = proposer.arange[:total_num_tokens]
        assert torch.equal(token_indices, expected_token_indices)
        assert token_indices.shape == (24, )
        assert token_indices.dtype == torch.int32

        expected_sample_indices = torch.tensor([5, 13, 22], dtype=torch.int32)
        assert torch.equal(token_indices_to_sample, expected_sample_indices)

        assert isinstance(spec_common_attn_metadata,
                          AscendCommonAttentionMetadata)
        assert torch.equal(spec_common_attn_metadata.query_start_loc,
                           mock_common_attn_metadata.query_start_loc)
        assert torch.equal(spec_common_attn_metadata.query_start_loc_cpu,
                           mock_common_attn_metadata.query_start_loc_cpu)
        # NOTE(review): compared against seq_lens rather than seq_lens_cpu;
        # both tensors hold [8, 8, 8] in this test.
        assert torch.equal(spec_common_attn_metadata.seq_lens_cpu,
                           mock_common_attn_metadata.seq_lens)
        assert spec_common_attn_metadata.num_reqs == mock_common_attn_metadata.num_reqs
        assert spec_common_attn_metadata.num_actual_tokens == total_num_tokens
        assert spec_common_attn_metadata.max_query_len == 8
        assert spec_common_attn_metadata.actual_seq_lengths_q == proposer.runner.actual_seq_lengths_q
        assert spec_common_attn_metadata.attn_mask == proposer.runner.attn_mask
        assert spec_common_attn_metadata.spec_attn_mask == proposer.runner.spec_attn_mask