From bd11c0054f21bcc10f2419f5a50f1edb855f2be3 Mon Sep 17 00:00:00 2001 From: whx <56632993+whx-sjtu@users.noreply.github.com> Date: Tue, 21 Oct 2025 22:23:52 +0800 Subject: [PATCH] [BugFix] Fix torchair+mtp bug after deleting deepseek_mtp. (#3590) This fixes a bug that was introduced by PR #3561 and left unfixed there. - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- Signed-off-by: whx-sjtu <2952154980@qq.com> --- tests/ut/torchair/models/test_torchair_deepseek_mtp.py | 3 +-- vllm_ascend/spec_decode/mtp_proposer.py | 4 ++-- vllm_ascend/torchair/models/torchair_deepseek_mtp.py | 10 ++++------ 3 files changed, 7 insertions(+), 10 deletions(-) diff --git a/tests/ut/torchair/models/test_torchair_deepseek_mtp.py b/tests/ut/torchair/models/test_torchair_deepseek_mtp.py index b9b18ac..109c56e 100644 --- a/tests/ut/torchair/models/test_torchair_deepseek_mtp.py +++ b/tests/ut/torchair/models/test_torchair_deepseek_mtp.py @@ -145,8 +145,7 @@ class TestTorchairDeepSeekMultiTokenPredictor(PytestBase): return_value=None) predictor.logits_processor.return_value = torch.tensor([1.0, 2.0, 3.0]) - result_logits = predictor.compute_logits(hidden_states=hidden_states, - sampling_metadata=None) + result_logits = predictor.compute_logits(hidden_states=hidden_states) predictor.logits_processor.assert_called_once() assert torch.allclose(result_logits, torch.tensor([1.0, 2.0, 3.0])) diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py index 97b094e..66308ef 100644 --- a/vllm_ascend/spec_decode/mtp_proposer.py +++ b/vllm_ascend/spec_decode/mtp_proposer.py @@ -175,7 +175,7 @@ class MtpProposer(Proposer): torchair_compiled_model( input_ids=input_ids, positions=positions, - previous_hidden_states=previous_hidden_states, + hidden_states=previous_hidden_states, inputs_embeds=None, intermediate_tensors=None, attn_metadata=attn_metadata, @@ -460,7 +460,7 @@ class MtpProposer(Proposer): hidden_states = torchair_compiled_model( 
input_ids=self.input_ids[:num_input_tokens], positions=self.positions[:num_input_tokens], - previous_hidden_states=self. + hidden_states=self. hidden_states[:num_input_tokens], inputs_embeds=None, intermediate_tensors=None, diff --git a/vllm_ascend/torchair/models/torchair_deepseek_mtp.py b/vllm_ascend/torchair/models/torchair_deepseek_mtp.py index b3760f7..c8503e3 100644 --- a/vllm_ascend/torchair/models/torchair_deepseek_mtp.py +++ b/vllm_ascend/torchair/models/torchair_deepseek_mtp.py @@ -176,14 +176,12 @@ class TorchairDeepSeekMultiTokenPredictor(DeepSeekMultiTokenPredictor): def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata=None, # type: ignore spec_step_idx: int = 0, ) -> torch.Tensor: current_step_idx = (spec_step_idx % self.num_mtp_layers) mtp_layer = self.layers_list[current_step_idx] logits = self.logits_processor(mtp_layer.shared_head.head, - mtp_layer.shared_head(hidden_states), - sampling_metadata) + mtp_layer.shared_head(hidden_states)) return logits @@ -209,12 +207,12 @@ class TorchairDeepSeekMTP(DeepSeekMTP): positions: torch.Tensor, kv_caches: Optional[List[torch.Tensor]] = None, attn_metadata: Optional[AttentionMetadata] = None, - previous_hidden_states: Optional[torch.Tensor] = None, + hidden_states: Optional[torch.Tensor] = None, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, spec_step_idx: int = 0, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, previous_hidden_states, - inputs_embeds, spec_step_idx) + attn_metadata, hidden_states, inputs_embeds, + spec_step_idx) return hidden_states