[Bugfix] [MoE] fix error in deepseek when using allgather (#3824)
### What this PR does / why we need it? After refactoring vllm_ascend/models and FusedMoE, we are unable to pass `gate` from deepseekv2.py to `AscendFusedMoE.forward`, which will result in error when running deepseek v3/r1 with allgather. Hence, this pr removes `gate` related computations from FusedMoE module in eager/aclgraph mode. ### Does this PR introduce _any_ user-facing change? `rm_router_logits` is deprecated in eager/aclgraph. ### How was this patch tested? e2e & ut - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/releases/v0.11.1 Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
This commit is contained in:
@@ -64,13 +64,12 @@ class MoECommMethod(ABC):
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
gate=None
|
||||
replace_allreduce: bool = False
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
hidden_states, router_logits, mc2_mask, context_metadata = self.prepare_finalize.prepare(
|
||||
hidden_states, router_logits, enable_shared_expert_dp,
|
||||
replace_allreduce, gate)
|
||||
replace_allreduce)
|
||||
return hidden_states, router_logits, mc2_mask, context_metadata
|
||||
|
||||
def finalize(self,
|
||||
|
||||
@@ -27,7 +27,7 @@ from vllm.distributed.parallel_state import (
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
|
||||
|
||||
from vllm_ascend.utils import enable_sp, get_rm_router_logits_state
|
||||
from vllm_ascend.utils import enable_sp
|
||||
|
||||
|
||||
class PrepareAndFinalize(ABC):
|
||||
@@ -44,10 +44,6 @@ class PrepareAndFinalize(ABC):
|
||||
|
||||
def __init__(self, moe_config: FusedMoEConfig):
|
||||
self.moe_config = moe_config
|
||||
is_deepseek_v3_r1 = self.moe_config.original_num_experts == 256
|
||||
self.rm_router_logits = get_rm_router_logits_state(
|
||||
self.moe_config.ep_size, self.moe_config.dp_size,
|
||||
is_deepseek_v3_r1)
|
||||
|
||||
@abstractmethod
|
||||
def prepare(
|
||||
@@ -55,8 +51,7 @@ class PrepareAndFinalize(ABC):
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
gate=None
|
||||
replace_allreduce: bool = False
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
"""
|
||||
@@ -64,14 +59,12 @@ class PrepareAndFinalize(ABC):
|
||||
- Padding to align communication boundaries
|
||||
- Slicing across tensor-parallel ranks
|
||||
- Broadcasting across data-parallel ranks
|
||||
- Recomputing router logits if needed
|
||||
|
||||
Args:
|
||||
hidden_states (torch.Tensor): Input features, shape [num_tokens, hidden_size]
|
||||
router_logits (torch.Tensor): Router outputs, shape [num_tokens, num_experts]
|
||||
enable_shared_expert_dp (bool): Skip DP communication for shared experts
|
||||
replace_allreduce (bool): Bypass default all-reduce behavior
|
||||
gate (nn.Module, optional): Gate network to recompute router_logits if needed
|
||||
|
||||
Returns:
|
||||
Tuple of:
|
||||
@@ -124,8 +117,7 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
gate=None
|
||||
replace_allreduce: bool = False
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
"""
|
||||
@@ -221,8 +213,7 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
gate=None
|
||||
replace_allreduce: bool = False
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
"""
|
||||
@@ -303,7 +294,6 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
gate=None
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
"""
|
||||
@@ -318,7 +308,7 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
|
||||
|
||||
return self._prepare_with_dp_group(hidden_states, router_logits,
|
||||
enable_shared_expert_dp,
|
||||
replace_allreduce, gate)
|
||||
replace_allreduce)
|
||||
|
||||
def _prepare_with_ep_group(
|
||||
self,
|
||||
@@ -339,7 +329,6 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
gate=None
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
"""
|
||||
@@ -361,18 +350,14 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
|
||||
if pad_size > 0:
|
||||
hidden_states = nn.functional.pad(hidden_states,
|
||||
(0, 0, 0, pad_size))
|
||||
if not self.rm_router_logits:
|
||||
router_logits = nn.functional.pad(router_logits,
|
||||
(0, 0, 0, pad_size))
|
||||
router_logits = nn.functional.pad(router_logits,
|
||||
(0, 0, 0, pad_size))
|
||||
|
||||
# All-gather across DP group
|
||||
hidden_states = self.moe_config.dp_group.all_gather(
|
||||
hidden_states, 0)
|
||||
if self.rm_router_logits:
|
||||
router_logits, _ = gate(hidden_states) # Recompute globally
|
||||
else:
|
||||
router_logits = self.moe_config.dp_group.all_gather(
|
||||
router_logits, 0)
|
||||
router_logits = self.moe_config.dp_group.all_gather(
|
||||
router_logits, 0)
|
||||
return hidden_states, router_logits, None, None
|
||||
|
||||
def finalize(self,
|
||||
@@ -474,8 +459,7 @@ class PrepareAndFinalizeWithNaiveMulticast(PrepareAndFinalize):
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
gate=None
|
||||
replace_allreduce: bool = False
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
"""
|
||||
@@ -493,11 +477,8 @@ class PrepareAndFinalizeWithNaiveMulticast(PrepareAndFinalize):
|
||||
).dp_metadata.cu_tokens_across_sp(1)
|
||||
hidden_states = self._naive_multicast(hidden_states,
|
||||
self.cu_tokens_across_dp_cpu)
|
||||
if self.rm_router_logits:
|
||||
router_logits, _ = gate(hidden_states)
|
||||
else:
|
||||
router_logits = self._naive_multicast(
|
||||
router_logits, self.cu_tokens_across_dp_cpu)
|
||||
router_logits = self._naive_multicast(router_logits,
|
||||
self.cu_tokens_across_dp_cpu)
|
||||
|
||||
return hidden_states, router_logits, None, None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user