[bugfix] Use FUSED_MC2 MoE comm path for the op dispatch_ffn_combine (#5156)
### What this PR does / why we need it?
- Renames the MoE comm enum value `MoECommType.FUSED_ALLTOALL` to
`MoECommType.FUSED_MC2` and updates all call sites.
- Updates `select_moe_comm_method` to optionally select `FUSED_MC2` on
Ascend A3 when all of the following hold (see the sketch after this list):
  - `enable_expert_parallel=True`
  - quantization is `w8a8_dynamic`
  - `EP <= 16`
  - `dynamic_eplb` is disabled
  - `is_mtp_model=False`
- Replaces the old “fused all-to-all” comm implementation with
`FusedMC2CommImpl`, using `TokenDispatcherWithMC2` /
`PrepareAndFinalizeWithMC2` and `dispatch_ffn_combine`.
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
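
For reference, a minimal standalone sketch of the selection rule above. Only the enum value `FUSED_MC2` (renamed from `FUSED_ALLTOALL`) and the listed conditions come from this PR; the helper name, its parameters, and the enum string values below are illustrative placeholders, not the actual `select_moe_comm_method` signature:

```python
from enum import Enum


class MoECommType(Enum):
    # String values are placeholders; only the member names appear in the PR.
    ALLGATHER = "allgather"
    ALLTOALL = "alltoall"
    MC2 = "mc2"
    FUSED_MC2 = "fused_mc2"  # renamed from FUSED_ALLTOALL in this PR


def prefers_fused_mc2(is_a3_soc: bool, enable_expert_parallel: bool,
                      quant_type: str, ep_size: int, dynamic_eplb: bool,
                      is_mtp_model: bool) -> bool:
    """Restates the conditions listed above; the real check lives in
    select_moe_comm_method and may differ in detail."""
    return (is_a3_soc and enable_expert_parallel
            and quant_type == "w8a8_dynamic" and ep_size <= 16
            and not dynamic_eplb and not is_mtp_model)
```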
---------
Signed-off-by: Chen Chen <0109chenchen@gmail.com>
```diff
@@ -533,7 +533,7 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
         # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
         forward_context = get_forward_context()
         moe_comm_type = forward_context.moe_comm_type
-        if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_ALLTOALL} \
+        if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2} \
                 and not shared_expert_dp_enabled():
             shared_out = tensor_model_parallel_all_reduce(shared_out)
         else:
@@ -22,6 +22,7 @@ import torch
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.fused_moe import FusedMoEConfig
 
+import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_forward_context import MoECommType
 from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp
 from vllm_ascend.ops.fused_moe.prepare_finalize import (
@@ -43,8 +44,7 @@ def setup_moe_comm_method(moe_config):
     _MoECommMethods[MoECommType.ALLTOALL] = AlltoAllCommImpl(moe_config)
     _MoECommMethods[MoECommType.ALLGATHER] = AllGatherCommImpl(moe_config)
     _MoECommMethods[MoECommType.MC2] = MC2CommImpl(moe_config)
-    _MoECommMethods[MoECommType.FUSED_ALLTOALL] = FusedAlltoAllCommImpl(
-        moe_config)
+    _MoECommMethods[MoECommType.FUSED_MC2] = FusedMC2CommImpl(moe_config)
 
 
 class MoECommMethod(ABC):
@@ -241,30 +241,27 @@ class AlltoAllCommImpl(MoECommMethod):
         return PrepareAndFinalizeWithAll2All(self.moe_config)
 
 
-class FusedAlltoAllCommImpl(MoECommMethod):
+class FusedMC2CommImpl(MoECommMethod):
     """This implementation is for the scenarios listed below:
     1. `enable_expert_parallel=True`.
-    2. `npu_grouped_matmul` is available.
+    2. `npu_moe_distribute_dispatch` and `npu_moe_distribute_combine` are available.
+    3. `enable_expert_parallel=False` is not supported.
 
-    This implementation uses all-to-all communication to exchange tokens
-    between data parallel ranks before and after the MLP computation. It should
-    have better performance than AllGatherCommImpl when DP size > 1.
+    This implementation uses the MC2 communication method, which is optimized for
+    Communication and Computation parallelism on Ascend devices.
     """
 
     def _get_token_dispatcher(self):
-        return TokenDispatcherWithAll2AllV(
-            top_k=self.moe_config.experts_per_token,
-            num_experts=self.moe_config.num_experts,
-            num_local_experts=self.moe_config.num_local_experts)
+        return TokenDispatcherWithMC2()
 
     def _get_prepare_finalize(self):
-        return PrepareAndFinalizeWithAll2All(self.moe_config)
+        return PrepareAndFinalizeWithMC2(self.moe_config)
 
     def fused_experts(
             self,
             hidden_states: torch.Tensor,
-            w1: torch.Tensor,
-            w2: torch.Tensor,
+            w1: torch.Tensor | list[torch.Tensor],
+            w2: torch.Tensor | list[torch.Tensor],
             topk_weights: torch.Tensor,
             topk_ids: torch.Tensor,
             activation: str = "silu",
@@ -274,8 +271,8 @@ class FusedAlltoAllCommImpl(MoECommMethod):
             use_int4_w4a16: bool = False,
             global_num_experts: Optional[int] = None,
             expert_map: Optional[torch.Tensor] = None,
-            w1_scale: Optional[torch.Tensor] = None,
-            w2_scale: Optional[torch.Tensor] = None,
+            w1_scale: Optional[list[torch.Tensor]] = None,
+            w2_scale: Optional[list[torch.Tensor]] = None,
             w1_scale_bias: torch.Tensor = None,
             w2_scale_bias: torch.Tensor = None,
             w1_offset: Optional[torch.Tensor] = None,
@@ -291,18 +288,27 @@ class FusedAlltoAllCommImpl(MoECommMethod):
             dynamic_eplb: bool = False,
             mc2_mask: torch.Tensor = None,
             pertoken_scale: Optional[torch.Tensor] = None):
+        assert not (
+            w1_scale is None or w2_scale is None
+        ), "w1_scale and w2_scale cannot be None for FusedMC2CommImpl."
         out = torch.empty_like(hidden_states)
 
-        torch.ops._C_ascend.dispatch_ffn_combine(
-            x=hidden_states,
-            weight1=w1,
-            weight2=w2,
-            expert_idx=topk_ids,
-            scale1=w1_scale,
-            scale2=w2_scale,
-            probs=topk_weights.to(torch.float32),
-            group=self.token_dispatcher.moe_all_to_all_group_name,
-            max_output_size=65536,
-            out=out,
-        )
+        if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
+            torch.ops._C_ascend.dispatch_ffn_combine(
+                x=hidden_states,
+                weight1=w1[0],
+                weight2=w2[0],
+                expert_idx=topk_ids,
+                scale1=w1_scale[0],
+                scale2=w2_scale[0],
+                probs=topk_weights.to(torch.float32),
+                group=self.token_dispatcher.moe_all_to_all_group_name,
+                max_output_size=65536,
+                out=out,
+            )
+        elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
+            raise NotImplementedError()
+        else:
+            raise ValueError(
+                f"Wrong value of {envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2=}")
         return out
```
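
The new `fused_experts` gates the fused path on `VLLM_ASCEND_ENABLE_FUSED_MC2`: `1` takes the `dispatch_ffn_combine` path, `2` is reserved and raises `NotImplementedError`, and any other value is rejected. A minimal standalone sketch of that gating, assuming the flag is read from the environment as an integer (the helper name and the default of `1` are placeholders, not part of the PR):

```python
import os


def fused_mc2_mode() -> int:
    """Illustrative mirror of the VLLM_ASCEND_ENABLE_FUSED_MC2 gating in
    FusedMC2CommImpl.fused_experts: 1 selects the dispatch_ffn_combine path,
    2 is reserved (NotImplementedError in the PR), anything else is rejected."""
    value = int(os.getenv("VLLM_ASCEND_ENABLE_FUSED_MC2", "1"))
    if value not in (1, 2):
        raise ValueError(f"Wrong value of VLLM_ASCEND_ENABLE_FUSED_MC2={value}")
    return value
```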