### What this PR does / why we need it? This PR refactors SequenceRowParallelOp forward. In order to further expand the operator inclusion scope in dynamic judgment scenarios, this PR customizes the entire matmul computation and communication as a custom operator masking. With this refactor, it will support directly writing code such as common operation fusion into the SequenceRowParallelOp class's member function matmul_and_reduce, without the need to register more redundant custom masking operators. ### How was this patch tested? CI passed with new added/existing test. Signed-off-by: rjg-lyh <1318825571@qq.com>
466 lines
18 KiB
Python
466 lines
18 KiB
Python
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
|
# This file is a part of the vllm-ascend project.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""
|
|
To customize linear communication groups or forward of classes in this file,
|
|
extend new linear operations in linear_op.py.
|
|
The classes in this file should not be modified, including AscendQKVParallelLinear,
|
|
AscendMergedColumnParallelLinear, AscendMergedColumnParallelLinear,
|
|
AscendRowParallelLinear and AscendColumnParallelLinear.
|
|
"""
|
|
|
|
from typing import Optional, Union
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch_npu
|
|
from torch.nn.parameter import Parameter
|
|
from vllm.config import get_current_vllm_config
|
|
from vllm.distributed import divide
|
|
from vllm.model_executor.layers.linear import ( # noqa
|
|
WEIGHT_LOADER_V2_SUPPORTED, ColumnParallelLinear, LinearBase,
|
|
MergedColumnParallelLinear, QKVParallelLinear, QuantizeMethodBase,
|
|
ReplicatedLinear, RowParallelLinear, UnquantizedLinearMethod)
|
|
from vllm.model_executor.layers.quantization.base_config import \
|
|
QuantizationConfig
|
|
from vllm.model_executor.utils import set_weight_attrs
|
|
|
|
from vllm_ascend.ops.linear_op import get_parallel_op, get_replicated_op
|
|
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz
|
|
|
|
|
|
class AscendUnquantizedLinearMethod(UnquantizedLinearMethod):
|
|
"""Linear method without quantization"""
|
|
|
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
|
super().process_weights_after_loading(layer)
|
|
if (is_enable_nz() and torch.version.cann.startswith("8.3") and
|
|
layer.weight.data.dtype in [torch.float16, torch.bfloat16]):
|
|
layer.weight.data = torch_npu.npu_format_cast(
|
|
layer.weight.data, ACL_FORMAT_FRACTAL_NZ)
|
|
|
|
|
|
# TODO(realliujiaxu): Remove this class after linear of vllm supports custom comm group
|
|
class AscendLinearBase(LinearBase):
|
|
|
|
def __init__(
|
|
self,
|
|
input_size: int,
|
|
output_size: int,
|
|
skip_bias_add: bool = False,
|
|
params_dtype: Optional[torch.dtype] = None,
|
|
quant_config: Optional[QuantizationConfig] = None,
|
|
prefix: str = "",
|
|
*,
|
|
return_bias: bool = True,
|
|
disable_tp: bool = False,
|
|
):
|
|
nn.Module.__init__(self)
|
|
|
|
# Keep input parameters
|
|
self.input_size = input_size
|
|
self.output_size = output_size
|
|
self.skip_bias_add = skip_bias_add
|
|
if params_dtype is None:
|
|
params_dtype = torch.get_default_dtype()
|
|
self.params_dtype = params_dtype
|
|
self.quant_config = quant_config
|
|
self.prefix = prefix
|
|
if quant_config is None:
|
|
self.quant_method: Optional[
|
|
QuantizeMethodBase] = AscendUnquantizedLinearMethod()
|
|
else:
|
|
self.quant_method = quant_config.get_quant_method(self,
|
|
prefix=prefix)
|
|
self.return_bias = return_bias
|
|
self.disable_tp = disable_tp
|
|
|
|
|
|
class AscendQKVParallelLinear(QKVParallelLinear):
|
|
"""Linear layers for the attention's QKV transformation.
|
|
|
|
Linear layers for the linear transformation of the query, key, and value
|
|
vectors in the attention layer. The weight matrix is concatenated along
|
|
the output dimension. The layer is parallelized along the head dimension.
|
|
When the number of key/value heads is smaller than the number of query
|
|
heads (e.g., multi-query/grouped-query attention), the key/value head may
|
|
be replicated while the query heads are partitioned.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
hidden_size: int,
|
|
head_size: int,
|
|
total_num_heads: int,
|
|
total_num_kv_heads: Optional[int] = None,
|
|
bias: bool = True,
|
|
skip_bias_add: bool = False,
|
|
params_dtype: Optional[torch.dtype] = None,
|
|
quant_config: Optional[QuantizationConfig] = None,
|
|
prefix: str = "",
|
|
*,
|
|
return_bias: bool = True,
|
|
disable_tp: bool = False,
|
|
):
|
|
self.custom_op, _, tp_size = get_parallel_op(disable_tp, prefix, self,
|
|
"column")
|
|
# TODO(realliujiaxu): Replace the initialization code below with super().__init__ after linear of vllm supports custom comm group
|
|
self.hidden_size = hidden_size
|
|
self.head_size = head_size
|
|
self.total_num_heads = total_num_heads
|
|
if total_num_kv_heads is None:
|
|
total_num_kv_heads = total_num_heads
|
|
self.total_num_kv_heads = total_num_kv_heads
|
|
# Divide the weight matrix along the last dimension.
|
|
self.num_heads = divide(self.total_num_heads, tp_size)
|
|
if tp_size >= self.total_num_kv_heads:
|
|
self.num_kv_heads = 1
|
|
self.num_kv_head_replicas = divide(tp_size,
|
|
self.total_num_kv_heads)
|
|
else:
|
|
self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
|
|
self.num_kv_head_replicas = 1
|
|
input_size = self.hidden_size
|
|
output_size = (self.num_heads +
|
|
2 * self.num_kv_heads) * tp_size * self.head_size
|
|
self.output_sizes = [
|
|
self.num_heads * self.head_size * tp_size, # q_proj
|
|
self.num_kv_heads * self.head_size * tp_size, # k_proj
|
|
self.num_kv_heads * self.head_size * tp_size, # v_proj
|
|
]
|
|
AscendColumnParallelLinear.__init__(self,
|
|
input_size=input_size,
|
|
output_size=output_size,
|
|
bias=bias,
|
|
gather_output=False,
|
|
skip_bias_add=skip_bias_add,
|
|
params_dtype=params_dtype,
|
|
quant_config=quant_config,
|
|
prefix=prefix,
|
|
return_bias=return_bias,
|
|
disable_tp=disable_tp)
|
|
|
|
def forward(
|
|
self,
|
|
input_,
|
|
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
|
if self.custom_op is not None:
|
|
return self.custom_op.apply(input_)
|
|
|
|
return super().forward(input_)
|
|
|
|
|
|
class AscendMergedColumnParallelLinear(MergedColumnParallelLinear):
|
|
"""Packed linear layers with column parallelism.
|
|
|
|
Similar to ColumnParallelLinear, but the weight matrix is concatenated
|
|
along the output dimension. When the weight matrix is loaded, the
|
|
different partitions are sharded separately.
|
|
|
|
Use the MLP tensor parallelism group in the MLP module,
|
|
and the original TP group in other modules.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
input_size: int,
|
|
output_sizes: list[int],
|
|
bias: bool = True,
|
|
gather_output: bool = False,
|
|
skip_bias_add: bool = False,
|
|
params_dtype: Optional[torch.dtype] = None,
|
|
quant_config: Optional[QuantizationConfig] = None,
|
|
prefix: str = "",
|
|
*,
|
|
return_bias: bool = True,
|
|
disable_tp: bool = False,
|
|
):
|
|
self.custom_op, self.tp_rank, self.tp_size = get_parallel_op(
|
|
disable_tp, prefix, self, "column")
|
|
# TODO(realliujiaxu): Replace the initialization code below with super().__init__ after linear of vllm supports custom comm group
|
|
self.output_sizes = output_sizes
|
|
assert all(output_size % self.tp_size == 0
|
|
for output_size in output_sizes)
|
|
AscendColumnParallelLinear.__init__(self,
|
|
input_size=input_size,
|
|
output_size=sum(output_sizes),
|
|
bias=bias,
|
|
gather_output=gather_output,
|
|
skip_bias_add=skip_bias_add,
|
|
params_dtype=params_dtype,
|
|
quant_config=quant_config,
|
|
prefix=prefix,
|
|
return_bias=return_bias,
|
|
disable_tp=disable_tp)
|
|
|
|
def forward(
|
|
self,
|
|
input_,
|
|
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
|
if self.custom_op is not None:
|
|
return self.custom_op.apply(input_)
|
|
|
|
return super().forward(input_)
|
|
|
|
|
|
class AscendRowParallelLinear(RowParallelLinear):
|
|
"""Linear layer with row parallelism.
|
|
Use the MLP tensor parallelism group in the MLP module,
|
|
and the original TP group in other modules.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
input_size: int,
|
|
output_size: int,
|
|
bias: bool = True,
|
|
input_is_parallel: bool = True,
|
|
skip_bias_add: bool = False,
|
|
params_dtype: Optional[torch.dtype] = None,
|
|
reduce_results: bool = True,
|
|
quant_config: Optional[QuantizationConfig] = None,
|
|
prefix: str = "",
|
|
*,
|
|
return_bias: bool = True,
|
|
disable_tp: bool = False,
|
|
):
|
|
compilation_config = get_current_vllm_config().compilation_config
|
|
# TODO(shaopeng-666): Remove the visual check after the mm model reconstruction is complete.
|
|
if prefix in compilation_config.static_forward_context and \
|
|
"visual" not in prefix:
|
|
raise ValueError(f"Duplicate layer name: {prefix}")
|
|
compilation_config.static_forward_context[prefix] = self
|
|
|
|
self.custom_op, self.tp_rank, self.tp_size = get_parallel_op(
|
|
disable_tp, prefix, self, "row")
|
|
# TODO(realliujiaxu): Replace the initialization code below with super().__init__ after linear of vllm supports custom comm group
|
|
# Divide the weight matrix along the first dimension.
|
|
self.input_size_per_partition = divide(input_size, self.tp_size)
|
|
self.output_size_per_partition = output_size
|
|
self.output_partition_sizes = [output_size]
|
|
|
|
AscendLinearBase.__init__(self,
|
|
input_size,
|
|
output_size,
|
|
skip_bias_add,
|
|
params_dtype,
|
|
quant_config,
|
|
prefix,
|
|
return_bias=return_bias,
|
|
disable_tp=disable_tp)
|
|
|
|
self.input_is_parallel = input_is_parallel
|
|
self.reduce_results = reduce_results
|
|
|
|
assert self.quant_method is not None
|
|
self.quant_method.create_weights(
|
|
layer=self,
|
|
input_size_per_partition=self.input_size_per_partition,
|
|
output_partition_sizes=self.output_partition_sizes,
|
|
input_size=self.input_size,
|
|
output_size=self.output_size,
|
|
params_dtype=self.params_dtype,
|
|
weight_loader=(
|
|
self.weight_loader_v2 if self.quant_method.__class__.__name__
|
|
in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
|
|
if not reduce_results and (bias and not skip_bias_add):
|
|
raise ValueError("When not reduce the results, adding bias to the "
|
|
"results can lead to incorrect results")
|
|
|
|
if bias:
|
|
self.bias = Parameter(
|
|
torch.empty(self.output_size, dtype=params_dtype))
|
|
set_weight_attrs(self.bias, {
|
|
"output_dim": 0,
|
|
"weight_loader": self.weight_loader,
|
|
})
|
|
else:
|
|
self.register_parameter("bias", None)
|
|
|
|
if self.custom_op is not None:
|
|
self.custom_op.update_attrs()
|
|
|
|
def forward(
|
|
self,
|
|
input_,
|
|
is_prefill: bool = True,
|
|
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
|
if self.custom_op is not None:
|
|
return self.custom_op.apply(input_)
|
|
|
|
return super().forward(input_)
|
|
|
|
|
|
class AscendColumnParallelLinear(ColumnParallelLinear):
|
|
"""Linear layer with column parallelism.
|
|
|
|
Use the MLP tensor parallelism group in the MLP module,
|
|
and the original TP group in other modules.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
input_size: int,
|
|
output_size: int,
|
|
bias: bool = True,
|
|
gather_output: bool = False,
|
|
skip_bias_add: bool = False,
|
|
params_dtype: Optional[torch.dtype] = None,
|
|
quant_config: Optional[QuantizationConfig] = None,
|
|
output_sizes: Optional[list[int]] = None,
|
|
prefix: str = "",
|
|
*,
|
|
return_bias: bool = True,
|
|
disable_tp: bool = False,
|
|
):
|
|
self.custom_op, self.tp_rank, self.tp_size = get_parallel_op(
|
|
disable_tp, prefix, self, "column")
|
|
# TODO(realliujiaxu): Replace the initialization code below with super().__init__ after linear of vllm supports custom comm group
|
|
self.input_size_per_partition = input_size
|
|
self.output_size_per_partition = divide(output_size, self.tp_size)
|
|
self.output_partition_sizes = [self.output_size_per_partition]
|
|
# If QKV or MergedColumn, use output size of each partition.
|
|
if hasattr(self, "output_sizes"):
|
|
self.output_partition_sizes = [
|
|
divide(output_size, self.tp_size)
|
|
for output_size in self.output_sizes
|
|
]
|
|
|
|
AscendLinearBase.__init__(self,
|
|
input_size,
|
|
output_size,
|
|
skip_bias_add,
|
|
params_dtype,
|
|
quant_config,
|
|
prefix,
|
|
return_bias=return_bias,
|
|
disable_tp=disable_tp)
|
|
|
|
self.gather_output = gather_output
|
|
|
|
if output_sizes is None:
|
|
output_sizes = [output_size]
|
|
|
|
assert self.quant_method is not None
|
|
self.quant_method.create_weights(
|
|
layer=self,
|
|
input_size_per_partition=self.input_size_per_partition,
|
|
output_partition_sizes=self.output_partition_sizes,
|
|
input_size=self.input_size,
|
|
output_size=self.output_size,
|
|
params_dtype=self.params_dtype,
|
|
weight_loader=(
|
|
self.weight_loader_v2 if self.quant_method.__class__.__name__
|
|
in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
|
|
if bias:
|
|
self.bias = Parameter(
|
|
torch.empty(self.output_size_per_partition,
|
|
dtype=params_dtype))
|
|
set_weight_attrs(self.bias, {
|
|
"output_dim": 0,
|
|
"weight_loader": self.weight_loader,
|
|
})
|
|
else:
|
|
self.register_parameter("bias", None)
|
|
|
|
if self.custom_op is not None:
|
|
self.custom_op.update_attrs()
|
|
|
|
def forward(
|
|
self,
|
|
input_,
|
|
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
|
if self.custom_op is not None:
|
|
return self.custom_op.apply(input_)
|
|
|
|
return super().forward(input_)
|
|
|
|
|
|
class AscendReplicatedLinear(ReplicatedLinear):
|
|
"""Ascend Replicated linear layer.
|
|
|
|
Args:
|
|
input_size: input dimension of the linear layer.
|
|
output_size: output dimension of the linear layer.
|
|
bias: If true, add bias.
|
|
skip_bias_add: If true, skip adding bias but instead return it.
|
|
params_dtype: Data type for the parameters.
|
|
quant_config: Quantization configure.
|
|
prefix: The name of the layer in the state dict, including all parents
|
|
(e.g. model.layers.0.qkv_proj)
|
|
return_bias: If true, return bias together with outputs in forward pass.
|
|
disable_tp: Take no effect for replicated linear layers.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
input_size: int,
|
|
output_size: int,
|
|
bias: bool = True,
|
|
skip_bias_add: bool = False,
|
|
params_dtype: Optional[torch.dtype] = None,
|
|
quant_config: Optional[QuantizationConfig] = None,
|
|
prefix: str = "",
|
|
*,
|
|
return_bias: bool = True,
|
|
disable_tp: bool = False,
|
|
):
|
|
self.custom_op = get_replicated_op(disable_tp, prefix, self)
|
|
# If MergedReplicatedLinear, use output size of each partition.
|
|
if hasattr(self, "output_sizes"):
|
|
self.output_partition_sizes = self.output_sizes
|
|
else:
|
|
self.output_partition_sizes = [output_size]
|
|
|
|
AscendLinearBase.__init__(self,
|
|
input_size,
|
|
output_size,
|
|
skip_bias_add,
|
|
params_dtype,
|
|
quant_config,
|
|
prefix=prefix,
|
|
return_bias=return_bias,
|
|
disable_tp=disable_tp)
|
|
|
|
# All the linear layer supports quant method.
|
|
assert self.quant_method is not None
|
|
self.quant_method.create_weights(self,
|
|
self.input_size, [self.output_size],
|
|
self.input_size,
|
|
self.output_size,
|
|
self.params_dtype,
|
|
weight_loader=self.weight_loader)
|
|
|
|
if bias:
|
|
self.bias = Parameter(
|
|
torch.empty(self.output_size, dtype=self.params_dtype))
|
|
set_weight_attrs(self.bias, {
|
|
"output_dim": 0,
|
|
"weight_loader": self.weight_loader,
|
|
})
|
|
else:
|
|
self.register_parameter("bias", None)
|
|
|
|
if self.custom_op is not None:
|
|
self.custom_op.update_attrs()
|
|
|
|
def forward(
|
|
self,
|
|
input_,
|
|
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
|
if self.custom_op is not None:
|
|
return self.custom_op.apply(input_)
|
|
|
|
return super().forward(input_)
|