[BugFix] Fix torchair+mtp bug after deleting deepseek_mtp. (#3590)

This fixes a bug that was missed in PR #3561.

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
whx
2025-10-21 22:23:52 +08:00
committed by GitHub
parent 0c83eee9b1
commit bd11c0054f
3 changed files with 7 additions and 10 deletions

View File

@@ -145,8 +145,7 @@ class TestTorchairDeepSeekMultiTokenPredictor(PytestBase):
                              return_value=None)
         predictor.logits_processor.return_value = torch.tensor([1.0, 2.0, 3.0])
-        result_logits = predictor.compute_logits(hidden_states=hidden_states,
-                                                 sampling_metadata=None)
+        result_logits = predictor.compute_logits(hidden_states=hidden_states)
         predictor.logits_processor.assert_called_once()
         assert torch.allclose(result_logits, torch.tensor([1.0, 2.0, 3.0]))

View File

@@ -175,7 +175,7 @@ class MtpProposer(Proposer):
             torchair_compiled_model(
                 input_ids=input_ids,
                 positions=positions,
-                previous_hidden_states=previous_hidden_states,
+                hidden_states=previous_hidden_states,
                 inputs_embeds=None,
                 intermediate_tensors=None,
                 attn_metadata=attn_metadata,
@@ -460,7 +460,7 @@ class MtpProposer(Proposer):
             hidden_states = torchair_compiled_model(
                 input_ids=self.input_ids[:num_input_tokens],
                 positions=self.positions[:num_input_tokens],
-                previous_hidden_states=self.hidden_states[:num_input_tokens],
+                hidden_states=self.hidden_states[:num_input_tokens],
                 inputs_embeds=None,
                 intermediate_tensors=None,

View File

@@ -176,14 +176,12 @@ class TorchairDeepSeekMultiTokenPredictor(DeepSeekMultiTokenPredictor):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata=None,  # type: ignore
         spec_step_idx: int = 0,
     ) -> torch.Tensor:
         current_step_idx = (spec_step_idx % self.num_mtp_layers)
         mtp_layer = self.layers_list[current_step_idx]
-        logits = self.logits_processor(mtp_layer.shared_head.head,
-                                       mtp_layer.shared_head(hidden_states),
-                                       sampling_metadata)
+        logits = self.logits_processor(mtp_layer.shared_head.head,
+                                       mtp_layer.shared_head(hidden_states))
         return logits
@@ -209,12 +207,12 @@ class TorchairDeepSeekMTP(DeepSeekMTP):
         positions: torch.Tensor,
         kv_caches: Optional[List[torch.Tensor]] = None,
         attn_metadata: Optional[AttentionMetadata] = None,
-        previous_hidden_states: Optional[torch.Tensor] = None,
+        hidden_states: Optional[torch.Tensor] = None,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         spec_step_idx: int = 0,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, previous_hidden_states,
-                                   inputs_embeds, spec_step_idx)
+        hidden_states = self.model(input_ids, positions, kv_caches,
+                                   attn_metadata, hidden_states, inputs_embeds,
+                                   spec_step_idx)
         return hidden_states