[MoE] [Refactor] Combine common_fused_moe and fused_moe (#3176)
### What this PR does / why we need it?
1. Move the additional functionality from fused_moe.py into common_fused_moe.py and remove fused_moe.py.
2. Remove unnecessary custom classes from qwen3_moe.py; qwen3_moe.py itself will be removed completely after vllm-ascend v0.11.0 is released.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Qwen3-30B-A3B/Qwen3-30B-A3B-W8A8/DeepSeek-V3-W4A8-Pruing/deepseek-mtp/pangu-pro-moe-pruing:
1. Enable/Disable EP
2. Aclgraph & eager
3. SP

- vLLM version: v0.11.0

---------

Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
Co-authored-by: weijinqian0 <12153182+weijinqian0@users.noreply.github.com>
This commit is contained in:
@@ -26,6 +26,8 @@ from vllm.distributed.parallel_state import (
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
|
||||
|
||||
from vllm_ascend.utils import get_rm_router_logits_state
|
||||
|
||||
|
||||
class FusedMoEPrepareAndFinalize(ABC):
|
||||
"""
|
||||
@@ -41,13 +43,16 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
|
||||
def __init__(self, moe_config: FusedMoEConfig):
|
||||
self.moe_config = moe_config
|
||||
is_deepseek_v3_r1 = self.moe_config.original_num_experts == 256
|
||||
self.rm_router_logits = get_rm_router_logits_state(
|
||||
self.moe_config.ep_size, self.moe_config.dp_size,
|
||||
is_deepseek_v3_r1)
|
||||
|
||||
@abstractmethod
|
||||
def prepare(self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
rm_router_logits: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
@@ -61,7 +66,6 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
hidden_states (torch.Tensor): Input features, shape [num_tokens, hidden_size]
|
||||
router_logits (torch.Tensor): Router outputs, shape [num_tokens, num_experts]
|
||||
enable_shared_expert_dp (bool): Skip DP communication for shared experts
|
||||
rm_router_logits (bool): Discard input router_logits and recompute via gate
|
||||
replace_allreduce (bool): Bypass default all-reduce behavior
|
||||
gate (nn.Module, optional): Gate network to recompute router_logits if needed
|
||||
|
||||
@@ -116,7 +120,6 @@ class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalize):
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
rm_router_logits: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
@@ -215,7 +218,6 @@ class FusedMoEPrepareAndFinalizeWithAll2All(FusedMoEPrepareAndFinalize):
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
rm_router_logits: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
@@ -294,7 +296,6 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
rm_router_logits: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
@@ -302,7 +303,6 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
|
||||
1. Fetch max token count across DP group from forward context.
|
||||
2. Pad local tensors to that size.
|
||||
3. All-gather across DP group to form global input tensor.
|
||||
4. Optionally recompute router_logits using gate if `rm_router_logits=True`.
|
||||
|
||||
Returns:
|
||||
Tuple of (global_hidden_states, global_router_logits, None)
|
||||
@@ -318,14 +318,14 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
|
||||
if pad_size > 0:
|
||||
hidden_states = nn.functional.pad(hidden_states,
|
||||
(0, 0, 0, pad_size))
|
||||
if not rm_router_logits:
|
||||
if not self.rm_router_logits:
|
||||
router_logits = nn.functional.pad(router_logits,
|
||||
(0, 0, 0, pad_size))
|
||||
|
||||
# All-gather across DP group
|
||||
hidden_states = self.moe_config.dp_group.all_gather(
|
||||
hidden_states, 0)
|
||||
if rm_router_logits:
|
||||
if self.rm_router_logits:
|
||||
router_logits, _ = gate(hidden_states) # Recompute globally
|
||||
else:
|
||||
router_logits = self.moe_config.dp_group.all_gather(
|
||||
@@ -399,14 +399,12 @@ class FusedMoEPrepareAndFinalizeWithNaiveMulticast(FusedMoEPrepareAndFinalize):
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
rm_router_logits: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Preparation steps:
|
||||
1. Fetch cumulative token boundaries from forward context.
|
||||
2. Multicast hidden_states and router_logits to form global tensors.
|
||||
3. Optionally recompute router_logits globally if `rm_router_logits=True`.
|
||||
|
||||
Returns:
|
||||
Tuple of (global_hidden_states, global_router_logits, None)
|
||||
@@ -418,7 +416,7 @@ class FusedMoEPrepareAndFinalizeWithNaiveMulticast(FusedMoEPrepareAndFinalize):
|
||||
).dp_metadata.cu_tokens_across_sp(1)
|
||||
hidden_states = self._naive_multicast(hidden_states,
|
||||
self.cu_tokens_across_dp_cpu)
|
||||
if rm_router_logits:
|
||||
if self.rm_router_logits:
|
||||
router_logits, _ = gate(hidden_states)
|
||||
else:
|
||||
router_logits = self._naive_multicast(
|
||||
|
||||
@@ -67,12 +67,11 @@ class MoECommMethod(ABC):
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
rm_router_logits: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
gate=None) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
hidden_states, router_logits, mc2_mask = self.fused_moe_prepare_finalize.prepare(
|
||||
hidden_states, router_logits, enable_shared_expert_dp,
|
||||
rm_router_logits, replace_allreduce, gate)
|
||||
replace_allreduce, gate)
|
||||
self.mc2_mask = mc2_mask
|
||||
return hidden_states, router_logits
|
||||
|
||||
|
||||
@@ -468,9 +468,6 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
|
||||
super().__init__(**kwargs)
|
||||
self.with_quant = False
|
||||
self.num_local_experts = kwargs.get("num_local_experts", 0)
|
||||
self.num_global_redundant_experts = kwargs.get(
|
||||
"num_global_redundant_experts", 0)
|
||||
self.num_experts = self.num_experts + self.num_global_redundant_experts
|
||||
|
||||
self.hidden_shape = None
|
||||
self.topk_weights = None
|
||||
|
||||
Reference in New Issue
Block a user