[Refactor] [MoE] Rename moe-related classes & files (#3646)

### What this PR does / why we need it?
1. Rename common_fused_moe.py to fused_moe.py.
2. Rename fused_moe_prepare_and_finalize.py / FusedMoEPrepareAndFinalize
to prepare_finalize.py / PrepareAndFinalize.
3. Rename vllm_ascend/ops/moe to vllm_ascend/ops/fused_moe.
4. Move vllm_ascend/ops/fused_moe.py to
vllm_ascend/ops/fused_moe/fused_moe.py
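
For downstream code this is purely a path change: apart from the `FusedMoEPrepareAndFinalize*` → `PrepareAndFinalize*` family, the exported class and function names stay the same. A minimal before/after sketch of the import paths touched by this diff (the grouping of imports here is illustrative, not a real module):

```python
# Old layout (before this PR)
# from vllm_ascend.ops.common_fused_moe import AscendFusedMoE
# from vllm_ascend.ops.moe.experts_selector import select_experts
# from vllm_ascend.ops.moe.fused_moe_prepare_and_finalize import \
#     FusedMoEPrepareAndFinalizeWithAllGather

# New layout (after this PR)
from vllm_ascend.ops.fused_moe.fused_moe import AscendFusedMoE
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
from vllm_ascend.ops.fused_moe.prepare_finalize import \
    PrepareAndFinalizeWithAllGather
```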

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
e2e and unit tests.

- vLLM version: v0.11.0rc3
- vLLM main: 17c540a993

Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
Author: weichen
Date: 2025-10-25 11:22:03 +08:00
Committed by: GitHub
Parent: 0637e8f021
Commit: 63c363d3de
25 changed files with 183 additions and 199 deletions


@@ -87,7 +87,8 @@ def set_ascend_forward_context(
):
forward_context = get_forward_context()
-from vllm_ascend.ops.moe.moe_comm_method import get_moe_comm_method
+from vllm_ascend.ops.fused_moe.moe_comm_method import \
+get_moe_comm_method
forward_context.moe_comm_type = moe_comm_type
forward_context.moe_comm_method = get_moe_comm_method(moe_comm_type)
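
Only the import path changes here; `get_moe_comm_method` still resolves the communication method from the module-level `_MoECommMethods` registry keyed by `MoECommType` (visible in the next hunk). A standalone sketch of that lookup pattern, assuming the registry is populated by `setup_moe_comm_method`; the enum members and the setup signature below are illustrative placeholders, not the real API:

```python
from enum import Enum
from typing import Dict, Optional


class MoECommType(Enum):
    # Placeholder members for illustration; the real enum lives in
    # vllm_ascend.ascend_forward_context.
    ALLGATHER = "allgather"
    MC2 = "mc2"


# Registry keyed by the comm type, as in moe_comm_method.py.
_MoECommMethods: Dict[Optional[MoECommType], object] = {}


def setup_moe_comm_method(comm_type: MoECommType, method: object) -> None:
    # Hypothetical signature, used here only to fill the registry in this sketch.
    _MoECommMethods[comm_type] = method


def get_moe_comm_method(moe_comm_type: Optional[MoECommType]) -> Optional[object]:
    # Returns None if no method was registered for this type.
    return _MoECommMethods.get(moe_comm_type)
```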


@@ -66,7 +66,7 @@ from vllm.platforms import current_platform
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.models.layers.sfa import AscendSFAModules, Indexer
-from vllm_ascend.ops.common_fused_moe import AscendFusedMoE
+from vllm_ascend.ops.fused_moe.fused_moe import AscendFusedMoE
from vllm_ascend.ops.linear import AscendLinearBase
from vllm_ascend.utils import vllm_version_is


@@ -17,7 +17,7 @@
import torch
-import vllm_ascend.ops.common_fused_moe # noqa
+import vllm_ascend.ops.fused_moe.fused_moe # noqa
import vllm_ascend.ops.layernorm # noqa
import vllm_ascend.ops.register_custom_ops # noqa
import vllm_ascend.ops.vocab_parallel_embedding # noqa


@@ -35,8 +35,8 @@ from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map,
determine_default_log2phy_map)
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
-from vllm_ascend.ops.moe.experts_selector import select_experts
-from vllm_ascend.ops.moe.moe_comm_method import setup_moe_comm_method
+from vllm_ascend.ops.fused_moe.experts_selector import select_experts
+from vllm_ascend.ops.fused_moe.moe_comm_method import setup_moe_comm_method
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, enable_sp, is_310p,
is_enable_nz, npu_stream_switch,
shared_expert_dp_enabled,


@@ -24,15 +24,13 @@ from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
from vllm_ascend.ascend_forward_context import MoECommType
-from vllm_ascend.ops.moe.fused_moe_prepare_and_finalize import (
-FusedMoEPrepareAndFinalizeWithAll2All,
-FusedMoEPrepareAndFinalizeWithAllGather, FusedMoEPrepareAndFinalizeWithMC2,
-FusedMoEPrepareAndFinalizeWithNaiveMulticast)
-from vllm_ascend.ops.moe.moe_mlp import unified_apply_mlp
-from vllm_ascend.ops.moe.token_dispatcher import (TokenDispatcherWithAll2AllV,
-TokenDispatcherWithAllGather,
-TokenDispatcherWithMC2,
-TokenDispatcherWithMoge)
+from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp
+from vllm_ascend.ops.fused_moe.prepare_finalize import (
+PrepareAndFinalizeWithAll2All, PrepareAndFinalizeWithAllGather,
+PrepareAndFinalizeWithMC2, PrepareAndFinalizeWithNaiveMulticast)
+from vllm_ascend.ops.fused_moe.token_dispatcher import (
+TokenDispatcherWithAll2AllV, TokenDispatcherWithAllGather,
+TokenDispatcherWithMC2, TokenDispatcherWithMoge)
_MoECommMethods: Dict[Optional[MoECommType], MoECommMethod] = {}
@@ -59,8 +57,7 @@ class MoECommMethod(ABC):
self.moe_config = moe_config
self.token_dispatcher = self._get_token_dispatcher()
-self.fused_moe_prepare_finalize = self._get_fused_moe_prepare_finalize(
-)
+self.prepare_finalize = self._get_prepare_finalize()
def prepare(
self,
@@ -71,7 +68,7 @@ class MoECommMethod(ABC):
gate=None
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]:
-hidden_states, router_logits, mc2_mask, context_metadata = self.fused_moe_prepare_finalize.prepare(
+hidden_states, router_logits, mc2_mask, context_metadata = self.prepare_finalize.prepare(
hidden_states, router_logits, enable_shared_expert_dp,
replace_allreduce, gate)
return hidden_states, router_logits, mc2_mask, context_metadata
@@ -80,8 +77,9 @@ class MoECommMethod(ABC):
hidden_states: torch.Tensor,
reduce_results: bool,
context_metadata: Optional[dict] = None) -> torch.Tensor:
-hidden_states = self.fused_moe_prepare_finalize.finalize(
-hidden_states, reduce_results, context_metadata)
+hidden_states = self.prepare_finalize.finalize(hidden_states,
+reduce_results,
+context_metadata)
return hidden_states
def fused_experts(
@@ -169,9 +167,9 @@ class MoECommMethod(ABC):
"_get_token_dispatcher function not implemented.")
@abstractmethod
-def _get_fused_moe_prepare_finalize(self):
+def _get_prepare_finalize(self):
raise NotImplementedError(
"_get_fused_moe_prepare_finalize function not implemented.")
"_get_prepare_finalize function not implemented.")
class AllGatherCommImpl(MoECommMethod):
@@ -205,8 +203,8 @@ class AllGatherCommImpl(MoECommMethod):
num_experts=self.moe_config.num_experts,
num_local_experts=self.moe_config.num_local_experts)
-def _get_fused_moe_prepare_finalize(self):
-return FusedMoEPrepareAndFinalizeWithAllGather(self.moe_config)
+def _get_prepare_finalize(self):
+return PrepareAndFinalizeWithAllGather(self.moe_config)
class MC2CommImpl(MoECommMethod):
@@ -222,8 +220,8 @@ class MC2CommImpl(MoECommMethod):
def _get_token_dispatcher(self):
return TokenDispatcherWithMC2()
-def _get_fused_moe_prepare_finalize(self):
-return FusedMoEPrepareAndFinalizeWithMC2(self.moe_config)
+def _get_prepare_finalize(self):
+return PrepareAndFinalizeWithMC2(self.moe_config)
class AlltoAllCommImpl(MoECommMethod):
@@ -242,8 +240,8 @@ class AlltoAllCommImpl(MoECommMethod):
num_experts=self.moe_config.num_experts,
num_local_experts=self.moe_config.num_local_experts)
-def _get_fused_moe_prepare_finalize(self):
-return FusedMoEPrepareAndFinalizeWithAll2All(self.moe_config)
+def _get_prepare_finalize(self):
+return PrepareAndFinalizeWithAll2All(self.moe_config)
class NaiveMulticastCommImpl(MoECommMethod):
@@ -271,5 +269,5 @@ class NaiveMulticastCommImpl(MoECommMethod):
num_experts=self.moe_config.num_experts,
num_local_experts=self.moe_config.num_local_experts)
-def _get_fused_moe_prepare_finalize(self):
-return FusedMoEPrepareAndFinalizeWithNaiveMulticast(self.moe_config)
+def _get_prepare_finalize(self):
+return PrepareAndFinalizeWithNaiveMulticast(self.moe_config)
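
After the rename, every communication implementation exposes its `PrepareAndFinalize` strategy through the single `_get_prepare_finalize` hook, which the base constructor calls. A stripped-down, standalone sketch of that pattern (the real classes also wire a token dispatcher and carry much more state; `moe_config` is just a placeholder here):

```python
from abc import ABC, abstractmethod


class PrepareAndFinalizeWithAllGather:
    """Stand-in for the real strategy class from prepare_finalize.py."""

    def __init__(self, moe_config):
        self.moe_config = moe_config


class MoECommMethod(ABC):
    def __init__(self, moe_config):
        self.moe_config = moe_config
        # Renamed from self.fused_moe_prepare_finalize / _get_fused_moe_prepare_finalize().
        self.prepare_finalize = self._get_prepare_finalize()

    @abstractmethod
    def _get_prepare_finalize(self):
        raise NotImplementedError("_get_prepare_finalize function not implemented.")


class AllGatherCommImpl(MoECommMethod):
    def _get_prepare_finalize(self):
        return PrepareAndFinalizeWithAllGather(self.moe_config)


# Usage: the base constructor wires the strategy via the hook.
impl = AllGatherCommImpl(moe_config=None)
assert isinstance(impl.prepare_finalize, PrepareAndFinalizeWithAllGather)
```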


@@ -30,7 +30,7 @@ from vllm.model_executor.layers.fused_moe import FusedMoEConfig
from vllm_ascend.utils import enable_sp, get_rm_router_logits_state
-class FusedMoEPrepareAndFinalize(ABC):
+class PrepareAndFinalize(ABC):
"""
Abstract base class for MoE (Mixture-of-Experts) tensor preparation and finalization
in distributed environments. Subclasses implement specific communication strategies
@@ -103,7 +103,7 @@ class FusedMoEPrepareAndFinalize(ABC):
raise NotImplementedError("Finalize function not implemented.")
-class FusedMoEPrepareAndFinalizeWithAll2All(FusedMoEPrepareAndFinalize):
+class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
"""
MoE communication strategy using All-to-All style slicing.
Similar to MC2 but does not use mc2_mask; instead pads to TP size for uniform slicing.
@@ -195,7 +195,7 @@ class FusedMoEPrepareAndFinalizeWithAll2All(FusedMoEPrepareAndFinalize):
return hidden_states
-class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalizeWithAll2All):
+class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
"""
MoE communication strategy using MC2, which is based on All2All. Hence, it inherits
All2All and share the same finalize method.
@@ -275,7 +275,7 @@ class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalizeWithAll2All):
return hidden_states, router_logits, mc2_mask, context_metadata
-class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
+class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
"""
MoE communication strategy using All-Gather + Reduce-Scatter on EP group.
There are two sets of prepare and finalize:
@@ -429,7 +429,7 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
return hidden_states
-class FusedMoEPrepareAndFinalizeWithNaiveMulticast(FusedMoEPrepareAndFinalize):
+class PrepareAndFinalizeWithNaiveMulticast(PrepareAndFinalize):
"""
MoE communication strategy using Naive Multicast (point-to-point broadcast).
Will be used in prefill when using allgather in decode. Each DP rank broadcasts its slice to all others.
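
The renamed classes keep their inheritance structure: `PrepareAndFinalizeWithMC2` still derives from `PrepareAndFinalizeWithAll2All` and reuses its finalize step. A minimal standalone sketch of the interface as it appears in this diff; the parameter names come from the hunks above, but the defaults and the pass-through bodies are illustrative assumptions, not the real collective-communication logic:

```python
from abc import ABC, abstractmethod
from typing import Optional

import torch


class PrepareAndFinalize(ABC):
    """Prepares MoE tensors before expert dispatch and finalizes outputs after."""

    @abstractmethod
    def prepare(self, hidden_states: torch.Tensor, router_logits: torch.Tensor,
                enable_shared_expert_dp: bool = False,
                replace_allreduce: bool = False, gate=None):
        raise NotImplementedError("Prepare function not implemented.")

    @abstractmethod
    def finalize(self, hidden_states: torch.Tensor, reduce_results: bool,
                 context_metadata: Optional[dict] = None) -> torch.Tensor:
        raise NotImplementedError("Finalize function not implemented.")


class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
    def prepare(self, hidden_states, router_logits, enable_shared_expert_dp=False,
                replace_allreduce=False, gate=None):
        # Real implementation pads to TP size and slices; pass-through here.
        return hidden_states, router_logits, None, None

    def finalize(self, hidden_states, reduce_results, context_metadata=None):
        return hidden_states


class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
    # Inherits All2All's finalize, as the class docstring in the diff notes.
    pass
```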


@@ -28,7 +28,7 @@ import torch_npu
from vllm.distributed.parallel_state import get_ep_group
from vllm_ascend.distributed.parallel_state import get_mc2_group
-from vllm_ascend.ops.moe.comm_utils import (
+from vllm_ascend.ops.fused_moe.comm_utils import (
async_all_to_all, gather_from_sequence_parallel_region)
from vllm_ascend.utils import (AscendSocVersion, get_ascend_soc_version,
is_hierarchical_communication_enabled)


@@ -37,7 +37,7 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
get_otp_group)
-from vllm_ascend.ops.common_fused_moe import AscendUnquantizedFusedMoEMethod
+from vllm_ascend.ops.fused_moe.fused_moe import AscendUnquantizedFusedMoEMethod
from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod
from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, mlp_tp_enable,
oproj_tp_enable)


@@ -26,7 +26,7 @@ from vllm.forward_context import get_forward_context
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import get_mc2_group
-from vllm_ascend.ops.moe.experts_selector import select_experts
+from vllm_ascend.ops.fused_moe.experts_selector import select_experts
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz


@@ -24,7 +24,7 @@ from vllm.distributed.parallel_state import get_ep_group
from vllm.forward_context import get_forward_context
from vllm_ascend.attention.attention_v1 import AscendAttentionState
-from vllm_ascend.ops.moe.experts_selector import select_experts
+from vllm_ascend.ops.fused_moe.experts_selector import select_experts
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p, is_enable_nz


@@ -25,7 +25,7 @@ from vllm.forward_context import get_forward_context
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import get_mc2_group
-from vllm_ascend.ops.moe.experts_selector import select_experts
+from vllm_ascend.ops.fused_moe.experts_selector import select_experts
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, is_enable_nz,
vllm_version_is)


@@ -538,8 +538,8 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
from vllm_ascend.models.layers.mla import AscendMultiHeadLatentAttention
from vllm_ascend.models.layers.sfa import AscendSparseFlashAttention
from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
-from vllm_ascend.ops.common_fused_moe import (AscendFusedMoE,
-AscendSharedFusedMoE)
+from vllm_ascend.ops.fused_moe.fused_moe import (AscendFusedMoE,
+AscendSharedFusedMoE)
from vllm_ascend.ops.layernorm import (AscendGemmaRMSNorm,
AscendQuantRMSNorm, AscendRMSNorm)
from vllm_ascend.ops.linear import (AscendColumnParallelLinear,