Files
xc-llm-ascend/tests/ut/spec_decode/test_mtp_proposer.py
MengLong Chen fa367e3b1a [CI] Add mtp_proposer ut (#4397)
### What this PR does / why we need it?
Add mtp_proposer ut

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: chenmenglong <chenmenglong1@huawei.com>
2025-12-12 20:41:31 +08:00

446 lines
20 KiB
Python

from unittest.mock import MagicMock, patch
import numpy as np
import pytest
import torch
from vllm.config import (CacheConfig, CompilationConfig, CUDAGraphMode,
ModelConfig, SchedulerConfig, SpeculativeConfig,
VllmConfig)
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm_ascend.ascend_config import init_ascend_config
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.spec_decode.interface import SpecDcodeType
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
class TestMtpProposer:
@pytest.fixture
def vllm_config(self):
    """Assemble a mocked VllmConfig carrying every field MtpProposer reads."""
    cfg = MagicMock(spec=VllmConfig)
    cfg.additional_config = None

    # Speculative decoding: deepseek MTP drafting 2 tokens per step.
    cfg.speculative_config = MagicMock(spec=SpeculativeConfig)
    cfg.speculative_config.num_speculative_tokens = 2
    cfg.speculative_config.method = "deepseek_mtp"
    cfg.speculative_config.draft_model_config = MagicMock()
    cfg.speculative_config.draft_model_config.get_hidden_size.return_value = 4096

    # Target model config; mrope disabled so the plain positions path is used.
    cfg.model_config = MagicMock(spec=ModelConfig)
    cfg.model_config.dtype = torch.float16
    cfg.model_config.max_model_len = 2048
    cfg.model_config.uses_mrope = False
    cfg.model_config.hf_config = None

    cfg.load_config = None

    cfg.cache_config = MagicMock(spec=CacheConfig)
    cfg.cache_config.block_size = 16

    cfg.scheduler_config = MagicMock(spec=SchedulerConfig)
    cfg.scheduler_config.max_num_batched_tokens = 4096
    cfg.scheduler_config.max_num_seqs = 256

    cfg.compilation_config = MagicMock(spec=CompilationConfig)
    cfg.compilation_config.cudagraph_capture_sizes = [1, 2, 4, 8]
    cfg.compilation_config.static_forward_context = dict()

    cfg.device_config = MagicMock()
    cfg.device_config.device = torch.device("cpu")

    # Ascend global config must be initialised before MtpProposer is built.
    init_ascend_config(cfg)
    return cfg
@pytest.fixture
def runner(self):
    """Mocked model runner exposing the attributes MtpProposer queries."""
    mock_runner = MagicMock()
    # Single-rank setup: no prefill/decode context parallelism.
    mock_runner.pcp_size = 1
    mock_runner.dcp_size = 1
    mock_runner.pcp_rank = 0
    mock_runner.max_num_tokens = 4096
    mock_runner.max_num_reqs = 256
    # aclgraph off by default; individual tests flip this as needed.
    mock_runner._use_aclgraph.return_value = False
    mock_runner.reserved_mc2_mask = None
    mock_runner.in_profile_run = False
    return mock_runner
@patch("vllm_ascend.spec_decode.mtp_proposer.CpuGpuBuffer")
def test_init(self, mock_cpu_gpu_buffer, vllm_config, runner):
    """Plain construction copies config values straight onto the proposer."""
    mock_cpu_gpu_buffer.return_value = MagicMock()

    proposer = MtpProposer(vllm_config, torch.device("cpu"), runner)

    assert proposer.name == SpecDcodeType.MTP
    assert proposer.vllm_config == vllm_config
    assert proposer.device == torch.device("cpu")
    assert proposer.dtype == torch.float16
    assert proposer.num_speculative_tokens == 2
    assert proposer.hidden_size == 4096
    assert proposer.block_size == 16
    # mrope is disabled in the fixture, so only plain positions must exist.
    assert hasattr(proposer, "positions")
    assert not hasattr(proposer, "mrope_positions")
    assert proposer.use_sparse is False
@patch("vllm_ascend.spec_decode.mtp_proposer.CpuGpuBuffer")
def test_init_with_aclgraph(self, mock_cpu_gpu_buffer, vllm_config,
                            runner):
    """When the runner reports aclgraph usage, capture sizes are adopted."""
    mock_cpu_gpu_buffer.return_value = MagicMock()
    runner._use_aclgraph.return_value = True

    proposer = MtpProposer(vllm_config, torch.device("cpu"), runner)

    assert proposer.use_aclgraph is True
    assert proposer.cudagraph_batch_sizes == [1, 2, 4, 8]
@patch("vllm.config.get_layers_from_vllm_config")
@patch("vllm_ascend.spec_decode.mtp_proposer.get_model_loader")
@patch(
    "vllm_ascend.spec_decode.mtp_proposer.process_weights_after_loading")
@patch("vllm_ascend.spec_decode.mtp_proposer.set_default_torch_dtype")
@patch("vllm_ascend.spec_decode.mtp_proposer.set_current_vllm_config")
@patch("vllm_ascend.spec_decode.mtp_proposer.CpuGpuBuffer")
def test_load_model(self, mock_cpu_gpu_buffer, mock_set_config,
                    mock_set_dtype, mock_process_weights, mock_get_loader,
                    mock_get_layers, vllm_config, runner):
    """load_model is expected to raise AssertionError for this layer setup.

    One layer name ("draft_attn_exclude_by_indexer") appears both among the
    attention layers and among the DeepseekV32 indexer layers.
    """
    mock_cpu_gpu_buffer.return_value = MagicMock()

    attention_layers = {
        "target_attn_layer": "val0",
        "draft_attn_layer": "val1",
        "draft_attn_exclude_by_indexer": "val2",
    }
    indexer_layers = {
        "target_indexer_0": "val3",
        "draft_attn_exclude_by_indexer": "val4",
    }

    def fake_get_layers(vllm_config, cache_cls):
        # Dispatch on the requested layer class, mirroring the real helper.
        if cache_cls == AttentionLayerBase:
            return attention_layers
        if cache_cls == DeepseekV32IndexerCache:
            return indexer_layers
        return {}

    proposer = MtpProposer(vllm_config, torch.device("cpu"), runner)
    proposer._init_mtp_model = MagicMock()
    target_model = MagicMock()
    proposer.model = target_model

    loader = MagicMock()
    loader.get_all_weights.return_value = {
        "dummy_weight": torch.tensor([1.0])
    }
    mock_get_loader.return_value = loader
    mock_get_layers.side_effect = fake_get_layers

    with pytest.raises(AssertionError):
        proposer.load_model(target_model)
@patch("vllm_ascend.spec_decode.mtp_proposer.get_forward_context")
@patch("vllm_ascend.spec_decode.mtp_proposer.set_ascend_forward_context")
@patch("vllm_ascend.spec_decode.mtp_proposer.CpuGpuBuffer")
def test_dummy_run(self, mock_cpu_gpu_buffer, mock_set_context,
                   mock_get_forward_context, vllm_config, runner):
    """dummy_run drives the draft model once per speculative token."""
    mock_cpu_gpu_buffer.return_value = MagicMock()

    proposer = MtpProposer(vllm_config, torch.device("cpu"), runner)
    proposer.model = MagicMock()
    proposer.enable_shared_expert_dp = False
    runner._sync_metadata_across_dp.return_value = (8, 8, False)
    runner._select_moe_comm_method.return_value = "alltoall"

    # BUGFIX: the original rebound `mock_get_forward_context` to a fresh
    # local MagicMock, so the attributes below never reached the code under
    # test (which sees the patched function's return_value).  Configure the
    # patch's return value instead of shadowing the parameter.
    forward_context = mock_get_forward_context.return_value
    forward_context.cudagraph_runtime_mode = None
    forward_context.capturing = True

    # Execute
    proposer.dummy_run(8)

    # Verify the runner hooks fired and a forward context was established.
    runner._sync_metadata_across_dp.assert_called_once()
    runner._select_moe_comm_method.assert_called_once()
    mock_set_context.assert_called()
    # The draft model must run once for each speculative token.
    assert proposer.model.call_count == vllm_config.speculative_config.num_speculative_tokens
@patch("vllm_ascend.spec_decode.mtp_proposer.get_forward_context")
@patch("vllm_ascend.spec_decode.mtp_proposer.set_ascend_forward_context")
@patch("vllm_ascend.spec_decode.mtp_proposer.CpuGpuBuffer")
def test_dummy_run_full_graph(self, mock_cpu_gpu_buffer, mock_set_context,
                              mock_get_forward_context, vllm_config,
                              runner):
    """dummy_run in FULL aclgraph mode still runs the model per draft token."""
    # Setup
    mock_cpu_gpu_buffer.return_value = MagicMock()
    proposer = MtpProposer(vllm_config, torch.device("cpu"), runner)
    proposer.enable_shared_expert_dp = False
    proposer.model = MagicMock()
    runner._sync_metadata_across_dp.return_value = (8, 8, False)
    runner._select_moe_comm_method.return_value = "alltoall"
    runner.attn_groups = []

    # BUGFIX: the original rebound `mock_get_forward_context` to a fresh
    # local MagicMock, making the attribute configuration dead code.  Set
    # the attributes on the patched function's return_value so the code
    # under test actually observes them.
    forward_context = mock_get_forward_context.return_value
    forward_context.cudagraph_runtime_mode = None
    forward_context.capturing = True

    # Execute
    proposer.dummy_run(num_tokens=8,
                       num_reqs=5,
                       aclgraph_runtime_mode=CUDAGraphMode.FULL)

    # Verify
    runner._sync_metadata_across_dp.assert_called_once()
    runner._select_moe_comm_method.assert_called_once()
    mock_set_context.assert_called()
    # The draft model must run once for each speculative token.
    assert proposer.model.call_count == vllm_config.speculative_config.num_speculative_tokens
@patch("vllm_ascend.spec_decode.mtp_proposer.CpuGpuBuffer")
def test_generate_token_ids(self, mock_cpu_gpu_buffer):
    """generate_token_ids (padded drafter batch enabled) must call the
    padded prepare helpers and _propose exactly once each, and return
    whatever _propose produced."""
    mock_buffer_instance = MagicMock()
    mock_cpu_gpu_buffer.return_value = mock_buffer_instance
    # Holder object bundling all collaborator mocks for readability.
    mock_deps = MagicMock()
    mock_deps.scheduler_output = MagicMock(spec=SchedulerOutput)
    mock_deps.scheduler_output.num_scheduled_tokens = 16
    mock_deps.spec_decode_metadata = MagicMock(spec=SpecDecodeMetadata)
    mock_deps.spec_decode_metadata.num_draft_tokens = 2
    mock_deps.runner = MagicMock()
    mock_deps.runner.input_batch = MagicMock(num_reqs=4)
    mock_deps.runner.input_ids = torch.arange(16, dtype=torch.int32)
    mock_deps.runner.spec_decode_common_attn_metadata = MagicMock()
    # pcp_size=2: the pcp-full inputs below are presumably consumed on this
    # path — confirm against generate_token_ids' pcp branch.
    mock_deps.runner.pcp_size = 2
    mock_deps.runner.input_ids_pcp_full = torch.arange(32,
                                                       dtype=torch.int32)
    mock_deps.runner.query_start_loc_pcp_full_cpu = torch.tensor(
        [0, 8, 16, 24, 32])
    mock_deps.positions = torch.arange(16, dtype=torch.int32)
    mock_deps.hidden_states = torch.zeros(16, 4096, dtype=torch.float16)
    # -1 entries mark invalid/rejected sampled positions.
    mock_deps.sampled_token_ids = torch.tensor([[100, 101, -1],
                                                [200, -1, -1],
                                                [300, 301, 302]])
    proposer = MagicMock(spec=MtpProposer)
    proposer.enable_shared_expert_dp = False
    proposer.runner = mock_deps.runner
    proposer.decode_threshold = 1
    # disable_padded_drafter_batch=False selects the padded helper path.
    proposer.speculative_config = MagicMock(
        disable_padded_drafter_batch=False)
    proposer.pcp_size = mock_deps.runner.pcp_size
    proposer._get_attn_metadata = MagicMock(return_value=MagicMock())
    proposer.prepare_next_token_ids_padded = MagicMock(
        return_value=(torch.tensor([101, 200, 302]), 3))
    proposer.prepare_inputs_padded = MagicMock(
        return_value=(MagicMock(), torch.tensor([0, 2, 4]),
                      torch.tensor([7, 15, 23])))
    proposer._propose = MagicMock(
        return_value=torch.tensor([400, 401, 402]))
    # Bind the real implementation onto the otherwise fully mocked instance.
    proposer.generate_token_ids = MtpProposer.generate_token_ids.__get__(
        proposer)
    draft_token_ids = proposer.generate_token_ids(
        sampled_token_ids=mock_deps.sampled_token_ids,
        scheduler_output=mock_deps.scheduler_output,
        spec_decode_metadata=mock_deps.spec_decode_metadata,
        positions=mock_deps.positions,
        num_scheduled_tokens=mock_deps.scheduler_output.
        num_scheduled_tokens,
        hidden_states=mock_deps.hidden_states,
    )
    # Each collaborator runs exactly once; the result is _propose's output.
    proposer.prepare_next_token_ids_padded.assert_called_once()
    proposer.prepare_inputs_padded.assert_called_once()
    proposer._get_attn_metadata.assert_called_once()
    proposer._propose.assert_called_once()
    assert torch.equal(draft_token_ids, proposer._propose.return_value)
@patch("vllm_ascend.spec_decode.mtp_proposer.CpuGpuBuffer")
def test_prepare_next_token_ids_cpu(self, mock_cpu_gpu_buffer):
    """The CPU path returns the last sampled token of every request."""
    mock_cpu_gpu_buffer.return_value = MagicMock()

    sampled = [[10, 20, 30], [40, 50], [60]]
    batch = MagicMock()
    batch.req_ids = ["req1", "req2", "req3"]
    scheduled = {"req1": 0, "req2": 0, "req3": 0}

    proposer = MagicMock(spec=MtpProposer)
    proposer.input_ids = MagicMock(device=torch.device("cpu"))
    # Bind the real implementation onto the mocked instance.
    proposer.prepare_next_token_ids_cpu = MtpProposer.prepare_next_token_ids_cpu.__get__(
        proposer)

    result = proposer.prepare_next_token_ids_cpu(
        sampled_token_ids=sampled,
        requests={},
        gpu_input_batch=batch,
        num_scheduled_tokens=scheduled)

    assert torch.all(
        result == torch.tensor([30, 50, 60], dtype=torch.int32))
@patch("vllm_ascend.spec_decode.mtp_proposer.CpuGpuBuffer")
def test_prepare_next_token_ids_padded(self, mock_cpu_gpu_buffer):
    """prepare_next_token_ids_padded: discarded requests contribute zero
    valid tokens, and their next-token ids fall back to the backup buffer."""
    mock_common_attn_metadata = MagicMock(spec=CommonAttentionMetadata)
    mock_common_attn_metadata.seq_lens_cpu = torch.tensor(
        [10, 8, 5, 12], dtype=torch.int32)
    # -1 entries are invalid samples; 10000 exceeds vocab_size (5000).
    mock_sampled_token_ids = torch.tensor([
        [101, 102, 103],
        [201, -1, 203],
        [-1, -1, -1],
        [301, 10000, 303],
    ],
                                          dtype=torch.int32,
                                          device=torch.device("cpu"))
    mock_requests = {}  # dict[str, CachedRequestState]
    # Each request's get_token_id supplies the fallback "next" token that
    # should land in the backup buffer.
    req0 = MagicMock(spec=CachedRequestState)
    req0.get_token_id = MagicMock(return_value=1000)
    mock_requests["req_0"] = req0
    req1 = MagicMock(spec=CachedRequestState)
    req1.get_token_id = MagicMock(return_value=2000)
    mock_requests["req_1"] = req1
    req2 = MagicMock(spec=CachedRequestState)
    req2.get_token_id = MagicMock(return_value=3000)
    mock_requests["req_2"] = req2
    req3 = MagicMock(spec=CachedRequestState)
    req3.get_token_id = MagicMock(return_value=4000)
    mock_requests["req_3"] = req3
    mock_gpu_input_batch = MagicMock(spec=InputBatch)
    mock_gpu_input_batch.num_reqs = 4
    mock_gpu_input_batch.req_ids = ["req_0", "req_1", "req_2", "req_3"]
    mock_gpu_input_batch.vocab_size = 5000
    # Hand-built CpuGpuBuffer stand-in: .np is the CPU side the method
    # writes into, .gpu is what fallback next-token ids are read from.
    mock_backup = MagicMock()
    mock_backup.np = np.array([1, 2, 3, 4, 5, 6, 7], dtype=np.int32)
    mock_backup.gpu = torch.tensor([1, 2, 3, 4, 5, 6, 7],
                                   dtype=torch.int32)
    mock_backup.copy_to_gpu = MagicMock()
    mock_cpu_gpu_buffer.return_value = mock_backup
    proposer = MagicMock(spec=MtpProposer)
    proposer.backup_next_token_ids = mock_backup
    proposer.input_ids = MagicMock(device=torch.device("cpu"))
    # Bind the real implementation onto the mocked instance.
    proposer.prepare_next_token_ids_padded = MtpProposer.prepare_next_token_ids_padded.__get__(
        proposer)
    # Requests 1 and 3 are discarded (e.g. their samples are thrown away).
    discard_request_indices = torch.tensor([1, 3], dtype=torch.int64)
    num_discarded_requests = 2
    next_token_ids, valid_sampled_tokens_count = proposer.prepare_next_token_ids_padded(
        common_attn_metadata=mock_common_attn_metadata,
        sampled_token_ids=mock_sampled_token_ids,
        requests=mock_requests,
        gpu_input_batch=mock_gpu_input_batch,
        discard_request_indices=discard_request_indices,
        num_discarded_requests=num_discarded_requests)
    # The backup CPU buffer must hold each request's fallback token id,
    # and exactly the first num_reqs entries are copied to the GPU side.
    mock_backup_output = proposer.backup_next_token_ids
    expected_backup_cpu = np.array(
        [1000, 2000, 3000, 4000, 0, 0, 0, 0, 0, 0])
    assert np.array_equal(mock_backup_output.np[:4],
                          expected_backup_cpu[:4])
    mock_backup_output.copy_to_gpu.assert_called_once_with(4)
    # Mirrors the discard masking the implementation performs; computed
    # here for reference only — not asserted directly.
    modified_sampled = mock_sampled_token_ids.clone()
    modified_sampled.index_fill_(
        0, discard_request_indices[:num_discarded_requests], -1)
    # Discarded requests (1, 3) must report zero valid sampled tokens.
    assert valid_sampled_tokens_count[1].item() == 0
    assert valid_sampled_tokens_count[3].item() == 0
    # NOTE(review): rows 2 also yields 0 (all -1); only row 0 is fully valid.
    expected_valid_count = torch.tensor([3, 0, 0, 0], dtype=torch.int32)
    assert torch.equal(valid_sampled_tokens_count, expected_valid_count)
    # Row 0 keeps its last valid sample (103); the rest fall back to the
    # backup buffer's gpu values (2, 3, 4).
    expected_next_tokens = torch.tensor([103, 2, 3, 4],
                                        dtype=torch.int32,
                                        device=torch.device("cpu"))
    assert torch.equal(next_token_ids, expected_next_tokens)
@patch("vllm_ascend.spec_decode.mtp_proposer.CpuGpuBuffer")
def test_prepare_inputs_padded(self, mock_cpu_gpu_buffer):
mock_buffer_instance = MagicMock()
mock_cpu_gpu_buffer.return_value = mock_buffer_instance
mock_common_attn_metadata = MagicMock(spec=CommonAttentionMetadata)
mock_common_attn_metadata.query_start_loc_cpu = torch.tensor(
[0, 8, 16, 24], dtype=torch.int32)
mock_common_attn_metadata.seq_lens_cpu = torch.tensor(
[8, 8, 8], dtype=torch.int32)
mock_common_attn_metadata.num_input_tokens = 3
mock_common_attn_metadata.query_start_loc = torch.tensor(
[0, 8, 16, 24], dtype=torch.int32)
mock_common_attn_metadata.seq_lens = torch.tensor([8, 8, 8],
dtype=torch.int32)
mock_common_attn_metadata.num_reqs = 3
mock_common_attn_metadata.num_computed_tokens_cpu = torch.tensor(
[5, 6, 7], dtype=torch.int32)
mock_common_attn_metadata.block_table_tensor = MagicMock()
mock_common_attn_metadata.slot_mapping = MagicMock()
mock_common_attn_metadata.positions = MagicMock()
mock_spec_decode_metadata = MagicMock(spec=SpecDecodeMetadata)
mock_spec_decode_metadata.cu_num_draft_tokens = torch.tensor(
[3, 5, 7], dtype=torch.int32)
mock_runner = MagicMock()
mock_runner.actual_seq_lengths_q = MagicMock()
mock_runner.attn_mask = MagicMock()
mock_runner.spec_attn_mask = MagicMock()
mock_runner.attn_state = MagicMock()
mock_runner.graph_pad_size = 0
mock_runner.decode_token_per_req = MagicMock()
proposer = MagicMock(spec=MtpProposer)
proposer.runner = mock_runner
proposer.arange = torch.arange(100, dtype=torch.int32)
proposer.prepare_inputs_padded = MtpProposer.prepare_inputs_padded.__get__(
proposer)
mock_valid_sampled_tokens_count = torch.tensor([2, 1, 2],
dtype=torch.int32)
(spec_common_attn_metadata, token_indices,
token_indices_to_sample) = proposer.prepare_inputs_padded(
common_attn_metadata=mock_common_attn_metadata,
spec_decode_metadata=mock_spec_decode_metadata,
valid_sampled_tokens_count=mock_valid_sampled_tokens_count)
total_num_tokens = mock_common_attn_metadata.query_start_loc_cpu[
-1].item()
expected_token_indices = proposer.arange[:total_num_tokens]
assert torch.equal(token_indices, expected_token_indices)
assert token_indices.shape == (24, )
assert token_indices.dtype == torch.int32
expected_sample_indices = torch.tensor([5, 13, 22], dtype=torch.int32)
assert torch.equal(token_indices_to_sample, expected_sample_indices)
assert isinstance(spec_common_attn_metadata,
AscendCommonAttentionMetadata)
assert torch.equal(spec_common_attn_metadata.query_start_loc,
mock_common_attn_metadata.query_start_loc)
assert torch.equal(spec_common_attn_metadata.query_start_loc_cpu,
mock_common_attn_metadata.query_start_loc_cpu)
assert torch.equal(spec_common_attn_metadata.seq_lens_cpu,
mock_common_attn_metadata.seq_lens)
assert spec_common_attn_metadata.num_reqs == mock_common_attn_metadata.num_reqs
assert spec_common_attn_metadata.num_actual_tokens == total_num_tokens
assert spec_common_attn_metadata.max_query_len == 8
assert spec_common_attn_metadata.actual_seq_lengths_q == proposer.runner.actual_seq_lengths_q
assert spec_common_attn_metadata.attn_mask == proposer.runner.attn_mask
assert spec_common_attn_metadata.spec_attn_mask == proposer.runner.spec_attn_mask