[Feature]Use DispatchGmmCombineDecode operator to replace MC2(Optional) (#5040)
### What this PR does / why we need it?
This PR adds model-side integration for the previously introduced
experimental AscendC fused operator DispatchGmmCombineDecode, used in
MoE decoding.
The operator implementation itself was added in a prior PR[#4139
](https://github.com/vllm-project/vllm-ascend/pull/4139).
This change only adapts the model execution path to optionally use the
fused operator.
When the environment variable VLLM_ASCEND_ENABLE_FUSED_MC2=2 is set, the
original MC2 path composed of multiple operators (A8W8 dispatch → GMM →
SwiGLU → GMM → combine) might be replaced by the single fused operator
DispatchGmmCombineDecode.
By default, the existing multi-operator MC2 implementation is preserved.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
Signed-off-by: wangqiankun <wangqiankun13@huawei.com>
This commit is contained in:
@@ -345,7 +345,7 @@ class AscendFusedMoE(FusedMoE):
|
||||
shared_out = fc3_context.shared_experts(hidden_states)
|
||||
# NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
|
||||
moe_comm_type = forward_context.moe_comm_type
|
||||
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2} \
|
||||
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2} \
|
||||
and not shared_expert_dp_enabled():
|
||||
shared_out = tensor_model_parallel_all_reduce(shared_out)
|
||||
set_flash_common3_context(shared_out=shared_out)
|
||||
|
||||
@@ -291,9 +291,9 @@ class FusedMC2CommImpl(MoECommMethod):
|
||||
assert not (
|
||||
w1_scale is None or w2_scale is None
|
||||
), "w1_scale and w2_scale cannot be None for FusedMC2CommImpl."
|
||||
out = torch.empty_like(hidden_states)
|
||||
|
||||
if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
|
||||
out = torch.empty_like(hidden_states)
|
||||
torch.ops._C_ascend.dispatch_ffn_combine(
|
||||
x=hidden_states,
|
||||
weight1=w1[0],
|
||||
@@ -307,7 +307,21 @@ class FusedMC2CommImpl(MoECommMethod):
|
||||
out=out,
|
||||
)
|
||||
elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
|
||||
raise NotImplementedError()
|
||||
assert expert_map is not None, "expert_map cannot be None."
|
||||
out, _ = torch.ops._C_ascend.dispatch_gmm_combine_decode(
|
||||
x=hidden_states,
|
||||
expert_ids=topk_ids,
|
||||
gmm1_permuted_weight=w1[0],
|
||||
gmm1_permuted_weight_scale=w1_scale[0],
|
||||
gmm2_weight=w2[0],
|
||||
gmm2_weight_scale=w2_scale[0],
|
||||
expert_smooth_scales=None,
|
||||
expert_scales=topk_weights.to(torch.float32),
|
||||
group_ep=self.token_dispatcher.moe_all_to_all_group_name,
|
||||
ep_rank_size=self.token_dispatcher.ep_world_size,
|
||||
ep_rank_id=self.token_dispatcher.ep_rank_id,
|
||||
moe_expert_num=len(expert_map),
|
||||
global_bs=self.token_dispatcher.fused_global_bs)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Wrong value of {envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2=}")
|
||||
|
||||
@@ -125,6 +125,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
max_num_tokens = min(max_num_reqs * uniform_decode_query_len, 512)
|
||||
num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
|
||||
self.global_bs = num_tokens_per_tp_rank * self.ep_world_size
|
||||
self.fused_global_bs = max_num_tokens * self.ep_world_size
|
||||
|
||||
def get_dispatch_mc2_kwargs(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user