[Feature] Use DispatchGmmCombineDecode operator to replace MC2 (optional) (#5040)
### What this PR does / why we need it?
This PR adds model-side integration for the previously introduced
experimental AscendC fused operator DispatchGmmCombineDecode, used in
MoE decoding.
The operator implementation itself was added in an earlier PR,
[#4139](https://github.com/vllm-project/vllm-ascend/pull/4139).
This change only adapts the model execution path to optionally use the
fused operator.
When the environment variable `VLLM_ASCEND_ENABLE_FUSED_MC2=2` is set, the
original MC2 path composed of multiple operators (A8W8 dispatch → GMM →
SwiGLU → GMM → combine) is replaced by the single fused operator
DispatchGmmCombineDecode.
By default, the existing multi-operator MC2 implementation is preserved.
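
As a rough illustration only, here is a minimal sketch of how such a flag could gate path selection. This is not the actual vllm-ascend code; `run_mc2_pipeline` and `run_fused_mc2` are hypothetical placeholders:

```python
import os

def run_mc2_pipeline(hidden_states):
    """Placeholder for the existing multi-operator MC2 path
    (A8W8 dispatch -> GMM -> SwiGLU -> GMM -> combine)."""
    ...

def run_fused_mc2(hidden_states):
    """Placeholder for the single fused DispatchGmmCombineDecode kernel."""
    ...

def select_moe_decode_path(hidden_states):
    # Flag value "2" mirrors the PR description; any other value keeps
    # the default multi-operator implementation.
    if os.environ.get("VLLM_ASCEND_ENABLE_FUSED_MC2") == "2":
        return run_fused_mc2(hidden_states)
    return run_mc2_pipeline(hidden_states)
```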
### Does this PR introduce _any_ user-facing change?
Yes, but only opt-in: a new environment variable `VLLM_ASCEND_ENABLE_FUSED_MC2` selects the fused operator; default behavior is unchanged.
### How was this patch tested?
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
Signed-off-by: wangqiankun <wangqiankun13@huawei.com>
```diff
@@ -345,7 +345,7 @@ class AscendFusedMoE(FusedMoE):
         shared_out = fc3_context.shared_experts(hidden_states)
         # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
         moe_comm_type = forward_context.moe_comm_type
-        if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2} \
+        if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2} \
                 and not shared_expert_dp_enabled():
            shared_out = tensor_model_parallel_all_reduce(shared_out)
         set_flash_common3_context(shared_out=shared_out)
```
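
The substance of the hunk is that the new `MoECommType.FUSED_MC2` joins `ALLTOALL` and `MC2` in triggering a tensor-parallel all-reduce of the shared-expert output. A minimal self-contained sketch of that gating, assuming a simplified enum (the function name and enum scaffolding are illustrative, not the library's API):

```python
from enum import Enum, auto

class MoECommType(Enum):
    # Only ALLTOALL, MC2, and FUSED_MC2 appear in the diff above.
    ALLTOALL = auto()
    MC2 = auto()
    FUSED_MC2 = auto()  # new comm type for DispatchGmmCombineDecode

def needs_shared_expert_all_reduce(comm_type: MoECommType,
                                   shared_expert_dp: bool) -> bool:
    # Mirrors the condition in the hunk: the fused path is treated like
    # the existing ALLTOALL/MC2 paths unless shared-expert DP is enabled.
    return (comm_type in {MoECommType.ALLTOALL, MoECommType.MC2,
                          MoECommType.FUSED_MC2}
            and not shared_expert_dp)
```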