[Refactor] Quantization Module Refactor (#5738)
### Summary
This PR refactors the `vllm_ascend/quantization` module to improve code
organization, maintainability, and extensibility. The refactoring
introduces a clear separation of concerns with a registry-based scheme
discovery pattern, abstract base classes for quantization schemes, and
dedicated wrapper classes.
### Key Changes
#### 1. **Modular Directory Structure**
| Before | After |
|--------|-------|
| Flat file structure with mixed responsibilities | Organized into
`methods/` subpackage for schemes |
| Single `quant_config.py` (600+ lines) | Separate config files:
`modelslim_config.py`, `compressed_tensors_config.py` |
| `utils.py` with scheme lookup logic | `methods/registry.py` with
decorator-based registration |
#### 2. **Registry-Based Scheme Discovery**
Replaced hardcoded `ASCEND_QUANTIZATION_METHOD_MAP` dictionary with a
decorator-based registry pattern:
```python
# Before: Manual dictionary mapping
ASCEND_QUANTIZATION_METHOD_MAP = {
"W8A8_DYNAMIC": {"linear": AscendW8A8DynamicLinearMethod, ...},
...
}
# After: Decorator-based registration
@register_scheme("W8A8_DYNAMIC", "linear")
class AscendW8A8DynamicLinearMethod(AscendLinearScheme):
...
```
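For reference, here is a minimal sketch of what the registry in `methods/registry.py` could look like; the names `register_scheme`, `get_scheme_class`, and `_SCHEME_REGISTRY` come from this PR, but the error handling and typing shown here are assumptions:
```python
# Illustrative sketch of methods/registry.py; details are assumptions.
from typing import Callable, Dict, Tuple, Type

_SCHEME_REGISTRY: Dict[Tuple[str, str], Type] = {}


def register_scheme(quant_type: str, layer_type: str) -> Callable[[Type], Type]:
    """Class decorator that records a scheme under (quant_type, layer_type)."""

    def _wrap(cls: Type) -> Type:
        _SCHEME_REGISTRY[(quant_type, layer_type)] = cls
        return cls

    return _wrap


def get_scheme_class(quant_type: str, layer_type: str) -> Type:
    """Look up a registered scheme class, failing loudly if none is found."""
    try:
        return _SCHEME_REGISTRY[(quant_type, layer_type)]
    except KeyError as e:
        raise ValueError(
            f"No quantization scheme registered for ({quant_type!r}, {layer_type!r})"
        ) from e
```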
#### 3. **Abstract Base Classes**
Introduced three abstract base classes in `methods/base.py` (a minimal sketch follows this list):
- `AscendLinearScheme` - Base for linear layer quantization
- `AscendMoEScheme` - Base for MoE layer quantization
- `AscendAttentionScheme` - Base for attention layer quantization
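A minimal sketch of these base classes, with method names taken from the architecture diagram below; the exact signatures in `methods/base.py` are assumptions:
```python
# Illustrative sketch only; actual signatures in methods/base.py may differ.
from abc import ABC, abstractmethod
from typing import Any, Dict

import torch


class AscendLinearScheme(ABC):
    """Base class for linear-layer quantization schemes."""

    @abstractmethod
    def get_weight(self, *args, **kwargs) -> Dict[str, Any]:
        """Return the quantized weight tensors the layer must register."""

    @abstractmethod
    def apply(self, layer: torch.nn.Module, x: torch.Tensor, *args,
              **kwargs) -> torch.Tensor:
        """Run the quantized forward computation."""


class AscendMoEScheme(ABC):
    """Base class for MoE-layer quantization schemes."""

    @abstractmethod
    def get_weight(self, *args, **kwargs) -> Dict[str, Any]:
        """Return the packed expert weight tensors."""

    @abstractmethod
    def get_dynamic_quant_param(self, *args, **kwargs) -> Dict[str, Any]:
        """Return per-expert scales, offsets and other quantization params."""

    @abstractmethod
    def apply(self, layer: torch.nn.Module, x: torch.Tensor, *args,
              **kwargs) -> torch.Tensor:
        """Run the quantized MoE forward computation."""


# AscendAttentionScheme follows the same pattern for attention layers.
```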
#### 4. **Separated Config and Wrapper Classes**
- **Config classes** (`AscendModelSlimConfig`,
`AscendCompressedTensorsConfig`): Handle config parsing and scheme
selection
- **Wrapper classes** (`AscendLinearMethod`, `AscendFusedMoEMethod`,
etc.): Implement vLLM interfaces and delegate to schemes (see the sketch below)
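As an illustration of the delegation pattern, a hedged sketch of how a wrapper such as `AscendFusedMoEMethod` might forward calls to its scheme; the real implementation in `wrappers.py` will differ in argument handling:
```python
# Illustrative sketch only; the real wrapper in wrappers.py may differ.
import torch


class AscendFusedMoEMethod:
    """Implements the vLLM FusedMoE interface by delegating to an AscendMoEScheme."""

    def __init__(self, quant_method) -> None:
        # quant_method is the scheme instance selected by the config class.
        self.quant_method = quant_method

    def create_weights(self, layer: torch.nn.Module, **weight_args) -> None:
        # Ask the scheme which tensors it needs and register them on the layer.
        for name, tensor in self.quant_method.get_weight(**weight_args).items():
            layer.register_parameter(
                name, torch.nn.Parameter(tensor, requires_grad=False))

    def apply(self, layer: torch.nn.Module, x: torch.Tensor, *args,
              **kwargs) -> torch.Tensor:
        # The forward computation is fully delegated to the scheme.
        return self.quant_method.apply(layer, x, *args, **kwargs)
```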
#### 5. **Cleaner Public API**
```python
# New clean module interface
from vllm_ascend.quantization import (
AscendModelSlimConfig,
AscendCompressedTensorsConfig,
)
from vllm_ascend.quantization.methods import get_scheme_class
```
### Architecture Diagram
```mermaid
classDiagram
direction TB
class QuantizationConfig {
<<vLLM Interface>>
+get_quant_method()
}
class AscendModelSlimConfig {
+quant_description
+get_quant_method()
-create_scheme_for_layer()
}
class AscendCompressedTensorsConfig {
+target_scheme_map
+get_quant_method()
-_get_scheme_from_parts()
}
class AscendLinearMethod {
<<Wrapper>>
+quant_method: AscendLinearScheme
+create_weights()
+apply()
}
class AscendFusedMoEMethod {
<<Wrapper>>
+quant_method: AscendMoEScheme
+create_weights()
+apply()
}
class AscendLinearScheme {
<<Abstract>>
+get_weight()*
+apply()*
+get_pertensor_param()
+get_perchannel_param()
}
class AscendMoEScheme {
<<Abstract>>
+get_weight()*
+get_dynamic_quant_param()*
+apply()*
}
class W8A8DynamicLinear {
+get_weight()
+apply()
}
class W8A8DynamicMoE {
+get_weight()
+apply()
}
QuantizationConfig <|-- AscendModelSlimConfig
QuantizationConfig <|-- AscendCompressedTensorsConfig
AscendModelSlimConfig ..> AscendLinearMethod : creates
AscendModelSlimConfig ..> AscendFusedMoEMethod : creates
AscendCompressedTensorsConfig ..> AscendLinearMethod : creates
AscendCompressedTensorsConfig ..> AscendFusedMoEMethod : creates
AscendLinearMethod o-- AscendLinearScheme : delegates to
AscendFusedMoEMethod o-- AscendMoEScheme : delegates to
AscendLinearScheme <|-- W8A8DynamicLinear
AscendMoEScheme <|-- W8A8DynamicMoE
```
### Scheme Registration Flow
```mermaid
sequenceDiagram
participant Module as Scheme Module
participant Registry as _SCHEME_REGISTRY
participant Config as QuantConfig
participant Wrapper as Wrapper Class
Note over Module: At import time
Module->>Registry: @register_scheme("W8A8_DYNAMIC", "linear")
Registry->>Registry: Store (quant_type, layer_type) -> Class
Note over Config: At runtime
Config->>Config: Determine quant_type from description
Config->>Registry: get_scheme_class(quant_type, layer_type)
Registry-->>Config: Return scheme class
Config->>Config: scheme = scheme_cls()
Config->>Wrapper: Create wrapper with scheme
Wrapper-->>Config: Return wrapper instance
```
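In code, the runtime half of this flow corresponds roughly to the sketch below; the wrapper constructor signatures and the `wrappers` import path are assumptions for illustration:
```python
# Hedged sketch of the runtime lookup; only get_scheme_class is confirmed
# public API in this PR, the rest is illustrative.
from vllm_ascend.quantization.methods import get_scheme_class


def build_quant_method(quant_type: str, layer_type: str):
    # Look up the scheme class registered at import time and instantiate it.
    scheme_cls = get_scheme_class(quant_type, layer_type)
    scheme = scheme_cls()
    # Wrap it in the vLLM-facing wrapper; constructor signatures are assumed.
    if layer_type == "moe":
        from vllm_ascend.quantization.wrappers import AscendFusedMoEMethod
        return AscendFusedMoEMethod(scheme)
    from vllm_ascend.quantization.wrappers import AscendLinearMethod
    return AscendLinearMethod(scheme)
```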
### File Changes Summary
| Original Files | Refactored Files |
|----------------|------------------|
| `__init__.py` (empty) | `__init__.py` (exports public API) |
| `quant_config.py` | `modelslim_config.py` + `wrappers.py` |
| `compressed_tensors/` | `compressed_tensors_config.py` |
| `utils.py` | `methods/registry.py` |
| `w8a8_dynamic.py` | `methods/w8a8_dynamic.py` |
| `w8a8.py` | `methods/w8a8_static.py` |
| `w4a4_flatquant_dynamic.py` | `methods/w4a4_flatquant.py` |
| ... | `methods/base.py` (new) |
### Benefits
1. **Extensibility**: Adding a new quantization scheme only requires
implementing the base class and adding the `@register_scheme` decorator
2. **Maintainability**: Clear separation between config parsing, wrapper
logic, and scheme implementation
3. **Testability**: Abstract base classes enable easier unit testing and
mocking
4. **Discoverability**: Registry pattern makes it easy to list all
supported schemes
5. **Reduced Coupling**: Config classes no longer need to know about all
scheme implementations
___
- vLLM version: v0.13.0
- vLLM main:
https://github.com/vllm-project/vllm/commit/2f4e6548efec402b913ffddc8726230d9311948d
---------
Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
2026-01-23 14:13:47 +08:00
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Any, Callable, Dict, Optional

import torch
import torch_npu
from vllm.config import get_current_vllm_config
from vllm.forward_context import get_forward_context

from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ops.fused_moe.experts_selector import select_experts

from .base import AscendMoEScheme
from .registry import register_scheme

def unpack_from_int32(
    weight: torch.Tensor,
    shape: torch.Size,
    num_bits: int,
    packed_dim: int = 1,
) -> torch.Tensor:
    """Unpacks quantized weights from int32 format back to the original bit width.

    :param weight: The packed int32 tensor containing quantized weights
    :param shape: Original (unpacked) shape to restore
    :param num_bits: The number of bits used for quantization (<= 8)
    :param packed_dim: Dimension along which weights are packed (0 or 1), defaults to 1
    :return: Unpacked tensor with int8 dtype after applying offset correction
    """
    assert weight.dtype == torch.int32, f"Expecting `weight.dtype` to be torch.int32 but got {weight.dtype}."
    assert num_bits <= 8, f"Expecting `num_bits` to be no larger than 8 but got {num_bits}."

    pack_factor = 32 // num_bits
    mask = (1 << num_bits) - 1
    if packed_dim == 1:
        unpacked_weight = torch.zeros(
            (weight.shape[0], weight.shape[1] * pack_factor),
            device=weight.device,
            dtype=torch.int32,
        )
        for i in range(pack_factor):
            unpacked_weight[:, i::pack_factor] = (weight >>
                                                  (num_bits * i)) & mask
        original_row_size = int(shape[1])
        unpacked_weight = unpacked_weight[:, :original_row_size]
    else:
        unpacked_weight = torch.zeros(
            (weight.shape[0] * pack_factor, weight.shape[1]),
            device=weight.device,
            dtype=torch.int32,
        )
        for i in range(pack_factor):
            unpacked_weight[i::pack_factor, :] = (weight >>
                                                  (num_bits * i)) & mask
        original_row_size = int(shape[0])
        unpacked_weight = unpacked_weight[:original_row_size, :]
    # Shift from the unsigned packed representation back to signed int8.
    offset = pow(2, num_bits) // 2
    unpacked_weight = (unpacked_weight - offset).to(torch.int8)
    return unpacked_weight

def pack_to_int32(weight: torch.Tensor) -> torch.Tensor:
    """Packs quantized weights into int32 format for storage.

    :param weight: The 3D tensor to pack, must be int8 or int32 dtype
    :return: Packed tensor with int32 dtype optimized for storage
    """
    assert weight.dim() == 3, \
        f"Expecting `weight.dim()` to be 3 ([e, n, k] or [e, k, n]) but got {weight.dim()}."
    assert weight.dtype in [
        torch.int8, torch.int32
    ], f"Expecting `weight.dtype` to be torch.int8 or torch.int32 but got {weight.dtype}."
    if weight.dtype == torch.int32:
        assert weight.shape[
            -1] % 8 == 0, "the last dim of weight needs to be divisible by 8."
        packed_weight = torch_npu.npu_convert_weight_to_int4pack(
            weight.flatten(0, 1))
        packed_weight = packed_weight.view(weight.shape[0], weight.shape[1],
                                           -1)
    else:
        assert weight.shape[
            -1] % 4 == 0, "the last dim of weight needs to be divisible by 4."
        packed_weight = weight.view(torch.int32).contiguous()
    return packed_weight

@register_scheme("W4A16", "moe")
class AscendW4A16FusedMoEMethod(AscendMoEScheme):
    """FusedMoE method for Ascend W4A16."""

    def __init__(self) -> None:
        self.transpose_weight = True
        self.num_bits = 4  # weights are stored as torch.int4 values
        self.pack_factor = 8  # 8 int4 values are packed into one torch.int32
        vllm_config = get_current_vllm_config()
        self.group_size = vllm_config.quant_config.quant_description.get(
            "group_size", 32)
        self.dynamic_eplb = get_ascend_config().eplb_config.dynamic_eplb

    def get_weight(
        self,
        num_experts: int,
        intermediate_size_per_partition: int,
        hidden_sizes: int,
        params_dtype: torch.dtype,
    ) -> Dict[str, Any]:
        assert intermediate_size_per_partition % self.pack_factor == 0, (
            f"Expecting `intermediate_size_per_partition` "
            f"{intermediate_size_per_partition} to be divisible by "
            f"`pack_factor` {self.pack_factor}")
        assert hidden_sizes % self.pack_factor == 0, (
            f"Expecting `hidden_sizes` {hidden_sizes} to be divisible by "
            f"`pack_factor` {self.pack_factor}")
        param_dict = {}
        param_dict["w13_weight_packed"] = torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            hidden_sizes // self.pack_factor,
            dtype=torch.int32)
        param_dict["w2_weight_packed"] = torch.empty(
            num_experts,
            hidden_sizes,
            intermediate_size_per_partition // self.pack_factor,
            dtype=torch.int32)
        return param_dict

    def get_dynamic_quant_param(
        self,
        num_experts: int,
        intermediate_size_per_partition: int,
        hidden_sizes: int,
        params_dtype: torch.dtype,
    ) -> Dict[str, Any]:
        assert intermediate_size_per_partition % self.group_size == 0, (
            f"Expecting `intermediate_size_per_partition` "
            f"{intermediate_size_per_partition} to be divisible by "
            f"`group_size` {self.group_size}")
        assert hidden_sizes % self.group_size == 0, (
            f"Expecting `hidden_sizes` {hidden_sizes} to be divisible by "
            f"`group_size` {self.group_size}")
        param_dict = {}
        param_dict["w13_weight_scale"] = torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            hidden_sizes // self.group_size,
            dtype=torch.bfloat16)
        param_dict["w2_weight_scale"] = torch.empty(
            num_experts,
            hidden_sizes,
            intermediate_size_per_partition // self.group_size,
            dtype=torch.bfloat16)
        param_dict["w13_weight_shape"] = torch.empty(num_experts,
                                                     2,
                                                     dtype=torch.int32)
        param_dict["w2_weight_shape"] = torch.empty(num_experts,
                                                    2,
                                                    dtype=torch.int32)
        param_dict["w13_weight_offset"] = torch.zeros(
            num_experts,
            2 * intermediate_size_per_partition,
            hidden_sizes // self.group_size,
            dtype=torch.bfloat16)
        param_dict["w2_weight_offset"] = torch.zeros(
            num_experts,
            hidden_sizes,
            intermediate_size_per_partition // self.group_size,
            dtype=torch.bfloat16)
        return param_dict

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: Optional[torch.Tensor] = None,
        is_prefill: bool = True,
        enable_force_load_balance: bool = True,
        log2phy: Optional[torch.Tensor] = None,
        global_redundant_expert_num: int = 0,
        **kwargs,
    ) -> torch.Tensor:
        assert router_logits.shape[
            1] == global_num_experts - global_redundant_expert_num, (
                "Number of global experts mismatch (excluding redundancy)")
        topk_weights, topk_ids = select_experts(
            hidden_states=x,
            router_logits=router_logits,
            top_k=top_k,
            use_grouped_topk=use_grouped_topk,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            global_num_experts=global_num_experts)
        topk_ids = topk_ids.to(torch.int32)
        topk_weights = topk_weights.to(x.dtype)
        moe_comm_method = get_forward_context().moe_comm_method
        return moe_comm_method.fused_experts(
            hidden_states=x,
            w1=layer.w13_weight_packed,
            w2=layer.w2_weight_packed,
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            w1_offset=layer.w13_weight_offset,
            w2_offset=layer.w2_weight_offset,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            use_int4_w4a16=True,
            expert_map=expert_map,
            log2phy=log2phy,
            dynamic_eplb=self.dynamic_eplb,
            mc2_mask=kwargs.get("mc2_mask", None))

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        if self.transpose_weight:
            w13_shape = layer.w13_weight_packed.data.shape
            w2_shape = layer.w2_weight_packed.data.shape
            # Unpack the int4 weights, transpose the last two dims into the
            # layout expected by the NPU kernel, then repack to int32.
            unpacked_w13_weight = (unpack_from_int32(
                layer.w13_weight_packed.data.flatten(0, 1),
                torch.Size([
                    w13_shape[0] * w13_shape[1],
                    w13_shape[2] * self.pack_factor
                ]),
                self.num_bits,
            ).view(w13_shape[0], w13_shape[1],
                   -1).transpose(1, 2).contiguous().int())
            unpacked_w2_weight = (unpack_from_int32(
                layer.w2_weight_packed.data.flatten(0, 1),
                torch.Size([
                    w2_shape[0] * w2_shape[1], w2_shape[2] * self.pack_factor
                ]),
                self.num_bits,
            ).view(w2_shape[0], w2_shape[1],
                   -1).transpose(1, 2).contiguous().int())
            layer.w13_weight_packed.data = pack_to_int32(unpacked_w13_weight)
            layer.w2_weight_packed.data = pack_to_int32(unpacked_w2_weight)
            layer.w13_weight_scale.data = layer.w13_weight_scale.data.transpose(
                1, 2).contiguous()
            layer.w2_weight_scale.data = layer.w2_weight_scale.data.transpose(
                1, 2).contiguous()
            layer.w13_weight_offset.data = layer.w13_weight_offset.data.transpose(
                1, 2).contiguous()
            layer.w2_weight_offset.data = layer.w2_weight_offset.data.transpose(
                1, 2).contiguous()