[Feat] Unquantized Linear to nz and control all nz-cast (#3356)

### What this PR does / why we need it?
Currently, when the Linear layers of a model are executed in vLLM-Ascend,
the weight format is ND in both the unquantized case and the skipped-ascend
case. This PR supplements the execution logic for the Linear layer. We add a
new global environment variable: VLLM_ASCEND_ENABLE_NZ. When
VLLM_ASCEND_ENABLE_NZ=1 and the CANN version is 8.3, the weights of the
Linear layer are converted to FRACTAL_NZ in both the unquantized case and
the skipped-ascend case. We also use VLLM_ASCEND_ENABLE_NZ to control the
existing NZ conversions, such as the one in the w8a8-quantized case.

### Does this PR introduce _any_ user-facing change?
Yes. This PR adds a new global environment variable, VLLM_ASCEND_ENABLE_NZ.
To use the NZ format, set VLLM_ASCEND_ENABLE_NZ=1.

### How was this patch tested?

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: anon189Ty <Stari_Falcon@outlook.com>
This commit is contained in:
anon189Ty
2025-10-14 17:39:26 +08:00
committed by GitHub
parent 5c45c227dc
commit 07e39620ea
22 changed files with 413 additions and 49 deletions

View File

@@ -24,17 +24,29 @@ from typing import Optional, Union
import torch
import torch.nn as nn
import torch_npu
from torch.nn.parameter import Parameter
from vllm.distributed import divide
from vllm.model_executor.layers.linear import ( # noqa
WEIGHT_LOADER_V2_SUPPORTED, ColumnParallelLinear, LinearBase,
MergedColumnParallelLinear, QKVParallelLinear, QuantizeMethodBase,
RowParallelLinear, UnquantizedLinearMethod)
ReplicatedLinear, RowParallelLinear, UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.base_config import \
QuantizationConfig
from vllm.model_executor.utils import set_weight_attrs
from vllm_ascend.ops.linear_op import get_parallel_op
from vllm_ascend.ops.linear_op import get_parallel_op, get_replicated_op
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz
class AscendUnquantizedLinearMethod(UnquantizedLinearMethod):
    """Unquantized linear method that can re-layout weights to FRACTAL_NZ."""

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Run the parent post-load step, then optionally cast the weight.

        The weight of ``layer`` is cast to ``ACL_FORMAT_FRACTAL_NZ`` only
        when ``is_enable_nz()`` is truthy and the CANN version string
        reported by ``torch.version.cann`` starts with ``"8.3"``.

        Args:
            layer: Module whose ``weight`` has just been loaded.
        """
        super().process_weights_after_loading(layer)

        # The NZ switch is checked first so the CANN version is only
        # inspected when the conversion has been requested.
        if not is_enable_nz():
            return
        if not torch.version.cann.startswith("8.3"):
            return

        nz_weight = torch_npu.npu_format_cast(layer.weight.data,
                                              ACL_FORMAT_FRACTAL_NZ)
        layer.weight.data = nz_weight
# TODO(realliujiaxu): Remove this class after linear of vllm supports custom comm group
@@ -65,7 +77,7 @@ class AscendLinearBase(LinearBase):
self.prefix = prefix
if quant_config is None:
self.quant_method: Optional[
QuantizeMethodBase] = UnquantizedLinearMethod()
QuantizeMethodBase] = AscendUnquantizedLinearMethod()
else:
self.quant_method = quant_config.get_quant_method(self,
prefix=prefix)
@@ -364,3 +376,81 @@ class AscendColumnParallelLinear(ColumnParallelLinear):
return self.custom_op.apply(input_)
return super().forward(input_)
class AscendReplicatedLinear(ReplicatedLinear):
    """Replicated linear layer for Ascend.

    Initialization goes through ``AscendLinearBase.__init__`` and an
    optional custom op obtained from ``get_replicated_op``; when such an op
    exists, ``forward`` delegates to it.

    Args:
        input_size: Input dimension of the linear layer.
        output_size: Output dimension of the linear layer.
        bias: Whether to create a bias parameter.
        skip_bias_add: If true, the bias is returned instead of being added.
        params_dtype: Data type of the parameters.
        quant_config: Quantization configuration.
        prefix: Name of the layer in the state dict, including all parents
            (e.g. model.layers.0.qkv_proj).
        return_bias: If true, the forward pass returns the bias together
            with the output.
        disable_tp: Forwarded to ``get_replicated_op`` and to
            ``AscendLinearBase.__init__``. Documented upstream as having no
            effect for replicated linear layers.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # The custom op is resolved before the base initializer runs.
        self.custom_op = get_replicated_op(disable_tp, prefix, self)

        # A merged replicated layer exposes ``output_sizes``; in that case
        # each partition keeps its own size, otherwise there is a single
        # partition covering the whole output.
        self.output_partition_sizes = (self.output_sizes if hasattr(
            self, "output_sizes") else [output_size])

        AscendLinearBase.__init__(
            self,
            input_size,
            output_size,
            skip_bias_add,
            params_dtype,
            quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        # Every linear layer is expected to have a quant method by now.
        assert self.quant_method is not None
        self.quant_method.create_weights(
            self,
            self.input_size,
            [self.output_size],
            self.input_size,
            self.output_size,
            self.params_dtype,
            weight_loader=self.weight_loader,
        )

        if not bias:
            self.register_parameter("bias", None)
        else:
            bias_storage = torch.empty(self.output_size,
                                       dtype=self.params_dtype)
            self.bias = Parameter(bias_storage)
            bias_attrs = {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            }
            set_weight_attrs(self.bias, bias_attrs)

        if self.custom_op is not None:
            self.custom_op.update_attrs()

    def forward(
        self,
        input_,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Apply the layer, using the custom op when one was selected.

        Args:
            input_: Input passed to the custom op or the parent forward.

        Returns:
            Whatever ``custom_op.apply`` returns when a custom op is set,
            otherwise the result of the parent class ``forward``.
        """
        if self.custom_op is None:
            return super().forward(input_)
        return self.custom_op.apply(input_)