[Main][Bugfix] Avoid using the fusion operator in the MOE model (#3834)

### What this PR does / why we need it?
The current MatmulReduceScatter operator experiences performance
degradation in small-shape scenarios, so it determines whether to use
this operator by judging the size of the shape.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.11.0rc3
- vLLM main:
https://github.com/vllm-project/vllm/commit/releases/v0.11.1

---------

Signed-off-by: ZYang6263 <zy626375@gmail.com>
This commit is contained in:
ZYang6263
2025-10-28 23:30:27 +08:00
committed by GitHub
parent 90ae114569
commit d08401d1e7
2 changed files with 13 additions and 6 deletions

View File

@@ -113,13 +113,16 @@ def set_ascend_forward_context(
# Currently, it is an empirical value. In normal scenarios, if the concurrency exceeds this threshold,
# the performance benefits can be maximized. Conversely, if the concurrency is below the threshold,
# the performance may degrade due to the switching of communication methods.
mmrs_fusion = True
if is_moe_model(vllm_config):
    sp_enabled = enable_sp(vllm_config) and \
        tp_world_size > 1 and num_tokens is not None
mmrs_fusion = False
else:
    sp_enabled = enable_sp(vllm_config) and \
        tp_world_size > 1 and \
        num_tokens is not None and num_tokens > 1000
forward_context.mmrs_fusion = mmrs_fusion
if sp_enabled:
    pad_size = (tp_world_size -

View File

@@ -382,8 +382,10 @@ class SequenceRowParallelOp(CustomRowParallelOp):
try:
    forward_context = get_forward_context()
    sp_enabled = forward_context.sp_enabled
mmrs_fusion = forward_context.mmrs_fusion
except AssertionError:
    sp_enabled = False
mmrs_fusion = False
x = input_parallel
@@ -409,7 +411,8 @@ class SequenceRowParallelOp(CustomRowParallelOp):
    quant_per_tensor)
# For unquant
if mmrs_fusion and isinstance(
        self.layer.quant_method, UnquantizedLinearMethod
) and torch.version.cann.startswith("8.3"):
    output = torch_npu.npu_mm_reduce_scatter_base(
        x,
@@ -423,7 +426,8 @@ class SequenceRowParallelOp(CustomRowParallelOp):
if bias_ is not None:
    output.add_(bias_)
# For w8a8 quant
elif mmrs_fusion and (
        isinstance(self.layer.quant_method, AscendLinearMethod)
        and isinstance(self.layer.quant_method.quant_method,
                       AscendW8A8LinearMethod)
) and torch.version.cann.startswith("8.3"):