From 5c0a23f98b2442b8c86786d6205e4ad0f06767a4 Mon Sep 17 00:00:00 2001 From: ZYang6263 <50876451+ZYang6263@users.noreply.github.com> Date: Sat, 25 Oct 2025 08:20:43 +0800 Subject: [PATCH] [0.11.0][Perf] Add fused matmul/reduce-scatter kernel for performance optimization. (#3725) ### What this PR does / why we need it? This PR boosts performance by introducing a fused kernel for the matrix matmul and reduce scatter operations. It supports both unquantized (e.g., BFloat16) and W8A8 quantized models. ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 Signed-off-by: ZYang6263 --- vllm_ascend/ops/linear_op.py | 89 +++++++++++++++++++++++++++++++++--- 1 file changed, 82 insertions(+), 7 deletions(-) diff --git a/vllm_ascend/ops/linear_op.py b/vllm_ascend/ops/linear_op.py index ff0462e..b7000da 100644 --- a/vllm_ascend/ops/linear_op.py +++ b/vllm_ascend/ops/linear_op.py @@ -39,11 +39,15 @@ from typing import Optional, Union import torch import torch.distributed as dist +import torch.nn.functional as F import torch_npu from torch.distributed import ProcessGroup from torch.nn.parameter import Parameter -from vllm.distributed import split_tensor_along_last_dim +from vllm.distributed import (split_tensor_along_last_dim, + tensor_model_parallel_all_reduce, + tensor_model_parallel_reduce_scatter) from vllm.distributed.parallel_state import get_tp_group +from vllm.forward_context import get_forward_context from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group, get_otp_group) @@ -375,12 +379,83 @@ class SequenceRowParallelOp(CustomRowParallelOp): def matmul_and_reduce(self, input_parallel: torch.Tensor, bias_: Optional[Parameter]) -> torch.Tensor: assert self.quant_method is not None - output_parallel = 
self.quant_method.apply(self.layer, - input_parallel, - bias=bias_) - from vllm_ascend.ops.register_custom_ops import \ - _maybe_pad_and_reduce_impl - output = _maybe_pad_and_reduce_impl(output_parallel) + try: + forward_context = get_forward_context() + sp_enabled = forward_context.sp_enabled + except AssertionError: + sp_enabled = False + + x = input_parallel + + if not sp_enabled: + output_parallel = self.layer.quant_method.apply(self.layer, + x, + bias=bias_) + return tensor_model_parallel_all_reduce(output_parallel) + + pad_size = forward_context.pad_size + if pad_size > 0: + x = F.pad(x, (0, 0, 0, pad_size)) + + world_size = self.layer.tp_size + comm_mode = "aiv" + hcom_name = get_tp_group().device_group._get_backend( + torch.device('npu')).get_hccl_comm_name(self.layer.tp_rank) + + from vllm.model_executor.layers.linear import UnquantizedLinearMethod + + from vllm_ascend.quantization.quant_config import AscendLinearMethod + from vllm_ascend.quantization.w8a8 import (AscendW8A8LinearMethod, + quant_per_tensor) + + # For unquant + if isinstance(self.layer.quant_method, UnquantizedLinearMethod + ) and torch.version.cann.startswith("8.3"): + output = torch_npu.npu_mm_reduce_scatter_base( + x, + self.layer.weight.t(), + hcom_name, + world_size, + reduce_op="sum", + bias=None, + comm_turn=0, + comm_mode=comm_mode) + if bias_ is not None: + output.add_(bias_) + # For w8a8 quant + elif (isinstance(self.layer.quant_method, AscendLinearMethod) + and isinstance(self.layer.quant_method.quant_method, + AscendW8A8LinearMethod) + ) and torch.version.cann.startswith("8.3"): + if x.dtype != torch.int8: + x_quant = quant_per_tensor( + x, self.layer.aclnn_input_scale_reciprocal, + self.layer.aclnn_input_offset) + else: + x_quant = x + quant_bias = self.layer.quant_bias + deq_scale = self.layer.deq_scale + output_dtype = torch.bfloat16 + output = torch_npu.npu_mm_reduce_scatter_base( + x_quant, + self.layer.weight, + hcom_name, + world_size, + reduce_op="sum", + bias=None, + 
comm_turn=0, + x2_scale=deq_scale, + output_dtype=output_dtype, + comm_mode=comm_mode) + output = torch.add( + output, + torch.mul(quant_bias, deq_scale).to(self.layer.params_dtype)) + else: + output_parallel = self.layer.quant_method.apply(self.layer, + x, + bias=bias_) + output = tensor_model_parallel_reduce_scatter(output_parallel, 0) + return output def update_attrs(self):