[Feature] Use DispatchGmmCombineDecode operator to replace MC2 (Optional) (#5040)

### What this PR does / why we need it?

This PR adds model-side integration for the previously introduced
experimental AscendC fused operator DispatchGmmCombineDecode, used in
MoE decoding.

The operator implementation itself was added in a prior PR,
[#4139](https://github.com/vllm-project/vllm-ascend/pull/4139).
This change only adapts the model execution path to optionally use the
fused operator.

When the environment variable VLLM_ASCEND_ENABLE_FUSED_MC2=2 is set, the
original MC2 path composed of multiple operators (A8W8 dispatch → GMM →
SwiGLU → GMM → combine) might be replaced by the single fused operator
DispatchGmmCombineDecode.

By default, the existing multi-operator MC2 implementation is preserved.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

Signed-off-by: wangqiankun <wangqiankun13@huawei.com>
This commit is contained in:
wangqiankun13
2025-12-21 15:23:59 +08:00
committed by GitHub
parent 67a0325cf2
commit 904c18f929
6 changed files with 51 additions and 9 deletions

View File

@@ -345,7 +345,7 @@ class AscendFusedMoE(FusedMoE):
shared_out = fc3_context.shared_experts(hidden_states)
# NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
moe_comm_type = forward_context.moe_comm_type
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2} \
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2} \
and not shared_expert_dp_enabled():
shared_out = tensor_model_parallel_all_reduce(shared_out)
set_flash_common3_context(shared_out=shared_out)

View File

@@ -291,9 +291,9 @@ class FusedMC2CommImpl(MoECommMethod):
assert not (
w1_scale is None or w2_scale is None
), "w1_scale and w2_scale cannot be None for FusedMC2CommImpl."
out = torch.empty_like(hidden_states)
if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
out = torch.empty_like(hidden_states)
torch.ops._C_ascend.dispatch_ffn_combine(
x=hidden_states,
weight1=w1[0],
@@ -307,7 +307,21 @@ class FusedMC2CommImpl(MoECommMethod):
out=out,
)
elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
raise NotImplementedError()
assert expert_map is not None, "expert_map cannot be None."
out, _ = torch.ops._C_ascend.dispatch_gmm_combine_decode(
x=hidden_states,
expert_ids=topk_ids,
gmm1_permuted_weight=w1[0],
gmm1_permuted_weight_scale=w1_scale[0],
gmm2_weight=w2[0],
gmm2_weight_scale=w2_scale[0],
expert_smooth_scales=None,
expert_scales=topk_weights.to(torch.float32),
group_ep=self.token_dispatcher.moe_all_to_all_group_name,
ep_rank_size=self.token_dispatcher.ep_world_size,
ep_rank_id=self.token_dispatcher.ep_rank_id,
moe_expert_num=len(expert_map),
global_bs=self.token_dispatcher.fused_global_bs)
else:
raise ValueError(
f"Wrong value of {envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2=}")

View File

@@ -125,6 +125,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
max_num_tokens = min(max_num_reqs * uniform_decode_query_len, 512)
num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
self.global_bs = num_tokens_per_tp_rank * self.ep_world_size
self.fused_global_bs = max_num_tokens * self.ep_world_size
def get_dispatch_mc2_kwargs(
self,