[Bugfix] [MoE] fix error in deepseek when using allgather (#3824)

### What this PR does / why we need it?
After refactoring vllm_ascend/models and FusedMoE, we are unable to pass
`gate` from deepseekv2.py to `AscendFusedMoE.forward`, which results in an
error when running DeepSeek V3/R1 with allgather.
Hence, this PR removes the `gate`-related computations from the FusedMoE
module in eager/aclgraph mode.
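
Roughly, the call pattern after this change looks like the sketch below (simplified names, not the exact vllm-ascend signatures): the gate runs in the model layer, and the MoE communication method only ever receives precomputed router logits.

```python
import torch


def moe_forward_sketch(hidden_states: torch.Tensor,
                       gate: torch.nn.Module,
                       comm_method) -> tuple[torch.Tensor, torch.Tensor]:
    # The gate now runs inside the MoE layer, before dispatch...
    router_logits = gate(hidden_states)
    # ...so the comm method's prepare() only receives precomputed
    # router_logits and no longer takes a `gate` argument.
    hidden_states, router_logits, _mc2_mask, _metadata = comm_method.prepare(
        hidden_states, router_logits)
    return hidden_states, router_logits
```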
### Does this PR introduce _any_ user-facing change?
`rm_router_logits` is deprecated in eager/aclgraph mode.
### How was this patch tested?
E2E and unit tests.

- vLLM version: v0.11.0rc3
- vLLM main:
https://github.com/vllm-project/vllm/commit/releases/v0.11.1

Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
weichen
2025-10-29 14:51:39 +08:00
committed by GitHub
parent 900086fdc6
commit 0d1859af08
7 changed files with 56 additions and 85 deletions


@@ -64,13 +64,12 @@ class MoECommMethod(ABC):
         hidden_states: torch.Tensor,
         router_logits: torch.Tensor,
         enable_shared_expert_dp: bool = False,
-        replace_allreduce: bool = False,
-        gate=None
+        replace_allreduce: bool = False
     ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
                Optional[torch.Tensor]]:
         hidden_states, router_logits, mc2_mask, context_metadata = self.prepare_finalize.prepare(
             hidden_states, router_logits, enable_shared_expert_dp,
-            replace_allreduce, gate)
+            replace_allreduce)
         return hidden_states, router_logits, mc2_mask, context_metadata

     def finalize(self,


@@ -27,7 +27,7 @@ from vllm.distributed.parallel_state import (
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.fused_moe import FusedMoEConfig

-from vllm_ascend.utils import enable_sp, get_rm_router_logits_state
+from vllm_ascend.utils import enable_sp


 class PrepareAndFinalize(ABC):
@@ -44,10 +44,6 @@ class PrepareAndFinalize(ABC):
     def __init__(self, moe_config: FusedMoEConfig):
         self.moe_config = moe_config
-        is_deepseek_v3_r1 = self.moe_config.original_num_experts == 256
-        self.rm_router_logits = get_rm_router_logits_state(
-            self.moe_config.ep_size, self.moe_config.dp_size,
-            is_deepseek_v3_r1)

     @abstractmethod
     def prepare(
@@ -55,8 +51,7 @@ class PrepareAndFinalize(ABC):
         hidden_states: torch.Tensor,
         router_logits: torch.Tensor,
         enable_shared_expert_dp: bool = False,
-        replace_allreduce: bool = False,
-        gate=None
+        replace_allreduce: bool = False
     ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
                Optional[torch.Tensor]]:
         """
@@ -64,14 +59,12 @@
         - Padding to align communication boundaries
         - Slicing across tensor-parallel ranks
         - Broadcasting across data-parallel ranks
-        - Recomputing router logits if needed

         Args:
             hidden_states (torch.Tensor): Input features, shape [num_tokens, hidden_size]
             router_logits (torch.Tensor): Router outputs, shape [num_tokens, num_experts]
             enable_shared_expert_dp (bool): Skip DP communication for shared experts
             replace_allreduce (bool): Bypass default all-reduce behavior
-            gate (nn.Module, optional): Gate network to recompute router_logits if needed

         Returns:
             Tuple of:
@@ -124,8 +117,7 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
         hidden_states: torch.Tensor,
         router_logits: torch.Tensor,
         enable_shared_expert_dp: bool = False,
-        replace_allreduce: bool = False,
-        gate=None
+        replace_allreduce: bool = False
     ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
                Optional[torch.Tensor]]:
         """
@@ -221,8 +213,7 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
         hidden_states: torch.Tensor,
         router_logits: torch.Tensor,
         enable_shared_expert_dp: bool = False,
-        replace_allreduce: bool = False,
-        gate=None
+        replace_allreduce: bool = False
     ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
                Optional[torch.Tensor]]:
         """
@@ -303,7 +294,6 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
         router_logits: torch.Tensor,
         enable_shared_expert_dp: bool = False,
         replace_allreduce: bool = False,
-        gate=None
     ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
                Optional[torch.Tensor]]:
         """
@@ -318,7 +308,7 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
             return self._prepare_with_dp_group(hidden_states, router_logits,
                                                enable_shared_expert_dp,
-                                               replace_allreduce, gate)
+                                               replace_allreduce)

     def _prepare_with_ep_group(
         self,
@@ -339,7 +329,6 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
         router_logits: torch.Tensor,
         enable_shared_expert_dp: bool = False,
         replace_allreduce: bool = False,
-        gate=None
     ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
                Optional[torch.Tensor]]:
         """
@@ -361,18 +350,14 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
         if pad_size > 0:
             hidden_states = nn.functional.pad(hidden_states,
                                               (0, 0, 0, pad_size))
-            if not self.rm_router_logits:
-                router_logits = nn.functional.pad(router_logits,
-                                                  (0, 0, 0, pad_size))
+            router_logits = nn.functional.pad(router_logits,
+                                              (0, 0, 0, pad_size))
         # All-gather across DP group
         hidden_states = self.moe_config.dp_group.all_gather(
             hidden_states, 0)
-        if self.rm_router_logits:
-            router_logits, _ = gate(hidden_states)  # Recompute globally
-        else:
-            router_logits = self.moe_config.dp_group.all_gather(
-                router_logits, 0)
+        router_logits = self.moe_config.dp_group.all_gather(
+            router_logits, 0)

         return hidden_states, router_logits, None, None

     def finalize(self,
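
For reference, the resulting allgather (DP-group) path reads roughly like the sketch below; `dp_group` and `pad_size` stand in for the values the class derives from `self.moe_config` and the forward context.

```python
import torch
import torch.nn as nn


def allgather_prepare_sketch(hidden_states: torch.Tensor,
                             router_logits: torch.Tensor,
                             dp_group, pad_size: int):
    # Pad both tensors along the token dimension so every DP rank
    # contributes the same number of rows to the all-gather.
    if pad_size > 0:
        hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, pad_size))
        router_logits = nn.functional.pad(router_logits, (0, 0, 0, pad_size))
    # Router logits take the same all-gather path as hidden states; they are
    # no longer recomputed from a gate on the gathered activations.
    hidden_states = dp_group.all_gather(hidden_states, 0)
    router_logits = dp_group.all_gather(router_logits, 0)
    return hidden_states, router_logits, None, None
```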
@@ -474,8 +459,7 @@ class PrepareAndFinalizeWithNaiveMulticast(PrepareAndFinalize):
         hidden_states: torch.Tensor,
         router_logits: torch.Tensor,
         enable_shared_expert_dp: bool = False,
-        replace_allreduce: bool = False,
-        gate=None
+        replace_allreduce: bool = False
     ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
                Optional[torch.Tensor]]:
         """
@@ -493,11 +477,8 @@
         ).dp_metadata.cu_tokens_across_sp(1)
         hidden_states = self._naive_multicast(hidden_states,
                                               self.cu_tokens_across_dp_cpu)
-        if self.rm_router_logits:
-            router_logits, _ = gate(hidden_states)
-        else:
-            router_logits = self._naive_multicast(
-                router_logits, self.cu_tokens_across_dp_cpu)
+        router_logits = self._naive_multicast(router_logits,
+                                              self.cu_tokens_across_dp_cpu)

         return hidden_states, router_logits, None, None
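
The naive-multicast path simplifies the same way; in effect (sketch only, reusing the class's own `_naive_multicast` helper):

```python
def naive_multicast_prepare_sketch(self, hidden_states, router_logits):
    # Both tensors are replicated across DP ranks with the same helper;
    # the rm_router_logits / gate-recompute branch is gone.
    hidden_states = self._naive_multicast(hidden_states,
                                          self.cu_tokens_across_dp_cpu)
    router_logits = self._naive_multicast(router_logits,
                                          self.cu_tokens_across_dp_cpu)
    return hidden_states, router_logits, None, None
```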