[Feat] Unquantized Linear to nz and control all nz-cast (#3356)

### What this PR does / why we need it?
Currently, when vLLM-Ascend executes a model's Linear layers, the weight
format is ND in both the unquantized case and the case where ascend
quantization is skipped for the layer. This PR supplements the execution
logic of the Linear layer and introduces a new environment variable,
VLLM_ASCEND_ENABLE_NZ. When VLLM_ASCEND_ENABLE_NZ=1 and the CANN version
is 8.3, the weights of Linear layers are converted to FRACTAL_NZ, in both
the unquantized and skipped-ascend cases. VLLM_ASCEND_ENABLE_NZ also
gates the existing NZ conversions, such as the w8a8-quantized case.

### Does this PR introduce _any_ user-facing change?
Yes. This PR adds a new environment variable, VLLM_ASCEND_ENABLE_NZ. To
use the NZ weight format, set VLLM_ASCEND_ENABLE_NZ=1.
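
For example (illustrative, not project code), the flag is an ordinary process
environment variable and can be set before vLLM-Ascend initializes:

```python
import os

# Enable NZ weight format for this process before vLLM-Ascend is imported.
os.environ["VLLM_ASCEND_ENABLE_NZ"] = "1"

# Flags like this are typically read once as a boolean, roughly:
enable_nz = os.getenv("VLLM_ASCEND_ENABLE_NZ", "0") == "1"
```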

### How was this patch tested?

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: anon189Ty <Stari_Falcon@outlook.com>

Commit 07e39620ea (parent 5c45c227dc)
Author: anon189Ty
Date: 2025-10-14 17:39:26 +08:00
Committed via GitHub
22 changed files with 413 additions and 49 deletions


@@ -17,16 +17,16 @@ This file extends the functionality of linear operations by encapsulating custom
communication groups and forward functions into classes (linear ops).
Current class inheritance structure:
-CustomTensorParallelOp
+CustomLinearOp
 ├── CustomColumnParallelOp
 │   ├── MLPColumnParallelOp
 │   ├── SequenceColumnParallelOp
 └── CustomRowParallelOp
     ├── MLPRowParallelOp
     ├── OProjRowParallelOp
     ├── MatmulAllreduceRowParallelOp
     └── SequenceRowParallelOp
+└── CustomReplicatedOp
How to extend a new linear op? Taking column parallel op as an example:
1. Inherit from CustomColumnParallelOp and create a new class MyColumnParallelOp
2. [Optional] The default communication group is the TP group. If a custom communication group is needed, override the comm_group method
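
For illustration only (not part of this diff), a minimal subclass following
the extension steps quoted above might look like the sketch below.
MyColumnParallelOp, its forward body, and the module path in the import are
hypothetical; the bias / skip_bias_add / quant_method accessors are assumed
to be provided by the base op, as they are used by the ops in this file.

```python
# Assumed module path for the classes shown in this diff.
from vllm_ascend.ops.linear_op import CustomColumnParallelOp


class MyColumnParallelOp(CustomColumnParallelOp):
    # Step 2 is optional: keeping the default TP communication group here,
    # so comm_group is not overridden.

    def apply_impl(self, input_):
        # Custom forward logic; this sketch simply reuses the layer's quant method.
        bias = self.bias if not self.skip_bias_add else None
        output = self.quant_method.apply(self.layer, input_, bias)
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias
```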
@@ -52,7 +52,7 @@ from vllm_ascend.utils import (dense_optim_enable, enable_sp,
oproj_tp_enable)
-class CustomTensorParallelOp:
+class CustomLinearOp:
def __init__(self, layer):
self.layer = layer
@@ -95,7 +95,7 @@ class CustomTensorParallelOp:
return output, output_bias
-class CustomColumnParallelOp(CustomTensorParallelOp):
+class CustomColumnParallelOp(CustomLinearOp):
def __init__(self, layer):
super().__init__(layer)
@@ -106,7 +106,7 @@ class CustomColumnParallelOp(CustomTensorParallelOp):
self.gather_output = self.layer.gather_output
-class CustomRowParallelOp(CustomTensorParallelOp):
+class CustomRowParallelOp(CustomLinearOp):
def __init__(self, layer):
super().__init__(layer)
@@ -129,6 +129,18 @@ class CustomRowParallelOp(CustomTensorParallelOp):
return output, output_bias
class CustomReplicatedOp(CustomLinearOp):
def apply_impl(self, input_):
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None
output = self.quant_method.apply(self.layer, input_, bias)
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias
class MLPColumnParallelOp(CustomColumnParallelOp):
def __init__(self, layer):
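
The bias, skip_bias_add, and quant_method attributes used by apply_impl in
CustomReplicatedOp above are not defined in this hunk; presumably the
CustomLinearOp base forwards them from the wrapped layer, along these lines
(hypothetical sketch, not the code of this PR):

```python
class CustomLinearOp:
    def __init__(self, layer):
        self.layer = layer

    # Assumed delegation: expose the wrapped layer's attributes to the op.
    @property
    def bias(self):
        return self.layer.bias

    @property
    def skip_bias_add(self):
        return self.layer.skip_bias_add

    @property
    def quant_method(self):
        return self.layer.quant_method
```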
@@ -422,3 +434,11 @@ def get_parallel_op(disable_tp, prefix, layer, direct):
return custom_op, custom_op.tp_rank, custom_op.tp_size
return None, get_tp_group().rank_in_group, get_tp_group().world_size
def get_replicated_op(disable_tp, prefix,
layer) -> Optional[Union[CustomReplicatedOp]]:
if disable_tp:
return None
return CustomReplicatedOp(layer)
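
An illustrative call site (not part of this diff) showing how a
replicated-linear layer might consume this helper; apply_linear, the fallback
path, and the argument values are assumptions:

```python
def apply_linear(layer, input_, disable_tp: bool, prefix: str):
    # Hypothetical caller: use the custom op when one is returned, otherwise
    # fall back to the layer's own quant-method path (disable_tp=True case).
    custom_op = get_replicated_op(disable_tp, prefix, layer)
    if custom_op is not None:
        return custom_op.apply_impl(input_)
    return layer.quant_method.apply(layer, input_, layer.bias), None
```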