[Main][Perf] Add fused matmul/reduce-scatter kernel for performance optimization. (#3693)

### What this PR does / why we need it?
This PR improves performance by introducing a fused kernel that combines the
matmul and reduce-scatter operations in row-parallel linear layers. It supports
both unquantized (e.g., BFloat16) and W8A8 quantized models.
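
For reference, a minimal sketch of the unquantized fused path. The wrapper function and argument names below are illustrative and not part of this PR; the `torch_npu.npu_mm_reduce_scatter_base` call and its arguments mirror the diff:

```python
import torch
import torch_npu  # Ascend extension providing the fused matmul + reduce-scatter op


def fused_matmul_reduce_scatter(x: torch.Tensor, weight: torch.Tensor,
                                hcom_name: str, tp_size: int) -> torch.Tensor:
    """Compute x @ weight.T and reduce-scatter the result in one NPU kernel.

    x:         per-rank activations, rows padded to a multiple of tp_size
    weight:    layer weight of shape [out_features, in_features]
    hcom_name: HCCL communicator name of the tensor-parallel group
    tp_size:   number of ranks in the tensor-parallel group
    """
    return torch_npu.npu_mm_reduce_scatter_base(
        x,
        weight.t(),          # transposed to match the unfused x @ weight.T
        hcom_name,
        tp_size,
        reduce_op="sum",
        bias=None,
        comm_turn=0,
        comm_mode="aiv")
```

The W8A8 path is analogous: it passes the int8 activations, supplies `deq_scale` via `x2_scale` with `output_dtype=torch.bfloat16`, and adds the dequantized `quant_bias` afterwards. When sequence parallelism is disabled or the CANN version is below 8.3, the layer falls back to the existing matmul followed by all-reduce/reduce-scatter.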

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: ZYang6263 <zy626375@gmail.com>
Author: ZYang6263
Date: 2025-10-24 18:19:58 +08:00
Committed by: GitHub
Parent: 82a4970fe9
Commit: 0b1da24742

@@ -39,11 +39,15 @@ from typing import Optional, Union
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torch_npu
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from vllm.distributed import split_tensor_along_last_dim
from vllm.distributed import (split_tensor_along_last_dim,
                              tensor_model_parallel_all_reduce,
                              tensor_model_parallel_reduce_scatter)
from vllm.distributed.parallel_state import get_tp_group
from vllm.forward_context import get_forward_context
from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
                                                     get_otp_group)
@@ -375,12 +379,83 @@ class SequenceRowParallelOp(CustomRowParallelOp):
    def matmul_and_reduce(self, input_parallel: torch.Tensor,
                          bias_: Optional[Parameter]) -> torch.Tensor:
        assert self.quant_method is not None
        output_parallel = self.quant_method.apply(self.layer,
                                                  input_parallel,
                                                  bias=bias_)
        from vllm_ascend.ops.register_custom_ops import \
            _maybe_pad_and_reduce_impl
        output = _maybe_pad_and_reduce_impl(output_parallel)
        try:
            forward_context = get_forward_context()
            sp_enabled = forward_context.sp_enabled
        except AssertionError:
            sp_enabled = False
        x = input_parallel
        if not sp_enabled:
            output_parallel = self.layer.quant_method.apply(self.layer,
                                                            x,
                                                            bias=bias_)
            return tensor_model_parallel_all_reduce(output_parallel)
        pad_size = forward_context.pad_size
        if pad_size > 0:
            x = F.pad(x, (0, 0, 0, pad_size))
        world_size = self.layer.tp_size
        comm_mode = "aiv"
        hcom_name = get_tp_group().device_group._get_backend(
            torch.device('npu')).get_hccl_comm_name(self.layer.tp_rank)
        from vllm.model_executor.layers.linear import UnquantizedLinearMethod
        from vllm_ascend.quantization.quant_config import AscendLinearMethod
        from vllm_ascend.quantization.w8a8 import (AscendW8A8LinearMethod,
                                                   quant_per_tensor)
        # For unquant
        if isinstance(self.layer.quant_method, UnquantizedLinearMethod
                      ) and torch.version.cann.startswith("8.3"):
            output = torch_npu.npu_mm_reduce_scatter_base(
                x,
                self.layer.weight.t(),
                hcom_name,
                world_size,
                reduce_op="sum",
                bias=None,
                comm_turn=0,
                comm_mode=comm_mode)
            if bias_ is not None:
                output.add_(bias_)
        # For w8a8 quant
        elif (isinstance(self.layer.quant_method, AscendLinearMethod)
              and isinstance(self.layer.quant_method.quant_method,
                             AscendW8A8LinearMethod)
              ) and torch.version.cann.startswith("8.3"):
            if x.dtype != torch.int8:
                x_quant = quant_per_tensor(
                    x, self.layer.aclnn_input_scale_reciprocal,
                    self.layer.aclnn_input_offset)
            else:
                x_quant = x
            quant_bias = self.layer.quant_bias
            deq_scale = self.layer.deq_scale
            output_dtype = torch.bfloat16
            output = torch_npu.npu_mm_reduce_scatter_base(
                x_quant,
                self.layer.weight,
                hcom_name,
                world_size,
                reduce_op="sum",
                bias=None,
                comm_turn=0,
                x2_scale=deq_scale,
                output_dtype=output_dtype,
                comm_mode=comm_mode)
            output = torch.add(
                output,
                torch.mul(quant_bias, deq_scale).to(self.layer.params_dtype))
        else:
            output_parallel = self.layer.quant_method.apply(self.layer,
                                                            x,
                                                            bias=bias_)
            output = tensor_model_parallel_reduce_scatter(output_parallel, 0)
        return output

    def update_attrs(self):