[v0.18.0][BugFix] Fix Qwen3.5 MoE flash comm v1 shared expert shape error in the MTP layer on A2 (#8004)

### What this PR does / why we need it?
Fix a shared expert shape error in the Qwen3.5 MoE MTP layer when flash
comm v1 is enabled on A2.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?

- vLLM version: v0.18.0
- vLLM main: 35141a7eed

Signed-off-by: Wangbingjie <wangbj1207@126.com>
Author: wangbj127
Date: 2026-04-13 17:36:09 +08:00
Committed by: GitHub
Parent: 39c071a0f5
Commit: d94b1dc2d0
2 changed files with 2 additions and 4 deletions

```diff
@@ -80,7 +80,7 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor, is_ep_comm: bool = False) -> tor
         enable_sp_by_pass() and is_ep_comm
     )
-    if not flash_comm_v1_enabled or (forward_context.is_draft_model and is_vl_model()):
+    if not flash_comm_v1_enabled or (forward_context.is_draft_model and is_vl_model() and not is_ep_comm):
         return tensor_model_parallel_all_reduce(x)
     dp_metadata = forward_context.dp_metadata
```
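
For readers unfamiliar with flash comm v1, a minimal sketch of the pad-and-reduce idea behind `maybe_pad_and_reduce` may help. This is illustrative only, not the vLLM-Ascend implementation: `pad_and_reduce_sketch` and `tp_size` are made-up names, and a real implementation would reduce-scatter across the tensor-parallel group instead of slicing locally.

```python
# Illustrative sketch only -- not the vLLM-Ascend implementation.
# Flash comm v1 replaces an all-reduce with a reduce-scatter, which
# requires the token dimension to split evenly across TP ranks, hence
# the padding step.
import torch
import torch.nn.functional as F


def pad_and_reduce_sketch(x: torch.Tensor, tp_size: int) -> torch.Tensor:
    num_tokens = x.shape[0]
    pad = (-num_tokens) % tp_size  # rows needed to reach a multiple of tp_size
    if pad:
        # Zero-pad the leading (token) dimension; the last dim is untouched.
        x = F.pad(x, (0, 0, 0, pad))
    shard = x.shape[0] // tp_size
    # A real implementation would call dist.reduce_scatter_tensor here;
    # this single-process sketch just returns rank 0's shard.
    return x[:shard]
```

With 5 tokens and `tp_size=4`, three zero rows are padded in and every rank ends up with an identically shaped 2-row shard; without the padding, shard shapes would disagree across ranks.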

```diff
@@ -1595,10 +1595,8 @@ class SpecDecodeBaseProposer(EagleProposer):
         hidden_states: torch.Tensor,
         positions: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        if self.is_multimodal_model and _EXTRA_CTX.flash_comm_v1_enabled:
-            return hidden_states, positions
         if self.method == "mtp":
-            if _EXTRA_CTX.flash_comm_v1_enabled:
+            if _EXTRA_CTX.flash_comm_v1_enabled and not self.is_multimodal_model:
                 hidden_states = torch.ops.vllm.maybe_pad_and_reduce(hidden_states)
                 positions = positions.unsqueeze(-1)
                 positions = torch.ops.vllm.maybe_pad_and_reduce(positions)
```
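
The `positions.unsqueeze(-1)` in the hunk above is needed because the pad-and-reduce op pads along the leading token dimension of a 2-D tensor, so the 1-D positions vector first gets a trailing unit dimension. A shape walk-through using the hypothetical `pad_and_reduce_sketch` from earlier (all sizes made up):

```python
# Hypothetical shapes; reuses pad_and_reduce_sketch from the sketch above.
hidden_states = torch.randn(5, 4096)  # 5 tokens, hidden size 4096
positions = torch.arange(5)           # 1-D positions, shape (5,)

tp_size = 4
hs_shard = pad_and_reduce_sketch(hidden_states, tp_size)
pos_shard = pad_and_reduce_sketch(positions.unsqueeze(-1), tp_size)

print(hs_shard.shape)   # torch.Size([2, 4096]) -- (5 tokens + 3 pad) / 4 ranks
print(pos_shard.shape)  # torch.Size([2, 1])
```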