[Refactor] [MoE] Rename moe-related classes & files (#3646)
### What this PR does / why we need it?
1. Rename common_fused_moe.py to fused_moe.py.
2. Rename fused_moe_prepare_and_finalize.py / FusedMoEPrepareAndFinalize to prepare_finalize.py / PrepareAndFinalize.
3. Rename vllm_ascend/ops/moe to vllm_ascend/ops/fused_moe.
4. Move vllm_ascend/ops/fused_moe.py to vllm_ascend/ops/fused_moe/fused_moe.py.
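
For reference, a minimal before/after sketch of the affected import paths (illustrative only; all names are taken from the removed/added lines in the diff below, and module contents are unchanged):

```python
# Old layout (removed by this PR)
# from vllm_ascend.ops.common_fused_moe import AscendFusedMoE
# from vllm_ascend.ops.moe.fused_moe_prepare_and_finalize import \
#     FusedMoEPrepareAndFinalizeWithMC2

# New layout (introduced by this PR)
from vllm_ascend.ops.fused_moe.fused_moe import AscendFusedMoE
from vllm_ascend.ops.fused_moe.prepare_finalize import PrepareAndFinalizeWithMC2
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
```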
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
E2E and unit tests.
- vLLM version: v0.11.0rc3
- vLLM main: 17c540a993
Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
@@ -87,7 +87,8 @@ def set_ascend_forward_context(
 ):
     forward_context = get_forward_context()
-    from vllm_ascend.ops.moe.moe_comm_method import get_moe_comm_method
+    from vllm_ascend.ops.fused_moe.moe_comm_method import \
+        get_moe_comm_method
     forward_context.moe_comm_type = moe_comm_type
     forward_context.moe_comm_method = get_moe_comm_method(moe_comm_type)
@@ -66,7 +66,7 @@ from vllm.platforms import current_platform
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.models.layers.sfa import AscendSFAModules, Indexer
-from vllm_ascend.ops.common_fused_moe import AscendFusedMoE
+from vllm_ascend.ops.fused_moe.fused_moe import AscendFusedMoE
 from vllm_ascend.ops.linear import AscendLinearBase
 from vllm_ascend.utils import vllm_version_is
@@ -17,7 +17,7 @@
 import torch
 
-import vllm_ascend.ops.common_fused_moe  # noqa
+import vllm_ascend.ops.fused_moe.fused_moe  # noqa
 import vllm_ascend.ops.layernorm  # noqa
 import vllm_ascend.ops.register_custom_ops  # noqa
 import vllm_ascend.ops.vocab_parallel_embedding  # noqa
@@ -35,8 +35,8 @@ from vllm_ascend.distributed.parallel_state import get_mc2_group
 from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map,
                                               determine_default_log2phy_map)
 from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
-from vllm_ascend.ops.moe.experts_selector import select_experts
-from vllm_ascend.ops.moe.moe_comm_method import setup_moe_comm_method
+from vllm_ascend.ops.fused_moe.experts_selector import select_experts
+from vllm_ascend.ops.fused_moe.moe_comm_method import setup_moe_comm_method
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, enable_sp, is_310p,
                                is_enable_nz, npu_stream_switch,
                                shared_expert_dp_enabled,
@@ -24,15 +24,13 @@ from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.fused_moe import FusedMoEConfig
 
 from vllm_ascend.ascend_forward_context import MoECommType
-from vllm_ascend.ops.moe.fused_moe_prepare_and_finalize import (
-    FusedMoEPrepareAndFinalizeWithAll2All,
-    FusedMoEPrepareAndFinalizeWithAllGather, FusedMoEPrepareAndFinalizeWithMC2,
-    FusedMoEPrepareAndFinalizeWithNaiveMulticast)
-from vllm_ascend.ops.moe.moe_mlp import unified_apply_mlp
-from vllm_ascend.ops.moe.token_dispatcher import (TokenDispatcherWithAll2AllV,
-                                                  TokenDispatcherWithAllGather,
-                                                  TokenDispatcherWithMC2,
-                                                  TokenDispatcherWithMoge)
+from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp
+from vllm_ascend.ops.fused_moe.prepare_finalize import (
+    PrepareAndFinalizeWithAll2All, PrepareAndFinalizeWithAllGather,
+    PrepareAndFinalizeWithMC2, PrepareAndFinalizeWithNaiveMulticast)
+from vllm_ascend.ops.fused_moe.token_dispatcher import (
+    TokenDispatcherWithAll2AllV, TokenDispatcherWithAllGather,
+    TokenDispatcherWithMC2, TokenDispatcherWithMoge)
 
 _MoECommMethods: Dict[Optional[MoECommType], MoECommMethod] = {}
@@ -59,8 +57,7 @@ class MoECommMethod(ABC):
         self.moe_config = moe_config
 
         self.token_dispatcher = self._get_token_dispatcher()
-        self.fused_moe_prepare_finalize = self._get_fused_moe_prepare_finalize(
-        )
+        self.prepare_finalize = self._get_prepare_finalize()
 
     def prepare(
             self,
@@ -71,7 +68,7 @@ class MoECommMethod(ABC):
             gate=None
     ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
               Optional[torch.Tensor]]:
-        hidden_states, router_logits, mc2_mask, context_metadata = self.fused_moe_prepare_finalize.prepare(
+        hidden_states, router_logits, mc2_mask, context_metadata = self.prepare_finalize.prepare(
             hidden_states, router_logits, enable_shared_expert_dp,
             replace_allreduce, gate)
         return hidden_states, router_logits, mc2_mask, context_metadata
@@ -80,8 +77,9 @@ class MoECommMethod(ABC):
                  hidden_states: torch.Tensor,
                  reduce_results: bool,
                  context_metadata: Optional[dict] = None) -> torch.Tensor:
-        hidden_states = self.fused_moe_prepare_finalize.finalize(
-            hidden_states, reduce_results, context_metadata)
+        hidden_states = self.prepare_finalize.finalize(hidden_states,
+                                                       reduce_results,
+                                                       context_metadata)
         return hidden_states
 
     def fused_experts(
@@ -169,9 +167,9 @@ class MoECommMethod(ABC):
             "_get_token_dispatcher function not implemented.")
 
     @abstractmethod
-    def _get_fused_moe_prepare_finalize(self):
+    def _get_prepare_finalize(self):
         raise NotImplementedError(
-            "_get_fused_moe_prepare_finalize function not implemented.")
+            "_get_prepare_finalize function not implemented.")
 
 
 class AllGatherCommImpl(MoECommMethod):
@@ -205,8 +203,8 @@ class AllGatherCommImpl(MoECommMethod):
             num_experts=self.moe_config.num_experts,
             num_local_experts=self.moe_config.num_local_experts)
 
-    def _get_fused_moe_prepare_finalize(self):
-        return FusedMoEPrepareAndFinalizeWithAllGather(self.moe_config)
+    def _get_prepare_finalize(self):
+        return PrepareAndFinalizeWithAllGather(self.moe_config)
 
 
 class MC2CommImpl(MoECommMethod):
@@ -222,8 +220,8 @@ class MC2CommImpl(MoECommMethod):
     def _get_token_dispatcher(self):
         return TokenDispatcherWithMC2()
 
-    def _get_fused_moe_prepare_finalize(self):
-        return FusedMoEPrepareAndFinalizeWithMC2(self.moe_config)
+    def _get_prepare_finalize(self):
+        return PrepareAndFinalizeWithMC2(self.moe_config)
 
 
 class AlltoAllCommImpl(MoECommMethod):
@@ -242,8 +240,8 @@ class AlltoAllCommImpl(MoECommMethod):
             num_experts=self.moe_config.num_experts,
             num_local_experts=self.moe_config.num_local_experts)
 
-    def _get_fused_moe_prepare_finalize(self):
-        return FusedMoEPrepareAndFinalizeWithAll2All(self.moe_config)
+    def _get_prepare_finalize(self):
+        return PrepareAndFinalizeWithAll2All(self.moe_config)
 
 
 class NaiveMulticastCommImpl(MoECommMethod):
@@ -271,5 +269,5 @@ class NaiveMulticastCommImpl(MoECommMethod):
             num_experts=self.moe_config.num_experts,
             num_local_experts=self.moe_config.num_local_experts)
 
-    def _get_fused_moe_prepare_finalize(self):
-        return FusedMoEPrepareAndFinalizeWithNaiveMulticast(self.moe_config)
+    def _get_prepare_finalize(self):
+        return PrepareAndFinalizeWithNaiveMulticast(self.moe_config)
@@ -30,7 +30,7 @@ from vllm.model_executor.layers.fused_moe import FusedMoEConfig
 from vllm_ascend.utils import enable_sp, get_rm_router_logits_state
 
 
-class FusedMoEPrepareAndFinalize(ABC):
+class PrepareAndFinalize(ABC):
     """
     Abstract base class for MoE (Mixture-of-Experts) tensor preparation and finalization
    in distributed environments. Subclasses implement specific communication strategies
@@ -103,7 +103,7 @@ class FusedMoEPrepareAndFinalize(ABC):
         raise NotImplementedError("Finalize function not implemented.")
 
 
-class FusedMoEPrepareAndFinalizeWithAll2All(FusedMoEPrepareAndFinalize):
+class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
     """
     MoE communication strategy using All-to-All style slicing.
     Similar to MC2 but does not use mc2_mask; instead pads to TP size for uniform slicing.
@@ -195,7 +195,7 @@ class FusedMoEPrepareAndFinalizeWithAll2All(FusedMoEPrepareAndFinalize):
         return hidden_states
 
 
-class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalizeWithAll2All):
+class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
     """
     MoE communication strategy using MC2, which is based on All2All. Hence, it inherits
     All2All and share the same finalize method.
@@ -275,7 +275,7 @@ class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalizeWithAll2All):
         return hidden_states, router_logits, mc2_mask, context_metadata
 
 
-class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
+class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
     """
     MoE communication strategy using All-Gather + Reduce-Scatter on EP group.
     There are two sets of prepare and finalize:
@@ -429,7 +429,7 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
         return hidden_states
 
 
-class FusedMoEPrepareAndFinalizeWithNaiveMulticast(FusedMoEPrepareAndFinalize):
+class PrepareAndFinalizeWithNaiveMulticast(PrepareAndFinalize):
     """
     MoE communication strategy using Naive Multicast (point-to-point broadcast).
     Will be used in prefill when using allgather in decode. Each DP rank broadcasts its slice to all others.
@@ -28,7 +28,7 @@ import torch_npu
 from vllm.distributed.parallel_state import get_ep_group
 
 from vllm_ascend.distributed.parallel_state import get_mc2_group
-from vllm_ascend.ops.moe.comm_utils import (
+from vllm_ascend.ops.fused_moe.comm_utils import (
     async_all_to_all, gather_from_sequence_parallel_region)
 from vllm_ascend.utils import (AscendSocVersion, get_ascend_soc_version,
                                is_hierarchical_communication_enabled)
@@ -37,7 +37,7 @@ from vllm.model_executor.utils import set_weight_attrs
 
 from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
                                                     get_otp_group)
-from vllm_ascend.ops.common_fused_moe import AscendUnquantizedFusedMoEMethod
+from vllm_ascend.ops.fused_moe.fused_moe import AscendUnquantizedFusedMoEMethod
 from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod
 from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, mlp_tp_enable,
                                oproj_tp_enable)
@@ -26,7 +26,7 @@ from vllm.forward_context import get_forward_context
 
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.distributed.parallel_state import get_mc2_group
-from vllm_ascend.ops.moe.experts_selector import select_experts
+from vllm_ascend.ops.fused_moe.experts_selector import select_experts
 from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz
@@ -24,7 +24,7 @@ from vllm.distributed.parallel_state import get_ep_group
 from vllm.forward_context import get_forward_context
 
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
-from vllm_ascend.ops.moe.experts_selector import select_experts
+from vllm_ascend.ops.fused_moe.experts_selector import select_experts
 from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p, is_enable_nz
@@ -25,7 +25,7 @@ from vllm.forward_context import get_forward_context
 
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.distributed.parallel_state import get_mc2_group
-from vllm_ascend.ops.moe.experts_selector import select_experts
+from vllm_ascend.ops.fused_moe.experts_selector import select_experts
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, is_enable_nz,
                                vllm_version_is)
@@ -538,8 +538,8 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
     from vllm_ascend.models.layers.mla import AscendMultiHeadLatentAttention
     from vllm_ascend.models.layers.sfa import AscendSparseFlashAttention
     from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
-    from vllm_ascend.ops.common_fused_moe import (AscendFusedMoE,
-                                                  AscendSharedFusedMoE)
+    from vllm_ascend.ops.fused_moe.fused_moe import (AscendFusedMoE,
+                                                     AscendSharedFusedMoE)
    from vllm_ascend.ops.layernorm import (AscendGemmaRMSNorm,
                                            AscendQuantRMSNorm, AscendRMSNorm)
     from vllm_ascend.ops.linear import (AscendColumnParallelLinear,