diff --git a/tests/ut/torchair/models/test_torchair_deepseek_mtp.py b/tests/ut/torchair/models/test_torchair_deepseek_mtp.py index b9b18ac..109c56e 100644 --- a/tests/ut/torchair/models/test_torchair_deepseek_mtp.py +++ b/tests/ut/torchair/models/test_torchair_deepseek_mtp.py @@ -145,8 +145,7 @@ class TestTorchairDeepSeekMultiTokenPredictor(PytestBase): return_value=None) predictor.logits_processor.return_value = torch.tensor([1.0, 2.0, 3.0]) - result_logits = predictor.compute_logits(hidden_states=hidden_states, - sampling_metadata=None) + result_logits = predictor.compute_logits(hidden_states=hidden_states) predictor.logits_processor.assert_called_once() assert torch.allclose(result_logits, torch.tensor([1.0, 2.0, 3.0])) diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py index 97b094e..66308ef 100644 --- a/vllm_ascend/spec_decode/mtp_proposer.py +++ b/vllm_ascend/spec_decode/mtp_proposer.py @@ -175,7 +175,7 @@ class MtpProposer(Proposer): torchair_compiled_model( input_ids=input_ids, positions=positions, - previous_hidden_states=previous_hidden_states, + hidden_states=previous_hidden_states, inputs_embeds=None, intermediate_tensors=None, attn_metadata=attn_metadata, @@ -460,7 +460,7 @@ class MtpProposer(Proposer): hidden_states = torchair_compiled_model( input_ids=self.input_ids[:num_input_tokens], positions=self.positions[:num_input_tokens], - previous_hidden_states=self. + hidden_states=self. hidden_states[:num_input_tokens], inputs_embeds=None, intermediate_tensors=None, diff --git a/vllm_ascend/torchair/models/torchair_deepseek_mtp.py b/vllm_ascend/torchair/models/torchair_deepseek_mtp.py index b3760f7..c8503e3 100644 --- a/vllm_ascend/torchair/models/torchair_deepseek_mtp.py +++ b/vllm_ascend/torchair/models/torchair_deepseek_mtp.py @@ -176,14 +176,12 @@ class TorchairDeepSeekMultiTokenPredictor(DeepSeekMultiTokenPredictor): def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata=None, # type: ignore spec_step_idx: int = 0, ) -> torch.Tensor: current_step_idx = (spec_step_idx % self.num_mtp_layers) mtp_layer = self.layers_list[current_step_idx] logits = self.logits_processor(mtp_layer.shared_head.head, - mtp_layer.shared_head(hidden_states), - sampling_metadata) + mtp_layer.shared_head(hidden_states)) return logits @@ -209,12 +207,12 @@ class TorchairDeepSeekMTP(DeepSeekMTP): positions: torch.Tensor, kv_caches: Optional[List[torch.Tensor]] = None, attn_metadata: Optional[AttentionMetadata] = None, - previous_hidden_states: Optional[torch.Tensor] = None, + hidden_states: Optional[torch.Tensor] = None, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, spec_step_idx: int = 0, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, previous_hidden_states, - inputs_embeds, spec_step_idx) + attn_metadata, hidden_states, inputs_embeds, + spec_step_idx) return hidden_states