[BugFix] Fix torchair+mtp bug after deleting deepseek_mtp. (#3590)
This adds a missing fix for a bug introduced by PR #3561.

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
@@ -145,8 +145,7 @@ class TestTorchairDeepSeekMultiTokenPredictor(PytestBase):
|
|||||||
return_value=None)
|
return_value=None)
|
||||||
predictor.logits_processor.return_value = torch.tensor([1.0, 2.0, 3.0])
|
predictor.logits_processor.return_value = torch.tensor([1.0, 2.0, 3.0])
|
||||||
|
|
||||||
result_logits = predictor.compute_logits(hidden_states=hidden_states,
|
result_logits = predictor.compute_logits(hidden_states=hidden_states)
|
||||||
sampling_metadata=None)
|
|
||||||
predictor.logits_processor.assert_called_once()
|
predictor.logits_processor.assert_called_once()
|
||||||
assert torch.allclose(result_logits, torch.tensor([1.0, 2.0, 3.0]))
|
assert torch.allclose(result_logits, torch.tensor([1.0, 2.0, 3.0]))
|
||||||
|
|
||||||
|
|||||||
@@ -175,7 +175,7 @@ class MtpProposer(Proposer):
|
|||||||
torchair_compiled_model(
|
torchair_compiled_model(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
positions=positions,
|
positions=positions,
|
||||||
previous_hidden_states=previous_hidden_states,
|
hidden_states=previous_hidden_states,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
intermediate_tensors=None,
|
intermediate_tensors=None,
|
||||||
attn_metadata=attn_metadata,
|
attn_metadata=attn_metadata,
|
||||||
@@ -460,7 +460,7 @@ class MtpProposer(Proposer):
|
|||||||
hidden_states = torchair_compiled_model(
|
hidden_states = torchair_compiled_model(
|
||||||
input_ids=self.input_ids[:num_input_tokens],
|
input_ids=self.input_ids[:num_input_tokens],
|
||||||
positions=self.positions[:num_input_tokens],
|
positions=self.positions[:num_input_tokens],
|
||||||
previous_hidden_states=self.
|
hidden_states=self.
|
||||||
hidden_states[:num_input_tokens],
|
hidden_states[:num_input_tokens],
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
intermediate_tensors=None,
|
intermediate_tensors=None,
|
||||||
|
|||||||
@@ -176,14 +176,12 @@ class TorchairDeepSeekMultiTokenPredictor(DeepSeekMultiTokenPredictor):
|
|||||||
def compute_logits(
|
def compute_logits(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
sampling_metadata=None, # type: ignore
|
|
||||||
spec_step_idx: int = 0,
|
spec_step_idx: int = 0,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
current_step_idx = (spec_step_idx % self.num_mtp_layers)
|
current_step_idx = (spec_step_idx % self.num_mtp_layers)
|
||||||
mtp_layer = self.layers_list[current_step_idx]
|
mtp_layer = self.layers_list[current_step_idx]
|
||||||
logits = self.logits_processor(mtp_layer.shared_head.head,
|
logits = self.logits_processor(mtp_layer.shared_head.head,
|
||||||
mtp_layer.shared_head(hidden_states),
|
mtp_layer.shared_head(hidden_states))
|
||||||
sampling_metadata)
|
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
|
||||||
@@ -209,12 +207,12 @@ class TorchairDeepSeekMTP(DeepSeekMTP):
|
|||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
kv_caches: Optional[List[torch.Tensor]] = None,
|
kv_caches: Optional[List[torch.Tensor]] = None,
|
||||||
attn_metadata: Optional[AttentionMetadata] = None,
|
attn_metadata: Optional[AttentionMetadata] = None,
|
||||||
previous_hidden_states: Optional[torch.Tensor] = None,
|
hidden_states: Optional[torch.Tensor] = None,
|
||||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
inputs_embeds: Optional[torch.Tensor] = None,
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
spec_step_idx: int = 0,
|
spec_step_idx: int = 0,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||||
attn_metadata, previous_hidden_states,
|
attn_metadata, hidden_states, inputs_embeds,
|
||||||
inputs_embeds, spec_step_idx)
|
spec_step_idx)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|||||||
Reference in New Issue
Block a user