[Feat] Unquantized Linear to nz and control all nz-cast (#3356)
### What this PR does / why we need it?
Currently, when executing the Linear layers of models in vLLM-Ascend, the weight format is ND in the unquantized case and in the skipped-ascend case. This PR supplements the execution logic for the Linear layer with a new environment variable, `VLLM_ASCEND_ENABLE_NZ`. When `VLLM_ASCEND_ENABLE_NZ=1` and the CANN version is 8.3, the weights of the Linear layer are converted to FRACTAL_NZ, in both the unquantized case and the skipped-ascend case. `VLLM_ASCEND_ENABLE_NZ` also gates the pre-existing NZ conversions, such as the w8a8-quantized case.

### Does this PR introduce _any_ user-facing change?
Adds a new environment variable, `VLLM_ASCEND_ENABLE_NZ`. To use the NZ weight format, set `VLLM_ASCEND_ENABLE_NZ=1`.

### How was this patch tested?
- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: anon189Ty <Stari_Falcon@outlook.com>
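As context for reviewers, here is a minimal sketch of the gating the description above implies: cast a Linear weight to FRACTAL_NZ only when `VLLM_ASCEND_ENABLE_NZ=1`. The helper names (`is_enable_nz`, `maybe_cast_to_nz`) and the format constant are illustrative assumptions rather than the commit's exact code; `torch_npu.npu_format_cast` is the torch_npu API for such format casts.

```python
import os

import torch
import torch_npu  # Ascend NPU extension for PyTorch

# Illustrative constant for the NZ layout; verify the exact enum value
# against your torch_npu/CANN release (assumption, not from the commit).
ACL_FORMAT_FRACTAL_NZ = 29


def is_enable_nz() -> bool:
    """Hypothetical helper: NZ casts are gated on VLLM_ASCEND_ENABLE_NZ."""
    return os.getenv("VLLM_ASCEND_ENABLE_NZ", "0") == "1"


def maybe_cast_to_nz(weight: torch.Tensor) -> torch.Tensor:
    """Cast an ND weight to FRACTAL_NZ when the flag is set; else no-op."""
    if is_enable_nz():
        return torch_npu.npu_format_cast(weight, ACL_FORMAT_FRACTAL_NZ)
    return weight
```

At deployment time the flag is just an environment variable, e.g. `VLLM_ASCEND_ENABLE_NZ=1 vllm serve <model>`.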
The first hunk renames `CustomTensorParallelOp` to `CustomLinearOp` in the module docstring's inheritance diagram and adds the new `CustomReplicatedOp` leaf:

```diff
@@ -17,16 +17,16 @@ This file extends the functionality of linear operations by encapsulating custom
 communication groups and forward functions into classes (linear ops).
 
 Current class inheritance structure:
-CustomTensorParallelOp
+CustomLinearOp
 ├── CustomColumnParallelOp
 │   ├── MLPColumnParallelOp
 │   ├── SequenceColumnParallelOp
 └── CustomRowParallelOp
-    ├── MLPRowParallelOp
-    ├── OProjRowParallelOp
-    ├── MatmulAllreduceRowParallelOp
-    └── SequenceRowParallelOp
+│   ├── MLPRowParallelOp
+│   ├── OProjRowParallelOp
+│   ├── MatmulAllreduceRowParallelOp
+│   └── SequenceRowParallelOp
+└── CustomReplicatedOp
 
 How to extend a new linear op? Taking column parallel op as an example:
 1. Inherit from CustomColumnParallelOp and create a new class MyColumnParallelOp
 2. [Optional] The default communication group is the TP group. If a custom communication group is needed, override the comm_group method
```
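The how-to in that docstring translates directly into a subclass. Below is a hedged sketch of steps 1 and 2, assuming the classes live in `vllm_ascend.ops.linear_op`; `MyColumnParallelOp` and `my_custom_group` are invented names for illustration, not part of the commit.

```python
# Sketch only: MyColumnParallelOp and my_custom_group are hypothetical.
from vllm_ascend.ops.linear_op import CustomColumnParallelOp


def my_custom_group():
    """Placeholder for a real communication-group getter (hypothetical)."""
    raise NotImplementedError


class MyColumnParallelOp(CustomColumnParallelOp):
    """Step 1: inherit from CustomColumnParallelOp."""

    def __init__(self, layer):
        super().__init__(layer)

    # Step 2 (optional): override comm_group to replace the default
    # TP group with a custom communication group.
    def comm_group(self):
        return my_custom_group()
```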
The remaining hunks apply the rename to each subclass, implement `CustomReplicatedOp`, and add a `get_replicated_op` helper:

```diff
@@ -52,7 +52,7 @@ from vllm_ascend.utils import (dense_optim_enable, enable_sp,
                                oproj_tp_enable)
 
 
-class CustomTensorParallelOp:
+class CustomLinearOp:
 
     def __init__(self, layer):
         self.layer = layer
@@ -95,7 +95,7 @@ class CustomTensorParallelOp:
         return output, output_bias
 
 
-class CustomColumnParallelOp(CustomTensorParallelOp):
+class CustomColumnParallelOp(CustomLinearOp):
 
     def __init__(self, layer):
         super().__init__(layer)
@@ -106,7 +106,7 @@ class CustomColumnParallelOp(CustomTensorParallelOp):
         self.gather_output = self.layer.gather_output
 
 
-class CustomRowParallelOp(CustomTensorParallelOp):
+class CustomRowParallelOp(CustomLinearOp):
 
     def __init__(self, layer):
         super().__init__(layer)
@@ -129,6 +129,18 @@ class CustomRowParallelOp(CustomTensorParallelOp):
         return output, output_bias
 
 
+class CustomReplicatedOp(CustomLinearOp):
+
+    def apply_impl(self, input_):
+        bias = self.bias if not self.skip_bias_add else None
+        assert self.quant_method is not None
+
+        output = self.quant_method.apply(self.layer, input_, bias)
+        output_bias = self.bias if self.skip_bias_add else None
+
+        return output, output_bias
+
+
 class MLPColumnParallelOp(CustomColumnParallelOp):
 
     def __init__(self, layer):
@@ -422,3 +434,11 @@ def get_parallel_op(disable_tp, prefix, layer, direct):
         return custom_op, custom_op.tp_rank, custom_op.tp_size
 
     return None, get_tp_group().rank_in_group, get_tp_group().world_size
+
+
+def get_replicated_op(disable_tp, prefix,
+                      layer) -> Optional[Union[CustomReplicatedOp]]:
+    if disable_tp:
+        return None
+
+    return CustomReplicatedOp(layer)
```
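For orientation, a hypothetical call site for the new helper; the `disable_tp`/`prefix` values and the `layer`/`hidden_states` objects are invented for illustration, while `apply_impl` is the entry point the diff shows on `CustomReplicatedOp`:

```python
# Hypothetical usage, not from the commit: pick the custom replicated op
# unless tensor parallelism is disabled for this layer.
custom_op = get_replicated_op(disable_tp=False, prefix="gate",  # illustrative
                              layer=layer)
if custom_op is not None:
    # CustomReplicatedOp.apply_impl returns (output, output_bias).
    output, output_bias = custom_op.apply_impl(hidden_states)
```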