[Main][Perf] Add fused matmul/reduce-scatter kernel for performance optimization. (#3693)
### What this PR does / why we need it? This PR boosts performance by introducing a fused kernel for the matmul and reduce-scatter operations. It supports both unquantized (e.g., BFloat16) and W8A8-quantized models. ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 Signed-off-by: ZYang6263 <zy626375@gmail.com>
This commit is contained in:
@@ -39,11 +39,15 @@ from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
import torch_npu
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.nn.parameter import Parameter
|
||||
from vllm.distributed import split_tensor_along_last_dim
|
||||
from vllm.distributed import (split_tensor_along_last_dim,
|
||||
tensor_model_parallel_all_reduce,
|
||||
tensor_model_parallel_reduce_scatter)
|
||||
from vllm.distributed.parallel_state import get_tp_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
|
||||
from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
|
||||
get_otp_group)
|
||||
@@ -375,12 +379,83 @@ class SequenceRowParallelOp(CustomRowParallelOp):
|
||||
def matmul_and_reduce(self, input_parallel: torch.Tensor,
                      bias_: Optional[Parameter]) -> torch.Tensor:
    """Apply the layer's linear projection and reduce across the TP group.

    When sequence parallelism is enabled and a fused kernel is available
    (CANN 8.3) for the layer's quantization scheme, the matmul and the
    reduce-scatter are fused into a single
    ``torch_npu.npu_mm_reduce_scatter_base`` call. Otherwise the matmul
    and the collective run as two separate operations.

    NOTE(review): the diff rendering this was reconstructed from had lost
    its +/- markers; the pre-change body (quant apply followed by
    ``_maybe_pad_and_reduce_impl``) was dead duplicated code ahead of the
    fused implementation and has been dropped here.

    Args:
        input_parallel: Input activations, sharded for tensor parallelism.
        bias_: Optional bias. Passed to the quant method on the non-fused
            paths; applied after the fused kernel otherwise (the kernel is
            always invoked with ``bias=None``).

    Returns:
        The all-reduced output (sequence parallelism disabled) or the
        reduce-scattered output (sequence parallelism enabled).
    """
    assert self.quant_method is not None

    # get_forward_context() asserts when called outside a forward pass;
    # in that case fall back to the non-sequence-parallel path.
    try:
        forward_context = get_forward_context()
        sp_enabled = forward_context.sp_enabled
    except AssertionError:
        sp_enabled = False

    x = input_parallel

    if not sp_enabled:
        # Plain tensor parallelism: per-rank matmul then all-reduce.
        output_parallel = self.layer.quant_method.apply(self.layer,
                                                        x,
                                                        bias=bias_)
        return tensor_model_parallel_all_reduce(output_parallel)

    # Pad the sequence dimension by the amount precomputed in the forward
    # context so the reduce-scatter splits evenly across the TP ranks.
    pad_size = forward_context.pad_size
    if pad_size > 0:
        x = F.pad(x, (0, 0, 0, pad_size))

    world_size = self.layer.tp_size
    comm_mode = "aiv"
    # HCCL communicator name of the TP process group on this NPU rank,
    # required by the fused matmul/reduce-scatter kernel.
    hcom_name = get_tp_group().device_group._get_backend(
        torch.device('npu')).get_hccl_comm_name(self.layer.tp_rank)

    # Imported lazily to avoid circular imports at module load time.
    from vllm.model_executor.layers.linear import UnquantizedLinearMethod

    from vllm_ascend.quantization.quant_config import AscendLinearMethod
    from vllm_ascend.quantization.w8a8 import (AscendW8A8LinearMethod,
                                               quant_per_tensor)

    # The fused kernel is only taken on CANN 8.3. TODO(review): confirm
    # whether later CANN versions should also enable the fused path.
    fused_kernel_available = torch.version.cann.startswith("8.3")

    # For unquant
    if isinstance(self.layer.quant_method,
                  UnquantizedLinearMethod) and fused_kernel_available:
        output = torch_npu.npu_mm_reduce_scatter_base(
            x,
            self.layer.weight.t(),
            hcom_name,
            world_size,
            reduce_op="sum",
            bias=None,
            comm_turn=0,
            comm_mode=comm_mode)
        # Bias is not fused into the kernel call; apply it afterwards.
        if bias_ is not None:
            output.add_(bias_)
    # For w8a8 quant
    elif (isinstance(self.layer.quant_method, AscendLinearMethod)
          and isinstance(self.layer.quant_method.quant_method,
                         AscendW8A8LinearMethod)) and fused_kernel_available:
        # Quantize the activations to int8 unless the caller already did.
        if x.dtype != torch.int8:
            x_quant = quant_per_tensor(
                x, self.layer.aclnn_input_scale_reciprocal,
                self.layer.aclnn_input_offset)
        else:
            x_quant = x
        quant_bias = self.layer.quant_bias
        deq_scale = self.layer.deq_scale
        output_dtype = torch.bfloat16
        output = torch_npu.npu_mm_reduce_scatter_base(
            x_quant,
            self.layer.weight,
            hcom_name,
            world_size,
            reduce_op="sum",
            bias=None,
            comm_turn=0,
            x2_scale=deq_scale,
            output_dtype=output_dtype,
            comm_mode=comm_mode)
        # Apply the dequantized bias (quant_bias * deq_scale) in the
        # layer's parameter dtype after the fused kernel.
        output = torch.add(
            output,
            torch.mul(quant_bias, deq_scale).to(self.layer.params_dtype))
    else:
        # No fused kernel for this scheme/CANN version: fall back to a
        # separate matmul followed by a reduce-scatter along dim 0.
        output_parallel = self.layer.quant_method.apply(self.layer,
                                                        x,
                                                        bias=bias_)
        output = tensor_model_parallel_reduce_scatter(output_parallel, 0)

    return output
||||
|
||||
def update_attrs(self):
|
||||
|
||||
Reference in New Issue
Block a user