[0.11.0][Perf] Add fused matmul/reduce-scatter kernel for performance optimization. (#3725)

### What this PR does / why we need it?
This PR boosts performance by introducing a fused kernel for the matrix
matmul and reduce scatter operations. It supports both unquantized
(e.g., BFloat16) and W8A8 quantized models.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

<!--  Thanks for sending a pull request!

BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html

-->
### What this PR does / why we need it?
<!--
- Please clarify what changes you are proposing. The purpose of this
section is to outline the changes and how this PR fixes the issue.
If possible, please consider writing useful notes for better and faster
reviews in your PR.

- Please clarify why the changes are needed. For instance, the use case
and bug description.

- Fixes #
-->

### Does this PR introduce _any_ user-facing change?
<!--
Note that it means *any* user-facing change including all aspects such
as API, interface or other behavior changes.
Documentation-only updates are not considered user-facing changes.
-->

### How was this patch tested?
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->

Signed-off-by: ZYang6263 <zy626375@gmail.com>
This commit is contained in:
ZYang6263
2025-10-25 08:20:43 +08:00
committed by GitHub
parent 17dd9ae42c
commit 5c0a23f98b

View File

@@ -39,11 +39,15 @@ from typing import Optional, Union
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torch_npu
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from vllm.distributed import split_tensor_along_last_dim
from vllm.distributed import (split_tensor_along_last_dim,
tensor_model_parallel_all_reduce,
tensor_model_parallel_reduce_scatter)
from vllm.distributed.parallel_state import get_tp_group
from vllm.forward_context import get_forward_context
from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
get_otp_group)
@@ -375,12 +379,83 @@ class SequenceRowParallelOp(CustomRowParallelOp):
def matmul_and_reduce(self, input_parallel: torch.Tensor,
bias_: Optional[Parameter]) -> torch.Tensor:
assert self.quant_method is not None
output_parallel = self.quant_method.apply(self.layer,
input_parallel,
bias=bias_)
from vllm_ascend.ops.register_custom_ops import \
_maybe_pad_and_reduce_impl
output = _maybe_pad_and_reduce_impl(output_parallel)
try:
forward_context = get_forward_context()
sp_enabled = forward_context.sp_enabled
except AssertionError:
sp_enabled = False
x = input_parallel
if not sp_enabled:
output_parallel = self.layer.quant_method.apply(self.layer,
x,
bias=bias_)
return tensor_model_parallel_all_reduce(output_parallel)
pad_size = forward_context.pad_size
if pad_size > 0:
x = F.pad(x, (0, 0, 0, pad_size))
world_size = self.layer.tp_size
comm_mode = "aiv"
hcom_name = get_tp_group().device_group._get_backend(
torch.device('npu')).get_hccl_comm_name(self.layer.tp_rank)
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm_ascend.quantization.quant_config import AscendLinearMethod
from vllm_ascend.quantization.w8a8 import (AscendW8A8LinearMethod,
quant_per_tensor)
# For unquant
if isinstance(self.layer.quant_method, UnquantizedLinearMethod
) and torch.version.cann.startswith("8.3"):
output = torch_npu.npu_mm_reduce_scatter_base(
x,
self.layer.weight.t(),
hcom_name,
world_size,
reduce_op="sum",
bias=None,
comm_turn=0,
comm_mode=comm_mode)
if bias_ is not None:
output.add_(bias_)
# For w8a8 quant
elif (isinstance(self.layer.quant_method, AscendLinearMethod)
and isinstance(self.layer.quant_method.quant_method,
AscendW8A8LinearMethod)
) and torch.version.cann.startswith("8.3"):
if x.dtype != torch.int8:
x_quant = quant_per_tensor(
x, self.layer.aclnn_input_scale_reciprocal,
self.layer.aclnn_input_offset)
else:
x_quant = x
quant_bias = self.layer.quant_bias
deq_scale = self.layer.deq_scale
output_dtype = torch.bfloat16
output = torch_npu.npu_mm_reduce_scatter_base(
x_quant,
self.layer.weight,
hcom_name,
world_size,
reduce_op="sum",
bias=None,
comm_turn=0,
x2_scale=deq_scale,
output_dtype=output_dtype,
comm_mode=comm_mode)
output = torch.add(
output,
torch.mul(quant_bias, deq_scale).to(self.layer.params_dtype))
else:
output_parallel = self.layer.quant_method.apply(self.layer,
x,
bias=bias_)
output = tensor_model_parallel_reduce_scatter(output_parallel, 0)
return output
def update_attrs(self):