[bugfix] Use FUSED_MC2 MoE comm path for the op dispatch_ffn_combine (#5156)
### What this PR does / why we need it?
- Renames the MoE comm enum value `MoECommType.FUSED_ALLTOALL` to
`MoECommType.FUSED_MC2` and updates all call sites.
- Updates `select_moe_comm_method` to optionally select `FUSED_MC2` on
  Ascend A3 when:
  - `enable_expert_parallel=True`
  - quantization is `w8a8_dynamic`
  - `EP <= 16`
  - `dynamic_eplb` is disabled
  - `is_mtp_model=False`
- Replaces the old “fused all-to-all” comm implementation with
`FusedMC2CommImpl`, using `TokenDispatcherWithMC2` /
`PrepareAndFinalizeWithMC2` and `dispatch_ffn_combine`.
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
---------
Signed-off-by: Chen Chen <0109chenchen@gmail.com>
This commit is contained in:
@@ -23,6 +23,7 @@ from vllm.config import CompilationMode, get_current_vllm_config
|
||||
from vllm.distributed import get_ep_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.ascend_forward_context import MoECommType
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
@@ -246,15 +247,16 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
w2 = [layer.w2_weight]
|
||||
w2_scale = [layer.w2_weight_scale]
|
||||
|
||||
fused_flag = get_forward_context(
|
||||
).moe_comm_type == MoECommType.FUSED_ALLTOALL
|
||||
fused_scale_flag = (get_forward_context().moe_comm_type
|
||||
== MoECommType.FUSED_MC2
|
||||
and envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1)
|
||||
return moe_comm_method.fused_experts(
|
||||
hidden_states=x,
|
||||
pertoken_scale=pertoken_scale,
|
||||
w1=w1[0] if fused_flag else w1,
|
||||
w1_scale=layer.fused_w1_scale if fused_flag else w1_scale,
|
||||
w2=w2[0] if fused_flag else w2,
|
||||
w2_scale=layer.fused_w2_scale if fused_flag else w2_scale,
|
||||
w1=w1,
|
||||
w1_scale=[layer.fused_w1_scale] if fused_scale_flag else w1_scale,
|
||||
w2=w2,
|
||||
w2_scale=[layer.fused_w2_scale] if fused_scale_flag else w2_scale,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
use_int8_w8a8=True,
|
||||
|
||||
Reference in New Issue
Block a user