From 6188450269872f2e743518a1f7648f5cb40e9984 Mon Sep 17 00:00:00 2001 From: ZYang6263 <50876451+ZYang6263@users.noreply.github.com> Date: Tue, 28 Oct 2025 23:31:19 +0800 Subject: [PATCH] [v0.11.0][Bugfix]Avoid using the fusion operator in the MOE model (#3837) ### What this PR does / why we need it? The current MatmulReduceScatter operator experiences performance degradation in small-shape scenarios, so it determines whether to use this operator by judging the size of the shape. --------- Signed-off-by: ZYang6263 --- vllm_ascend/ascend_forward_context.py | 3 +++ vllm_ascend/ops/linear_op.py | 16 ++++++++++------ 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index 85348db..fa753f3 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -112,13 +112,16 @@ def set_ascend_forward_context( # Currently, it is an empirical value. In normal scenarios, if the concurrency exceeds this threshold, # the performance benefits can be maximized. Conversely, if the concurrency is below the threshold, # the performance may degrade due to the switching of communication methods. + mmrs_fusion = True if is_moe_model(vllm_config): sp_enabled = enable_sp(vllm_config) and \ tp_world_size > 1 and num_tokens is not None + mmrs_fusion = False else: sp_enabled = enable_sp(vllm_config) and \ tp_world_size > 1 and \ num_tokens is not None and num_tokens > 1000 + forward_context.mmrs_fusion = mmrs_fusion if sp_enabled: pad_size = (tp_world_size - diff --git a/vllm_ascend/ops/linear_op.py b/vllm_ascend/ops/linear_op.py index b7000da..be7fa31 100644 --- a/vllm_ascend/ops/linear_op.py +++ b/vllm_ascend/ops/linear_op.py @@ -382,8 +382,10 @@ class SequenceRowParallelOp(CustomRowParallelOp): try: forward_context = get_forward_context() sp_enabled = forward_context.sp_enabled + mmrs_fusion = forward_context.mmrs_fusion except AssertionError: sp_enabled = False + mmrs_fusion = False x = input_parallel @@ -409,8 +411,9 @@ class SequenceRowParallelOp(CustomRowParallelOp): quant_per_tensor) # For unquant - if isinstance(self.layer.quant_method, UnquantizedLinearMethod - ) and torch.version.cann.startswith("8.3"): + if mmrs_fusion and isinstance( + self.layer.quant_method, UnquantizedLinearMethod + ) and torch.version.cann.startswith("8.3"): output = torch_npu.npu_mm_reduce_scatter_base( x, self.layer.weight.t(), @@ -423,10 +426,11 @@ class SequenceRowParallelOp(CustomRowParallelOp): if bias_ is not None: output.add_(bias_) # For w8a8 quant - elif (isinstance(self.layer.quant_method, AscendLinearMethod) - and isinstance(self.layer.quant_method.quant_method, - AscendW8A8LinearMethod) - ) and torch.version.cann.startswith("8.3"): + elif mmrs_fusion and ( + isinstance(self.layer.quant_method, AscendLinearMethod) + and isinstance(self.layer.quant_method.quant_method, + AscendW8A8LinearMethod) + ) and torch.version.cann.startswith("8.3"): if x.dtype != torch.int8: x_quant = quant_per_tensor( x, self.layer.aclnn_input_scale_reciprocal,