[v0.11.0] [Bugfix] [MoE]fix error in deepseek when using allgather (#3827)
### What this PR does / why we need it? After refactoring vllm_ascend/models and FusedMoE, we are unable to pass `gate` from deepseekv2.py to `AscendFusedMoE.forward`, which will result in error when running deepseek v3/r1 with allgather. Hence, this pr removes `gate` related computations from FusedMoE module in eager/aclgraph mode. ### Does this PR introduce _any_ user-facing change? `rm_router_logits` is deprecated in eager/aclgraph. ### How was this patch tested? e2e & ut Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
This commit is contained in:
@@ -26,7 +26,7 @@ from vllm.distributed.parallel_state import (
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
|
||||
|
||||
from vllm_ascend.utils import enable_sp, get_rm_router_logits_state
|
||||
from vllm_ascend.utils import enable_sp
|
||||
|
||||
|
||||
class FusedMoEPrepareAndFinalize(ABC):
|
||||
@@ -43,31 +43,26 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
|
||||
def __init__(self, moe_config: FusedMoEConfig):
|
||||
self.moe_config = moe_config
|
||||
is_deepseek_v3_r1 = self.moe_config.original_num_experts == 256
|
||||
self.rm_router_logits = get_rm_router_logits_state(
|
||||
self.moe_config.ep_size, self.moe_config.dp_size,
|
||||
is_deepseek_v3_r1)
|
||||
|
||||
@abstractmethod
|
||||
def prepare(self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
def prepare(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Prepare tensors before MoE computation. May involve:
|
||||
- Padding to align communication boundaries
|
||||
- Slicing across tensor-parallel ranks
|
||||
- Broadcasting across data-parallel ranks
|
||||
- Recomputing router logits if needed
|
||||
|
||||
Args:
|
||||
hidden_states (torch.Tensor): Input features, shape [num_tokens, hidden_size]
|
||||
router_logits (torch.Tensor): Router outputs, shape [num_tokens, num_experts]
|
||||
enable_shared_expert_dp (bool): Skip DP communication for shared experts
|
||||
replace_allreduce (bool): Bypass default all-reduce behavior
|
||||
gate (nn.Module, optional): Gate network to recompute router_logits if needed
|
||||
|
||||
Returns:
|
||||
Tuple of:
|
||||
@@ -116,12 +111,13 @@ class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalize):
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
def prepare(self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
def prepare(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Preparation steps:
|
||||
1. Fetch `mc2_mask` and target padding length from forward context.
|
||||
@@ -214,12 +210,13 @@ class FusedMoEPrepareAndFinalizeWithAll2All(FusedMoEPrepareAndFinalize):
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
def prepare(self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
def prepare(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Preparation steps:
|
||||
1. Pad hidden_states and router_logits to next multiple of TP size.
|
||||
@@ -307,12 +304,13 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
|
||||
TP AG → Attn → TP RS → EP AG → MoE → EP RS
|
||||
"""
|
||||
|
||||
def prepare(self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
def prepare(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Preparation steps:
|
||||
AllGather hidden_states and router_logits to form global tensors.
|
||||
@@ -325,7 +323,7 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
|
||||
|
||||
return self._prepare_with_dp_group(hidden_states, router_logits,
|
||||
enable_shared_expert_dp,
|
||||
replace_allreduce, gate)
|
||||
replace_allreduce)
|
||||
|
||||
def _prepare_with_ep_group(
|
||||
self,
|
||||
@@ -340,12 +338,12 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
|
||||
return hidden_states, router_logits, None
|
||||
|
||||
def _prepare_with_dp_group(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Preparation steps:
|
||||
1. Fetch max token count across DP group from forward context.
|
||||
@@ -365,18 +363,14 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
|
||||
if pad_size > 0:
|
||||
hidden_states = nn.functional.pad(hidden_states,
|
||||
(0, 0, 0, pad_size))
|
||||
if not self.rm_router_logits:
|
||||
router_logits = nn.functional.pad(router_logits,
|
||||
(0, 0, 0, pad_size))
|
||||
router_logits = nn.functional.pad(router_logits,
|
||||
(0, 0, 0, pad_size))
|
||||
|
||||
# All-gather across DP group
|
||||
hidden_states = self.moe_config.dp_group.all_gather(
|
||||
hidden_states, 0)
|
||||
if self.rm_router_logits:
|
||||
router_logits, _ = gate(hidden_states) # Recompute globally
|
||||
else:
|
||||
router_logits = self.moe_config.dp_group.all_gather(
|
||||
router_logits, 0)
|
||||
router_logits = self.moe_config.dp_group.all_gather(
|
||||
router_logits, 0)
|
||||
|
||||
return hidden_states, router_logits, None
|
||||
|
||||
@@ -472,12 +466,13 @@ class FusedMoEPrepareAndFinalizeWithNaiveMulticast(FusedMoEPrepareAndFinalize):
|
||||
get_dp_group().broadcast(buffer[start:end, :], idx)
|
||||
return buffer
|
||||
|
||||
def prepare(self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
def prepare(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Preparation steps:
|
||||
1. Fetch cumulative token boundaries from forward context.
|
||||
@@ -493,11 +488,8 @@ class FusedMoEPrepareAndFinalizeWithNaiveMulticast(FusedMoEPrepareAndFinalize):
|
||||
).dp_metadata.cu_tokens_across_sp(1)
|
||||
hidden_states = self._naive_multicast(hidden_states,
|
||||
self.cu_tokens_across_dp_cpu)
|
||||
if self.rm_router_logits:
|
||||
router_logits, _ = gate(hidden_states)
|
||||
else:
|
||||
router_logits = self._naive_multicast(
|
||||
router_logits, self.cu_tokens_across_dp_cpu)
|
||||
router_logits = self._naive_multicast(router_logits,
|
||||
self.cu_tokens_across_dp_cpu)
|
||||
|
||||
return hidden_states, router_logits, None
|
||||
|
||||
|
||||
@@ -63,15 +63,16 @@ class MoECommMethod(ABC):
|
||||
self.fused_moe_prepare_finalize = self._get_fused_moe_prepare_finalize(
|
||||
)
|
||||
|
||||
def prepare(self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
gate=None) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
def prepare(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
hidden_states, router_logits, mc2_mask = self.fused_moe_prepare_finalize.prepare(
|
||||
hidden_states, router_logits, enable_shared_expert_dp,
|
||||
replace_allreduce, gate)
|
||||
replace_allreduce)
|
||||
self.mc2_mask = mc2_mask
|
||||
return hidden_states, router_logits
|
||||
|
||||
|
||||
Reference in New Issue
Block a user