2025-09-18 14:09:19 +08:00
|
|
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
|
|
|
|
# This file is a part of the vllm-ascend project.
|
|
|
|
|
#
|
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
|
#
|
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
#
|
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
"""
|
|
|
|
|
This file extends the functionality of linear operations by encapsulating custom
|
|
|
|
|
communication groups and forward functions into classes (linear ops).
|
|
|
|
|
|
|
|
|
|
Current class inheritance structure:
|
2025-10-14 17:39:26 +08:00
|
|
|
CustomLinearOp
|
2025-09-18 14:09:19 +08:00
|
|
|
├── CustomColumnParallelOp
|
|
|
|
|
│ ├── MLPColumnParallelOp
|
2025-10-13 23:02:12 +08:00
|
|
|
│ └── SequenceColumnParallelOp
|
2025-09-18 14:09:19 +08:00
|
|
|
├── CustomRowParallelOp
|
2025-10-14 17:39:26 +08:00
|
|
|
│ ├── MLPRowParallelOp
|
|
|
|
|
│ ├── OProjRowParallelOp
|
|
|
|
|
│ ├── MatmulAllreduceRowParallelOp
|
|
|
|
|
│ └── SequenceRowParallelOp
|
|
|
|
|
└── CustomReplicatedOp
|
2025-09-18 14:09:19 +08:00
|
|
|
How to extend a new linear op? Taking column parallel op as an example:
|
|
|
|
|
1. Inherit from CustomColumnParallelOp and create a new class MyColumnParallelOp
|
|
|
|
|
2. [Optional] The default communication group is the TP group. If a custom communication group is needed, override the comm_group property
|
|
|
|
|
3. Override the apply method according to requirements, which will replace the original linear.forward
|
|
|
|
|
4. Add selection logic for MyColumnParallelOp in the _get_column_parallel_op function, typically based on the layer prefix and configuration checks
|
|
|
|
|
Row parallel ops follow a similar approach: inherit from CustomRowParallelOp and register the new class in _get_row_parallel_op.
|
|
|
|
|
"""
|
|
|
|
|
|
2025-10-13 23:02:12 +08:00
|
|
|
from typing import Optional, Union
|
2025-09-18 14:09:19 +08:00
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
import torch.distributed as dist
|
[0.11.0][Perf] Add fused matmul/reduce-scatter kernel for performance optimization. (#3725)
### What this PR does / why we need it?
This PR boosts performance by introducing a fused kernel for the matrix
matmul and reduce scatter operations. It supports both unquantized
(e.g., BFloat16) and W8A8 quantized models.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0
<!-- Thanks for sending a pull request!
BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html
-->
### What this PR does / why we need it?
<!--
- Please clarify what changes you are proposing. The purpose of this
section is to outline the changes and how this PR fixes the issue.
If possible, please consider writing useful notes for better and faster
reviews in your PR.
- Please clarify why the changes are needed. For instance, the use case
and bug description.
- Fixes #
-->
### Does this PR introduce _any_ user-facing change?
<!--
Note that it means *any* user-facing change including all aspects such
as API, interface or other behavior changes.
Documentation-only updates are not considered user-facing changes.
-->
### How was this patch tested?
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->
Signed-off-by: ZYang6263 <zy626375@gmail.com>
2025-10-25 08:20:43 +08:00
|
|
|
import torch.nn.functional as F
|
2025-09-18 14:09:19 +08:00
|
|
|
import torch_npu
|
|
|
|
|
from torch.distributed import ProcessGroup
|
|
|
|
|
from torch.nn.parameter import Parameter
|
[0.11.0][Perf] Add fused matmul/reduce-scatter kernel for performance optimization. (#3725)
### What this PR does / why we need it?
This PR boosts performance by introducing a fused kernel for the matrix
matmul and reduce scatter operations. It supports both unquantized
(e.g., BFloat16) and W8A8 quantized models.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0
<!-- Thanks for sending a pull request!
BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html
-->
### What this PR does / why we need it?
<!--
- Please clarify what changes you are proposing. The purpose of this
section is to outline the changes and how this PR fixes the issue.
If possible, please consider writing useful notes for better and faster
reviews in your PR.
- Please clarify why the changes are needed. For instance, the use case
and bug description.
- Fixes #
-->
### Does this PR introduce _any_ user-facing change?
<!--
Note that it means *any* user-facing change including all aspects such
as API, interface or other behavior changes.
Documentation-only updates are not considered user-facing changes.
-->
### How was this patch tested?
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->
Signed-off-by: ZYang6263 <zy626375@gmail.com>
2025-10-25 08:20:43 +08:00
|
|
|
from vllm.distributed import (split_tensor_along_last_dim,
|
|
|
|
|
tensor_model_parallel_all_reduce,
|
|
|
|
|
tensor_model_parallel_reduce_scatter)
|
2025-09-18 14:09:19 +08:00
|
|
|
from vllm.distributed.parallel_state import get_tp_group
|
[0.11.0][Perf] Add fused matmul/reduce-scatter kernel for performance optimization. (#3725)
### What this PR does / why we need it?
This PR boosts performance by introducing a fused kernel for the matrix
matmul and reduce scatter operations. It supports both unquantized
(e.g., BFloat16) and W8A8 quantized models.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0
<!-- Thanks for sending a pull request!
BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html
-->
### What this PR does / why we need it?
<!--
- Please clarify what changes you are proposing. The purpose of this
section is to outline the changes and how this PR fixes the issue.
If possible, please consider writing useful notes for better and faster
reviews in your PR.
- Please clarify why the changes are needed. For instance, the use case
and bug description.
- Fixes #
-->
### Does this PR introduce _any_ user-facing change?
<!--
Note that it means *any* user-facing change including all aspects such
as API, interface or other behavior changes.
Documentation-only updates are not considered user-facing changes.
-->
### How was this patch tested?
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->
Signed-off-by: ZYang6263 <zy626375@gmail.com>
2025-10-25 08:20:43 +08:00
|
|
|
from vllm.forward_context import get_forward_context
|
2025-09-18 14:09:19 +08:00
|
|
|
|
|
|
|
|
from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
|
|
|
|
|
get_otp_group)
|
2025-09-24 11:29:59 +08:00
|
|
|
from vllm_ascend.utils import (dense_optim_enable, enable_sp,
|
|
|
|
|
matmul_allreduce_enable, mlp_tp_enable,
|
2025-10-15 19:36:32 +08:00
|
|
|
oproj_tp_enable, shared_expert_dp_enabled)
|
2025-09-18 14:09:19 +08:00
|
|
|
|
|
|
|
|
|
2025-10-14 17:39:26 +08:00
|
|
|
class CustomLinearOp:
    """Base class for custom linear ops that replace ``layer.forward``.

    An op wraps a linear ``layer`` and caches the attributes its ``apply``
    path needs. Those attributes start as ``None`` and are filled in by
    :meth:`update_attrs`, which must be called once the layer has finished
    its own initialization.
    """

    def __init__(self, layer):
        self.layer = layer
        self.bias = None
        self.skip_bias_add = None
        self.return_bias = None
        self.quant_method = None
        # Declared here like the other cached attributes so that it exists
        # (as None) even before update_attrs() has run.
        self.prefix = None

    # Custom communication group, which also determines how the weight is
    # sharded. Defaults to the tensor-parallel (TP) group.
    @property
    def comm_group(self):
        return get_tp_group()

    @property
    def tp_rank(self):
        """Rank of this process inside ``comm_group``."""
        return self.comm_group.rank_in_group

    @property
    def tp_size(self):
        """Number of processes in ``comm_group``."""
        return self.comm_group.world_size

    def update_attrs(self):
        """Copy the attributes required by ``apply()`` from the layer.

        Call this after the layer completes its initialization, specifically
        at the end of the layer's init. ``bias`` is only copied when the
        layer defines a ``bias`` attribute; otherwise it stays ``None``.
        """
        if hasattr(self.layer, "bias"):
            self.bias = self.layer.bias
        self.skip_bias_add = self.layer.skip_bias_add
        self.return_bias = self.layer.return_bias
        self.quant_method = self.layer.quant_method
        self.prefix = self.layer.prefix

    def apply_impl(self, input_):
        """Compute ``(output, output_bias)``; subclasses must override this.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError

    def apply(self, input_):
        """Replacement for ``layer.forward``.

        Returns:
            Only the output when ``return_bias`` is falsy, otherwise the
            ``(output, output_bias)`` tuple produced by ``apply_impl``.
        """
        output, output_bias = self.apply_impl(input_)
        if not self.return_bias:
            return output
        return output, output_bias
|
2025-09-18 14:09:19 +08:00
|
|
|
|
|
|
|
|
|
2025-10-14 17:39:26 +08:00
|
|
|
class CustomColumnParallelOp(CustomLinearOp):
    """Base class for column-parallel linear ops.

    Extends the attributes mirrored from the layer with ``gather_output``.
    """

    def __init__(self, layer):
        super().__init__(layer)
        # Populated from the layer by update_attrs().
        self.gather_output = None

    def update_attrs(self):
        """Mirror the base attributes plus ``gather_output`` from the layer."""
        super().update_attrs()
        source_layer = self.layer
        self.gather_output = source_layer.gather_output
|
|
|
|
|
|
|
|
|
|
|
2025-10-14 17:39:26 +08:00
|
|
|
class CustomRowParallelOp(CustomLinearOp):
    """Base class for row-parallel linear ops.

    Mirrors ``reduce_results``, ``input_is_parallel`` and
    ``input_size_per_partition`` from the layer in addition to the base
    attributes.
    """

    def __init__(self, layer):
        super().__init__(layer)
        # All three are populated from the layer by update_attrs().
        self.reduce_results = None
        self.input_is_parallel = None
        self.input_size_per_partition = None

    def update_attrs(self):
        """Mirror the base attributes plus the row-parallel settings."""
        super().update_attrs()
        source_layer = self.layer
        self.input_is_parallel = source_layer.input_is_parallel
        self.reduce_results = source_layer.reduce_results
        self.input_size_per_partition = source_layer.input_size_per_partition

    def apply(self, input_):
        """Run ``apply_impl`` and optionally trigger an MLP weight prefetch.

        When ``dense_optim_enable()`` is true, the output and this layer's
        prefix are passed to ``torch.ops.vllm.maybe_prefetch_mlp_gate_up_proj``
        before returning.

        Returns:
            Only the output when ``return_bias`` is falsy, otherwise the
            ``(output, output_bias)`` tuple.
        """
        result, result_bias = self.apply_impl(input_)
        if dense_optim_enable():
            torch.ops.vllm.maybe_prefetch_mlp_gate_up_proj(result, self.prefix)
        return (result, result_bias) if self.return_bias else result
|
|
|
|
|
|
2025-09-18 14:09:19 +08:00
|
|
|
|
2025-10-14 17:39:26 +08:00
|
|
|
class CustomReplicatedOp(CustomLinearOp):
    """Linear op for replicated (non-sharded) layers.

    Runs the layer's quant method on the full input without any
    communication.
    """

    def apply_impl(self, input_):
        """Return ``(output, output_bias)`` for a replicated layer.

        The bias is passed to the quant method unless ``skip_bias_add`` is
        set, in which case it is returned separately as ``output_bias``.
        """
        if self.skip_bias_add:
            fused_bias, deferred_bias = None, self.bias
        else:
            fused_bias, deferred_bias = self.bias, None
        assert self.quant_method is not None
        result = self.quant_method.apply(self.layer, input_, fused_bias)
        return result, deferred_bias
|
|
|
|
|
|
|
|
|
|
|
2025-09-18 14:09:19 +08:00
|
|
|
class MLPColumnParallelOp(CustomColumnParallelOp):
    """Column-parallel op that runs over the MLP tensor-parallel group.

    Before the matmul, the input is all-gathered along dim 0 across the MLP
    TP group.
    """

    @property
    def comm_group(self):
        """The MLP tensor-parallel group (see ``get_mlp_tp_group``)."""
        return get_mlp_tp_group()

    def apply_impl(
        self,
        input_: torch.Tensor,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Gather the input over the MLP TP group, then apply the matmul.

        Returns:
            ``(output, output_bias)``, where ``output_bias`` is the layer
            bias only when ``skip_bias_add`` is set.
        """
        fused_bias = None if self.skip_bias_add else self.bias
        # Matrix multiply.
        assert self.quant_method is not None
        gathered_input = self.comm_group.all_gather(input_, 0)
        result = self.quant_method.apply(self.layer, gathered_input, fused_bias)
        return result, (self.bias if self.skip_bias_add else None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MLPRowParallelOp(CustomRowParallelOp):
    """Row-parallel op that runs over the MLP tensor-parallel group.

    The partial matmul results are reduce-scattered along dim 0 across the
    MLP TP group.
    """

    def __init__(self, layer):
        super().__init__(layer)

    @property
    def comm_group(self):
        """The MLP tensor-parallel group (see ``get_mlp_tp_group``)."""
        return get_mlp_tp_group()

    def apply_impl(
        self, input_: torch.Tensor
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Compute ``(output, output_bias)`` for the MLP row-parallel layer.

        If the input is not already partitioned, it is split along its last
        dimension and this rank keeps its own shard. The bias is passed to
        the matmul only on rank 0, so that it is counted once after the
        reduce-scatter.
        """
        if self.input_is_parallel:
            input_parallel = input_
        else:
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size)
            input_parallel = splitted_input[self.tp_rank].contiguous()

        assert self.quant_method is not None
        # Use the cached self.bias like the other row-parallel ops. It is
        # None-safe when the layer has no ``bias`` attribute.
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        output_parallel = self.quant_method.apply(self.layer,
                                                  input_parallel,
                                                  bias=bias_)
        output = self.comm_group.reduce_scatter(output_parallel, 0)

        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class OProjRowParallelOp(CustomRowParallelOp):
    """Row-parallel op for ``o_proj`` that runs over the o_proj TP group.

    The input is redistributed with an all-to-all so that every rank holds
    its ``input_size_per_partition``-wide chunk for the rows of all ranks.
    The partial products are then reduce-scattered along dim 0.
    """

    def __init__(self, layer):
        super().__init__(layer)

    @property
    def comm_group(self):
        """The o_proj tensor-parallel group (see ``get_otp_group``)."""
        return get_otp_group()

    def apply_impl(
        self,
        input_: torch.Tensor,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """All-to-all the input, apply the matmul, then reduce-scatter.

        Returns:
            ``(output, output_bias)``, where ``output`` is viewed as
            ``(input_.shape[0], layer.output_size)``.
        """
        group_size = self.tp_size
        if self.input_is_parallel:
            local_input = input_
        else:
            shards = split_tensor_along_last_dim(input_,
                                                 num_partitions=group_size)
            local_input = shards[self.tp_rank].contiguous()

        rows = local_input.size(0)
        chunk = self.input_size_per_partition

        # reshape(-1, group_size, chunk) then transpose(0, 1) puts the
        # group_size axis first, so the i-th contiguous slice of the flat
        # send buffer is the chunk destined for rank i. (Assumes the row
        # width is group_size * chunk — TODO confirm with callers.)
        send_buf = local_input.reshape(-1, group_size, chunk)
        send_buf = send_buf.transpose(0, 1).contiguous().view(-1)

        recv_buf = torch.empty(rows * group_size * chunk,
                               dtype=local_input.dtype,
                               device=local_input.device)
        dist.all_to_all_single(recv_buf,
                               send_buf,
                               group=self.comm_group.device_group)
        exchanged = recv_buf.view(rows * group_size, chunk)

        # Only fuse bias add for rank 0 to avoid duplicate bias addition in TP>1
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        assert self.quant_method is not None
        partial = self.quant_method.apply(self.layer, exchanged, bias=bias_)

        # otp-specific: combine the partial results across devices.
        output = self.comm_group.reduce_scatter(partial, dim=0)
        output = output.view(input_.shape[0], self.layer.output_size)

        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MatmulAllreduceRowParallelOp(CustomRowParallelOp):
    """Row-parallel op that fuses the matmul with the TP all-reduce.

    When results must be reduced across more than one rank, it calls
    ``torch_npu.npu_mm_all_reduce_base`` with the transposed weight and the
    HCCL communicator name of ``comm_group``. Otherwise it falls back to the
    layer's quant method.
    """

    # HCCL communicator name, cached on the class after the first lookup.
    # NOTE(review): the cache is not keyed by group, so every caller gets the
    # name resolved for the first group passed to get_hcomm_info().
    _HCOMM_INFO = None

    def __init__(self, layer):
        super().__init__(layer)
        self.hcomm_info = self.get_hcomm_info(self.comm_group.device_group)

    def apply_impl(
        self, input_: torch.Tensor
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Calculate the output tensor of forward, fusing communication and
        computation when possible.

        Returns:
            ``(output, output_bias)``, where ``output_bias`` is the layer
            bias only when ``skip_bias_add`` is set.
        """
        if not self.input_is_parallel:
            pieces = split_tensor_along_last_dim(input_,
                                                 num_partitions=self.tp_size)
            input_parallel = pieces[self.tp_rank].contiguous()
        else:
            input_parallel = input_

        # Only rank 0 passes the bias, so the reduced sum contains it once.
        keep_bias = not (self.tp_rank > 0 or self.skip_bias_add)
        bias_ = self.bias if keep_bias else None

        fuse_allreduce = self.reduce_results and self.tp_size > 1
        if fuse_allreduce:
            output = torch_npu.npu_mm_all_reduce_base(input_parallel,
                                                      self.weight_t,
                                                      self.hcomm_info,
                                                      bias=bias_)
        else:
            assert self.quant_method is not None
            output = self.quant_method.apply(self.layer,
                                             input_parallel,
                                             bias=bias_)

        return output, (self.bias if self.skip_bias_add else None)

    @classmethod
    def get_hcomm_info(cls, group: ProcessGroup) -> str:
        """Get the HCCL communication information for the given group.

        The value is cached in ``cls._HCOMM_INFO``. Once set, it is returned
        as-is for any ``group``.
        """
        if cls._HCOMM_INFO is None:
            local_rank = torch.distributed.get_rank(group)
            if torch.__version__ > "2.0":
                # Query the NPU backend using the rank's global id.
                global_rank = torch.distributed.get_global_rank(
                    group, local_rank)
                npu_backend = group._get_backend(torch.device("npu"))
                cls._HCOMM_INFO = npu_backend.get_hccl_comm_name(global_rank)
            else:
                cls._HCOMM_INFO = group.get_hccl_comm_name(local_rank)
        return cls._HCOMM_INFO

    def update_attrs(self):
        """Mirror the row-parallel attributes and cache the transposed weight.

        ``weight_t`` is computed once here and reused by the fused kernel in
        ``apply_impl``.
        """
        super().update_attrs()
        self.weight_t = self.layer.weight.t()
|
|
|
|
|
|
|
|
|
|
|
2025-10-13 23:02:12 +08:00
|
|
|
class SequenceColumnParallelOp(CustomColumnParallelOp):
    """Column-parallel op used on the sequence-parallel code path."""

    def apply_impl(
        self, input_: torch.Tensor
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Linear layer with column parallelism.

        Implemented multiple optimization projects for dense models, such as
        FlashComm and communication-computation fusion.

        Returns:
            ``(output, output_bias)``. ``output`` is all-gathered across
            ``comm_group`` when ``gather_output`` is set.
        """
        fused_bias = None if self.skip_bias_add else self.bias

        # Matrix multiply.
        assert self.quant_method is not None
        # Custom op that, as its name says, may all-gather and may unpad
        # the input before the matmul.
        prepared_input = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
            input_, True)
        output_parallel = self.quant_method.apply(self.layer, prepared_input,
                                                  fused_bias)

        # All-gather across the partitions when requested.
        output = (self.comm_group.all_gather(output_parallel)
                  if self.gather_output else output_parallel)
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SequenceRowParallelOp(CustomRowParallelOp):
|
2025-09-18 14:09:19 +08:00
|
|
|
|
2025-09-24 11:29:59 +08:00
|
|
|
def apply_impl(
|
2025-09-18 14:09:19 +08:00
|
|
|
self, input_: torch.Tensor
|
|
|
|
|
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
|
|
|
|
"""Linear layer with column parallelism.
|
|
|
|
|
|
|
|
|
|
Implemented multiple optimization projects for dense models, such as FlashComm and
|
|
|
|
|
communication-computation fusion.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
if self.input_is_parallel:
|
|
|
|
|
input_parallel = input_
|
|
|
|
|
else:
|
|
|
|
|
splitted_input = split_tensor_along_last_dim(
|
|
|
|
|
input_, num_partitions=self.tp_size)
|
|
|
|
|
input_parallel = splitted_input[self.tp_rank].contiguous()
|
|
|
|
|
|
|
|
|
|
assert self.quant_method is not None
|
|
|
|
|
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
|
|
|
|
|
|
|
|
|
|
if self.tp_size == 1 or not self.reduce_results:
|
2025-09-26 10:55:32 +08:00
|
|
|
output = self.quant_method.apply(self.layer,
|
|
|
|
|
input_parallel,
|
|
|
|
|
bias=bias_)
|
2025-09-18 14:09:19 +08:00
|
|
|
else:
|
2025-10-23 14:45:49 +08:00
|
|
|
output = torch.ops.vllm.matmul_and_reduce(input_parallel,
|
|
|
|
|
self.prefix)
|
2025-09-18 14:09:19 +08:00
|
|
|
|
|
|
|
|
output_bias = self.bias if self.skip_bias_add else None
|
|
|
|
|
return output, output_bias
|
|
|
|
|
|
2025-10-23 14:45:49 +08:00
|
|
|
def matmul_and_reduce(self, input_parallel: torch.Tensor,
|
|
|
|
|
bias_: Optional[Parameter]) -> torch.Tensor:
|
|
|
|
|
assert self.quant_method is not None
|
[0.11.0][Perf] Add fused matmul/reduce-scatter kernel for performance optimization. (#3725)
### What this PR does / why we need it?
This PR boosts performance by introducing a fused kernel for the matrix
matmul and reduce scatter operations. It supports both unquantized
(e.g., BFloat16) and W8A8 quantized models.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0
<!-- Thanks for sending a pull request!
BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html
-->
### What this PR does / why we need it?
<!--
- Please clarify what changes you are proposing. The purpose of this
section is to outline the changes and how this PR fixes the issue.
If possible, please consider writing useful notes for better and faster
reviews in your PR.
- Please clarify why the changes are needed. For instance, the use case
and bug description.
- Fixes #
-->
### Does this PR introduce _any_ user-facing change?
<!--
Note that it means *any* user-facing change including all aspects such
as API, interface or other behavior changes.
Documentation-only updates are not considered user-facing changes.
-->
### How was this patch tested?
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->
Signed-off-by: ZYang6263 <zy626375@gmail.com>
2025-10-25 08:20:43 +08:00
|
|
|
try:
|
|
|
|
|
forward_context = get_forward_context()
|
|
|
|
|
sp_enabled = forward_context.sp_enabled
|
2025-10-28 23:31:19 +08:00
|
|
|
mmrs_fusion = forward_context.mmrs_fusion
|
[0.11.0][Perf] Add fused matmul/reduce-scatter kernel for performance optimization. (#3725)
### What this PR does / why we need it?
This PR boosts performance by introducing a fused kernel for the matrix
matmul and reduce scatter operations. It supports both unquantized
(e.g., BFloat16) and W8A8 quantized models.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0
<!-- Thanks for sending a pull request!
BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html
-->
### What this PR does / why we need it?
<!--
- Please clarify what changes you are proposing. The purpose of this
section is to outline the changes and how this PR fixes the issue.
If possible, please consider writing useful notes for better and faster
reviews in your PR.
- Please clarify why the changes are needed. For instance, the use case
and bug description.
- Fixes #
-->
### Does this PR introduce _any_ user-facing change?
<!--
Note that it means *any* user-facing change including all aspects such
as API, interface or other behavior changes.
Documentation-only updates are not considered user-facing changes.
-->
### How was this patch tested?
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->
Signed-off-by: ZYang6263 <zy626375@gmail.com>
2025-10-25 08:20:43 +08:00
|
|
|
except AssertionError:
|
|
|
|
|
sp_enabled = False
|
2025-10-28 23:31:19 +08:00
|
|
|
mmrs_fusion = False
|
[0.11.0][Perf] Add fused matmul/reduce-scatter kernel for performance optimization. (#3725)
### What this PR does / why we need it?
This PR boosts performance by introducing a fused kernel for the matrix
matmul and reduce scatter operations. It supports both unquantized
(e.g., BFloat16) and W8A8 quantized models.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0
<!-- Thanks for sending a pull request!
BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html
-->
### What this PR does / why we need it?
<!--
- Please clarify what changes you are proposing. The purpose of this
section is to outline the changes and how this PR fixes the issue.
If possible, please consider writing useful notes for better and faster
reviews in your PR.
- Please clarify why the changes are needed. For instance, the use case
and bug description.
- Fixes #
-->
### Does this PR introduce _any_ user-facing change?
<!--
Note that it means *any* user-facing change including all aspects such
as API, interface or other behavior changes.
Documentation-only updates are not considered user-facing changes.
-->
### How was this patch tested?
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->
Signed-off-by: ZYang6263 <zy626375@gmail.com>
2025-10-25 08:20:43 +08:00
|
|
|
|
|
|
|
|
x = input_parallel
|
|
|
|
|
|
|
|
|
|
if not sp_enabled:
|
|
|
|
|
output_parallel = self.layer.quant_method.apply(self.layer,
|
|
|
|
|
x,
|
|
|
|
|
bias=bias_)
|
|
|
|
|
return tensor_model_parallel_all_reduce(output_parallel)
|
|
|
|
|
|
|
|
|
|
pad_size = forward_context.pad_size
|
|
|
|
|
if pad_size > 0:
|
|
|
|
|
x = F.pad(x, (0, 0, 0, pad_size))
|
|
|
|
|
|
|
|
|
|
world_size = self.layer.tp_size
|
|
|
|
|
comm_mode = "aiv"
|
|
|
|
|
hcom_name = get_tp_group().device_group._get_backend(
|
|
|
|
|
torch.device('npu')).get_hccl_comm_name(self.layer.tp_rank)
|
|
|
|
|
|
|
|
|
|
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
|
|
|
|
|
|
|
|
|
|
from vllm_ascend.quantization.quant_config import AscendLinearMethod
|
|
|
|
|
from vllm_ascend.quantization.w8a8 import (AscendW8A8LinearMethod,
|
|
|
|
|
quant_per_tensor)
|
|
|
|
|
|
|
|
|
|
# For unquant
|
2025-11-06 09:05:08 +08:00
|
|
|
if mmrs_fusion and isinstance(self.layer.quant_method,
|
|
|
|
|
UnquantizedLinearMethod):
|
[0.11.0][Perf] Add fused matmul/reduce-scatter kernel for performance optimization. (#3725)
### What this PR does / why we need it?
This PR boosts performance by introducing a fused kernel for the matrix
matmul and reduce scatter operations. It supports both unquantized
(e.g., BFloat16) and W8A8 quantized models.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0
<!-- Thanks for sending a pull request!
BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html
-->
### What this PR does / why we need it?
<!--
- Please clarify what changes you are proposing. The purpose of this
section is to outline the changes and how this PR fixes the issue.
If possible, please consider writing useful notes for better and faster
reviews in your PR.
- Please clarify why the changes are needed. For instance, the use case
and bug description.
- Fixes #
-->
### Does this PR introduce _any_ user-facing change?
<!--
Note that it means *any* user-facing change including all aspects such
as API, interface or other behavior changes.
Documentation-only updates are not considered user-facing changes.
-->
### How was this patch tested?
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->
Signed-off-by: ZYang6263 <zy626375@gmail.com>
2025-10-25 08:20:43 +08:00
|
|
|
output = torch_npu.npu_mm_reduce_scatter_base(
|
|
|
|
|
x,
|
|
|
|
|
self.layer.weight.t(),
|
|
|
|
|
hcom_name,
|
|
|
|
|
world_size,
|
|
|
|
|
reduce_op="sum",
|
|
|
|
|
bias=None,
|
|
|
|
|
comm_turn=0,
|
|
|
|
|
comm_mode=comm_mode)
|
|
|
|
|
if bias_ is not None:
|
|
|
|
|
output.add_(bias_)
|
|
|
|
|
# For w8a8 quant
|
2025-10-28 23:31:19 +08:00
|
|
|
elif mmrs_fusion and (
|
|
|
|
|
isinstance(self.layer.quant_method, AscendLinearMethod)
|
|
|
|
|
and isinstance(self.layer.quant_method.quant_method,
|
2025-11-06 09:05:08 +08:00
|
|
|
AscendW8A8LinearMethod)):
|
[0.11.0][Perf] Add fused matmul/reduce-scatter kernel for performance optimization. (#3725)
### What this PR does / why we need it?
This PR boosts performance by introducing a fused kernel for the matrix
matmul and reduce scatter operations. It supports both unquantized
(e.g., BFloat16) and W8A8 quantized models.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0
<!-- Thanks for sending a pull request!
BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html
-->
### What this PR does / why we need it?
<!--
- Please clarify what changes you are proposing. The purpose of this
section is to outline the changes and how this PR fixes the issue.
If possible, please consider writing useful notes for better and faster
reviews in your PR.
- Please clarify why the changes are needed. For instance, the use case
and bug description.
- Fixes #
-->
### Does this PR introduce _any_ user-facing change?
<!--
Note that it means *any* user-facing change including all aspects such
as API, interface or other behavior changes.
Documentation-only updates are not considered user-facing changes.
-->
### How was this patch tested?
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->
Signed-off-by: ZYang6263 <zy626375@gmail.com>
2025-10-25 08:20:43 +08:00
|
|
|
if x.dtype != torch.int8:
|
|
|
|
|
x_quant = quant_per_tensor(
|
|
|
|
|
x, self.layer.aclnn_input_scale_reciprocal,
|
|
|
|
|
self.layer.aclnn_input_offset)
|
|
|
|
|
else:
|
|
|
|
|
x_quant = x
|
|
|
|
|
quant_bias = self.layer.quant_bias
|
|
|
|
|
deq_scale = self.layer.deq_scale
|
|
|
|
|
output_dtype = torch.bfloat16
|
|
|
|
|
output = torch_npu.npu_mm_reduce_scatter_base(
|
|
|
|
|
x_quant,
|
|
|
|
|
self.layer.weight,
|
|
|
|
|
hcom_name,
|
|
|
|
|
world_size,
|
|
|
|
|
reduce_op="sum",
|
|
|
|
|
bias=None,
|
|
|
|
|
comm_turn=0,
|
|
|
|
|
x2_scale=deq_scale,
|
|
|
|
|
output_dtype=output_dtype,
|
|
|
|
|
comm_mode=comm_mode)
|
|
|
|
|
output = torch.add(
|
|
|
|
|
output,
|
|
|
|
|
torch.mul(quant_bias, deq_scale).to(self.layer.params_dtype))
|
|
|
|
|
else:
|
|
|
|
|
output_parallel = self.layer.quant_method.apply(self.layer,
|
|
|
|
|
x,
|
|
|
|
|
bias=bias_)
|
|
|
|
|
output = tensor_model_parallel_reduce_scatter(output_parallel, 0)
|
|
|
|
|
|
2025-10-23 14:45:49 +08:00
|
|
|
return output
|
|
|
|
|
|
2025-09-18 14:09:19 +08:00
|
|
|
def update_attrs(self):
    """Refresh attributes mirrored from the wrapped layer.

    Runs the parent's ``update_attrs`` first. It then copies the layer's
    ``input_is_parallel`` and ``reduce_results`` values onto this op.
    """
    super().update_attrs()
    # Mirror each flag verbatim from self.layer onto the op, in the same order.
    for attr_name in ("input_is_parallel", "reduce_results"):
        setattr(self, attr_name, getattr(self.layer, attr_name))
|
|
|
|
|
|
|
|
|
|
|
2025-10-13 23:02:12 +08:00
|
|
|
def _get_column_parallel_op(
    prefix, layer
) -> Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp]]:
    """Select the custom column-parallel op for the layer named ``prefix``.

    Resolution order (the first match wins):

    1. ``MLPColumnParallelOp`` when ``mlp_tp_enable()`` is true and the
       prefix contains ``gate_up_proj``.
    2. When ``enable_sp()`` is true and the prefix does not contain
       ``shared_expert``: ``SequenceColumnParallelOp`` if the prefix
       contains any of ``gate_up_proj``, ``in_proj``, ``qkv_proj`` or
       ``conv1d``.

    Returns ``None`` when no custom op applies.
    """
    # mlp_tp_enable() is queried first for every layer, as before.
    if mlp_tp_enable() and "gate_up_proj" in prefix:
        return MLPColumnParallelOp(layer)

    # Shared-expert layers never receive a sequence op.
    if not enable_sp() or "shared_expert" in prefix:
        return None

    sequence_targets = ("gate_up_proj", "in_proj", "qkv_proj", "conv1d")
    matched = any(target in prefix for target in sequence_targets)
    return SequenceColumnParallelOp(layer) if matched else None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_row_parallel_op(
    prefix, layer
) -> Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
                    MatmulAllreduceRowParallelOp, SequenceRowParallelOp]]:
    """Select the custom row-parallel op for the layer named ``prefix``.

    Candidates are checked in priority order (the first match wins):

    1. ``MLPRowParallelOp`` for ``down_proj`` layers when
       ``mlp_tp_enable()`` is true.
    2. ``OProjRowParallelOp`` for ``o_proj`` layers when
       ``oproj_tp_enable()`` is true.
    3. ``MatmulAllreduceRowParallelOp`` for any layer when
       ``matmul_allreduce_enable()`` is true.
    4. ``SequenceRowParallelOp`` for ``o_proj``, ``out_proj`` or
       ``down_proj`` layers when ``enable_sp()`` is true, except for
       ``shared_expert`` layers.

    Returns ``None`` when no custom op applies.
    """
    # The prefix test comes before each flag call, so the flag is only
    # queried for matching layers.
    if "down_proj" in prefix and mlp_tp_enable():
        return MLPRowParallelOp(layer)
    if "o_proj" in prefix and oproj_tp_enable():
        return OProjRowParallelOp(layer)
    if matmul_allreduce_enable():
        return MatmulAllreduceRowParallelOp(layer)

    # Shared-expert layers never receive a sequence op.
    if not enable_sp() or "shared_expert" in prefix:
        return None

    sequence_targets = ("o_proj", "out_proj", "down_proj")
    if any(target in prefix for target in sequence_targets):
        return SequenceRowParallelOp(layer)
    return None
|
2025-09-18 14:09:19 +08:00
|
|
|
|
|
|
|
|
|
2025-10-13 23:02:12 +08:00
|
|
|
def get_parallel_op(disable_tp, prefix, layer, direct):
    """Resolve the custom linear op and its TP rank/size for one layer.

    Args:
        disable_tp: When truthy, no custom op is used and TP is treated as
            size 1.
        prefix: Module name of the layer. Substrings of it select the op.
        layer: The linear layer handed to the custom op constructor.
        direct: ``"row"`` or ``"column"``. Any other value selects no
            custom op.

    Returns:
        A ``(op, tp_rank, tp_size)`` tuple:

        - ``(None, 0, 1)`` when ``disable_tp`` is truthy, or when the
          prefix contains ``shared_experts`` and
          ``shared_expert_dp_enabled()`` is true.
        - ``(op, op.tp_rank, op.tp_size)`` when a custom op matches.
        - ``(None, rank_in_group, world_size)`` of ``get_tp_group()``
          otherwise.
    """
    # Shared experts under shared-expert DP are handled like an explicit
    # TP disable.
    if disable_tp or ("shared_experts" in prefix
                      and shared_expert_dp_enabled()):
        return None, 0, 1

    custom_op: Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp,
                              MLPRowParallelOp, OProjRowParallelOp,
                              MatmulAllreduceRowParallelOp,
                              SequenceRowParallelOp]] = None
    if direct == "row":
        custom_op = _get_row_parallel_op(prefix, layer)
    elif direct == "column":
        custom_op = _get_column_parallel_op(prefix, layer)

    if custom_op is None:
        # No custom op matched: fall back to the default TP group.
        tp_group = get_tp_group()
        return None, tp_group.rank_in_group, tp_group.world_size
    return custom_op, custom_op.tp_rank, custom_op.tp_size
|
2025-10-14 17:39:26 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_replicated_op(disable_tp, prefix,
                      layer) -> Optional[CustomReplicatedOp]:
    """Return a ``CustomReplicatedOp`` wrapping ``layer``, or ``None``.

    ``None`` is returned when ``disable_tp`` is truthy. ``prefix`` is
    accepted but not used by this function.
    """
    return None if disable_tp else CustomReplicatedOp(layer)
|