[Main] [Refactor] Enable MoECommMethod in Eager Mode (#2791)
### What this PR does / why we need it?
1. Replace prepare/finalize operation in fused_moe.py by
moe_comm_method.prepare()/finalize()
2. Replace unified_fused_experts by moe_comm_method.fused_experts() in
fused_moe.py/w8a8_dynamic.py/w4a8_dynamic.py
3. Add calling _select_moe_comm_method in spec-decode proposers.
4. Currently, w4a8_dynamic does not support gatherep, use all2allv
instead.
5. Remove redundant code.
### Does this PR introduce _any_ user-facing change?
AllgatherEP switch is disabled in aclgraph/eager mode, just follow the
rules in modelrunner_v1._select_moe_comm_method()
### How was this patch tested?
e2e & ut
- vLLM version: v0.10.2
- vLLM main:
7f6f2c1182
Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
Co-authored-by: weijinqian0 <12153182+weijinqian0@users.noreply.github.com>
This commit is contained in:
@@ -32,8 +32,8 @@ from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.ops.moe.experts_selector import select_experts
|
||||
from vllm_ascend.ops.moe.moe_comm_method import (AllGatherCommImpl,
|
||||
AlltoAllCommImpl, MC2CommImpl)
|
||||
from vllm_ascend.ops.moe.token_dispatcher import setup_token_dispatchers
|
||||
AlltoAllCommImpl, MC2CommImpl,
|
||||
NaiveMulticastCommImpl)
|
||||
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p
|
||||
|
||||
original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__
|
||||
@@ -280,17 +280,17 @@ class AscendFusedMoE(FusedMoE):
|
||||
num_redundant_experts,
|
||||
has_bias,
|
||||
)
|
||||
setup_token_dispatchers(self.moe_config.ep_size,
|
||||
top_k=self.top_k,
|
||||
num_experts=self.global_num_experts,
|
||||
num_local_experts=self.local_num_experts)
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.moe_config.tp_group = get_tp_group()
|
||||
self.moe_config.dp_group = get_dp_group()
|
||||
self.moe_config.ep_group = get_ep_group()
|
||||
self.moe_config.mc2_group = get_mc2_group()
|
||||
|
||||
for method in {AllGatherCommImpl, AlltoAllCommImpl, MC2CommImpl}:
|
||||
for method in {
|
||||
AllGatherCommImpl, AlltoAllCommImpl, MC2CommImpl,
|
||||
NaiveMulticastCommImpl
|
||||
}:
|
||||
setattr(
|
||||
self, method.__name__.lower(),
|
||||
method(moe_config=self.moe_config)) # type: ignore[abstract]
|
||||
|
||||
Reference in New Issue
Block a user