bugfix for mtp with multistream_moe (#3419)

### What this PR does / why we need it?
when infer deepseek mtp layer with multistream_moe, we should pass a
boolean to evaluate this feature and fix bugs when we are in mtp layer

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: zouyida2052 <zouyida2002@gmail.com>
This commit is contained in:
zouyida2052
2025-10-15 08:59:58 +08:00
committed by GitHub
parent c2c1db78a7
commit 3642b64afc
5 changed files with 22 additions and 11 deletions

View File

@@ -975,7 +975,7 @@ class TorchairDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
# to save npu memory because they're no longer used.
dispose_tensor(previous_hidden_states)
dispose_tensor(previous_residual)
if mla_moe_communication and self.layer_idx > self.first_k_dense_replace:
if mla_moe_communication and self.layer_idx > self.first_k_dense_replace and self.layer_idx < self.layers:
hidden_states = tensor_model_parallel_all_gather(hidden_states,
dim=0)
@@ -1034,7 +1034,7 @@ class TorchairDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
# The scaling of DeepseekV2MOE output would be done in the forward
# of DeepseekV2MOE
hidden_states *= 1. / self.routed_scaling_factor
if mla_moe_communication and self.layer_idx == self.layers - 1:
if mla_moe_communication and self.layer_idx >= self.layers - 1:
hidden_states = tensor_model_parallel_all_gather(hidden_states,
dim=0)
residual = tensor_model_parallel_all_gather(residual, dim=0)