From a60e179c7f5ca5ab878bf51f177cbc9e24a6fb6b Mon Sep 17 00:00:00 2001 From: Zetong Li <48438720+slippersss@users.noreply.github.com> Date: Fri, 6 Mar 2026 09:10:57 +0800 Subject: [PATCH] [Refactor][EAGLE] 8/N delete mtp_proposer (#7016) ### What this PR does / why we need it? This PR aims to delete mtp_proposer. By fixing a bug in both dsv32 and glm5, now it should be ok to remove mtp_proposer. The bug is actually about unnecessary slicing of `slot_mapping`. ### Does this PR introduce _any_ user-facing change? N/A ### How was this patch tested? by ci - vLLM version: v0.16.0 - vLLM main: https://github.com/vllm-project/vllm/commit/15d76f74e2fdb12a95ea00f0ca283acf6219a2b7 --------- Signed-off-by: Zetong Li --- tests/ut/spec_decode/test_mtp_proposer.py | 367 --------------- vllm_ascend/attention/attention_v1.py | 6 +- vllm_ascend/spec_decode/__init__.py | 5 +- vllm_ascend/spec_decode/eagle_proposer.py | 17 +- vllm_ascend/spec_decode/mtp_proposer.py | 547 ---------------------- vllm_ascend/worker/model_runner_v1.py | 8 +- 6 files changed, 19 insertions(+), 931 deletions(-) delete mode 100644 tests/ut/spec_decode/test_mtp_proposer.py delete mode 100644 vllm_ascend/spec_decode/mtp_proposer.py diff --git a/tests/ut/spec_decode/test_mtp_proposer.py b/tests/ut/spec_decode/test_mtp_proposer.py deleted file mode 100644 index 6a4d88f8..00000000 --- a/tests/ut/spec_decode/test_mtp_proposer.py +++ /dev/null @@ -1,367 +0,0 @@ -from unittest.mock import MagicMock, patch - -import numpy as np -import pytest -import torch -from vllm.config import (CacheConfig, CompilationConfig, CUDAGraphMode, - ModelConfig, SchedulerConfig, SpeculativeConfig, - VllmConfig, set_current_vllm_config) -from vllm.v1.attention.backends.utils import CommonAttentionMetadata -from vllm.v1.spec_decode.metadata import SpecDecodeMetadata -from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch - -from vllm_ascend.ascend_config import init_ascend_config -from vllm_ascend.attention.utils import AscendCommonAttentionMetadata -from vllm_ascend.spec_decode.mtp_proposer import AscendMtpProposer - - -class TestMtpProposer: - - @pytest.fixture(autouse=True) - def patch_supports_multimodal_inputs(self): - with patch( - "vllm.multimodal.registry.MultiModalRegistry.supports_multimodal_inputs", - return_value=False - ): - yield - - @pytest.fixture - def vllm_config(self): - config = MagicMock(spec=VllmConfig) - config.additional_config = None - config.speculative_config = MagicMock(spec=SpeculativeConfig) - config.speculative_config.num_speculative_tokens = 2 - config.speculative_config.method = "mtp" - config.speculative_config.draft_model_config = MagicMock() - config.speculative_config.draft_model_config.get_hidden_size.return_value = 4096 - config.speculative_config.draft_model_config.uses_mrope = False - config.speculative_config.draft_model_config.uses_xdrope_dim = 0 - config.speculative_config.speculative_token_tree = str([ - (i + 1) * (0, ) for i in range(2) - ]) - config.speculative_config.disable_padded_drafter_batch = False - - config.model_config = MagicMock(spec=ModelConfig) - config.model_config.dtype = torch.float16 - config.model_config.max_model_len = 2048 - config.model_config.uses_mrope = False - config.model_config.uses_xdrope_dim = 0 - config.model_config.hf_text_config = MagicMock(spec=[]) # Empty spec to prevent hasattr from returning True - config.model_config.hf_text_config.to_dict = MagicMock(return_value={}) - config.model_config.hf_config = None - config.parallel_config.tensor_parallel_size = 1 - config.parallel_config.data_parallel_rank = 0 - config.parallel_config.data_parallel_size = 1 - config.parallel_config.prefill_context_parallel_size = 1 - config.parallel_config.enable_expert_parallel = False - config.speculative_config.draft_tensor_parallel_size = 1 - - config.load_config = None - - config.cache_config = MagicMock(spec=CacheConfig) - config.cache_config.block_size = 16 - - config.scheduler_config = MagicMock(spec=SchedulerConfig) - config.scheduler_config.max_num_batched_tokens = 4096 - config.scheduler_config.max_num_seqs = 256 - - config.compilation_config = MagicMock(spec=CompilationConfig) - config.compilation_config.cudagraph_capture_sizes = [1, 2, 4, 8] - config.compilation_config.static_forward_context = dict() - config.compilation_config.pass_config = MagicMock() - config.compilation_config.pass_config.enable_sp = False - - config.device_config = MagicMock() - config.device_config.device = torch.device("cpu") - init_ascend_config(config) - return config - - @pytest.fixture - def runner(self): - runner = MagicMock() - runner.pcp_size = 1 - runner.dcp_size = 1 - runner.pcp_rank = 0 - runner.max_num_tokens = 4096 - runner.max_num_reqs = 256 - runner._use_aclgraph.return_value = False - runner.reserved_mc2_mask = None - runner.pin_memory = False - return runner - - @patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer") - def test_init(self, mock_cpu_gpu_buffer, vllm_config, runner): - mock_buffer_instance = MagicMock() - mock_cpu_gpu_buffer.return_value = mock_buffer_instance - - # Test basic initialization - with set_current_vllm_config(vllm_config): - proposer = AscendMtpProposer(vllm_config, torch.device("cpu"), runner) - - assert proposer.vllm_config == vllm_config - assert proposer.device == torch.device("cpu") - assert proposer.dtype == torch.float16 - assert proposer.num_speculative_tokens == 2 - assert proposer.hidden_size == 4096 - - # Test with mrope enabled - assert hasattr(proposer, "positions") - assert not hasattr(proposer, "mrope_positions") - assert proposer.use_sparse is False - - @patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer") - def test_init_with_aclgraph(self, mock_cpu_gpu_buffer, vllm_config, - runner): - mock_buffer_instance = MagicMock() - mock_cpu_gpu_buffer.return_value = mock_buffer_instance - runner._use_aclgraph.return_value = True - vllm_config.scheduler_config.async_scheduling = False - vllm_config.speculative_config.enforce_eager = False - with set_current_vllm_config(vllm_config): - proposer = AscendMtpProposer(vllm_config, torch.device("cpu"), runner) - - assert proposer.use_cuda_graph is True - - @patch("vllm_ascend.ascend_forward_context.get_dp_group") - @patch("vllm_ascend.ascend_forward_context.get_tensor_model_parallel_world_size", return_value=1) - @patch("vllm_ascend.spec_decode.mtp_proposer.get_forward_context") - @patch("vllm_ascend.spec_decode.mtp_proposer.set_ascend_forward_context") - @patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer") - def test_dummy_run(self, mock_cpu_gpu_buffer, mock_set_context, - mock_get_forward_context, mock_tp_world_size, mock_dp_group, vllm_config, runner): - mock_buffer_instance = MagicMock() - mock_cpu_gpu_buffer.return_value = mock_buffer_instance - mock_dp_group.return_value.world_size = 1 - with set_current_vllm_config(vllm_config): - proposer = AscendMtpProposer(vllm_config, torch.device("cpu"), runner) - - # Mock _runnable to prevent actual execution - proposer._runnable = MagicMock() - proposer.enable_shared_expert_dp = False - runner._sync_metadata_across_dp.return_value = (8, 8, False) - - mock_get_forward_context = MagicMock() - mock_get_forward_context.cudagraph_runtime_mode = None - mock_get_forward_context.capturing = True - # Execute - proposer.dummy_run(8) - - # Verify - runner._sync_metadata_across_dp.assert_called_once() - - # Check that _runnable was called - assert proposer._runnable.call_count == 1 - - @patch("vllm_ascend.ascend_forward_context.get_dp_group") - @patch("vllm_ascend.ascend_forward_context.get_tensor_model_parallel_world_size", return_value=1) - @patch("vllm_ascend.spec_decode.mtp_proposer.get_forward_context") - @patch("vllm_ascend.spec_decode.mtp_proposer.set_ascend_forward_context") - @patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer") - def test_dummy_run_full_graph(self, mock_cpu_gpu_buffer, mock_set_context, - mock_get_forward_context, mock_tp_world_size, mock_dp_group, vllm_config, - runner): - # Setup - mock_buffer_instance = MagicMock() - mock_cpu_gpu_buffer.return_value = mock_buffer_instance - mock_dp_group.return_value.world_size = 1 - with set_current_vllm_config(vllm_config): - proposer = AscendMtpProposer(vllm_config, torch.device("cpu"), runner) - - # Mock _runnable to prevent actual execution - proposer._runnable = MagicMock() - proposer.enable_shared_expert_dp = False - runner._sync_metadata_across_dp.return_value = (8, 8, False) - runner.attn_groups = [] - - mock_get_forward_context = MagicMock() - mock_get_forward_context.cudagraph_runtime_mode = None - mock_get_forward_context.capturing = True - # Execute - proposer.dummy_run(num_tokens=8, - num_reqs=5, - aclgraph_runtime_mode=CUDAGraphMode.FULL) - - # Verify - runner._sync_metadata_across_dp.assert_called_once() - - # Check that _runnable was called - assert proposer._runnable.call_count == 1 - - @patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer") - def test_prepare_next_token_ids_cpu(self, mock_cpu_gpu_buffer): - mock_buffer_instance = MagicMock() - mock_cpu_gpu_buffer.return_value = mock_buffer_instance - sampled_token_ids = [[10, 20, 30], [40, 50], [60]] - - mock_gpu_batch = MagicMock() - mock_gpu_batch.req_ids = ["req1", "req2", "req3"] - mock_num_scheduled = {"req1": 0, "req2": 0, "req3": 0} - - proposer = MagicMock(spec=AscendMtpProposer) - proposer.input_ids = MagicMock(device=torch.device("cpu")) - proposer.prepare_next_token_ids_cpu = AscendMtpProposer.prepare_next_token_ids_cpu.__get__( - proposer) - result = proposer.prepare_next_token_ids_cpu( - sampled_token_ids=sampled_token_ids, - requests={}, - gpu_input_batch=mock_gpu_batch, - num_scheduled_tokens=mock_num_scheduled) - - assert torch.all( - result == torch.tensor([30, 50, 60], dtype=torch.int32)) - - @patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer") - def test_prepare_next_token_ids_padded(self, mock_cpu_gpu_buffer): - mock_common_attn_metadata = MagicMock(spec=CommonAttentionMetadata) - mock_common_attn_metadata.seq_lens_cpu = torch.tensor( - [10, 8, 5, 12], dtype=torch.int32) - mock_sampled_token_ids = torch.tensor([ - [101, 102, 103], - [201, -1, 203], - [-1, -1, -1], - [301, 10000, 303], - ], - dtype=torch.int32, - device=torch.device("cpu")) - - mock_requests = {} # dict[str, CachedRequestState] - req0 = MagicMock(spec=CachedRequestState) - req0.get_token_id = MagicMock(return_value=1000) - mock_requests["req_0"] = req0 - - req1 = MagicMock(spec=CachedRequestState) - req1.get_token_id = MagicMock(return_value=2000) - mock_requests["req_1"] = req1 - - req2 = MagicMock(spec=CachedRequestState) - req2.get_token_id = MagicMock(return_value=3000) - mock_requests["req_2"] = req2 - - req3 = MagicMock(spec=CachedRequestState) - req3.get_token_id = MagicMock(return_value=4000) - mock_requests["req_3"] = req3 - - mock_gpu_input_batch = MagicMock(spec=InputBatch) - mock_gpu_input_batch.num_reqs = 4 - mock_gpu_input_batch.req_ids = ["req_0", "req_1", "req_2", "req_3"] - mock_gpu_input_batch.vocab_size = 5000 - - mock_backup = MagicMock() - mock_backup.np = np.array([1, 2, 3, 4, 5, 6, 7], dtype=np.int32) - mock_backup.gpu = torch.tensor([1, 2, 3, 4, 5, 6, 7], - dtype=torch.int32) - mock_backup.copy_to_gpu = MagicMock() - mock_cpu_gpu_buffer.return_value = mock_backup - - proposer = MagicMock(spec=AscendMtpProposer) - proposer.backup_next_token_ids = mock_backup - proposer.input_ids = MagicMock(device=torch.device("cpu")) - proposer.prepare_next_token_ids_padded = AscendMtpProposer.prepare_next_token_ids_padded.__get__( - proposer) - - discard_request_indices = torch.tensor([1, 3], dtype=torch.int64) - num_discarded_requests = 2 - - next_token_ids, valid_sampled_tokens_count = proposer.prepare_next_token_ids_padded( - common_attn_metadata=mock_common_attn_metadata, - sampled_token_ids=mock_sampled_token_ids, - requests=mock_requests, - gpu_input_batch=mock_gpu_input_batch, - discard_request_indices=discard_request_indices, - num_discarded_requests=num_discarded_requests) - - mock_backup_output = proposer.backup_next_token_ids - - expected_backup_cpu = np.array( - [1000, 2000, 3000, 4000, 0, 0, 0, 0, 0, 0]) - assert np.array_equal(mock_backup_output.np[:4], - expected_backup_cpu[:4]) - mock_backup_output.copy_to_gpu.assert_called_once_with(4) - - modified_sampled = mock_sampled_token_ids.clone() - modified_sampled.index_fill_( - 0, discard_request_indices[:num_discarded_requests], -1) - assert valid_sampled_tokens_count[1].item() == 0 - assert valid_sampled_tokens_count[3].item() == 0 - - expected_valid_count = torch.tensor([3, 0, 0, 0], dtype=torch.int32) - assert torch.equal(valid_sampled_tokens_count, expected_valid_count) - - expected_next_tokens = torch.tensor([103, 2, 3, 4], - dtype=torch.int32, - device=torch.device("cpu")) - assert torch.equal(next_token_ids, expected_next_tokens) - - @patch("vllm_ascend.spec_decode.eagle_proposer.HAS_TRITON", False) - @patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer") - def test_prepare_inputs_padded(self, mock_cpu_gpu_buffer): - mock_buffer_instance = MagicMock() - mock_cpu_gpu_buffer.return_value = mock_buffer_instance - - mock_common_attn_metadata = MagicMock(spec=CommonAttentionMetadata) - mock_common_attn_metadata.query_start_loc_cpu = torch.tensor( - [0, 8, 16, 24], dtype=torch.int32) - mock_common_attn_metadata.seq_lens_cpu = torch.tensor( - [8, 8, 8], dtype=torch.int32) - mock_common_attn_metadata.num_input_tokens = 3 - mock_common_attn_metadata.query_start_loc = torch.tensor( - [0, 8, 16, 24], dtype=torch.int32) - mock_common_attn_metadata.seq_lens = torch.tensor([8, 8, 8], - dtype=torch.int32) - mock_common_attn_metadata.num_actual_tokens = 24 - mock_common_attn_metadata.num_reqs = 3 - mock_common_attn_metadata.num_computed_tokens_cpu = torch.tensor( - [5, 6, 7], dtype=torch.int32) - mock_common_attn_metadata.block_table_tensor = MagicMock() - mock_common_attn_metadata.slot_mapping = MagicMock() - mock_common_attn_metadata.positions = MagicMock() - - mock_spec_decode_metadata = MagicMock(spec=SpecDecodeMetadata) - mock_spec_decode_metadata.cu_num_draft_tokens = torch.tensor( - [3, 5, 7], dtype=torch.int32) - - mock_runner = MagicMock() - mock_runner.actual_seq_lengths_q = MagicMock() - mock_runner.attn_state = MagicMock() - mock_runner.graph_pad_size = 0 - mock_runner.pcp_size = 1 - mock_runner.decode_token_per_req = MagicMock() - - proposer = MagicMock(spec=AscendMtpProposer) - proposer.runner = mock_runner - proposer.pcp_size = 1 - proposer.arange = torch.arange(100, dtype=torch.int32) - proposer.prepare_inputs_padded = AscendMtpProposer.prepare_inputs_padded.__get__( - proposer) - - mock_valid_sampled_tokens_count = torch.tensor([2, 1, 2], - dtype=torch.int32) - - (spec_common_attn_metadata, token_indices, - token_indices_to_sample) = proposer.prepare_inputs_padded( - common_attn_metadata=mock_common_attn_metadata, - spec_decode_metadata=mock_spec_decode_metadata, - valid_sampled_tokens_count=mock_valid_sampled_tokens_count) - - total_num_tokens = mock_common_attn_metadata.query_start_loc_cpu[ - -1].item() - expected_token_indices = proposer.arange[:total_num_tokens] - assert torch.equal(token_indices, expected_token_indices) - assert token_indices.shape == (24, ) - assert token_indices.dtype == torch.int32 - - expected_sample_indices = torch.tensor([5, 13, 22], dtype=torch.int32) - assert torch.equal(token_indices_to_sample, expected_sample_indices) - - assert isinstance(spec_common_attn_metadata, - AscendCommonAttentionMetadata) - assert torch.equal(spec_common_attn_metadata.query_start_loc, - mock_common_attn_metadata.query_start_loc) - assert torch.equal(spec_common_attn_metadata.query_start_loc_cpu, - mock_common_attn_metadata.query_start_loc_cpu) - assert torch.equal(spec_common_attn_metadata.seq_lens_cpu, - mock_common_attn_metadata.seq_lens) - assert spec_common_attn_metadata.num_reqs == mock_common_attn_metadata.num_reqs - assert spec_common_attn_metadata.num_actual_tokens == total_num_tokens - assert spec_common_attn_metadata.max_query_len == 8 - assert spec_common_attn_metadata.actual_seq_lengths_q == proposer.runner.actual_seq_lengths_q diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index c3cf61b9..a1c79d94 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -324,7 +324,11 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]): common_attn_metadata: AscendCommonAttentionMetadata, attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly, ): - if attn_state in (AscendAttentionState.DecodeOnly, AscendAttentionState.ChunkedPrefill): + if attn_state in ( + AscendAttentionState.DecodeOnly, + AscendAttentionState.ChunkedPrefill, + AscendAttentionState.SpecDecoding, + ): attn_metadata = self.build( common_prefix_len=0, common_attn_metadata=common_attn_metadata, diff --git a/vllm_ascend/spec_decode/__init__.py b/vllm_ascend/spec_decode/__init__.py index 5cfc6a70..78644448 100644 --- a/vllm_ascend/spec_decode/__init__.py +++ b/vllm_ascend/spec_decode/__init__.py @@ -18,7 +18,6 @@ # from vllm_ascend.spec_decode.eagle_proposer import AscendEagleProposer from vllm_ascend.spec_decode.medusa_proposer import AscendMedusaProposer -from vllm_ascend.spec_decode.mtp_proposer import AscendMtpProposer from vllm_ascend.spec_decode.ngram_proposer import AscendNgramProposer from vllm_ascend.spec_decode.suffix_proposer import AscendSuffixDecodingProposer @@ -30,9 +29,7 @@ def get_spec_decode_method(method, vllm_config, device, runner): return AscendSuffixDecodingProposer(vllm_config, runner) elif method == "medusa": return AscendMedusaProposer(vllm_config, device) - elif method in ("eagle", "eagle3"): + elif method in ("eagle", "eagle3", "mtp"): return AscendEagleProposer(vllm_config, device, runner) - elif method == "mtp": - return AscendMtpProposer(vllm_config, device, runner) else: raise ValueError(f"Unknown speculative decoding method: {method}") diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py index db6032c7..0a5e5d74 100644 --- a/vllm_ascend/spec_decode/eagle_proposer.py +++ b/vllm_ascend/spec_decode/eagle_proposer.py @@ -129,7 +129,11 @@ class AscendEagleProposer(EagleProposer): self.use_cuda_graph = self.runner._use_aclgraph() and not self.speculative_config.enforce_eager if self.method == "mtp": - self.use_cuda_graph = self.use_cuda_graph and not self.use_async_scheduling + self.use_cuda_graph = ( + self.use_cuda_graph + and not self.use_async_scheduling + and not self.speculative_config.disable_padded_drafter_batch + ) # TODO: Remove it when the bug of fx-graph is solved self.maybe_eager_context: AbstractContextManager[Any] = nullcontext() @@ -340,7 +344,8 @@ class AscendEagleProposer(EagleProposer): # Set the real slot_mapping. common_attn_metadata.slot_mapping = self.slot_mapping_group[draft_step] attn_metadata_eagle = builder.build_for_graph_capture( - common_attn_metadata, AscendAttentionState.ChunkedPrefill + common_attn_metadata, + AscendAttentionState.SpecDecoding if self.method == "mtp" else AscendAttentionState.ChunkedPrefill, ) per_layer_attn_metadata = dict() for layer_name in self.attn_layer_names: @@ -536,7 +541,7 @@ class AscendEagleProposer(EagleProposer): slot_mapping_lens = common_attn_metadata.slot_mapping.shape[0] self.slot_mapping_group[0][:slot_mapping_lens].copy_(common_attn_metadata.slot_mapping[:slot_mapping_lens]) self.slot_mapping_group[0][slot_mapping_lens:].fill_(-1) - common_attn_metadata.slot_mapping = self.slot_mapping_group[0][:slot_mapping_lens] + common_attn_metadata.slot_mapping = self.slot_mapping_group[0] common_attn_metadata.num_input_tokens = num_input_tokens # FIXME(woosuk): The below two ops cause synchronization. Optimize. builder = self.runner.attn_groups[0][0].get_metadata_builder() @@ -900,7 +905,9 @@ class AscendEagleProposer(EagleProposer): common_attn_metadata.num_actual_tokens = batch_size common_attn_metadata.max_query_len = 1 common_attn_metadata.decode_token_per_req = 1 - common_attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill + common_attn_metadata.attn_state = ( + AscendAttentionState.SpecDecoding if self.method == "mtp" else AscendAttentionState.ChunkedPrefill + ) common_attn_metadata.graph_pad_size = -1 common_attn_metadata.num_input_tokens = input_batch_size @@ -982,7 +989,7 @@ class AscendEagleProposer(EagleProposer): self.slot_mapping_group[draft_step][: slot_mapping.shape[0]].copy_(slot_mapping.to(torch.int32)) self.slot_mapping_group[draft_step][slot_mapping.shape[0] :].fill_(PADDING_SLOT_ID) # Set the address of the attn_metadata.slot_mapping to the self.slot_mapping_group[idx] - common_attn_metadata.slot_mapping = self.slot_mapping_group[draft_step][: slot_mapping.shape[0]] + common_attn_metadata.slot_mapping = self.slot_mapping_group[draft_step] # Rebuild attention metadata attn_metadata = attn_metadata_builder.build_for_drafting( # type: ignore diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py deleted file mode 100644 index 249e4502..00000000 --- a/vllm_ascend/spec_decode/mtp_proposer.py +++ /dev/null @@ -1,547 +0,0 @@ -import torch -import torch.nn as nn -from vllm.config import CUDAGraphMode -from vllm.distributed import get_pcp_group -from vllm.forward_context import get_forward_context -from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM -from vllm.v1.attention.backends.utils import CommonAttentionMetadata -from vllm.v1.core.sched.output import SchedulerOutput -from vllm.v1.sample.metadata import SamplingMetadata -from vllm.v1.spec_decode.eagle import PADDING_SLOT_ID -from vllm.v1.utils import record_function_or_nullcontext - -from vllm_ascend.ascend_forward_context import set_ascend_forward_context -from vllm_ascend.attention.attention_v1 import AscendAttentionState -from vllm_ascend.attention.utils import AscendCommonAttentionMetadata -from vllm_ascend.compilation.acl_graph import ACLGraphWrapper -from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla, update_cos_sin -from vllm_ascend.spec_decode.eagle_proposer import AscendEagleProposer -from vllm_ascend.utils import lmhead_tp_enable - - -class AscendMtpProposer(AscendEagleProposer): - # TODO: Find out why ModelRunner does not this explicit typing? - model: nn.Module | ACLGraphWrapper - - @torch.inference_mode() - def dummy_run( - self, - num_tokens: int, - with_prefill: bool = False, - in_graph_capturing: bool = False, - num_reqs: int = 0, - num_tokens_across_dp=None, - aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, - batch_descriptor=None, - dummy_compute_logits=lambda hidden_states: None, - is_profile=False, - ) -> None: - # Currently, both GLM and DS encounter issues when enabling the fullgraph mode and running on EagleProposer. - # Therefore, we temporarily bypass this problem by adding a conditional check for fullgraph. - # TODO: this conditional check should be removed after bug fixing. - if not self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs(): - super().dummy_run( - num_tokens, - with_prefill, - in_graph_capturing, - num_reqs, - num_tokens_across_dp, - aclgraph_runtime_mode, - batch_descriptor, - dummy_compute_logits, - is_profile, - ) - return - ( - num_tokens, - num_tokens_across_dp, - with_prefill, - ) = self.runner._sync_metadata_across_dp(num_tokens, with_prefill) - if not self.use_cuda_graph: - # there is synchronization between mtp steps when enabling aclgraph, - # disable aclgraph when use async scheduling to avoid the - # synchronization overhead. - # NOTE: we need to set aclgraph_runtime_mode to None in both dummy_run - # and _propose. - aclgraph_runtime_mode = CUDAGraphMode.NONE - if aclgraph_runtime_mode == CUDAGraphMode.FULL: - if len(self.runner.attn_groups) > 0: - num_computed_tokens_cpu = self.runner.input_batch.num_computed_tokens_cpu_tensor[:num_reqs] - common_attn_metadata = AscendCommonAttentionMetadata( - query_start_loc=self.runner.query_start_loc.gpu[: num_reqs + 1], - query_start_loc_cpu=self.runner.query_start_loc.cpu[: num_reqs + 1], - seq_lens_cpu=self.runner.seq_lens.cpu, - seq_lens=self.runner.seq_lens.gpu[:num_reqs], - num_reqs=num_reqs, - num_actual_tokens=num_tokens, - num_input_tokens=num_tokens, - max_query_len=self.num_speculative_tokens + 1, - num_computed_tokens_cpu=num_computed_tokens_cpu, - actual_seq_lengths_q=self.runner.actual_seq_lengths_q, - block_table_tensor=self.runner.input_batch.block_table[0].get_device_tensor(), - slot_mapping=self.runner.input_batch.block_table[0].slot_mapping.gpu, - positions=self.runner.positions.gpu, - attn_state=self.runner.attn_state, - decode_token_per_req=self.runner.decode_token_per_req, - max_seq_len=0, - ) - if self.pcp_size * self.dcp_size > 1: - # update long_seq related params and flatten block_table - common_attn_metadata.prefill_context_parallel_metadata = self.runner.pcp_manager.long_seq_metadata - common_attn_metadata.block_table_tensor = self.runner.input_batch.block_table[ - 0 - ].get_device_tensor()[: num_reqs * self.decode_threshold] - - builder = self.runner.attn_groups[0][0].get_metadata_builder() - # `AscendAttentionState.SpecDecoding` is only designed for MLA. - # `AscendAttentionState.ChunkedPrefill` is used in self-attention. - attn_state = ( - AscendAttentionState.SpecDecoding - if self.vllm_config.model_config.use_mla - else AscendAttentionState.ChunkedPrefill - ) - attn_metadata_mtp = builder.build_for_graph_capture(common_attn_metadata, attn_state) - attn_metadata = {} - for layer_name in self.attn_layer_names: - attn_metadata[layer_name] = attn_metadata_mtp - else: - attn_metadata = None - else: - attn_metadata = None - - input_ids = self.input_ids[:num_tokens] - positions = self._get_positions(num_tokens) - previous_hidden_states = self.hidden_states[:num_tokens] - for i in range(self.num_speculative_tokens): - if i > 0 and not in_graph_capturing and aclgraph_runtime_mode == CUDAGraphMode.FULL: - aclgraph_runtime_mode = CUDAGraphMode.NONE - with set_ascend_forward_context( - attn_metadata, - self.vllm_config, - num_tokens=num_tokens, - num_tokens_across_dp=num_tokens_across_dp, - num_actual_tokens=0, - aclgraph_runtime_mode=aclgraph_runtime_mode, - batch_descriptor=batch_descriptor, - is_draft_model=True, - in_profile_run=is_profile, - ): - # Reset MOE layer index for each MTP step iteration - forward_context = get_forward_context() - if forward_context is not None: - forward_context.moe_layer_index = 0 - previous_hidden_states, positions = self.maybe_pad_and_reduce(previous_hidden_states, positions) - self.model(input_ids=input_ids, positions=positions, hidden_states=previous_hidden_states) - forward_context = get_forward_context() - if ( - forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL - and not forward_context.capturing - and not self.use_sparse - ): - self._update_full_graph_params(forward_context, num_tokens) - - previous_hidden_states, positions, _ = self.maybe_all_gather_and_unpad( - previous_hidden_states, positions - ) - dummy_compute_logits(previous_hidden_states) - if with_prefill: - break - - def _propose( - self, - # [num_tokens] - target_token_ids: torch.Tensor, - # [num_tokens] or [3, num_tokens] when M-RoPE is enabled - target_positions: torch.Tensor, - # [num_tokens, hidden_size] - target_hidden_states: torch.Tensor, - # [batch_size] - next_token_ids: torch.Tensor, - last_token_indices: torch.Tensor | None, - common_attn_metadata: CommonAttentionMetadata, - sampling_metadata: SamplingMetadata, - mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None, - req_scheduled_tokens=None, - long_seq_metadata=None, - num_prefill_reqs=0, - num_decode_reqs=0, - scheduler_output: SchedulerOutput = None, - num_scheduled_tokens: int = 0, - ) -> torch.Tensor: - # Currently, both GLM and DS encounter issues when enabling the fullgraph mode and running on EagleProposer. - # Therefore, we temporarily bypass this problem by adding a conditional check for fullgraph. - # TODO: this conditional check should be removed after bug fixing. - if not self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs(): - draft_token_ids = super()._propose( - target_token_ids, - target_positions, - target_hidden_states, - next_token_ids, - last_token_indices, - common_attn_metadata, - sampling_metadata, - mm_embed_inputs, - req_scheduled_tokens, - long_seq_metadata, - num_prefill_reqs, - num_decode_reqs, - scheduler_output, - num_scheduled_tokens, - ) - return draft_token_ids - - num_tokens = target_token_ids.shape[0] - batch_size = next_token_ids.shape[0] - - if last_token_indices is None: - last_token_indices = common_attn_metadata.query_start_loc[1:] - 1 - - if self.method == "eagle3": - assert isinstance(self.model, Eagle3LlamaForCausalLM) - target_hidden_states = self.model.combine_hidden_states(target_hidden_states) - assert target_hidden_states.shape[-1] == self.hidden_size - - # Shift the input ids by one token. - # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] - self.input_ids[: num_tokens - 1] = target_token_ids[1:] - # Replace the last token with the next token. - # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] - self.input_ids[last_token_indices] = next_token_ids - - # update pcp related params - if self.pcp_size * self.dcp_size > 1: - assert long_seq_metadata is not None - common_attn_metadata.prefill_context_parallel_metadata = long_seq_metadata - ori_last_token_indices = last_token_indices.clone() - query_lens_d = self.runner.query_lens[:num_decode_reqs] - if self.pcp_size > 1: - # 1. preprocess decode/prefill input_ids & target_hidden_states - # decode input_ids: keep unchanged - # decode target_hidden_states: remove padding - # prefill input_ids: add padding and pcp split - # prefill target_hidden_states: pcp split - num_tokens_d = query_lens_d.sum().item() - num_tokens_d_padded = num_tokens_d * self.pcp_size - input_ids_d = self.input_ids[:num_tokens_d] - input_ids_p = self.input_ids[num_tokens_d:num_tokens] - target_hidden_states_d_padded = target_hidden_states[:num_tokens_d_padded] - if num_tokens_d: - # remove padding (from pcp all-gather) in decode part - mask_start_loc = torch.cat( - [torch.tensor([0], dtype=torch.int32), torch.cumsum(query_lens_d * self.pcp_size, dim=0)[:-1]] - ) - mask_len = query_lens_d - mask = [] - for req_id in range(num_decode_reqs): - mask += list(range(mask_start_loc[req_id], mask_start_loc[req_id] + mask_len[req_id])) - target_hidden_states_d = target_hidden_states_d_padded[mask] - else: - target_hidden_states_d = target_hidden_states_d_padded - target_hidden_states_p = target_hidden_states[num_tokens_d_padded:] - req_scheduled_tokens_p = {} - for i, req_id in enumerate(self.runner.input_batch.req_ids): - if i >= num_decode_reqs: - req_scheduled_tokens_p[req_id] = req_scheduled_tokens[req_id] - (num_tokens_p, input_ids_p, target_hidden_states_p, max_query_len_p, seq_lens_p, cu_num_tokens_p) = ( - self._split_pcp_input(req_scheduled_tokens_p, input_ids_p, target_hidden_states_p) - ) - num_tokens = num_tokens_d + num_tokens_p - target_positions = target_positions[:num_tokens] - self.input_ids[:num_tokens].copy_(torch.cat([input_ids_d, input_ids_p], dim=0)) - target_hidden_states = torch.cat([target_hidden_states_d, target_hidden_states_p], dim=0) - # 2. update sample_indices according to main model - if num_decode_reqs: - last_token_indices[:num_decode_reqs] = self.runner.logits_indices[last_token_indices[:num_decode_reqs]] - if num_prefill_reqs: - last_token_indices[-num_prefill_reqs:] = self.runner.logits_indices[-num_prefill_reqs:] - # 3. update attn_metadata params that may be influenced by pcp - common_attn_metadata.num_actual_tokens = num_tokens - common_attn_metadata.max_query_len = max(self.decode_threshold, max_query_len_p) - common_attn_metadata.seq_lens[-num_prefill_reqs:] = seq_lens_p - common_attn_metadata.seq_lens_cpu[-num_prefill_reqs:] = seq_lens_p - query_start_loc_p = cu_num_tokens_p[1:] + common_attn_metadata.query_start_loc[num_decode_reqs].item() - common_attn_metadata.query_start_loc[-num_prefill_reqs:] = query_start_loc_p - common_attn_metadata.query_start_loc_cpu[-num_prefill_reqs:] = query_start_loc_p - - assert self.runner is not None - - # Note(qcs): We may need to refactor these check logics. - if self.use_cuda_graph and num_scheduled_tokens <= self.runner.cudagraph_batch_sizes[-1]: - num_input_tokens = self.runner.cudagraph_dispatcher._bs_to_padded_graph_size[num_scheduled_tokens] - else: - # Eager mode, no padding needed - num_input_tokens = num_tokens - - # copy inputs to buffer for cudagraph - self._set_positions(num_tokens, target_positions) - self.hidden_states[:num_tokens] = target_hidden_states - # eager/acl piecewise mode need to update num_tokens_across_dp - (num_input_tokens, num_tokens_across_dp, with_prefill) = self.runner._sync_metadata_across_dp( - num_input_tokens, self.runner.with_prefill - ) - - # Enable shared_expert_dp and MTP FULL graph may cause accuracy issues. - if scheduler_output and not self.enable_shared_expert_dp: - max_query_len = common_attn_metadata.max_query_len - uniform_decode = (max_query_len in list(range(1, self.num_speculative_tokens + 2))) and ( - scheduler_output.total_num_scheduled_tokens - == self.runner.input_batch.num_reqs * (self.num_speculative_tokens + 1) - ) - else: - uniform_decode = False - has_lora = len(self.runner.input_batch.lora_id_to_lora_request) > 0 - aclgraph_runtime_mode, batch_descriptor = self.runner.cudagraph_dispatcher.dispatch( - num_tokens=num_input_tokens, uniform_decode=uniform_decode, has_lora=has_lora - ) - if not self.use_cuda_graph: - # there is synchronization between mtp steps when enabling aclgraph, - # disable aclgraph when use async scheduling to avoid the - # synchronization overhead. - # NOTE: we need to set aclgraph_runtime_mode to None in both dummy_run - # and _propose. - aclgraph_runtime_mode = CUDAGraphMode.NONE - - if ( - self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs() - and aclgraph_runtime_mode == CUDAGraphMode.FULL - ): - graph_pad_size = num_input_tokens - else: - graph_pad_size = -1 - - # If use fullgraph and disable_padded_drafter_batch=True, We need to - # update the graph_pad_size in common_attn_metadata, to tell the - # builder padding some elements. - common_attn_metadata.graph_pad_size = graph_pad_size - common_attn_metadata.num_input_tokens = num_input_tokens - builder = self.runner.attn_groups[0][0].get_metadata_builder() - attn_metadata_mtp = builder.build(0, common_attn_metadata, self.runner.get_model()) - attn_metadata = {} - for layer_name in self.attn_layer_names: - attn_metadata[layer_name] = attn_metadata_mtp - - update_cos_sin(self._get_positions(num_input_tokens)) - for step in range(self.num_speculative_tokens): - with set_ascend_forward_context( - attn_metadata, - self.vllm_config, - num_tokens=num_input_tokens, - num_tokens_across_dp=num_tokens_across_dp, - aclgraph_runtime_mode=aclgraph_runtime_mode, - batch_descriptor=batch_descriptor, - num_actual_tokens=num_tokens, - is_draft_model=True, - ): - # Reset MOE layer index for each MTP step to match all_moe_layers registration - forward_context = get_forward_context() - if forward_context is not None: - forward_context.moe_layer_index = 0 - - with record_function_or_nullcontext("mtp_forward"): - model_kwargs = {} - model_kwargs["attn_metadata"] = attn_metadata - input_ids = self.input_ids[:num_input_tokens] - positions = self._get_positions(num_input_tokens) - hidden_states = self.hidden_states[:num_input_tokens] - - hidden_states, positions = self.maybe_pad_and_reduce(hidden_states, positions) - - hidden_states = self.model(input_ids=input_ids, positions=positions, hidden_states=hidden_states) - forward_context = get_forward_context() - if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and not self.use_sparse: - self._update_full_graph_params(forward_context, num_input_tokens) - - hidden_states, positions, _ = self.maybe_all_gather_and_unpad(hidden_states, positions) - - num_indices = last_token_indices.shape[0] - if lmhead_tp_enable(): - max_num_reqs_across_dp = ( - self.vllm_config.scheduler_config.max_num_seqs * self.runner.uniform_decode_query_len - ) - last_token_indices = nn.functional.pad(last_token_indices, (0, max_num_reqs_across_dp - num_indices)) - - if self.pcp_size > 1 and step == 0: - # remove graph padding before all_gather - hidden_states = hidden_states[:num_tokens] - hidden_states = get_pcp_group().all_gather(hidden_states, 0) - hidden_states = torch.index_select( - hidden_states, 0, self.runner.pcp_manager.pcp_allgather_restore_idx.gpu[: hidden_states.shape[0]] - ) - - sample_hidden_states = hidden_states[last_token_indices] - logits = self.model.compute_logits(sample_hidden_states) - if lmhead_tp_enable() and num_indices < logits.shape[0]: - logits = logits[:num_indices] - last_token_indices = last_token_indices[:num_indices] - draft_token_ids = logits.argmax(dim=-1) - - if self.num_speculative_tokens == 1: - # [batch_size, 1] - return draft_token_ids.view(-1, 1) - - if step == 0: - draft_token_ids_list = [draft_token_ids] - else: - draft_token_ids_list.append(draft_token_ids) - - # prepare next mtp inputs - # mtp>1: prefill skip or decode skip last loop - if with_prefill: - for _ in range(self.num_speculative_tokens - 1): - draft_token_ids_list.append(draft_token_ids) - if step == self.num_speculative_tokens - 1 or with_prefill: - break - - attn_metadata_i = attn_metadata[self.attn_layer_names[0]] - - if step == 0: - positions = target_positions[last_token_indices] - hidden_states = hidden_states[last_token_indices] - slot_mapping = attn_metadata_i.slot_mapping[last_token_indices] - attn_metadata_i.slot_mapping.fill_(-1) - attn_metadata_i.query_start_loc = self.arange[: batch_size + 1] - last_token_indices = self.arange[:batch_size] - if getattr(attn_metadata_i, "num_decode_tokens", 0): - attn_metadata_i.num_decode_tokens = batch_size - if self.pcp_size * self.dcp_size > 1: - positions = target_positions[ori_last_token_indices] - # For pcp/dcp, tokens are split across different cp ranks, - # so we can not simply update slot_mapping by += 1. - # Instead, we pre-allocate mtp slot_mapping in model_runner - # (_generate_pcp_mtp_input), and use updated slot_indices - # to get corresponding slot_mapping in each step. - num_reject_tokens = ( - torch.tensor(self.runner.pcp_manager.cu_num_tokens_pcp_full, dtype=torch.int32).to(self.device) - - ori_last_token_indices - - 1 - ) - num_accept_tokens = query_lens_d.to(self.device) - num_reject_tokens - # `AscendAttentionState.SpecDecoding` is only designed for MLA. - # `AscendAttentionState.ChunkedPrefill` is used in self-attention. - mtp_slot_mapping = self.runner.pcp_manager.mtp_slot_pad - - # slot_mapping index base offset: - # scheduled tokens + pre-allocated mtp tokens + accepted tokens - slot_idx_base = ( - torch.cat( - [ - torch.tensor([0], dtype=torch.int32, device=self.device), - (torch.cumsum(query_lens_d, dim=0)[:-1] * self.pcp_size).to(self.device), - ] - ) - + torch.arange(num_decode_reqs, device=self.device) - * (self.num_speculative_tokens - 1) - * self.pcp_size - + (num_accept_tokens - 1) * self.pcp_size - ) - slot_indices_list = [] - for req_id in range(num_decode_reqs): - slot_indices_list.append( - torch.arange( - slot_idx_base[req_id], slot_idx_base[req_id] + self.pcp_size, device=self.device - ) - ) - slot_indices = torch.cat(slot_indices_list, dim=0) - - # fold block_table (restore it to original size before flattened) - block_indices = torch.cat( - [torch.tensor([0], dtype=torch.int32), torch.cumsum(query_lens_d, dim=0)[:-1]] - ) - attn_metadata_i.decode.block_table[:batch_size] = attn_metadata_i.decode.block_table[block_indices] - attn_metadata_i.decode.block_table = attn_metadata_i.decode.block_table[:batch_size] - - input_ids = draft_token_ids_list[-1].int() - positions += 1 - - decode_metadata = getattr(attn_metadata_i, "decode", None) - prefill_metadata = getattr(attn_metadata_i, "prefill", None) - # When disable_padded_drafter_batch=False, it should not to be updating these params, maybe. - if decode_metadata is not None and ( - self.speculative_config.disable_padded_drafter_batch or aclgraph_runtime_mode != CUDAGraphMode.FULL - ): - decode_metadata.actual_seq_lengths_q = self.arange_cpu[1 : batch_size + 1].tolist() - if aclgraph_runtime_mode == CUDAGraphMode.FULL: - decode_metadata.actual_seq_lengths_q = builder.pad_actual_seq_len_q_mtp_disable_pad( - graph_pad_size - batch_size, batch_size, decode_metadata.actual_seq_lengths_q - ) - decode_metadata.cos, decode_metadata.sin = get_cos_and_sin_mla(positions[:batch_size]) - # NOTE(woosuk): We should handle the case where the draft model - # generates tokens beyond the max model length. Since it is complex - # to remove such requests from the batch, we keep them in the batch - # but adjust the position ids and slot mappings to avoid the - # out-of-range access during the model execution. The draft tokens - # generated with this adjustment should be ignored. - exceeds_max_model_len = positions[:batch_size] >= self.runner.model_config.max_model_len - # Mask out the position ids that exceed the max model length. - # Otherwise, we may get out-of-range error in RoPE. - clamped_positions = torch.where(exceeds_max_model_len, 0, positions[:batch_size]) - # Increment the sequence lengths. - # This is an out-of-place operation to avoid modifying the original tensor - # when enable async_scheduling. - attn_metadata_i.seq_lens = attn_metadata_i.seq_lens + 1 - # For the requests that exceed the max model length, we set the - # sequence length to 1 to minimize their overheads in attention. - exceeds_mask = attn_metadata_i.seq_lens[:batch_size] > self.runner.model_config.max_model_len - attn_metadata_i.seq_lens[:batch_size].masked_fill_(exceeds_mask, 1) - # Mask out the slot mappings that exceed the max model length. - # Otherwise, the KV cache will be inadvertently updated with the - # padding tokens. - slot_mapping += 1 - if self.pcp_size > 1: - exceeds_max_model_len = exceeds_max_model_len.repeat_interleave( - slot_mapping.size(0) // exceeds_max_model_len.size(0) - ) - slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID) - - # copy inputs to buffer for cudagraph - self.input_ids[:batch_size] = input_ids - self._set_positions(batch_size, clamped_positions) - self.hidden_states[: hidden_states.shape[0]] = hidden_states - if self.pcp_size * self.dcp_size > 1: - # update local seq_len - num_computed_tokens_of_pcp_dcp = self.runner.pcp_manager._get_cp_local_seq_lens( - attn_metadata_i.seq_lens[:batch_size], - self.pcp_size, - self.dcp_size, - self.runner.parallel_config.cp_kv_cache_interleave_size, - ) - cp_seq_len = num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank] - attn_metadata_i.decode.cp_seq_len = cp_seq_len - # update slot_mapping - slot_indices += self.pcp_size - slot_mapping = mtp_slot_mapping[slot_indices] - attn_metadata_i.slot_mapping[: batch_size * self.pcp_size] = slot_mapping - else: - attn_metadata_i.slot_mapping[:batch_size] = slot_mapping - if self.speculative_config.disable_padded_drafter_batch: - if self.uses_mrope: - self.mrope_positions[:, batch_size:num_input_tokens] = 0 - else: - self.positions[batch_size:num_input_tokens] = 0 - self.input_ids[batch_size:num_input_tokens] = 0 - self.hidden_states[batch_size:num_input_tokens].fill_(0) - - if prefill_metadata is not None: - prefill_metadata.seq_lens = attn_metadata_i.seq_lens - prefill_metadata.seq_lens_list = prefill_metadata.seq_lens.tolist() - prefill_metadata.context_lens = attn_metadata_i.seq_lens - prefill_metadata.input_positions = self._get_positions(num_input_tokens) - prefill_metadata.max_seq_lens += 1 - prefill_metadata.max_seq_lens = min( - prefill_metadata.max_seq_lens, self.runner.model_config.max_model_len - ) - if decode_metadata is not None: - decode_metadata.seq_lens = attn_metadata_i.seq_lens - decode_metadata.seq_lens_list = decode_metadata.seq_lens.tolist() - decode_seq_lens_list = decode_metadata.seq_lens_list - if aclgraph_runtime_mode == CUDAGraphMode.FULL and self.speculative_config.disable_padded_drafter_batch: - decode_metadata.seq_lens_list = decode_seq_lens_list + [0] * ( - graph_pad_size - len(decode_seq_lens_list) - ) - decode_metadata.input_positions = self._get_positions(num_input_tokens) - decode_metadata.max_seq_lens += 1 - decode_metadata.max_seq_lens = min(decode_metadata.max_seq_lens, self.runner.model_config.max_model_len) - - # mtp>1: [batch_size, k] - draft_token_ids = torch.stack(draft_token_ids_list, dim=1) - return draft_token_ids diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 923f4fd4..f574b9e4 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -109,7 +109,6 @@ from vllm_ascend.sample.sampler import AscendSampler from vllm_ascend.spec_decode import get_spec_decode_method from vllm_ascend.spec_decode.eagle_proposer import AscendEagleProposer from vllm_ascend.spec_decode.medusa_proposer import AscendMedusaProposer -from vllm_ascend.spec_decode.mtp_proposer import AscendMtpProposer from vllm_ascend.spec_decode.ngram_proposer import AscendNgramProposer from vllm_ascend.spec_decode.suffix_proposer import AscendSuffixDecodingProposer from vllm_ascend.utils import ( @@ -404,12 +403,7 @@ class NPUModelRunner(GPUModelRunner): def _set_up_drafter(self): # Set up speculative decoding. self.drafter: ( - AscendNgramProposer - | AscendEagleProposer - | AscendMtpProposer - | AscendSuffixDecodingProposer - | AscendMedusaProposer - | None + AscendNgramProposer | AscendEagleProposer | AscendSuffixDecodingProposer | AscendMedusaProposer | None ) = None self.actual_seq_lengths_q: list[int] = [] self.decode_token_per_req = 1