[Feature]Use DispatchGmmCombineDecode operator to replace MC2(Optional) (#5040)
### What this PR does / why we need it?
This PR adds model-side integration for the previously introduced
experimental AscendC fused operator DispatchGmmCombineDecode, used in
MoE decoding.
The operator implementation itself was added in a prior PR[#4139
](https://github.com/vllm-project/vllm-ascend/pull/4139).
This change only adapts the model execution path to optionally use the
fused operator.
When the environment variable VLLM_ASCEND_ENABLE_FUSED_MC2=2 is set, the
original MC2 path composed of multiple operators (A8W8 dispatch → GMM →
SwiGLU → GMM → combine) might be replaced by the single fused operator
DispatchGmmCombineDecode.
By default, the existing multi-operator MC2 implementation is preserved.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
Signed-off-by: wangqiankun <wangqiankun13@huawei.com>
This commit is contained in:
@@ -231,6 +231,10 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
topk_weights = topk_weights.to(self.in_dtype)
|
||||
|
||||
moe_comm_method = get_forward_context().moe_comm_method
|
||||
# When VLLM_ASCEND_ENABLE_FUSED_MC2 == 2, use dispatch_gmm_combine_decode, need fp32 scale
|
||||
w2_weight_scale_fp32_flag = (
|
||||
get_forward_context().moe_comm_type == MoECommType.FUSED_MC2
|
||||
and envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2)
|
||||
if self.dynamic_eplb:
|
||||
w1 = layer.w13_weight_list
|
||||
w1_scale = layer.w13_weight_scale_fp32_list
|
||||
@@ -240,7 +244,10 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
w1 = [layer.w13_weight]
|
||||
w1_scale = [layer.w13_weight_scale_fp32]
|
||||
w2 = [layer.w2_weight]
|
||||
w2_scale = [layer.w2_weight_scale]
|
||||
w2_scale = [
|
||||
layer.w2_weight_scale_fp32
|
||||
if w2_weight_scale_fp32_flag else layer.w2_weight_scale
|
||||
]
|
||||
|
||||
fused_scale_flag = (get_forward_context().moe_comm_type
|
||||
== MoECommType.FUSED_MC2
|
||||
@@ -279,6 +286,8 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
layer.w13_weight_offset.data.shape[0], -1)
|
||||
layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(
|
||||
layer.w2_weight_scale.data.shape[0], -1)
|
||||
layer.w2_weight_scale_fp32 = layer.w2_weight_scale.data.to(
|
||||
torch.float32)
|
||||
layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(
|
||||
layer.w2_weight_offset.data.shape[0], -1)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user