diff --git a/tests/ut/spec_decode/test_mtp_proposer.py b/tests/ut/spec_decode/test_mtp_proposer.py
new file mode 100644
index 00000000..a2507ccc
--- /dev/null
+++ b/tests/ut/spec_decode/test_mtp_proposer.py
@@ -0,0 +1,445 @@
+from unittest.mock import MagicMock, patch
+
+import numpy as np
+import pytest
+import torch
+from vllm.config import (CacheConfig, CompilationConfig, CUDAGraphMode,
+                         ModelConfig, SchedulerConfig, SpeculativeConfig,
+                         VllmConfig)
+from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
+from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
+from vllm.v1.attention.backends.utils import CommonAttentionMetadata
+from vllm.v1.core.sched.output import SchedulerOutput
+from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
+from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
+
+from vllm_ascend.ascend_config import init_ascend_config
+from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
+from vllm_ascend.spec_decode.interface import SpecDcodeType
+from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
+
+
+class TestMtpProposer:
+
+    @pytest.fixture
+    def vllm_config(self):
+        config = MagicMock(spec=VllmConfig)
+        config.additional_config = None
+        config.speculative_config = MagicMock(spec=SpeculativeConfig)
+        config.speculative_config.num_speculative_tokens = 2
+        config.speculative_config.method = "deepseek_mtp"
+        config.speculative_config.draft_model_config = MagicMock()
+        config.speculative_config.draft_model_config.get_hidden_size.return_value = 4096
+
+        config.model_config = MagicMock(spec=ModelConfig)
+        config.model_config.dtype = torch.float16
+        config.model_config.max_model_len = 2048
+        config.model_config.uses_mrope = False
+        config.model_config.hf_config = None
+
+        config.load_config = None
+
+        config.cache_config = MagicMock(spec=CacheConfig)
+        config.cache_config.block_size = 16
+
+        config.scheduler_config = MagicMock(spec=SchedulerConfig)
+        config.scheduler_config.max_num_batched_tokens = 4096
+        config.scheduler_config.max_num_seqs = 256
+
+        config.compilation_config = MagicMock(spec=CompilationConfig)
+        config.compilation_config.cudagraph_capture_sizes = [1, 2, 4, 8]
+        config.compilation_config.static_forward_context = dict()
+
+        config.device_config = MagicMock()
+        config.device_config.device = torch.device("cpu")
+        init_ascend_config(config)
+        return config
+
+    @pytest.fixture
+    def runner(self):
+        runner = MagicMock()
+        runner.pcp_size = 1
+        runner.dcp_size = 1
+        runner.pcp_rank = 0
+        runner.max_num_tokens = 4096
+        runner.max_num_reqs = 256
+        runner._use_aclgraph.return_value = False
+        runner.reserved_mc2_mask = None
+        runner.in_profile_run = False
+        return runner
+
+    @patch("vllm_ascend.spec_decode.mtp_proposer.CpuGpuBuffer")
+    def test_init(self, mock_cpu_gpu_buffer, vllm_config, runner):
+        mock_buffer_instance = MagicMock()
+        mock_cpu_gpu_buffer.return_value = mock_buffer_instance
+
+        # Test basic initialization
+        proposer = MtpProposer(vllm_config, torch.device("cpu"), runner)
+
+        assert proposer.name == SpecDcodeType.MTP
+        assert proposer.vllm_config == vllm_config
+        assert proposer.device == torch.device("cpu")
+        assert proposer.dtype == torch.float16
+        assert proposer.num_speculative_tokens == 2
+        assert proposer.hidden_size == 4096
+        assert proposer.block_size == 16
+
+        # mrope is disabled in the fixture, so only the plain positions
+        # buffer should be allocated
+        assert hasattr(proposer, "positions")
+        assert not hasattr(proposer, "mrope_positions")
+        assert proposer.use_sparse is False
+
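+    # With runner._use_aclgraph() returning True, the proposer is expected
+    # to adopt the capture sizes configured in the vllm_config fixture's
+    # compilation_config ([1, 2, 4, 8]).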
@patch("vllm_ascend.spec_decode.mtp_proposer.CpuGpuBuffer") + def test_init_with_aclgraph(self, mock_cpu_gpu_buffer, vllm_config, + runner): + mock_buffer_instance = MagicMock() + mock_cpu_gpu_buffer.return_value = mock_buffer_instance + runner._use_aclgraph.return_value = True + proposer = MtpProposer(vllm_config, torch.device("cpu"), runner) + + assert proposer.use_aclgraph is True + assert proposer.cudagraph_batch_sizes == [1, 2, 4, 8] + + @patch("vllm.config.get_layers_from_vllm_config") + @patch("vllm_ascend.spec_decode.mtp_proposer.get_model_loader") + @patch( + "vllm_ascend.spec_decode.mtp_proposer.process_weights_after_loading") + @patch("vllm_ascend.spec_decode.mtp_proposer.set_default_torch_dtype") + @patch("vllm_ascend.spec_decode.mtp_proposer.set_current_vllm_config") + @patch("vllm_ascend.spec_decode.mtp_proposer.CpuGpuBuffer") + def test_load_model(self, mock_cpu_gpu_buffer, mock_set_config, + mock_set_dtype, mock_process_weights, mock_get_loader, + mock_get_layers, vllm_config, runner): + mock_buffer_instance = MagicMock() + mock_cpu_gpu_buffer.return_value = mock_buffer_instance + attn_layers_all = { + "target_attn_layer": "val0", + "draft_attn_layer": "val1", + "draft_attn_exclude_by_indexer": "val2", + } + + indexer_layers_all = { + "target_indexer_0": "val3", + "draft_attn_exclude_by_indexer": "val4" + } + + def get_layers_side_effect(vllm_config, cache_cls): + if cache_cls == AttentionLayerBase: + return attn_layers_all + elif cache_cls == DeepseekV32IndexerCache: + return indexer_layers_all + else: + return {} + + # Setup + proposer = MtpProposer(vllm_config, torch.device("cpu"), runner) + proposer._init_mtp_model = MagicMock() + mock_model = MagicMock() + proposer.model = mock_model + + mock_loader = MagicMock() + mock_get_loader.return_value = mock_loader + mock_loader.get_all_weights.return_value = { + "dummy_weight": torch.tensor([1.0]) + } + + mock_get_layers.side_effect = get_layers_side_effect + with pytest.raises(AssertionError): + proposer.load_model(mock_model) + + @patch("vllm_ascend.spec_decode.mtp_proposer.get_forward_context") + @patch("vllm_ascend.spec_decode.mtp_proposer.set_ascend_forward_context") + @patch("vllm_ascend.spec_decode.mtp_proposer.CpuGpuBuffer") + def test_dummy_run(self, mock_cpu_gpu_buffer, mock_set_context, + mock_get_forward_context, vllm_config, runner): + mock_buffer_instance = MagicMock() + mock_cpu_gpu_buffer.return_value = mock_buffer_instance + proposer = MtpProposer(vllm_config, torch.device("cpu"), runner) + proposer.model = MagicMock() + proposer.enable_shared_expert_dp = False + runner._sync_metadata_across_dp.return_value = (8, 8, False) + runner._select_moe_comm_method.return_value = "alltoall" + + mock_get_forward_context = MagicMock() + mock_get_forward_context.cudagraph_runtime_mode = None + mock_get_forward_context.capturing = True + # Execute + proposer.dummy_run(8) + + # Verify + runner._sync_metadata_across_dp.assert_called_once() + runner._select_moe_comm_method.assert_called_once() + mock_set_context.assert_called() + + # Check that model was called correct number of times + assert proposer.model.call_count == vllm_config.speculative_config.num_speculative_tokens + + @patch("vllm_ascend.spec_decode.mtp_proposer.get_forward_context") + @patch("vllm_ascend.spec_decode.mtp_proposer.set_ascend_forward_context") + @patch("vllm_ascend.spec_decode.mtp_proposer.CpuGpuBuffer") + def test_dummy_run_full_graph(self, mock_cpu_gpu_buffer, mock_set_context, + mock_get_forward_context, vllm_config, + runner): + # Setup + 
+        mock_buffer_instance = MagicMock()
+        mock_cpu_gpu_buffer.return_value = mock_buffer_instance
+        proposer = MtpProposer(vllm_config, torch.device("cpu"), runner)
+        proposer.enable_shared_expert_dp = False
+        proposer.model = MagicMock()
+        runner._sync_metadata_across_dp.return_value = (8, 8, False)
+        runner._select_moe_comm_method.return_value = "alltoall"
+        runner.attn_groups = []
+
+        # Configure the context returned by the patched get_forward_context()
+        mock_forward_context = mock_get_forward_context.return_value
+        mock_forward_context.cudagraph_runtime_mode = None
+        mock_forward_context.capturing = True
+        # Execute
+        proposer.dummy_run(num_tokens=8,
+                           num_reqs=5,
+                           aclgraph_runtime_mode=CUDAGraphMode.FULL)
+
+        # Verify
+        runner._sync_metadata_across_dp.assert_called_once()
+        runner._select_moe_comm_method.assert_called_once()
+        mock_set_context.assert_called()
+
+        # The draft model should run once per speculative token
+        assert proposer.model.call_count == vllm_config.speculative_config.num_speculative_tokens
+
+    @patch("vllm_ascend.spec_decode.mtp_proposer.CpuGpuBuffer")
+    def test_generate_token_ids(self, mock_cpu_gpu_buffer):
+        mock_buffer_instance = MagicMock()
+        mock_cpu_gpu_buffer.return_value = mock_buffer_instance
+
+        mock_deps = MagicMock()
+        mock_deps.scheduler_output = MagicMock(spec=SchedulerOutput)
+        mock_deps.scheduler_output.num_scheduled_tokens = 16
+        mock_deps.spec_decode_metadata = MagicMock(spec=SpecDecodeMetadata)
+        mock_deps.spec_decode_metadata.num_draft_tokens = 2
+        mock_deps.runner = MagicMock()
+        mock_deps.runner.input_batch = MagicMock(num_reqs=4)
+        mock_deps.runner.input_ids = torch.arange(16, dtype=torch.int32)
+        mock_deps.runner.spec_decode_common_attn_metadata = MagicMock()
+        mock_deps.runner.pcp_size = 2
+        mock_deps.runner.input_ids_pcp_full = torch.arange(32,
+                                                           dtype=torch.int32)
+        mock_deps.runner.query_start_loc_pcp_full_cpu = torch.tensor(
+            [0, 8, 16, 24, 32])
+        mock_deps.positions = torch.arange(16, dtype=torch.int32)
+        mock_deps.hidden_states = torch.zeros(16, 4096, dtype=torch.float16)
+        mock_deps.sampled_token_ids = torch.tensor([[100, 101, -1],
+                                                    [200, -1, -1],
+                                                    [300, 301, 302]])
+
+        proposer = MagicMock(spec=MtpProposer)
+        proposer.enable_shared_expert_dp = False
+        proposer.runner = mock_deps.runner
+        proposer.decode_threshold = 1
+        proposer.speculative_config = MagicMock(
+            disable_padded_drafter_batch=False)
+        proposer.pcp_size = mock_deps.runner.pcp_size
+        proposer._get_attn_metadata = MagicMock(return_value=MagicMock())
+        proposer.prepare_next_token_ids_padded = MagicMock(
+            return_value=(torch.tensor([101, 200, 302]), 3))
+        proposer.prepare_inputs_padded = MagicMock(
+            return_value=(MagicMock(), torch.tensor([0, 2, 4]),
+                          torch.tensor([7, 15, 23])))
+        proposer._propose = MagicMock(
+            return_value=torch.tensor([400, 401, 402]))
+        # Bind the real method onto the mocked instance so it runs with all
+        # of its collaborators stubbed out
+        proposer.generate_token_ids = MtpProposer.generate_token_ids.__get__(
+            proposer)
+
+        draft_token_ids = proposer.generate_token_ids(
+            sampled_token_ids=mock_deps.sampled_token_ids,
+            scheduler_output=mock_deps.scheduler_output,
+            spec_decode_metadata=mock_deps.spec_decode_metadata,
+            positions=mock_deps.positions,
+            num_scheduled_tokens=mock_deps.scheduler_output.
+            num_scheduled_tokens,
+            hidden_states=mock_deps.hidden_states,
+        )
+
+        proposer.prepare_next_token_ids_padded.assert_called_once()
+        proposer.prepare_inputs_padded.assert_called_once()
+        proposer._get_attn_metadata.assert_called_once()
+        proposer._propose.assert_called_once()
+        assert torch.equal(draft_token_ids, proposer._propose.return_value)
+
+    @patch("vllm_ascend.spec_decode.mtp_proposer.CpuGpuBuffer")
+    def test_prepare_next_token_ids_cpu(self, mock_cpu_gpu_buffer):
+        mock_buffer_instance = MagicMock()
+        mock_cpu_gpu_buffer.return_value = mock_buffer_instance
+        sampled_token_ids = [[10, 20, 30], [40, 50], [60]]
+
+        mock_gpu_batch = MagicMock()
+        mock_gpu_batch.req_ids = ["req1", "req2", "req3"]
+        mock_num_scheduled = {"req1": 0, "req2": 0, "req3": 0}
+
+        proposer = MagicMock(spec=MtpProposer)
+        proposer.input_ids = MagicMock(device=torch.device("cpu"))
+        proposer.prepare_next_token_ids_cpu = MtpProposer.prepare_next_token_ids_cpu.__get__(
+            proposer)
+        result = proposer.prepare_next_token_ids_cpu(
+            sampled_token_ids=sampled_token_ids,
+            requests={},
+            gpu_input_batch=mock_gpu_batch,
+            num_scheduled_tokens=mock_num_scheduled)
+
+        assert torch.all(
+            result == torch.tensor([30, 50, 60], dtype=torch.int32))
+
+    @patch("vllm_ascend.spec_decode.mtp_proposer.CpuGpuBuffer")
+    def test_prepare_next_token_ids_padded(self, mock_cpu_gpu_buffer):
+        mock_common_attn_metadata = MagicMock(spec=CommonAttentionMetadata)
+        mock_common_attn_metadata.seq_lens_cpu = torch.tensor(
+            [10, 8, 5, 12], dtype=torch.int32)
+        mock_sampled_token_ids = torch.tensor([
+            [101, 102, 103],
+            [201, -1, 203],
+            [-1, -1, -1],
+            [301, 10000, 303],
+        ],
+                                              dtype=torch.int32,
+                                              device=torch.device("cpu"))
+
+        mock_requests = {}  # dict[str, CachedRequestState]
+        req0 = MagicMock(spec=CachedRequestState)
+        req0.get_token_id = MagicMock(return_value=1000)
+        mock_requests["req_0"] = req0
+
+        req1 = MagicMock(spec=CachedRequestState)
+        req1.get_token_id = MagicMock(return_value=2000)
+        mock_requests["req_1"] = req1
+
+        req2 = MagicMock(spec=CachedRequestState)
+        req2.get_token_id = MagicMock(return_value=3000)
+        mock_requests["req_2"] = req2
+
+        req3 = MagicMock(spec=CachedRequestState)
+        req3.get_token_id = MagicMock(return_value=4000)
+        mock_requests["req_3"] = req3
+
+        mock_gpu_input_batch = MagicMock(spec=InputBatch)
+        mock_gpu_input_batch.num_reqs = 4
+        mock_gpu_input_batch.req_ids = ["req_0", "req_1", "req_2", "req_3"]
+        mock_gpu_input_batch.vocab_size = 5000
+
+        mock_backup = MagicMock()
+        mock_backup.np = np.array([1, 2, 3, 4, 5, 6, 7], dtype=np.int32)
+        mock_backup.gpu = torch.tensor([1, 2, 3, 4, 5, 6, 7],
+                                       dtype=torch.int32)
+        mock_backup.copy_to_gpu = MagicMock()
+        mock_cpu_gpu_buffer.return_value = mock_backup
+
+        proposer = MagicMock(spec=MtpProposer)
+        proposer.backup_next_token_ids = mock_backup
+        proposer.input_ids = MagicMock(device=torch.device("cpu"))
+        proposer.prepare_next_token_ids_padded = MtpProposer.prepare_next_token_ids_padded.__get__(
+            proposer)
+
+        discard_request_indices = torch.tensor([1, 3], dtype=torch.int64)
+        num_discarded_requests = 2
+
+        next_token_ids, valid_sampled_tokens_count = proposer.prepare_next_token_ids_padded(
+            common_attn_metadata=mock_common_attn_metadata,
+            sampled_token_ids=mock_sampled_token_ids,
+            requests=mock_requests,
+            gpu_input_batch=mock_gpu_input_batch,
+            discard_request_indices=discard_request_indices,
+            num_discarded_requests=num_discarded_requests)
+
+        mock_backup_output = proposer.backup_next_token_ids
+
+        expected_backup_cpu = np.array(
+            [1000, 2000, 3000, 4000, 0, 0, 0, 0, 0, 0])
+        assert np.array_equal(mock_backup_output.np[:4],
+                              expected_backup_cpu[:4])
+        mock_backup_output.copy_to_gpu.assert_called_once_with(4)
+
+        # Both discarded requests should report zero valid sampled tokens
+        assert valid_sampled_tokens_count[1].item() == 0
+        assert valid_sampled_tokens_count[3].item() == 0
+
+        expected_valid_count = torch.tensor([3, 0, 0, 0], dtype=torch.int32)
+        assert torch.equal(valid_sampled_tokens_count, expected_valid_count)
+
+        expected_next_tokens = torch.tensor([103, 2, 3, 4],
+                                            dtype=torch.int32,
+                                            device=torch.device("cpu"))
+        assert torch.equal(next_token_ids, expected_next_tokens)
+
+    @patch("vllm_ascend.spec_decode.mtp_proposer.CpuGpuBuffer")
+    def test_prepare_inputs_padded(self, mock_cpu_gpu_buffer):
+        mock_buffer_instance = MagicMock()
+        mock_cpu_gpu_buffer.return_value = mock_buffer_instance
+
+        mock_common_attn_metadata = MagicMock(spec=CommonAttentionMetadata)
+        mock_common_attn_metadata.query_start_loc_cpu = torch.tensor(
+            [0, 8, 16, 24], dtype=torch.int32)
+        mock_common_attn_metadata.seq_lens_cpu = torch.tensor(
+            [8, 8, 8], dtype=torch.int32)
+        mock_common_attn_metadata.num_input_tokens = 3
+        mock_common_attn_metadata.query_start_loc = torch.tensor(
+            [0, 8, 16, 24], dtype=torch.int32)
+        mock_common_attn_metadata.seq_lens = torch.tensor([8, 8, 8],
+                                                          dtype=torch.int32)
+        mock_common_attn_metadata.num_reqs = 3
+        mock_common_attn_metadata.num_computed_tokens_cpu = torch.tensor(
+            [5, 6, 7], dtype=torch.int32)
+        mock_common_attn_metadata.block_table_tensor = MagicMock()
+        mock_common_attn_metadata.slot_mapping = MagicMock()
+        mock_common_attn_metadata.positions = MagicMock()
+
+        mock_spec_decode_metadata = MagicMock(spec=SpecDecodeMetadata)
+        mock_spec_decode_metadata.cu_num_draft_tokens = torch.tensor(
+            [3, 5, 7], dtype=torch.int32)
+
+        mock_runner = MagicMock()
+        mock_runner.actual_seq_lengths_q = MagicMock()
+        mock_runner.attn_mask = MagicMock()
+        mock_runner.spec_attn_mask = MagicMock()
+        mock_runner.attn_state = MagicMock()
+        mock_runner.graph_pad_size = 0
+        mock_runner.decode_token_per_req = MagicMock()
+
+        proposer = MagicMock(spec=MtpProposer)
+        proposer.runner = mock_runner
+        proposer.arange = torch.arange(100, dtype=torch.int32)
+        proposer.prepare_inputs_padded = MtpProposer.prepare_inputs_padded.__get__(
+            proposer)
+
+        mock_valid_sampled_tokens_count = torch.tensor([2, 1, 2],
+                                                       dtype=torch.int32)
+
+        (spec_common_attn_metadata, token_indices,
+         token_indices_to_sample) = proposer.prepare_inputs_padded(
+             common_attn_metadata=mock_common_attn_metadata,
+             spec_decode_metadata=mock_spec_decode_metadata,
+             valid_sampled_tokens_count=mock_valid_sampled_tokens_count)
+
+        total_num_tokens = mock_common_attn_metadata.query_start_loc_cpu[
+            -1].item()
+        expected_token_indices = proposer.arange[:total_num_tokens]
+        assert torch.equal(token_indices, expected_token_indices)
+        assert token_indices.shape == (24, )
+        assert token_indices.dtype == torch.int32
+
+        expected_sample_indices = torch.tensor([5, 13, 22], dtype=torch.int32)
+        assert torch.equal(token_indices_to_sample, expected_sample_indices)
+
+        assert isinstance(spec_common_attn_metadata,
+                          AscendCommonAttentionMetadata)
+        assert torch.equal(spec_common_attn_metadata.query_start_loc,
+                           mock_common_attn_metadata.query_start_loc)
+        assert torch.equal(spec_common_attn_metadata.query_start_loc_cpu,
+                           mock_common_attn_metadata.query_start_loc_cpu)
+        assert torch.equal(spec_common_attn_metadata.seq_lens_cpu,
+                           mock_common_attn_metadata.seq_lens)
+        assert spec_common_attn_metadata.num_reqs == mock_common_attn_metadata.num_reqs
+        assert spec_common_attn_metadata.num_actual_tokens == total_num_tokens
+        assert spec_common_attn_metadata.max_query_len == 8
+        assert spec_common_attn_metadata.actual_seq_lengths_q == proposer.runner.actual_seq_lengths_q
+        assert spec_common_attn_metadata.attn_mask == proposer.runner.attn_mask
+        assert spec_common_attn_metadata.spec_attn_mask == proposer.runner.spec_attn_mask