Files
xc-llm-ascend/vllm_ascend/ops/linear_op.py

532 lines
20 KiB
Python
Raw Normal View History

# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file extends the functionality of linear operations by encapsulating custom
communication groups and forward functions into classes (linear ops).
Current class inheritance structure:
CustomLinearOp
├── CustomColumnParallelOp
│   ├── MLPColumnParallelOp
│   └── SequenceColumnParallelOp
├── CustomRowParallelOp
│   ├── MLPRowParallelOp
│   ├── OProjRowParallelOp
│   ├── MatmulAllreduceRowParallelOp
│   └── SequenceRowParallelOp
└── CustomReplicatedOp
How to extend a new linear op? Taking column parallel op as an example:
1. Inherit from CustomColumnParallelOp and create a new class MyColumnParallelOp
2. [Optional] The default communication group is the TP group. If a custom communication group is needed, override the comm_group method
3. Override the apply method according to requirements, which will replace the original linear.forward
4. Add selection logic for MyColumnParallelOp in the get_column_parallel_op method, typically based on prefix and configuration judgments
Row parallel op follows a similar approach - inherit from CustomRowParallelOp and register the new class in _get_row_parallel_op.
"""
from typing import Optional, Union
import torch
import torch.distributed as dist
[0.11.0][Perf] Add fused matmul/reduce-scatter kernel for performance optimization. (#3725) ### What this PR does / why we need it? This PR boosts performance by introducing a fused kernel for the matrix matmul and reduce scatter operations. It supports both unquantized (e.g., BFloat16) and W8A8 quantized models. ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 <!-- Thanks for sending a pull request! BEFORE SUBMITTING, PLEASE READ https://docs.vllm.ai/en/latest/contributing/overview.html --> ### What this PR does / why we need it? <!-- - Please clarify what changes you are proposing. The purpose of this section is to outline the changes and how this PR fixes the issue. If possible, please consider writing useful notes for better and faster reviews in your PR. - Please clarify why the changes are needed. For instance, the use case and bug description. - Fixes # --> ### Does this PR introduce _any_ user-facing change? <!-- Note that it means *any* user-facing change including all aspects such as API, interface or other behavior changes. Documentation-only updates are not considered user-facing changes. --> ### How was this patch tested? <!-- CI passed with new added/existing test. If it was tested in a way different from regular unit tests, please clarify how you tested step by step, ideally copy and paste-able, so that other reviewers can test and check, and descendants can verify in the future. If tests were not added, please describe why they were not added and/or why it was difficult to add. --> Signed-off-by: ZYang6263 <zy626375@gmail.com>
2025-10-25 08:20:43 +08:00
import torch.nn.functional as F
import torch_npu
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
[0.11.0][Perf] Add fused matmul/reduce-scatter kernel for performance optimization. (#3725) ### What this PR does / why we need it? This PR boosts performance by introducing a fused kernel for the matrix matmul and reduce scatter operations. It supports both unquantized (e.g., BFloat16) and W8A8 quantized models. ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 <!-- Thanks for sending a pull request! BEFORE SUBMITTING, PLEASE READ https://docs.vllm.ai/en/latest/contributing/overview.html --> ### What this PR does / why we need it? <!-- - Please clarify what changes you are proposing. The purpose of this section is to outline the changes and how this PR fixes the issue. If possible, please consider writing useful notes for better and faster reviews in your PR. - Please clarify why the changes are needed. For instance, the use case and bug description. - Fixes # --> ### Does this PR introduce _any_ user-facing change? <!-- Note that it means *any* user-facing change including all aspects such as API, interface or other behavior changes. Documentation-only updates are not considered user-facing changes. --> ### How was this patch tested? <!-- CI passed with new added/existing test. If it was tested in a way different from regular unit tests, please clarify how you tested step by step, ideally copy and paste-able, so that other reviewers can test and check, and descendants can verify in the future. If tests were not added, please describe why they were not added and/or why it was difficult to add. --> Signed-off-by: ZYang6263 <zy626375@gmail.com>
2025-10-25 08:20:43 +08:00
from vllm.distributed import (split_tensor_along_last_dim,
tensor_model_parallel_all_reduce,
tensor_model_parallel_reduce_scatter)
from vllm.distributed.parallel_state import get_tp_group
[0.11.0][Perf] Add fused matmul/reduce-scatter kernel for performance optimization. (#3725) ### What this PR does / why we need it? This PR boosts performance by introducing a fused kernel for the matrix matmul and reduce scatter operations. It supports both unquantized (e.g., BFloat16) and W8A8 quantized models. ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 <!-- Thanks for sending a pull request! BEFORE SUBMITTING, PLEASE READ https://docs.vllm.ai/en/latest/contributing/overview.html --> ### What this PR does / why we need it? <!-- - Please clarify what changes you are proposing. The purpose of this section is to outline the changes and how this PR fixes the issue. If possible, please consider writing useful notes for better and faster reviews in your PR. - Please clarify why the changes are needed. For instance, the use case and bug description. - Fixes # --> ### Does this PR introduce _any_ user-facing change? <!-- Note that it means *any* user-facing change including all aspects such as API, interface or other behavior changes. Documentation-only updates are not considered user-facing changes. --> ### How was this patch tested? <!-- CI passed with new added/existing test. If it was tested in a way different from regular unit tests, please clarify how you tested step by step, ideally copy and paste-able, so that other reviewers can test and check, and descendants can verify in the future. If tests were not added, please describe why they were not added and/or why it was difficult to add. --> Signed-off-by: ZYang6263 <zy626375@gmail.com>
2025-10-25 08:20:43 +08:00
from vllm.forward_context import get_forward_context
from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
get_otp_group)
from vllm_ascend.utils import (dense_optim_enable, enable_sp,
matmul_allreduce_enable, mlp_tp_enable,
oproj_tp_enable, shared_expert_dp_enabled)
class CustomLinearOp:
    """Base class for custom linear ops that replace a layer's forward.

    An op wraps a linear ``layer`` and mirrors the layer attributes that
    ``apply_impl`` needs. Subclasses implement ``apply_impl`` and may
    override ``comm_group`` to select a different communication group.
    """

    def __init__(self, layer):
        self.layer = layer
        self.bias = None
        self.skip_bias_add = None
        self.return_bias = None
        self.quant_method = None
        # Initialised here as well so that reading ``prefix`` before
        # ``update_attrs()`` yields None instead of raising AttributeError,
        # matching every other mirrored attribute.
        self.prefix = None

    # Custom communication group, while determining weight sharding
    @property
    def comm_group(self):
        """Communication group used by this op (TP group by default)."""
        return get_tp_group()

    @property
    def tp_rank(self):
        """Rank of this process inside ``comm_group``."""
        return self.comm_group.rank_in_group

    @property
    def tp_size(self):
        """Number of ranks in ``comm_group``."""
        return self.comm_group.world_size

    # Update the attributes required by apply(), obtaining them from the layer.
    # Call this after the layer completes its initialization, specifically at the end of layer.init().
    def update_attrs(self):
        """Copy the attributes needed by ``apply`` from the wrapped layer."""
        # ``bias`` is only copied when the layer actually exposes it;
        # otherwise the previous value (None after __init__) is kept.
        if hasattr(self.layer, "bias"):
            self.bias = self.layer.bias
        self.skip_bias_add = self.layer.skip_bias_add
        self.return_bias = self.layer.return_bias
        self.quant_method = self.layer.quant_method
        self.prefix = self.layer.prefix

    def apply_impl(self, input_):
        """Compute ``(output, output_bias)``; must be provided by subclasses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    # Replace layer.forward to customize the layer computation process.
    def apply(self, input_):
        """Run ``apply_impl`` and shape the result according to ``return_bias``.

        Returns:
            ``output`` alone when ``return_bias`` is falsy, otherwise the
            ``(output, output_bias)`` tuple.
        """
        output, output_bias = self.apply_impl(input_)
        if not self.return_bias:
            return output
        return output, output_bias
class CustomColumnParallelOp(CustomLinearOp):
    """Base class for column-parallel ops; additionally mirrors ``gather_output``."""

    def __init__(self, layer):
        super().__init__(layer)
        # Filled in from the layer by update_attrs().
        self.gather_output = None

    def update_attrs(self):
        """Refresh the common attributes, then ``gather_output``, from the layer."""
        super().update_attrs()
        wrapped = self.layer
        self.gather_output = wrapped.gather_output
class CustomRowParallelOp(CustomLinearOp):
    """Base class for row-parallel ops.

    Mirrors the row-parallel specific layer attributes and overrides
    ``apply`` so that the ``maybe_prefetch_mlp_gate_up_proj`` custom op is
    invoked on the output when ``dense_optim_enable()`` is true.
    """

    def __init__(self, layer):
        super().__init__(layer)
        # All three are filled in from the layer by update_attrs().
        self.reduce_results = None
        self.input_is_parallel = None
        self.input_size_per_partition = None

    def update_attrs(self):
        """Refresh common attributes plus the row-parallel ones from the layer."""
        super().update_attrs()
        wrapped = self.layer
        self.input_is_parallel = wrapped.input_is_parallel
        self.reduce_results = wrapped.reduce_results
        self.input_size_per_partition = wrapped.input_size_per_partition

    def apply(self, input_):
        """Run ``apply_impl``, optionally trigger the prefetch op, and return.

        Returns:
            ``(output, output_bias)`` when ``return_bias`` is truthy,
            otherwise ``output`` alone.
        """
        result, result_bias = self.apply_impl(input_)

        if dense_optim_enable():
            torch.ops.vllm.maybe_prefetch_mlp_gate_up_proj(result, self.prefix)

        if self.return_bias:
            return result, result_bias
        return result
class CustomReplicatedOp(CustomLinearOp):
    """Linear op for replicated layers: a plain ``quant_method.apply`` call."""

    def apply_impl(self, input_):
        """Apply the layer's quant method to ``input_``.

        The bias is either fused into the matmul or handed back to the
        caller, depending on ``skip_bias_add``.

        Returns:
            Tuple ``(output, output_bias)``; ``output_bias`` is None unless
            ``skip_bias_add`` is truthy.
        """
        if self.skip_bias_add:
            fused_bias = None
        else:
            fused_bias = self.bias

        assert self.quant_method is not None
        result = self.quant_method.apply(self.layer, input_, fused_bias)

        if self.skip_bias_add:
            deferred_bias = self.bias
        else:
            deferred_bias = None
        return result, deferred_bias
class MLPColumnParallelOp(CustomColumnParallelOp):
    """Column-parallel op that communicates over the MLP TP group."""

    @property
    def comm_group(self):
        """Use the dedicated MLP tensor-parallel group instead of the TP group."""
        return get_mlp_tp_group()

    def apply_impl(
        self,
        input_: torch.Tensor,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """All-gather the input along dim 0, then run the matmul.

        Returns:
            Tuple ``(output, output_bias)``; ``output_bias`` is None unless
            ``skip_bias_add`` is truthy.
        """
        if self.skip_bias_add:
            fused_bias, deferred_bias = None, self.bias
        else:
            fused_bias, deferred_bias = self.bias, None

        assert self.quant_method is not None
        # Collect the input from every rank of the MLP TP group (dim 0).
        gathered = self.comm_group.all_gather(input_, 0)
        # Matrix multiply.
        result = self.quant_method.apply(self.layer, gathered, fused_bias)
        return result, deferred_bias
class MLPRowParallelOp(CustomRowParallelOp):
    """Row-parallel op that communicates over the MLP TP group.

    The partial matmul results are combined with a reduce-scatter along
    dim 0 of the output.
    """

    def __init__(self, layer):
        super().__init__(layer)

    @property
    def comm_group(self):
        """Use the dedicated MLP tensor-parallel group instead of the TP group."""
        return get_mlp_tp_group()

    def apply_impl(
        self, input_: torch.Tensor
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Run the sharded matmul and reduce-scatter the result.

        Args:
            input_: Input tensor. When ``input_is_parallel`` is falsy it is
                split along the last dim and only this rank's slice is used.

        Returns:
            Tuple ``(output, output_bias)``; ``output_bias`` is None unless
            ``skip_bias_add`` is truthy.
        """
        if self.input_is_parallel:
            input_parallel = input_
        else:
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size)
            input_parallel = splitted_input[self.tp_rank].contiguous()

        assert self.quant_method is not None
        # Only rank 0 fuses the bias so the reduction adds it exactly once.
        # Use the cached ``self.bias`` (guarded by ``hasattr`` in
        # update_attrs) rather than ``self.layer.bias``, consistently with
        # ``output_bias`` below and with the other row-parallel ops.
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        output_parallel = self.quant_method.apply(self.layer,
                                                  input_parallel,
                                                  bias=bias_)
        output = self.comm_group.reduce_scatter(output_parallel, 0)

        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias
class OProjRowParallelOp(CustomRowParallelOp):
    """Row-parallel op for o_proj layers that communicates over the OTP group.

    The input is exchanged between ranks with an all-to-all before the
    matmul, and the partial outputs are combined with a reduce-scatter.
    """

    def __init__(self, layer):
        super().__init__(layer)

    @property
    def comm_group(self):
        """Use the o_proj tensor-parallel group instead of the TP group."""
        return get_otp_group()

    def update_attrs(self):
        """Refresh mirrored attributes from the wrapped layer."""
        super().update_attrs()
        wrapped = self.layer
        self.input_is_parallel = wrapped.input_is_parallel
        self.input_size_per_partition = wrapped.input_size_per_partition

    def apply_impl(
        self,
        input_: torch.Tensor,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """All-to-all the input, run the matmul, then reduce-scatter.

        Returns:
            Tuple ``(output, output_bias)``. ``output`` is viewed as
            ``(input_.shape[0], layer.output_size)``; ``output_bias`` is
            None unless ``skip_bias_add`` is truthy.
        """
        if self.input_is_parallel:
            local_input = input_
        else:
            shards = split_tensor_along_last_dim(input_,
                                                 num_partitions=self.tp_size)
            local_input = shards[self.tp_rank].contiguous()

        # Sizes needed to lay out the all-to-all buffers.
        rows_local = local_input.size(0)
        chunk = self.input_size_per_partition
        rows_total = rows_local * self.tp_size

        # View the input as (rows, tp_size, chunk), bring the tp_size axis to
        # the front, and flatten so the send buffer is one contiguous 1-D
        # tensor.
        regrouped = local_input.reshape(-1, self.tp_size, chunk)
        send_buf = regrouped.transpose(0, 1).contiguous().view(-1)

        # Receive buffer with the same element count, dtype and device.
        recv_buf = torch.empty(rows_total * chunk,
                               dtype=local_input.dtype,
                               device=local_input.device)

        dist.all_to_all_single(recv_buf,
                               send_buf,
                               group=self.comm_group.device_group)
        exchanged = recv_buf.view(rows_total, chunk)

        # Only rank 0 fuses the bias so it is not added more than once when
        # the partial results are reduced.
        if self.tp_rank > 0 or self.skip_bias_add:
            fused_bias = None
        else:
            fused_bias = self.bias

        assert self.quant_method is not None
        partial = self.quant_method.apply(self.layer,
                                          exchanged,
                                          bias=fused_bias)

        # otp-specific: combine the partial results across devices.
        reduced = self.comm_group.reduce_scatter(partial, dim=0)
        reduced = reduced.view(input_.shape[0], self.layer.output_size)

        # Hand the bias back to the caller when it was not fused.
        deferred_bias = self.bias if self.skip_bias_add else None
        return reduced, deferred_bias
class MatmulAllreduceRowParallelOp(CustomRowParallelOp):
    """Row-parallel op using the fused ``npu_mm_all_reduce_base`` kernel."""

    # HCCL communicator name, resolved once and shared by all instances.
    _HCOMM_INFO = None

    def __init__(self, layer):
        super().__init__(layer)
        self.hcomm_info = self.get_hcomm_info(self.comm_group.device_group)

    @classmethod
    def get_hcomm_info(cls, group: ProcessGroup) -> str:
        """Get the HCCL communication information for the given group.

        The value is cached on the class after the first call; later calls
        return the cached value regardless of ``group``.
        """
        if cls._HCOMM_INFO is None:
            rank = torch.distributed.get_rank(group)
            if torch.__version__ > "2.0":
                global_rank = torch.distributed.get_global_rank(group, rank)
                backend = group._get_backend(torch.device("npu"))
                cls._HCOMM_INFO = backend.get_hccl_comm_name(global_rank)
            else:
                cls._HCOMM_INFO = group.get_hccl_comm_name(rank)
        return cls._HCOMM_INFO

    def update_attrs(self):
        """Refresh mirrored attributes and cache the transposed weight."""
        super().update_attrs()
        self.weight_t = self.layer.weight.t()

    def apply_impl(
        self, input_: torch.Tensor
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Calculate the output tensor of forward by considering
        fusing communication and computation.

        Returns:
            Tuple ``(output, output_bias)``; ``output_bias`` is None unless
            ``skip_bias_add`` is truthy.
        """
        if self.input_is_parallel:
            local_input = input_
        else:
            shards = split_tensor_along_last_dim(input_,
                                                 num_partitions=self.tp_size)
            local_input = shards[self.tp_rank].contiguous()

        # Only rank 0 fuses the bias so the reduction adds it once.
        if self.tp_rank > 0 or self.skip_bias_add:
            fused_bias = None
        else:
            fused_bias = self.bias

        if self.reduce_results and self.tp_size > 1:
            # Fused matmul + all-reduce on the cached transposed weight.
            result = torch_npu.npu_mm_all_reduce_base(local_input,
                                                      self.weight_t,
                                                      self.hcomm_info,
                                                      bias=fused_bias)
        else:
            assert self.quant_method is not None
            result = self.quant_method.apply(self.layer,
                                             local_input,
                                             bias=fused_bias)

        deferred_bias = self.bias if self.skip_bias_add else None
        return result, deferred_bias
class SequenceColumnParallelOp(CustomColumnParallelOp):
    """Column-parallel op selected when sequence parallelism is enabled."""

    def apply_impl(
        self, input_: torch.Tensor
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Linear layer with column parallelism.

        Implemented multiple optimization projects for dense models, such as FlashComm and
        communication-computation fusion.

        Returns:
            Tuple ``(output, output_bias)``; ``output_bias`` is None unless
            ``skip_bias_add`` is truthy.
        """
        if self.skip_bias_add:
            fused_bias, deferred_bias = None, self.bias
        else:
            fused_bias, deferred_bias = self.bias, None

        assert self.quant_method is not None

        # The input first goes through the all-gather / unpad custom op.
        gathered_input = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
            input_, True)
        # Matrix multiply.
        partial = self.quant_method.apply(self.layer, gathered_input,
                                          fused_bias)

        if not self.gather_output:
            return partial, deferred_bias
        # All-gather across the partitions.
        return self.comm_group.all_gather(partial), deferred_bias
class SequenceRowParallelOp(CustomRowParallelOp):
    """Row-parallel op selected when sequence parallelism is enabled.

    ``apply_impl`` dispatches to the ``matmul_and_reduce`` custom op, which
    in turn is expected to call :meth:`matmul_and_reduce` on this object.
    """

    def apply_impl(
        self, input_: torch.Tensor
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Linear layer with row parallelism.

        Implemented multiple optimization projects for dense models, such as FlashComm and
        communication-computation fusion.

        Returns:
            Tuple ``(output, output_bias)``; ``output_bias`` is None unless
            ``skip_bias_add`` is truthy.
        """
        if self.input_is_parallel:
            input_parallel = input_
        else:
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size)
            input_parallel = splitted_input[self.tp_rank].contiguous()

        assert self.quant_method is not None
        # Only rank 0 fuses the bias so a later reduction adds it once.
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias

        if self.tp_size == 1 or not self.reduce_results:
            # Nothing to reduce: plain local matmul.
            output = self.quant_method.apply(self.layer,
                                             input_parallel,
                                             bias=bias_)
        else:
            output = torch.ops.vllm.matmul_and_reduce(input_parallel,
                                                      self.prefix)

        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def matmul_and_reduce(self, input_parallel: torch.Tensor,
                          bias_: Optional[Parameter]) -> torch.Tensor:
        """Run the matmul and combine the partial results across ranks.

        Without sequence parallelism the result is all-reduced. With it, the
        result is reduce-scattered along dim 0, using the fused
        ``npu_mm_reduce_scatter_base`` kernel when ``mmrs_fusion`` is set and
        the quant method is unquantized or W8A8.

        Args:
            input_parallel: This rank's input shard.
            bias_: Bias to fuse into the matmul *before* the reduction
                (``apply_impl`` builds a bias that is non-None on rank 0
                only), or None.

        Returns:
            The reduced output tensor.
        """
        assert self.quant_method is not None
        try:
            forward_context = get_forward_context()
            sp_enabled = forward_context.sp_enabled
            mmrs_fusion = forward_context.mmrs_fusion
        except AssertionError:
            # No usable forward context: fall back to the plain all-reduce.
            sp_enabled = False
            mmrs_fusion = False

        x = input_parallel

        if not sp_enabled:
            output_parallel = self.layer.quant_method.apply(self.layer,
                                                            x,
                                                            bias=bias_)
            return tensor_model_parallel_all_reduce(output_parallel)

        # Append ``pad_size`` zero rows along the second-to-last dim.
        pad_size = forward_context.pad_size
        if pad_size > 0:
            x = F.pad(x, (0, 0, 0, pad_size))

        world_size = self.layer.tp_size
        comm_mode = "aiv"
        # NOTE(review): the communicator name is looked up with the in-group
        # tp_rank here, whereas MatmulAllreduceRowParallelOp.get_hcomm_info
        # passes the global rank -- confirm which one the backend expects.
        hcom_name = get_tp_group().device_group._get_backend(
            torch.device('npu')).get_hccl_comm_name(self.layer.tp_rank)

        # Imported lazily inside the method, as in the original code.
        from vllm.model_executor.layers.linear import UnquantizedLinearMethod

        from vllm_ascend.quantization.quant_config import AscendLinearMethod
        from vllm_ascend.quantization.w8a8 import (AscendW8A8LinearMethod,
                                                   quant_per_tensor)

        # For unquant
        if mmrs_fusion and isinstance(self.layer.quant_method,
                                      UnquantizedLinearMethod):
            output = torch_npu.npu_mm_reduce_scatter_base(
                x,
                self.layer.weight.t(),
                hcom_name,
                world_size,
                reduce_op="sum",
                bias=None,
                comm_turn=0,
                comm_mode=comm_mode)
            # The bias is added AFTER the reduction here, so every rank has
            # to add it to its own shard. ``bias_`` is gated to rank 0 for
            # the pre-reduction case and would leave the shards on the other
            # ranks without bias, so fall back to the layer bias whenever it
            # is not deferred to the caller via ``skip_bias_add``.
            post_reduce_bias = bias_
            if post_reduce_bias is None and not self.skip_bias_add:
                post_reduce_bias = self.bias
            if post_reduce_bias is not None:
                output.add_(post_reduce_bias)
        # For w8a8 quant
        elif mmrs_fusion and (
                isinstance(self.layer.quant_method, AscendLinearMethod)
                and isinstance(self.layer.quant_method.quant_method,
                               AscendW8A8LinearMethod)):
            if x.dtype != torch.int8:
                x_quant = quant_per_tensor(
                    x, self.layer.aclnn_input_scale_reciprocal,
                    self.layer.aclnn_input_offset)
            else:
                x_quant = x
            quant_bias = self.layer.quant_bias
            deq_scale = self.layer.deq_scale
            # NOTE(review): output dtype is fixed to bfloat16 while the bias
            # term below is cast to layer.params_dtype -- confirm these agree
            # for non-bf16 models.
            output_dtype = torch.bfloat16
            output = torch_npu.npu_mm_reduce_scatter_base(
                x_quant,
                self.layer.weight,
                hcom_name,
                world_size,
                reduce_op="sum",
                bias=None,
                comm_turn=0,
                x2_scale=deq_scale,
                output_dtype=output_dtype,
                comm_mode=comm_mode)
            # Dequantized bias, added once per row after the reduction.
            output = torch.add(
                output,
                torch.mul(quant_bias, deq_scale).to(self.layer.params_dtype))
        else:
            # Unfused fallback: the rank-0-only ``bias_`` is fused before the
            # reduce-scatter, so each output row receives it exactly once.
            output_parallel = self.layer.quant_method.apply(self.layer,
                                                            x,
                                                            bias=bias_)
            output = tensor_model_parallel_reduce_scatter(output_parallel, 0)

        return output

    def update_attrs(self):
        """Refresh mirrored attributes from the wrapped layer."""
        super().update_attrs()
        self.input_is_parallel = self.layer.input_is_parallel
        self.reduce_results = self.layer.reduce_results
def _get_column_parallel_op(
    prefix, layer
) -> Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp]]:
    """Pick the custom column-parallel op for ``layer`` based on its prefix.

    Returns:
        An ``MLPColumnParallelOp`` for ``gate_up_proj`` layers when MLP TP
        is enabled; a ``SequenceColumnParallelOp`` for the supported
        projections when sequence parallelism is enabled (shared-expert
        layers excluded); otherwise None.
    """
    if mlp_tp_enable() and "gate_up_proj" in prefix:
        return MLPColumnParallelOp(layer)

    if not enable_sp():
        return None

    # Shared-expert layers never get a sequence-parallel op.
    if "shared_expert" in prefix:
        return None

    sp_layer_markers = ("gate_up_proj", "in_proj", "qkv_proj", "conv1d")
    if any(marker in prefix for marker in sp_layer_markers):
        return SequenceColumnParallelOp(layer)

    return None
def _get_row_parallel_op(
    prefix, layer
) -> Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
                    MatmulAllreduceRowParallelOp, SequenceRowParallelOp]]:
    """Pick the custom row-parallel op for ``layer`` based on its prefix.

    The checks run in priority order: MLP TP for ``down_proj``, o_proj TP
    for ``o_proj``, fused matmul/all-reduce for any layer, then sequence
    parallelism for the supported projections (shared-expert layers
    excluded). Returns None when nothing matches.
    """
    if "down_proj" in prefix and mlp_tp_enable():
        return MLPRowParallelOp(layer)

    if "o_proj" in prefix and oproj_tp_enable():
        return OProjRowParallelOp(layer)

    if matmul_allreduce_enable():
        return MatmulAllreduceRowParallelOp(layer)

    if not enable_sp():
        return None

    # Shared-expert layers never get a sequence-parallel op.
    if "shared_expert" in prefix:
        return None

    sp_layer_markers = ("o_proj", "out_proj", "down_proj")
    if any(marker in prefix for marker in sp_layer_markers):
        return SequenceRowParallelOp(layer)

    return None
def get_parallel_op(disable_tp, prefix, layer, direct):
    """Select the custom parallel op for a layer.

    Args:
        disable_tp: When truthy, no custom op is used.
        prefix: Layer name prefix used for the selection rules.
        layer: The linear layer to wrap.
        direct: ``"row"`` or ``"column"``; any other value selects no
            custom op.

    Returns:
        Tuple ``(custom_op, tp_rank, tp_size)``. ``custom_op`` is None when
        no custom op applies; the rank/size then come from the TP group, or
        are ``0``/``1`` when TP is disabled or the layer is a shared expert
        with shared-expert DP enabled.
    """
    shared_expert_dp = ("shared_experts" in prefix
                        and shared_expert_dp_enabled())
    if disable_tp or shared_expert_dp:
        return None, 0, 1

    custom_op: Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp,
                              MLPRowParallelOp, OProjRowParallelOp,
                              MatmulAllreduceRowParallelOp,
                              SequenceRowParallelOp]] = None
    if direct == "row":
        custom_op = _get_row_parallel_op(prefix, layer)
    elif direct == "column":
        custom_op = _get_column_parallel_op(prefix, layer)

    if custom_op is None:
        # No custom op: report the default TP group's rank and size.
        return None, get_tp_group().rank_in_group, get_tp_group().world_size

    return custom_op, custom_op.tp_rank, custom_op.tp_size
def get_replicated_op(disable_tp, prefix,
                      layer) -> Optional[Union[CustomReplicatedOp]]:
    """Return a ``CustomReplicatedOp`` for ``layer``, or None if TP is disabled.

    ``prefix`` is accepted for signature parity with the other selectors but
    is not used.
    """
    return None if disable_tp else CustomReplicatedOp(layer)