### Summary
This PR refactors the `vllm_ascend/quantization` module to improve code
organization, maintainability, and extensibility. The refactoring
introduces a clear separation of concerns with a registry-based scheme
discovery pattern, abstract base classes for quantization schemes, and
dedicated wrapper classes.
### Key Changes
#### 1. **Modular Directory Structure**
| Before | After |
|--------|-------|
| Flat file structure with mixed responsibilities | Organized into `methods/` subpackage for schemes |
| Single `quant_config.py` (600+ lines) | Separate config files: `modelslim_config.py`, `compressed_tensors_config.py` |
| `utils.py` with scheme lookup logic | `methods/registry.py` with decorator-based registration |
#### 2. **Registry-Based Scheme Discovery**
Replaced the hardcoded `ASCEND_QUANTIZATION_METHOD_MAP` dictionary with a
decorator-based registry:
```python
# Before: Manual dictionary mapping
ASCEND_QUANTIZATION_METHOD_MAP = {
    "W8A8_DYNAMIC": {"linear": AscendW8A8DynamicLinearMethod, ...},
    ...
}

# After: Decorator-based registration
@register_scheme("W8A8_DYNAMIC", "linear")
class AscendW8A8DynamicLinearMethod(AscendLinearScheme):
    ...
```
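
A registry of this kind can be sketched in a few lines. The block below is a minimal, self-contained illustration and not the code in `methods/registry.py`: the names `register_scheme`, `get_scheme_class`, and `_SCHEME_REGISTRY` appear in this PR, but their bodies, the duplicate-registration check, the `None` return for unknown keys, and the `DemoW8A8DynamicLinear` class are assumptions made for the example.

```python
from typing import Callable, Dict, Optional, Tuple, Type

# Maps (quant_type, layer_type) -> scheme class. The layout is assumed.
_SCHEME_REGISTRY: Dict[Tuple[str, str], Type] = {}


def register_scheme(quant_type: str,
                    layer_type: str) -> Callable[[Type], Type]:
    """Class decorator that records a scheme under (quant_type, layer_type)."""

    def decorator(cls: Type) -> Type:
        key = (quant_type, layer_type)
        if key in _SCHEME_REGISTRY:
            # Assumption: registering the same key twice is an error.
            raise ValueError(f"Scheme already registered for {key}")
        _SCHEME_REGISTRY[key] = cls
        return cls

    return decorator


def get_scheme_class(quant_type: str, layer_type: str) -> Optional[Type]:
    """Return the registered class, or None if nothing matches."""
    return _SCHEME_REGISTRY.get((quant_type, layer_type))


@register_scheme("W8A8_DYNAMIC", "linear")
class DemoW8A8DynamicLinear:  # hypothetical stand-in for a real scheme
    pass
```

Because registration happens as a side effect of the decorator, a scheme module only has to be imported for its classes to become discoverable.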
#### 3. **Abstract Base Classes**
Introduced three abstract base classes in `methods/base.py`:
- `AscendLinearScheme` - Base for linear layer quantization
- `AscendMoEScheme` - Base for MoE layer quantization
- `AscendAttentionScheme` - Base for attention layer quantization
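
As a rough sketch of what such a base class might look like, the example below marks `get_weight` and `apply` as abstract and leaves the two parameter hooks optional, mirroring the class diagram in this description. The class name `LinearSchemeSketch`, the simplified signatures, the empty-dict defaults, and `IdentityScheme` are assumptions for illustration, not the contents of `methods/base.py`.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict


class LinearSchemeSketch(ABC):
    """Hypothetical stand-in for AscendLinearScheme."""

    @abstractmethod
    def get_weight(self, input_size: int, output_size: int) -> Dict[str, Any]:
        """Describe the weight parameters the scheme needs."""

    @abstractmethod
    def apply(self, layer: Any, x: Any) -> Any:
        """Run the quantized forward pass."""

    # Optional hooks: assumed to default to "no extra parameters".
    def get_pertensor_param(self) -> Dict[str, Any]:
        return {}

    def get_perchannel_param(self, output_size: int) -> Dict[str, Any]:
        return {}


class IdentityScheme(LinearSchemeSketch):
    """Smallest concrete scheme: implements only the abstract methods."""

    def get_weight(self, input_size: int, output_size: int) -> Dict[str, Any]:
        return {"weight_shape": (output_size, input_size)}

    def apply(self, layer: Any, x: Any) -> Any:
        return x
```

Instantiating `LinearSchemeSketch` directly raises `TypeError`, so a scheme that forgets one of the required methods fails when it is constructed rather than in the middle of a forward pass.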
#### 4. **Separated Config and Wrapper Classes**
- **Config classes** (`AscendModelSlimConfig`,
`AscendCompressedTensorsConfig`): Handle config parsing and scheme
selection
- **Wrapper classes** (`AscendLinearMethod`, `AscendFusedMoEMethod`,
etc.): Implement vLLM interfaces and delegate to schemes
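
The delegation idea can be shown with a toy wrapper. This is a sketch under stated assumptions: `LinearMethodSketch`, `ToyScheme`, and `ToyLayer` are hypothetical names, plain floats stand in for tensors, and only the `quant_method` attribute name follows the class diagram in this description.

```python
from typing import Any, Dict


class ToyScheme:
    """Hypothetical scheme: declares one parameter and scales inputs by it."""

    def get_weight(self, input_size: int, output_size: int) -> Dict[str, Any]:
        return {"scale": 2.0}

    def apply(self, layer: Any, x: float) -> float:
        return x * layer.scale


class ToyLayer:
    """Bare attribute holder standing in for a torch.nn.Module."""


class LinearMethodSketch:
    """Wrapper that holds a scheme and forwards calls to it."""

    def __init__(self, scheme: Any) -> None:
        self.quant_method = scheme

    def create_weights(self, layer: Any, input_size: int,
                       output_size: int) -> None:
        # Attach every parameter the scheme declares to the layer.
        params = self.quant_method.get_weight(input_size, output_size)
        for name, value in params.items():
            setattr(layer, name, value)

    def apply(self, layer: Any, x: float) -> float:
        # No quantization logic here: the scheme does the work.
        return self.quant_method.apply(layer, x)


layer = ToyLayer()
method = LinearMethodSketch(ToyScheme())
method.create_weights(layer, input_size=4, output_size=4)
result = method.apply(layer, 3.0)
```

Keeping the wrapper this thin means the framework-facing interface and the device-specific math can change independently.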
#### 5. **Cleaner Public API**
```python
# New clean module interface
from vllm_ascend.quantization import (
    AscendModelSlimConfig,
    AscendCompressedTensorsConfig,
)
from vllm_ascend.quantization.methods import get_scheme_class
```
### Architecture Diagram
```mermaid
classDiagram
direction TB
class QuantizationConfig {
<<vLLM Interface>>
+get_quant_method()
}
class AscendModelSlimConfig {
+quant_description
+get_quant_method()
-create_scheme_for_layer()
}
class AscendCompressedTensorsConfig {
+target_scheme_map
+get_quant_method()
-_get_scheme_from_parts()
}
class AscendLinearMethod {
<<Wrapper>>
+quant_method: AscendLinearScheme
+create_weights()
+apply()
}
class AscendFusedMoEMethod {
<<Wrapper>>
+quant_method: AscendMoEScheme
+create_weights()
+apply()
}
class AscendLinearScheme {
<<Abstract>>
+get_weight()*
+apply()*
+get_pertensor_param()
+get_perchannel_param()
}
class AscendMoEScheme {
<<Abstract>>
+get_weight()*
+get_dynamic_quant_param()*
+apply()*
}
class W8A8DynamicLinear {
+get_weight()
+apply()
}
class W8A8DynamicMoE {
+get_weight()
+apply()
}
QuantizationConfig <|-- AscendModelSlimConfig
QuantizationConfig <|-- AscendCompressedTensorsConfig
AscendModelSlimConfig ..> AscendLinearMethod : creates
AscendModelSlimConfig ..> AscendFusedMoEMethod : creates
AscendCompressedTensorsConfig ..> AscendLinearMethod : creates
AscendCompressedTensorsConfig ..> AscendFusedMoEMethod : creates
AscendLinearMethod o-- AscendLinearScheme : delegates to
AscendFusedMoEMethod o-- AscendMoEScheme : delegates to
AscendLinearScheme <|-- W8A8DynamicLinear
AscendMoEScheme <|-- W8A8DynamicMoE
```
### Scheme Registration Flow
```mermaid
sequenceDiagram
participant Module as Scheme Module
participant Registry as _SCHEME_REGISTRY
participant Config as QuantConfig
participant Wrapper as Wrapper Class
Note over Module: At import time
Module->>Registry: @register_scheme("W8A8_DYNAMIC", "linear")
Registry->>Registry: Store (quant_type, layer_type) -> Class
Note over Config: At runtime
Config->>Config: Determine quant_type from description
Config->>Registry: get_scheme_class(quant_type, layer_type)
Registry-->>Config: Return scheme class
Config->>Config: scheme = scheme_cls()
Config->>Wrapper: Create wrapper with scheme
Wrapper-->>Config: Return wrapper instance
```
### File Changes Summary
| Original Files | Refactored Files |
|----------------|------------------|
| `__init__.py` (empty) | `__init__.py` (exports public API) |
| `quant_config.py` | `modelslim_config.py` + `wrappers.py` |
| `compressed_tensors/` | `compressed_tensors_config.py` |
| `utils.py` | `methods/registry.py` |
| `w8a8_dynamic.py` | `methods/w8a8_dynamic.py` |
| `w8a8.py` | `methods/w8a8_static.py` |
| `w4a4_flatquant_dynamic.py` | `methods/w4a4_flatquant.py` |
| ... | `methods/base.py` (new) |
### Benefits
1. **Extensibility**: Adding a new quantization scheme only requires
implementing the base class and adding the `@register_scheme` decorator
2. **Maintainability**: Clear separation between config parsing, wrapper
logic, and scheme implementation
3. **Testability**: Abstract base classes enable easier unit testing and
mocking
4. **Discoverability**: Registry pattern makes it easy to list all
supported schemes
5. **Reduced Coupling**: Config classes no longer need to know about all
scheme implementations
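
To make the testability point concrete, here is a minimal sketch of a unit test that swaps the scheme for a mock. `WrapperUnderTest` is a hypothetical simplified wrapper written for this example; it is not `AscendLinearMethod`, and the argument values are placeholders.

```python
from unittest.mock import MagicMock


class WrapperUnderTest:
    """Hypothetical wrapper that forwards apply() to its scheme."""

    def __init__(self, scheme):
        self.quant_method = scheme

    def apply(self, layer, x, bias=None):
        return self.quant_method.apply(layer, x, bias)


def test_wrapper_delegates_to_scheme():
    # The mock replaces a real scheme, so no NPU or weights are needed.
    scheme = MagicMock()
    scheme.apply.return_value = "quantized-output"
    wrapper = WrapperUnderTest(scheme)

    result = wrapper.apply("layer", "x", bias=None)

    assert result == "quantized-output"
    scheme.apply.assert_called_once_with("layer", "x", None)


test_wrapper_delegates_to_scheme()
```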
___
- vLLM version: v0.13.0
- vLLM main: 2f4e6548ef
---------
Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Any, Dict, Optional

import torch
import torch_npu

from vllm_ascend.utils import (COMPRESSED_TENSORS_METHOD, AscendDeviceType,
                               get_ascend_device_type,
                               get_weight_prefetch_method, maybe_trans_nz)

from .base import AscendLinearScheme
from .registry import register_scheme


@register_scheme("W8A8", "linear")
class AscendW8A8LinearMethod(AscendLinearScheme):
    """Linear method for Ascend W8A8 static quantization.

    This scheme uses static per-tensor quantization for activations
    and per-channel quantization for weights.
    """

    def __init__(self) -> None:
        pass

    def get_weight(
        self,
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype = torch.bfloat16,
    ) -> Dict[str, Any]:
        params_dict = {
            "weight": torch.empty(output_size, input_size, dtype=torch.int8)
        }
        return params_dict

    def get_pertensor_param(self, params_dtype: torch.dtype) -> Dict[str, Any]:
        params_dict = {}
        params_dict["input_scale"] = torch.empty(1, dtype=params_dtype)
        params_dict["input_offset"] = torch.empty(1, dtype=torch.int8)
        return params_dict

    def get_perchannel_param(
        self,
        output_size: int,
        params_dtype: torch.dtype,
    ) -> Dict[str, Any]:
        params_dict = {}
        params_dict["quant_bias"] = torch.empty(output_size, dtype=torch.int32)
        if params_dtype == torch.bfloat16:
            params_dict["deq_scale"] = torch.empty(output_size,
                                                   dtype=torch.float32)
        elif params_dtype == torch.float16:
            params_dict["deq_scale"] = torch.empty(output_size,
                                                   dtype=torch.int64)
        params_dict["weight_scale"] = torch.empty(output_size,
                                                  1,
                                                  dtype=params_dtype)
        params_dict["weight_offset"] = torch.empty(output_size,
                                                   1,
                                                   dtype=params_dtype)
        return params_dict

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
        tp_rank: Optional[int] = 0,
    ) -> torch.Tensor:
        if x.dtype != torch.int8:
            layer_cls_name = layer.__class__.__name__
            weight_prefetch_method = get_weight_prefetch_method()
            # prefetch qkvo_proj.weight preprocess
            if weight_prefetch_method:
                weight_prefetch_method.maybe_prefetch_attn_weight_preprocess(
                    layer_cls_name=layer_cls_name,
                    weight=layer.weight,
                    start_flag=x,
                )
            try:
                quant_comm_config = getattr(layer, "_quant_comm_config")
            except AttributeError:
                quant_comm_config = {}
            comm_fn = quant_comm_config.get("communication_fn")
            enable_flashcomm2_quant_comm = comm_fn is not None and (
                "o_proj" in layer.prefix or "out_proj" in layer.prefix)
            if enable_flashcomm2_quant_comm:
                quant_input_x = x.contiguous().view(
                    -1, layer.aclnn_input_scale_reciprocal.size(0))
                quant_x = torch.ops.vllm.quantize(
                    quant_input_x,
                    layer.aclnn_input_scale,
                    layer.aclnn_input_scale_reciprocal,
                    layer.aclnn_input_offset,
                )
                comm_input = quant_x.view(x.size(0), -1)
                assert comm_fn is not None
                x = comm_fn(comm_input)
            else:
                # quant
                x = torch.ops.vllm.quantize(
                    x,
                    layer.aclnn_input_scale,
                    layer.aclnn_input_scale_reciprocal,
                    layer.aclnn_input_offset,
                )

            # prefetch qkvo_proj.weight postprocess
            if weight_prefetch_method:
                weight_prefetch_method.maybe_prefetch_attn_weight_postprocess(
                    layer_cls_name=layer_cls_name,
                    stop_flag=x,
                )

        quant_bias = layer.quant_bias if tp_rank == 0 else None

        try:
            ascend_quant_method = getattr(layer, "ascend_quant_method")
        except AttributeError:
            ascend_quant_method = ""
        if ascend_quant_method == COMPRESSED_TENSORS_METHOD:
            quant_bias = bias

        if get_ascend_device_type() == AscendDeviceType._310P:
            # On 300I Duo platform, we need transpose again if
            # using nz. This transpose can be skipped in torchair.
            output = torch_npu.npu_quant_matmul(
                x,
                layer.weight.data.transpose(1, 0),
                layer.deq_scale,
                bias=quant_bias,
                output_dtype=layer.params_dtype,
            )
        else:
            output = torch_npu.npu_quant_matmul(
                x,
                layer.weight,
                layer.deq_scale,
                bias=quant_bias,
                output_dtype=layer.params_dtype,
            )
        return output

    def process_weights_after_loading(self, layer):
        expanding_factor = layer.weight.data.shape[1]
        layer.aclnn_input_scale = torch.nn.Parameter(
            layer.input_scale.data.repeat(expanding_factor),
            requires_grad=False)
        layer.aclnn_input_scale_reciprocal = 1 / torch.nn.Parameter(
            layer.input_scale.data.repeat(expanding_factor),
            requires_grad=False)
        layer.aclnn_input_offset = torch.nn.Parameter(
            layer.input_offset.data.repeat(expanding_factor),
            requires_grad=False).to(layer.aclnn_input_scale.dtype)
        if get_ascend_device_type() != AscendDeviceType._310P:
            layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
        layer.weight.data = maybe_trans_nz(layer.weight.data)
        layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
        layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
        ascend_quant_method = getattr(layer, "ascend_quant_method", "")
        if ascend_quant_method == COMPRESSED_TENSORS_METHOD:
            deq_scale = layer.input_scale.data * layer.weight_scale.data
            layer.deq_scale = torch.nn.Parameter(deq_scale,
                                                 requires_grad=False)
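
The `deq_scale = input_scale * weight_scale` step in `process_weights_after_loading` can be checked with a small worked example in plain Python. This is a sketch under stated assumptions: the behaviour of `torch.ops.vllm.quantize` is not shown in this file, so `quantize_int8` below assumes the common affine form `clamp(round(v / scale) + offset, -128, 127)`; the offset is taken as zero, and the numbers are powers of two chosen so that every step is exact.

```python
from typing import List


def quantize_int8(values: List[float], scale: float,
                  offset: int = 0) -> List[int]:
    """Assumed affine int8 quantization: clamp(round(v / scale) + offset)."""
    return [max(-128, min(127, round(v / scale) + offset)) for v in values]


# One per-tensor activation scale, one weight scale per output channel.
input_scale = 0.25
weight_scale = [0.5, 0.125]

x = [0.5, -1.0, 0.25]
weight = [
    [1.0, 0.5, -0.5],       # output channel 0
    [0.25, 0.125, -0.125],  # output channel 1
]

x_q = quantize_int8(x, input_scale)
w_q = [quantize_int8(row, s) for row, s in zip(weight, weight_scale)]

# Integer matmul, then one multiply per output channel to return to float.
acc = [sum(a * b for a, b in zip(x_q, row)) for row in w_q]
deq_scale = [input_scale * s for s in weight_scale]
output = [a * d for a, d in zip(acc, deq_scale)]

# Floating-point reference for comparison.
reference = [sum(a * b for a, b in zip(x, row)) for row in weight]
```

Here `output` and `reference` agree exactly only because the inputs are exact multiples of their scales; with real activations and weights the two differ by the rounding error introduced during quantization.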