xc-llm-ascend/vllm_ascend/ops/linear.py
rjg-lyh 0005479b9c [main] mlp weight prefetch in Qwen Dense Models (#2816)
### What this PR does / why we need it?
This PR prefetches the weights of the MLP layers in Qwen dense models,
mainly to optimize performance in the decode phase.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI passed with newly added and existing tests.

- vLLM version: main
- vLLM main:
a1213fae5f

Signed-off-by: rjg-lyh <1318825571@qq.com>
Co-authored-by: Shuming19 <313093131@qq.com>
2025-09-11 21:20:09 +08:00

651 lines
25 KiB
Python

"""
Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
This file is a part of the vllm-ascend project.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import Optional, Union

import torch
import torch.distributed as dist
import torch.nn as nn
import torch_npu
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from vllm.distributed import divide, split_tensor_along_last_dim
from vllm.distributed.parallel_state import get_tp_group
from vllm.lora.utils import LinearBase
from vllm.model_executor.layers.linear import (  # noqa
    WEIGHT_LOADER_V2_SUPPORTED, ColumnParallelLinear,
    MergedColumnParallelLinear, QKVParallelLinear, QuantizeMethodBase,
    RowParallelLinear, UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.base_config import \
    QuantizationConfig
from vllm.model_executor.utils import set_weight_attrs

from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
                                                    get_otp_group)
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, dense_optim_enable,
                               matmul_allreduce_enable, mlp_tp_enable,
                               oproj_tp_enable)

_HCOMM_INFO = None


class AscendUnquantizedLinearMethod(UnquantizedLinearMethod):
    """Linear method without quantization."""

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        super().process_weights_after_loading(layer)
        if torch.version.cann.startswith("8.3"):
            layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
            layer.weight.data = torch_npu.npu_format_cast(
                layer.weight.data, ACL_FORMAT_FRACTAL_NZ)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        if torch.version.cann.startswith("8.3"):
            if bias is None:
                return torch.matmul(x, layer.weight)
            else:
                return torch.matmul(x, layer.weight) + bias
        else:
            return torch.nn.functional.linear(x, layer.weight, bias)
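
    # The two paths in `apply` are numerically equivalent:
    #     torch.nn.functional.linear(x, W, b) == x @ W.transpose(0, 1) + b
    # On CANN 8.3 the weight was already transposed (and cast to
    # ACL_FORMAT_FRACTAL_NZ) in process_weights_after_loading, so the plain
    # matmul consumes it directly without a second transpose.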


class AscendColumnParallelLinear(ColumnParallelLinear):
    """Linear layer with column parallelism.

    Use the MLP tensor parallelism group in the MLP module,
    and the original TP group in other modules.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        output_sizes: Optional[list[int]] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        self.comm_group = None
        if prefix.find("gate_up_proj") != -1 and mlp_tp_enable():
            self.comm_group = get_mlp_tp_group()
        else:
            self.comm_group = get_tp_group()

        self.tp_size = self.comm_group.world_size
        self.tp_rank = self.comm_group.rank_in_group

        self.input_size_per_partition = input_size
        self.output_size_per_partition = divide(output_size, self.tp_size)
        self.output_partition_sizes = [self.output_size_per_partition]
        # If QKV or MergedColumn, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = [
                divide(output_size, self.tp_size)
                for output_size in self.output_sizes
            ]

        AscendLinearBase.__init__(self,
                                  input_size,
                                  output_size,
                                  skip_bias_add,
                                  params_dtype,
                                  quant_config,
                                  prefix,
                                  return_bias=return_bias,
                                  disable_tp=disable_tp)

        self.gather_output = gather_output

        if output_sizes is None:
            output_sizes = [output_size]

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2 if self.quant_method.__class__.__name__
                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition,
                            dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)
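
    # Note: group selection above keys off the parameter prefix. Assuming the
    # usual vLLM naming scheme, e.g. "model.layers.<i>.mlp.gate_up_proj" (the
    # exact format depends on the model implementation), gate/up projections
    # join the dedicated MLP TP group when mlp_tp_enable() is set, while all
    # other projections stay on the global TP group.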


class AscendRowParallelLinear(RowParallelLinear):
    """Linear layer with row parallelism.

    Use the MLP tensor parallelism group in the MLP module,
    and the original TP group in other modules.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        input_is_parallel: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        reduce_results: bool = True,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        if prefix.find("down_proj") != -1 and mlp_tp_enable():
            comm_group = get_mlp_tp_group()
            self.forward_type = "mlp_tp"
        elif prefix.find("o_proj") != -1 and oproj_tp_enable():
            comm_group = get_otp_group()
            self.forward_type = "oproj_tp"
        elif matmul_allreduce_enable():
            comm_group = get_tp_group()
            self.forward_type = "matmul_allreduce"
            self.hcomm_info = self.get_hcomm_info(comm_group.device_group)
        elif dense_optim_enable():
            comm_group = get_tp_group()
            self.forward_type = "dense_optim"
        else:
            comm_group = get_tp_group()
            self.forward_type = "normal"
        self.comm_group = comm_group

        # TODO: check for disable_tp
        self.tp_size = self.comm_group.world_size
        self.tp_rank = self.comm_group.rank_in_group

        # Divide the weight matrix along the first dimension.
        self.input_size_per_partition = divide(input_size, self.tp_size)
        self.output_size_per_partition = output_size
        self.output_partition_sizes = [output_size]

        AscendLinearBase.__init__(self,
                                  input_size,
                                  output_size,
                                  skip_bias_add,
                                  params_dtype,
                                  quant_config,
                                  prefix,
                                  return_bias=return_bias,
                                  disable_tp=disable_tp)

        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2 if self.quant_method.__class__.__name__
                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError("When reduce_results=False, adding bias to the "
                             "results can lead to incorrect results")

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

        if matmul_allreduce_enable():
            self.weight_t = self.weight.t()
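            # The transposed view is cached once because
            # npu_mm_all_reduce_base consumes the weight in transposed
            # layout on every forward call; weight.t() is a view, so no
            # extra memory is held.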

    @staticmethod
    def get_hcomm_info(group: ProcessGroup) -> str:
        """Get the HCCL communication information for the given group."""
        global _HCOMM_INFO
        if _HCOMM_INFO is not None:
            return _HCOMM_INFO

        rank = torch.distributed.get_rank(group)
        if torch.__version__ > "2.0":
            global_rank = torch.distributed.get_global_rank(group, rank)
            _HCOMM_INFO = group._get_backend(
                torch.device("npu")).get_hccl_comm_name(global_rank)
        else:
            _HCOMM_INFO = group.get_hccl_comm_name(rank)
        return _HCOMM_INFO
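
    # Caveat: _HCOMM_INFO is a module-level cache, so this helper effectively
    # assumes a single process group (the global TP group). If different
    # groups were ever passed in, the first group's HCCL name would be
    # returned for all subsequent callers.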

    def forward(
        self,
        input_,
        is_prefill: bool = True,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        # Choose different forward function according to the type of TP group
        if self.forward_type == "oproj_tp":
            return self._forward_oproj_tp(input_)
        elif self.forward_type == "mlp_tp":
            return self._forward_mlp_tp(input_)
        elif self.forward_type == "matmul_allreduce":
            return self._forward_matmul_allreduce(input_)
        elif self.forward_type == "dense_optim":
            return self._forward_dense_optim(input_)
        else:
            return super().forward(input_)

    # enable custom MLP tensor parallel
    def _forward_mlp_tp(self, input_: torch.Tensor) -> torch.Tensor:
        if self.input_is_parallel:
            input_parallel = input_
        else:
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size)
            input_parallel = splitted_input[self.tp_rank].contiguous()

        assert self.quant_method is not None
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        output_parallel = self.quant_method.apply(self,
                                                  input_parallel,
                                                  bias=bias_)
        output = self.comm_group.reduce_scatter(output_parallel, 0)
        output_bias = self.bias if self.skip_bias_add else None
        if not self.return_bias:
            return output
        return output, output_bias

    # enable custom Oproj tensor parallel
    def _forward_oproj_tp(
        self,
        input_: torch.Tensor,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        if self.input_is_parallel:
            input_parallel = input_
        else:
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size)
            input_parallel = splitted_input[self.tp_rank].contiguous()

        # Prepare tensors for all-to-all communication
        local_batch_size = input_parallel.size(0)
        chunk_size = self.input_size_per_partition
        total_batch_size = local_batch_size * self.tp_size

        # Reshape tensor for efficient cross-device transfer:
        # [batch, dim] -> [tp_size, batch, chunk] -> flattened
        send_buf = input_parallel.reshape(
            -1, self.tp_size, chunk_size).transpose(0, 1).contiguous().view(-1)
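        # Shape walk-through with illustrative numbers: for tp_size=2, a
        # local batch of 3 and chunk_size=k, input_parallel is [3, 2k].
        # reshape -> [3, 2, k], transpose -> [2, 3, k], then flatten. The
        # flattened buffer lays out the chunks destined for rank 0 before
        # those for rank 1, which is the contiguous equal-split layout
        # dist.all_to_all_single expects.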
        # Create receive buffer
        recv_buf = torch.empty(total_batch_size * chunk_size,
                               dtype=input_parallel.dtype,
                               device=input_parallel.device)

        # Perform all-to-all communication
        dist.all_to_all_single(recv_buf,
                               send_buf,
                               group=self.comm_group.device_group)
        input_parallel = recv_buf.view(total_batch_size, chunk_size)

        # Only fuse bias add for rank 0 to avoid duplicate bias addition in TP>1
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        assert self.quant_method is not None
        output_parallel = self.quant_method.apply(self,
                                                  input_parallel,
                                                  bias=bias_)

        # otp-specific: combine partial results across devices
        output = self.comm_group.reduce_scatter(output_parallel, dim=0)

        # Handle bias return based on configuration
        output_bias = self.bias if self.skip_bias_add else None
        if not self.return_bias:
            return output
        return output, output_bias

    def _forward_matmul_allreduce(
        self, input_: torch.Tensor
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Calculate the forward output by fusing communication and
        computation (matmul + all-reduce in a single NPU kernel)."""
        if self.input_is_parallel:
            input_parallel = input_
        else:
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size)
            input_parallel = splitted_input[self.tp_rank].contiguous()

        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        if self.reduce_results and self.tp_size > 1:
            output = torch_npu.npu_mm_all_reduce_base(input_parallel,
                                                      self.weight_t,
                                                      self.hcomm_info,
                                                      bias=bias_)
        else:
            output = self.quant_method.apply(self, input_parallel, bias=bias_)

        output_bias = self.bias if self.skip_bias_add else None
        if not self.return_bias:
            return output
        return output, output_bias

    def _forward_dense_optim(
        self, input_: torch.Tensor
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Row-parallel forward with dense-model optimizations enabled,
        such as FlashComm and communication-computation fusion.
        """
        if self.input_is_parallel:
            input_parallel = input_
        else:
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size)
            input_parallel = splitted_input[self.tp_rank].contiguous()

        assert self.quant_method is not None
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        if self.tp_size == 1 or not self.reduce_results:
            output = self.quant_method.apply(self, input_parallel, bias=bias_)
        else:
            output_parallel = self.quant_method.apply(self,
                                                      input_parallel,
                                                      bias=bias_)
            output = torch.ops.vllm.maybe_pad_and_reduce(output_parallel)
        torch.ops.vllm.maybe_prefetch_mlp_gate_up_proj(output, self.prefix)

        output_bias = self.bias if self.skip_bias_add else None
        if not self.return_bias:
            return output
        return output, output_bias
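
    # Note on _forward_dense_optim: maybe_prefetch_mlp_gate_up_proj is the
    # hook added by this commit. Per the commit message, it starts loading
    # the MLP gate_up_proj weights ahead of use so the weight fetch overlaps
    # with other decode-phase work; the op itself (registered under
    # torch.ops.vllm elsewhere in vllm-ascend) decides when prefetching
    # actually happens.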


class AscendMergedColumnParallelLinear(MergedColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Use the MLP tensor parallelism group in the MLP module,
    and the original TP group in other modules.
    """

    def __init__(
        self,
        input_size: int,
        output_sizes: list[int],
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        if prefix.find("gate_up_proj") != -1 and mlp_tp_enable():
            comm_group = get_mlp_tp_group()
            self.forward_type = "mlp_tp"
        elif dense_optim_enable():
            comm_group = get_tp_group()
            self.forward_type = "dense_optim"
        else:
            comm_group = get_tp_group()
            self.forward_type = "normal_tp"
        self.comm_group = comm_group

        # TODO: check for disable_tp
        self.tp_rank = comm_group.rank_in_group
        self.tp_size = comm_group.world_size

        self.output_sizes = output_sizes
        assert all(output_size % self.tp_size == 0
                   for output_size in output_sizes)

        AscendColumnParallelLinear.__init__(self,
                                            input_size=input_size,
                                            output_size=sum(output_sizes),
                                            bias=bias,
                                            gather_output=gather_output,
                                            skip_bias_add=skip_bias_add,
                                            params_dtype=params_dtype,
                                            quant_config=quant_config,
                                            prefix=prefix,
                                            return_bias=return_bias,
                                            disable_tp=disable_tp)

    def forward(
        self,
        input_,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        if self.forward_type == "mlp_tp":
            return self._forward_mlp_tp(input_)
        elif self.forward_type == "dense_optim":
            return self._forward_dense_optim(input_)
        else:
            return super().forward(input_)

    def _forward_mlp_tp(
        self,
        input_: torch.Tensor,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        bias = self.bias if not self.skip_bias_add else None
        # Matrix multiply.
        assert self.quant_method is not None
        input_parallel = get_mlp_tp_group().all_gather(input_, 0)
        output = self.quant_method.apply(self, input_parallel, bias)
        output_bias = self.bias if self.skip_bias_add else None
        if not self.return_bias:
            return output
        return output, output_bias
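
    # Note: the all_gather in _forward_mlp_tp pairs with the reduce_scatter
    # in AscendRowParallelLinear._forward_mlp_tp. Together they move the same
    # data as one all-reduce (an all-reduce decomposes into a reduce-scatter
    # followed by an all-gather), but the MLP's inputs and outputs stay
    # sharded along dim 0 across ranks, so the ops between the two
    # projections run on smaller tensors.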

    def _forward_dense_optim(
        self, input_: torch.Tensor
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Column-parallel forward with dense-model optimizations enabled,
        such as FlashComm and communication-computation fusion.
        """
        bias = self.bias if not self.skip_bias_add else None
        # Matrix multiply.
        assert self.quant_method is not None
        input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, True)
        output_parallel = self.quant_method.apply(self, input_, bias)
        if self.gather_output:
            # All-gather across the partitions.
            output = self.comm_group.all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        if not self.return_bias:
            return output
        return output, output_bias


class AscendQKVParallelLinear(QKVParallelLinear):
    """Linear layers for the attention's QKV transformation.

    Linear layers for the linear transformation of the query, key, and value
    vectors in the attention layer. The weight matrix is concatenated along
    the output dimension. The layer is parallelized along the head dimension.
    When the number of key/value heads is smaller than the number of query
    heads (e.g., multi-query/grouped-query attention), the key/value head may
    be replicated while the query heads are partitioned.
    """

    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: Optional[int] = None,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        if dense_optim_enable():
            self.forward_type = "dense_optim"
        else:
            self.forward_type = "normal_tp"
        self.comm_group = get_tp_group()

        self.hidden_size = hidden_size
        self.head_size = head_size
        self.total_num_heads = total_num_heads
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads
        self.total_num_kv_heads = total_num_kv_heads

        # Divide the weight matrix along the last dimension.
        # TODO: check for disable_tp
        tp_size = self.comm_group.world_size
        self.num_heads = divide(self.total_num_heads, tp_size)
        if tp_size >= self.total_num_kv_heads:
            self.num_kv_heads = 1
            self.num_kv_head_replicas = divide(tp_size,
                                               self.total_num_kv_heads)
        else:
            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
            self.num_kv_head_replicas = 1

        input_size = self.hidden_size
        output_size = (self.num_heads +
                       2 * self.num_kv_heads) * tp_size * self.head_size
        self.output_sizes = [
            self.num_heads * self.head_size * tp_size,  # q_proj
            self.num_kv_heads * self.head_size * tp_size,  # k_proj
            self.num_kv_heads * self.head_size * tp_size,  # v_proj
        ]
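        # Worked example with illustrative numbers: GQA with
        # total_num_heads=32, total_num_kv_heads=8, head_size=128, tp_size=8
        # gives num_heads=4 and num_kv_heads=1 per rank, so
        # output_sizes == [4096, 1024, 1024] and each rank's shard of the
        # packed QKV weight covers (4 + 2 * 1) * 128 = 768 output rows.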
        AscendColumnParallelLinear.__init__(self,
                                            input_size=input_size,
                                            output_size=output_size,
                                            bias=bias,
                                            gather_output=False,
                                            skip_bias_add=skip_bias_add,
                                            params_dtype=params_dtype,
                                            quant_config=quant_config,
                                            prefix=prefix,
                                            return_bias=return_bias,
                                            disable_tp=disable_tp)

    def forward(
        self,
        input_,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        if self.forward_type == "dense_optim":
            return self._forward_dense_optim(input_)
        else:
            return super().forward(input_)

    def _forward_dense_optim(
        self, input_: torch.Tensor
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Column-parallel QKV forward with dense-model optimizations
        enabled, such as FlashComm and communication-computation fusion.
        """
        bias = self.bias if not self.skip_bias_add else None
        # Matrix multiply.
        assert self.quant_method is not None
        layer_num = self.prefix.split('.')[2]
        input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
            input_, layer_num != '0')
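        # The layer index is parsed positionally, which assumes a prefix of
        # the form "model.layers.<i>.self_attn.qkv_proj" (index at position
        # 2). The flag distinguishes layer 0 from later layers; presumably
        # layer 0's input has not been scattered by a preceding row-parallel
        # layer, so the gather/unpad step is skipped for it (see the op's
        # registration for the exact semantics).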
        output_parallel = self.quant_method.apply(self, input_, bias)
        if self.gather_output:
            # All-gather across the partitions.
            output = self.comm_group.all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        if not self.return_bias:
            return output
        return output, output_bias


class AscendLinearBase(LinearBase):

    def __init__(
        self,
        input_size: int,
        output_size: int,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        nn.Module.__init__(self)

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype
        self.quant_config = quant_config
        self.prefix = prefix
        if quant_config is None:
            self.quant_method: Optional[
                QuantizeMethodBase] = AscendUnquantizedLinearMethod()
        else:
            self.quant_method = quant_config.get_quant_method(self,
                                                              prefix=prefix)
        self.return_bias = return_bias
        self.disable_tp = disable_tp
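
# Minimal usage sketch (hypothetical; names and sizes are illustrative, and
# distributed groups must already be initialized). These classes mirror the
# vLLM linear-layer constructor API, with the communication group chosen
# from the prefix:
#
#     layer = AscendRowParallelLinear(input_size=4096,
#                                     output_size=4096,
#                                     bias=False,
#                                     prefix="model.layers.0.mlp.down_proj")
#     out, _ = layer.forward(x)  # x: [num_tokens, 4096 // tp_size]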