[Main][Perf] Add fused matmul/reduce-scatter kernel for performance optimization. (#3693)
### What this PR does / why we need it? This PR boosts performance by introducing a fused kernel for the matrix matmul and reduce scatter operations. It supports both unquantized (e.g., BFloat16) and W8A8 quantized models. ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 Signed-off-by: ZYang6263 <zy626375@gmail.com>
This commit is contained in:
@@ -39,11 +39,15 @@ from typing import Optional, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
import torch.nn.functional as F
|
||||||
import torch_npu
|
import torch_npu
|
||||||
from torch.distributed import ProcessGroup
|
from torch.distributed import ProcessGroup
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
from vllm.distributed import split_tensor_along_last_dim
|
from vllm.distributed import (split_tensor_along_last_dim,
|
||||||
|
tensor_model_parallel_all_reduce,
|
||||||
|
tensor_model_parallel_reduce_scatter)
|
||||||
from vllm.distributed.parallel_state import get_tp_group
|
from vllm.distributed.parallel_state import get_tp_group
|
||||||
|
from vllm.forward_context import get_forward_context
|
||||||
|
|
||||||
from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
|
from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
|
||||||
get_otp_group)
|
get_otp_group)
|
||||||
@@ -375,12 +379,83 @@ class SequenceRowParallelOp(CustomRowParallelOp):
|
|||||||
def matmul_and_reduce(self, input_parallel: torch.Tensor,
|
def matmul_and_reduce(self, input_parallel: torch.Tensor,
|
||||||
bias_: Optional[Parameter]) -> torch.Tensor:
|
bias_: Optional[Parameter]) -> torch.Tensor:
|
||||||
assert self.quant_method is not None
|
assert self.quant_method is not None
|
||||||
output_parallel = self.quant_method.apply(self.layer,
|
try:
|
||||||
input_parallel,
|
forward_context = get_forward_context()
|
||||||
bias=bias_)
|
sp_enabled = forward_context.sp_enabled
|
||||||
from vllm_ascend.ops.register_custom_ops import \
|
except AssertionError:
|
||||||
_maybe_pad_and_reduce_impl
|
sp_enabled = False
|
||||||
output = _maybe_pad_and_reduce_impl(output_parallel)
|
|
||||||
|
x = input_parallel
|
||||||
|
|
||||||
|
if not sp_enabled:
|
||||||
|
output_parallel = self.layer.quant_method.apply(self.layer,
|
||||||
|
x,
|
||||||
|
bias=bias_)
|
||||||
|
return tensor_model_parallel_all_reduce(output_parallel)
|
||||||
|
|
||||||
|
pad_size = forward_context.pad_size
|
||||||
|
if pad_size > 0:
|
||||||
|
x = F.pad(x, (0, 0, 0, pad_size))
|
||||||
|
|
||||||
|
world_size = self.layer.tp_size
|
||||||
|
comm_mode = "aiv"
|
||||||
|
hcom_name = get_tp_group().device_group._get_backend(
|
||||||
|
torch.device('npu')).get_hccl_comm_name(self.layer.tp_rank)
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
|
||||||
|
|
||||||
|
from vllm_ascend.quantization.quant_config import AscendLinearMethod
|
||||||
|
from vllm_ascend.quantization.w8a8 import (AscendW8A8LinearMethod,
|
||||||
|
quant_per_tensor)
|
||||||
|
|
||||||
|
# For unquant
|
||||||
|
if isinstance(self.layer.quant_method, UnquantizedLinearMethod
|
||||||
|
) and torch.version.cann.startswith("8.3"):
|
||||||
|
output = torch_npu.npu_mm_reduce_scatter_base(
|
||||||
|
x,
|
||||||
|
self.layer.weight.t(),
|
||||||
|
hcom_name,
|
||||||
|
world_size,
|
||||||
|
reduce_op="sum",
|
||||||
|
bias=None,
|
||||||
|
comm_turn=0,
|
||||||
|
comm_mode=comm_mode)
|
||||||
|
if bias_ is not None:
|
||||||
|
output.add_(bias_)
|
||||||
|
# For w8a8 quant
|
||||||
|
elif (isinstance(self.layer.quant_method, AscendLinearMethod)
|
||||||
|
and isinstance(self.layer.quant_method.quant_method,
|
||||||
|
AscendW8A8LinearMethod)
|
||||||
|
) and torch.version.cann.startswith("8.3"):
|
||||||
|
if x.dtype != torch.int8:
|
||||||
|
x_quant = quant_per_tensor(
|
||||||
|
x, self.layer.aclnn_input_scale_reciprocal,
|
||||||
|
self.layer.aclnn_input_offset)
|
||||||
|
else:
|
||||||
|
x_quant = x
|
||||||
|
quant_bias = self.layer.quant_bias
|
||||||
|
deq_scale = self.layer.deq_scale
|
||||||
|
output_dtype = torch.bfloat16
|
||||||
|
output = torch_npu.npu_mm_reduce_scatter_base(
|
||||||
|
x_quant,
|
||||||
|
self.layer.weight,
|
||||||
|
hcom_name,
|
||||||
|
world_size,
|
||||||
|
reduce_op="sum",
|
||||||
|
bias=None,
|
||||||
|
comm_turn=0,
|
||||||
|
x2_scale=deq_scale,
|
||||||
|
output_dtype=output_dtype,
|
||||||
|
comm_mode=comm_mode)
|
||||||
|
output = torch.add(
|
||||||
|
output,
|
||||||
|
torch.mul(quant_bias, deq_scale).to(self.layer.params_dtype))
|
||||||
|
else:
|
||||||
|
output_parallel = self.layer.quant_method.apply(self.layer,
|
||||||
|
x,
|
||||||
|
bias=bias_)
|
||||||
|
output = tensor_model_parallel_reduce_scatter(output_parallel, 0)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def update_attrs(self):
|
def update_attrs(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user