init v0.11.0rc0

This commit is contained in:
2025-10-14 10:38:28 +08:00
parent 67afd0ea78
commit 66dc16f966
278 changed files with 28130 additions and 11708 deletions

View File

@@ -1,45 +1,159 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
This file is a part of the vllm-ascend project.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
To customize linear communication groups or forward of classes in this file,
extend new linear operations in linear_op.py.
The classes in this file should not be modified, including AscendQKVParallelLinear,
AscendMergedColumnParallelLinear, AscendMergedColumnParallelLinear,
AscendRowParallelLinear and AscendColumnParallelLinear.
"""
from typing import Optional, Union
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.linear import (WEIGHT_LOADER_V2_SUPPORTED,
ColumnParallelLinear,
LinearBase,
MergedColumnParallelLinear,
RowParallelLinear)
from vllm.distributed import divide
from vllm.model_executor.layers.linear import ( # noqa
WEIGHT_LOADER_V2_SUPPORTED, ColumnParallelLinear, LinearBase,
MergedColumnParallelLinear, QKVParallelLinear, QuantizeMethodBase,
RowParallelLinear, UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.base_config import \
QuantizationConfig
from vllm.model_executor.utils import set_weight_attrs
from vllm_ascend.distributed.parallel_state import (
get_mlp_tensor_model_parallel_rank,
get_mlp_tensor_model_parallel_world_size, get_mlp_tp_group)
from vllm_ascend.ops.linear_op import (get_column_parallel_op,
get_row_parallel_op)
class AscendMlpColumnParallelLinear(ColumnParallelLinear):
"""Linear layer with column parallelism.
# TODO(realliujiaxu): Remove this class after linear of vllm supports custom comm group
class AscendLinearBase(LinearBase):
def __init__(
self,
input_size: int,
output_size: int,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
*,
return_bias: bool = True,
disable_tp: bool = False,
):
nn.Module.__init__(self)
# Keep input parameters
self.input_size = input_size
self.output_size = output_size
self.skip_bias_add = skip_bias_add
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
self.quant_config = quant_config
self.prefix = prefix
if quant_config is None:
self.quant_method: Optional[
QuantizeMethodBase] = UnquantizedLinearMethod()
else:
self.quant_method = quant_config.get_quant_method(self,
prefix=prefix)
self.return_bias = return_bias
self.disable_tp = disable_tp
class AscendQKVParallelLinear(QKVParallelLinear):
"""Linear layers for the attention's QKV transformation.
Linear layers for the linear transformation of the query, key, and value
vectors in the attention layer. The weight matrix is concatenated along
the output dimension. The layer is parallelized along the head dimension.
When the number of key/value heads is smaller than the number of query
heads (e.g., multi-query/grouped-query attention), the key/value head may
be replicated while the query heads are partitioned.
"""
def __init__(
self,
hidden_size: int,
head_size: int,
total_num_heads: int,
total_num_kv_heads: Optional[int] = None,
bias: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
*,
return_bias: bool = True,
disable_tp: bool = False,
):
self.custom_op, _, tp_size = get_column_parallel_op(
disable_tp, prefix, self)
# TODO(realliujiaxu): Replace the initialization code below with super().__init__ after linear of vllm supports custom comm group
self.hidden_size = hidden_size
self.head_size = head_size
self.total_num_heads = total_num_heads
if total_num_kv_heads is None:
total_num_kv_heads = total_num_heads
self.total_num_kv_heads = total_num_kv_heads
# Divide the weight matrix along the last dimension.
self.num_heads = divide(self.total_num_heads, tp_size)
if tp_size >= self.total_num_kv_heads:
self.num_kv_heads = 1
self.num_kv_head_replicas = divide(tp_size,
self.total_num_kv_heads)
else:
self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
self.num_kv_head_replicas = 1
input_size = self.hidden_size
output_size = (self.num_heads +
2 * self.num_kv_heads) * tp_size * self.head_size
self.output_sizes = [
self.num_heads * self.head_size * tp_size, # q_proj
self.num_kv_heads * self.head_size * tp_size, # k_proj
self.num_kv_heads * self.head_size * tp_size, # v_proj
]
AscendColumnParallelLinear.__init__(self,
input_size=input_size,
output_size=output_size,
bias=bias,
gather_output=False,
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix,
return_bias=return_bias,
disable_tp=disable_tp)
def forward(
self,
input_,
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
if self.custom_op is not None:
return self.custom_op.apply(input_)
return super().forward(input_)
class AscendMergedColumnParallelLinear(MergedColumnParallelLinear):
"""Packed linear layers with column parallelism.
Similar to ColumnParallelLinear, but the weight matrix is concatenated
along the output dimension. When the weight matrix is loaded, the
different partitions are sharded separately.
Use the MLP tensor parallelism group in the MLP module,
and the original TP group in other modules.
@@ -48,73 +162,46 @@ class AscendMlpColumnParallelLinear(ColumnParallelLinear):
def __init__(
self,
input_size: int,
output_size: int,
output_sizes: list[int],
bias: bool = True,
gather_output: bool = False,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
output_sizes: Optional[list[int]] = None,
prefix: str = "",
*,
return_bias: bool = True,
disable_tp: bool = False,
):
# Divide the weight matrix along the last dimension.
if prefix.find("gate_up_proj") != -1:
self.tp_size = get_mlp_tensor_model_parallel_world_size()
self.tp_rank = get_mlp_tensor_model_parallel_rank()
self.enable_mlp_optimze = True
else:
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.enable_mlp_optimze = False
self.input_size_per_partition = input_size
self.output_size_per_partition = divide(output_size, self.tp_size)
self.output_partition_sizes = [self.output_size_per_partition]
# If QKV or MergedColumn, use output size of each partition.
if hasattr(self, "output_sizes"):
self.output_partition_sizes = [
divide(output_size, self.tp_size)
for output_size in self.output_sizes
]
LinearBase.__init__(self,
input_size,
output_size,
skip_bias_add,
params_dtype,
quant_config,
prefix,
return_bias=return_bias)
self.custom_op, self.tp_rank, self.tp_size = get_column_parallel_op(
disable_tp, prefix, self)
# TODO(realliujiaxu): Replace the initialization code below with super().__init__ after linear of vllm supports custom comm group
self.output_sizes = output_sizes
assert all(output_size % self.tp_size == 0
for output_size in output_sizes)
AscendColumnParallelLinear.__init__(self,
input_size=input_size,
output_size=sum(output_sizes),
bias=bias,
gather_output=gather_output,
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix,
return_bias=return_bias,
disable_tp=disable_tp)
self.gather_output = gather_output
def forward(
self,
input_,
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
if self.custom_op is not None:
return self.custom_op.apply(input_)
if output_sizes is None:
output_sizes = [output_size]
assert self.quant_method is not None
self.quant_method.create_weights(
layer=self,
input_size_per_partition=self.input_size_per_partition,
output_partition_sizes=self.output_partition_sizes,
input_size=self.input_size,
output_size=self.output_size,
params_dtype=self.params_dtype,
weight_loader=(
self.weight_loader_v2 if self.quant_method.__class__.__name__
in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
if bias:
self.bias = Parameter(
torch.empty(self.output_size_per_partition,
dtype=params_dtype))
set_weight_attrs(self.bias, {
"output_dim": 0,
"weight_loader": self.weight_loader,
})
else:
self.register_parameter("bias", None)
return super().forward(input_)
class AscendMlpRowParallelLinear(RowParallelLinear):
class AscendRowParallelLinear(RowParallelLinear):
"""Linear layer with row parallelism.
Use the MLP tensor parallelism group in the MLP module,
and the original TP group in other modules.
@@ -133,28 +220,25 @@ class AscendMlpRowParallelLinear(RowParallelLinear):
prefix: str = "",
*,
return_bias: bool = True,
disable_tp: bool = False,
):
if prefix.find("down_proj") != -1:
self.tp_size = get_mlp_tensor_model_parallel_world_size()
self.tp_rank = get_mlp_tensor_model_parallel_rank()
self.enable_mlp_optimze = True
else:
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.enable_mlp_optimze = False
self.custom_op, self.tp_rank, self.tp_size = get_row_parallel_op(
disable_tp, prefix, self)
# TODO(realliujiaxu): Replace the initialization code below with super().__init__ after linear of vllm supports custom comm group
# Divide the weight matrix along the first dimension.
self.input_size_per_partition = divide(input_size, self.tp_size)
self.output_size_per_partition = output_size
self.output_partition_sizes = [output_size]
LinearBase.__init__(self,
input_size,
output_size,
skip_bias_add,
params_dtype,
quant_config,
prefix,
return_bias=return_bias)
AscendLinearBase.__init__(self,
input_size,
output_size,
skip_bias_add,
params_dtype,
quant_config,
prefix,
return_bias=return_bias,
disable_tp=disable_tp)
self.input_is_parallel = input_is_parallel
self.reduce_results = reduce_results
@@ -184,66 +268,22 @@ class AscendMlpRowParallelLinear(RowParallelLinear):
else:
self.register_parameter("bias", None)
if self.custom_op is not None:
self.custom_op.update_attrs()
def forward(
self,
input_,
is_prefill: bool = True,
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
if self.enable_mlp_optimze:
tp_rank = get_mlp_tensor_model_parallel_rank()
if self.input_is_parallel:
input_parallel = input_
else:
tp_rank = get_mlp_tensor_model_parallel_rank()
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.tp_size)
input_parallel = splitted_input[tp_rank].contiguous()
# Matrix multiply.
assert self.quant_method is not None
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case)
bias_ = None if (self.tp_rank > 0
or self.skip_bias_add) else self.bias
output_parallel = self.quant_method.apply(self,
input_parallel,
bias=bias_)
output = get_mlp_tp_group().reduce_scatter(output_parallel, 0)
# output = output[:num_tokens,:]
# dispose_tensor(output_parallel)
else:
if self.input_is_parallel:
input_parallel = input_
else:
tp_rank = get_tensor_model_parallel_rank()
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.tp_size)
input_parallel = splitted_input[tp_rank].contiguous()
if self.custom_op is not None:
return self.custom_op.apply(input_)
# Matrix multiply.
assert self.quant_method is not None
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case)
bias_ = None if (self.tp_rank > 0
or self.skip_bias_add) else self.bias
output_parallel = self.quant_method.apply(self,
input_parallel,
bias=bias_)
if self.reduce_results and self.tp_size > 1:
output = tensor_model_parallel_all_reduce(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
return super().forward(input_)
class AscendMlpMergedColumnParallelLinear(MergedColumnParallelLinear):
"""Packed linear layers with column parallelism.
Similar to ColumnParallelLinear, but the weight matrix is concatenated
along the output dimension. When the weight matrix is loaded, the
different partitions are sharded separately.
class AscendColumnParallelLinear(ColumnParallelLinear):
"""Linear layer with column parallelism.
Use the MLP tensor parallelism group in the MLP module,
and the original TP group in other modules.
@@ -252,58 +292,76 @@ class AscendMlpMergedColumnParallelLinear(MergedColumnParallelLinear):
def __init__(
self,
input_size: int,
output_sizes: list[int],
output_size: int,
bias: bool = True,
gather_output: bool = False,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
output_sizes: Optional[list[int]] = None,
prefix: str = "",
*,
return_bias: bool = True,
disable_tp: bool = False,
):
self.output_sizes = output_sizes
if prefix.find("gate_up_proj") != -1:
self.tp_size = get_mlp_tensor_model_parallel_world_size()
self.tp_rank = get_mlp_tensor_model_parallel_rank()
self.enable_mlp_optimze = True
self.custom_op, self.tp_rank, self.tp_size = get_column_parallel_op(
disable_tp, prefix, self)
# TODO(realliujiaxu): Replace the initialization code below with super().__init__ after linear of vllm supports custom comm group
self.input_size_per_partition = input_size
self.output_size_per_partition = divide(output_size, self.tp_size)
self.output_partition_sizes = [self.output_size_per_partition]
# If QKV or MergedColumn, use output size of each partition.
if hasattr(self, "output_sizes"):
self.output_partition_sizes = [
divide(output_size, self.tp_size)
for output_size in self.output_sizes
]
AscendLinearBase.__init__(self,
input_size,
output_size,
skip_bias_add,
params_dtype,
quant_config,
prefix,
return_bias=return_bias,
disable_tp=disable_tp)
self.gather_output = gather_output
if output_sizes is None:
output_sizes = [output_size]
assert self.quant_method is not None
self.quant_method.create_weights(
layer=self,
input_size_per_partition=self.input_size_per_partition,
output_partition_sizes=self.output_partition_sizes,
input_size=self.input_size,
output_size=self.output_size,
params_dtype=self.params_dtype,
weight_loader=(
self.weight_loader_v2 if self.quant_method.__class__.__name__
in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
if bias:
self.bias = Parameter(
torch.empty(self.output_size_per_partition,
dtype=params_dtype))
set_weight_attrs(self.bias, {
"output_dim": 0,
"weight_loader": self.weight_loader,
})
else:
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.enable_mlp_optimze = False
assert all(output_size % self.tp_size == 0
for output_size in output_sizes)
AscendMlpColumnParallelLinear.__init__(self,
input_size=input_size,
output_size=sum(output_sizes),
bias=bias,
gather_output=gather_output,
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix,
return_bias=return_bias)
self.register_parameter("bias", None)
if self.custom_op is not None:
self.custom_op.update_attrs()
def forward(
self,
input_,
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
bias = self.bias if not self.skip_bias_add else None
# self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
# Matrix multiply.
assert self.quant_method is not None
if self.enable_mlp_optimze:
input2_ = get_mlp_tp_group().all_gather(input_, 0)
output = self.quant_method.apply(self, input2_, bias)
else:
output_parallel = self.quant_method.apply(self, input_, bias)
if self.gather_output:
# All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel)
else:
output = output_parallel
if self.custom_op is not None:
return self.custom_op.apply(input_)
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
return super().forward(input_)