### Summary
This PR refactors the `vllm_ascend/quantization` module to improve code
organization, maintainability, and extensibility. The refactoring
introduces a clear separation of concerns with a registry-based scheme
discovery pattern, abstract base classes for quantization schemes, and
dedicated wrapper classes.
### Key Changes
#### 1. **Modular Directory Structure**
| Before | After |
|--------|-------|
| Flat file structure with mixed responsibilities | Organized into `methods/` subpackage for schemes |
| Single `quant_config.py` (600+ lines) | Separate config files: `modelslim_config.py`, `compressed_tensors_config.py` |
| `utils.py` with scheme lookup logic | `methods/registry.py` with decorator-based registration |
#### 2. **Registry-Based Scheme Discovery**
Replaced the hardcoded `ASCEND_QUANTIZATION_METHOD_MAP` dictionary with a
decorator-based registry pattern:
```python
# Before: manual dictionary mapping
ASCEND_QUANTIZATION_METHOD_MAP = {
    "W8A8_DYNAMIC": {
        "linear": AscendW8A8DynamicLinearMethod,
        # ... other layer types
    },
    # ... other quant types
}

# After: decorator-based registration
@register_scheme("W8A8_DYNAMIC", "linear")
class AscendW8A8DynamicLinearMethod(AscendLinearScheme):
    ...
```
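The registry side can be sketched in a few lines. This is a minimal, self-contained approximation, assuming the `(quant_type, layer_type)` keying shown in the registration-flow diagram. `_SCHEME_REGISTRY`, `register_scheme`, and `get_scheme_class` reuse the PR's names, but their bodies here (duplicate detection, raising `ValueError` on a miss) and the `list_schemes` helper are assumptions, not the code in `methods/registry.py`.

```python
from typing import Callable, Dict, List, Tuple, TypeVar

T = TypeVar("T", bound=type)

# Hypothetical stand-in for the registry in methods/registry.py:
# (quant_type, layer_type) -> scheme class.
_SCHEME_REGISTRY: Dict[Tuple[str, str], type] = {}


def register_scheme(quant_type: str, layer_type: str) -> Callable[[T], T]:
    """Class decorator that records the class under (quant_type, layer_type)."""

    def decorator(cls: T) -> T:
        key = (quant_type, layer_type)
        # Assumption: registering the same key twice is treated as a bug.
        if key in _SCHEME_REGISTRY:
            raise ValueError(f"scheme already registered for {key}")
        _SCHEME_REGISTRY[key] = cls
        return cls  # the class itself is returned unchanged

    return decorator


def get_scheme_class(quant_type: str, layer_type: str) -> type:
    """Return the scheme class registered for (quant_type, layer_type)."""
    try:
        return _SCHEME_REGISTRY[(quant_type, layer_type)]
    except KeyError:
        # Assumption: unknown combinations raise instead of returning None.
        raise ValueError(f"no scheme registered for quant_type={quant_type!r}, "
                         f"layer_type={layer_type!r}") from None


def list_schemes() -> List[Tuple[str, str]]:
    """Hypothetical helper: all registered (quant_type, layer_type) keys."""
    return sorted(_SCHEME_REGISTRY)


# Registration happens as a side effect of executing the class statement.
@register_scheme("W8A8_DYNAMIC", "linear")
class DemoW8A8DynamicLinear:  # hypothetical placeholder, not the real scheme
    pass
```

Because registration runs when the decorated class is defined, a scheme only becomes discoverable once its module has been imported, so a package built this way has to make sure every scheme module is imported (for instance from `methods/__init__.py`) before the first lookup.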
#### 3. **Abstract Base Classes**
Introduced three abstract base classes in `methods/base.py`:
- `AscendLinearScheme` - Base for linear layer quantization
- `AscendMoEScheme` - Base for MoE layer quantization
- `AscendAttentionScheme` - Base for attention layer quantization
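Because the base classes use `abc.abstractmethod`, a scheme that leaves a required method unimplemented cannot be instantiated at all, so the mistake surfaces when the scheme object is created rather than on the first forward pass. The sketch below illustrates this with `LinearSchemeBase`, a trimmed, hypothetical stand-in for `AscendLinearScheme` whose signatures drop the tensor and dtype arguments; `IncompleteScheme`, `CompleteScheme`, and `can_instantiate` are illustrative names only.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict


class LinearSchemeBase(ABC):
    """Trimmed stand-in for AscendLinearScheme (signatures simplified)."""

    @abstractmethod
    def get_weight(self, input_size: int, output_size: int) -> Dict[str, Any]:
        ...

    @abstractmethod
    def apply(self, x: Any) -> Any:
        ...

    # Optional hook with a default, mirroring get_perchannel_param().
    def get_perchannel_param(self, output_size: int) -> Dict[str, Any]:
        return {}


class IncompleteScheme(LinearSchemeBase):
    # Hypothetical: implements get_weight() but forgets apply().
    def get_weight(self, input_size: int, output_size: int) -> Dict[str, Any]:
        # Shape tuple as a placeholder for an empty weight tensor.
        return {"weight": (output_size, input_size)}


class CompleteScheme(IncompleteScheme):
    def apply(self, x: Any) -> Any:
        return x  # identity placeholder for a real quantized matmul


def can_instantiate(cls: type) -> bool:
    """Return True if cls has no unimplemented abstract methods."""
    try:
        cls()
    except TypeError:
        return False
    return True
```

The same property makes the bases convenient in tests: a test double only has to implement the abstract methods and inherits the empty-dict defaults for the rest.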
#### 4. **Separated Config and Wrapper Classes**
- **Config classes** (`AscendModelSlimConfig`,
`AscendCompressedTensorsConfig`): Handle config parsing and scheme
selection
- **Wrapper classes** (`AscendLinearMethod`, `AscendFusedMoEMethod`,
etc.): Implement vLLM interfaces and delegate to schemes
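One way the config/wrapper split could look in code: the wrapper owns a scheme, turns the scheme's parameter dictionaries into layer attributes, and forwards `apply()`. This is a simplified sketch, not the real `AscendLinearMethod`. `LinearWrapperSketch` and `ToyScheme` are hypothetical, shape tuples stand in for tensors, and plain `setattr` stands in for registering `torch.nn.Parameter` objects; only the scheme method names are taken from `AscendLinearScheme`.

```python
from types import SimpleNamespace
from typing import Any, Dict, Optional


class LinearWrapperSketch:
    """Hypothetical wrapper: vLLM-facing calls are forwarded to a scheme."""

    def __init__(self, scheme: Any) -> None:
        self.quant_method = scheme  # attribute name taken from the diagram

    def create_weights(self, layer: Any, input_size: int, output_size: int,
                       params_dtype: Any) -> None:
        scheme = self.quant_method
        params: Dict[str, Any] = {}
        # Required parameters first, then the optional hooks (which may be {}).
        params.update(scheme.get_weight(input_size, output_size, params_dtype))
        params.update(scheme.get_pertensor_param(params_dtype))
        params.update(scheme.get_perchannel_param(output_size, params_dtype))
        for name, value in params.items():
            # Real code would wrap tensors in torch.nn.Parameter here.
            setattr(layer, name, value)

    def apply(self, layer: Any, x: Any, bias: Optional[Any] = None) -> Any:
        return self.quant_method.apply(layer, x, bias)


class ToyScheme:
    """Hypothetical scheme that returns shape tuples instead of tensors."""

    def get_weight(self, input_size, output_size, params_dtype):
        return {"weight": (output_size, input_size)}

    def get_pertensor_param(self, params_dtype):
        return {"input_scale": (1,)}

    def get_perchannel_param(self, output_size, params_dtype):
        return {"weight_scale": (output_size, 1)}

    def apply(self, layer, x, bias=None):
        return ("toy_linear", x, bias)


# Usage: the layer object only needs to accept attributes.
layer = SimpleNamespace()
wrapper = LinearWrapperSketch(ToyScheme())
wrapper.create_weights(layer, input_size=16, output_size=32,
                       params_dtype="bfloat16")
```

Keeping the vLLM-facing surface in the wrapper means a scheme never has to know which vLLM hook called it; it only answers "which parameters do you need" and "how do you compute".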
#### 5. **Cleaner Public API**
```python
# New clean module interface
from vllm_ascend.quantization import (
    AscendModelSlimConfig,
    AscendCompressedTensorsConfig,
)
from vllm_ascend.quantization.methods import get_scheme_class
```
### Architecture Diagram
```mermaid
classDiagram
direction TB
class QuantizationConfig {
<<vLLM Interface>>
+get_quant_method()
}
class AscendModelSlimConfig {
+quant_description
+get_quant_method()
-create_scheme_for_layer()
}
class AscendCompressedTensorsConfig {
+target_scheme_map
+get_quant_method()
-_get_scheme_from_parts()
}
class AscendLinearMethod {
<<Wrapper>>
+quant_method: AscendLinearScheme
+create_weights()
+apply()
}
class AscendFusedMoEMethod {
<<Wrapper>>
+quant_method: AscendMoEScheme
+create_weights()
+apply()
}
class AscendLinearScheme {
<<Abstract>>
+get_weight()*
+apply()*
+get_pertensor_param()
+get_perchannel_param()
}
class AscendMoEScheme {
<<Abstract>>
+get_weight()*
+get_dynamic_quant_param()*
+apply()*
}
class W8A8DynamicLinear {
+get_weight()
+apply()
}
class W8A8DynamicMoE {
+get_weight()
+apply()
}
QuantizationConfig <|-- AscendModelSlimConfig
QuantizationConfig <|-- AscendCompressedTensorsConfig
AscendModelSlimConfig ..> AscendLinearMethod : creates
AscendModelSlimConfig ..> AscendFusedMoEMethod : creates
AscendCompressedTensorsConfig ..> AscendLinearMethod : creates
AscendCompressedTensorsConfig ..> AscendFusedMoEMethod : creates
AscendLinearMethod o-- AscendLinearScheme : delegates to
AscendFusedMoEMethod o-- AscendMoEScheme : delegates to
AscendLinearScheme <|-- W8A8DynamicLinear
AscendMoEScheme <|-- W8A8DynamicMoE
```
### Scheme Registration Flow
```mermaid
sequenceDiagram
participant Module as Scheme Module
participant Registry as _SCHEME_REGISTRY
participant Config as QuantConfig
participant Wrapper as Wrapper Class
Note over Module: At import time
Module->>Registry: @register_scheme("W8A8_DYNAMIC", "linear")
Registry->>Registry: Store (quant_type, layer_type) -> Class
Note over Config: At runtime
Config->>Config: Determine quant_type from description
Config->>Registry: get_scheme_class(quant_type, layer_type)
Registry-->>Config: Return scheme class
Config->>Config: scheme = scheme_cls()
Config->>Wrapper: Create wrapper with scheme
Wrapper-->>Config: Return wrapper instance
```
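The runtime half of this sequence can be sketched as a single function. Assumptions: a ModelSlim-style description maps `"<layer prefix>.weight"` to a quant-type string, a plain dict stands in for `_SCHEME_REGISTRY`, and a missing entry means "leave the layer unquantized". `resolve_quant_type`, `create_method_for_layer`, `DemoLinearScheme`, and `Wrapper` are hypothetical names; a real config will likely need to handle more cases (fused modules, for example).

```python
from typing import Any, Dict, Optional, Tuple


class DemoLinearScheme:  # hypothetical scheme class
    pass


# Stand-in for _SCHEME_REGISTRY; in the PR it is filled at import time.
REGISTRY: Dict[Tuple[str, str], type] = {
    ("W8A8_DYNAMIC", "linear"): DemoLinearScheme,
}


class Wrapper:
    """Hypothetical wrapper that keeps the scheme it delegates to."""

    def __init__(self, scheme: Any) -> None:
        self.quant_method = scheme


def resolve_quant_type(quant_description: Dict[str, str],
                       prefix: str) -> Optional[str]:
    # Assumption: per-layer entries are keyed as "<prefix>.weight".
    return quant_description.get(f"{prefix}.weight")


def create_method_for_layer(quant_description: Dict[str, str], prefix: str,
                            layer_type: str) -> Optional[Wrapper]:
    # 1. Determine quant_type from the description.
    quant_type = resolve_quant_type(quant_description, prefix)
    if quant_type is None:
        return None  # assumption: caller falls back to an unquantized method
    # 2. Ask the registry for the scheme class.
    scheme_cls = REGISTRY.get((quant_type, layer_type))
    if scheme_cls is None:
        return None  # assumption: no registered scheme means "not quantized"
    # 3. Instantiate the scheme, then 4. wrap it.
    return Wrapper(scheme_cls())
```

The config never names a concrete scheme class; it only knows the key it computed, which is what lets new schemes be added without touching config code.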
### File Changes Summary
| Original Files | Refactored Files |
|----------------|------------------|
| `__init__.py` (empty) | `__init__.py` (exports public API) |
| `quant_config.py` | `modelslim_config.py` + `wrappers.py` |
| `compressed_tensors/` | `compressed_tensors_config.py` |
| `utils.py` | `methods/registry.py` |
| `w8a8_dynamic.py` | `methods/w8a8_dynamic.py` |
| `w8a8.py` | `methods/w8a8_static.py` |
| `w4a4_flatquant_dynamic.py` | `methods/w4a4_flatquant.py` |
| ... | `methods/base.py` (new) |
### Benefits
1. **Extensibility**: Adding a new quantization scheme only requires
subclassing the appropriate base class and applying the `@register_scheme` decorator
2. **Maintainability**: Clear separation between config parsing, wrapper
logic, and scheme implementation
3. **Testability**: Abstract base classes enable easier unit testing and
mocking
4. **Discoverability**: Registry pattern makes it easy to list all
supported schemes
5. **Reduced Coupling**: Config classes no longer need to know about all
scheme implementations
___
- vLLM version: v0.13.0
- vLLM main: 2f4e6548ef
---------
Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
```python
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Abstract base classes for Ascend quantization schemes."""

from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Callable, Dict, Optional

import torch


class QuantType(Enum):
    """Quantization type enum for MoE schemes."""
    NONE = 0
    W8A8 = 1
    W4A8 = 2


class AscendLinearScheme(ABC):
    """Base class for all linear quantization schemes.

    Subclasses must implement get_weight() and apply() methods.
    Other methods have default implementations that return empty dicts
    or do nothing.
    """

    @abstractmethod
    def get_weight(self, input_size: int, output_size: int,
                   params_dtype: torch.dtype) -> Dict[str, Any]:
        """Return weight tensor specifications.

        Args:
            input_size: Input dimension of the linear layer.
            output_size: Output dimension of the linear layer.
            params_dtype: Data type for parameters.

        Returns:
            Dictionary mapping parameter names to empty tensors with
            the correct shape and dtype.
        """
        ...

    def get_pertensor_param(self, params_dtype: torch.dtype) -> Dict[str, Any]:
        """Return per-tensor parameter specifications (e.g., input_scale).

        Args:
            params_dtype: Data type for parameters.

        Returns:
            Dictionary mapping parameter names to empty tensors.
        """
        return {}

    def get_perchannel_param(self, output_size: int,
                             params_dtype: torch.dtype) -> Dict[str, Any]:
        """Return per-channel parameter specifications (e.g., weight_scale).

        Args:
            output_size: Output dimension of the linear layer.
            params_dtype: Data type for parameters.

        Returns:
            Dictionary mapping parameter names to empty tensors.
        """
        return {}

    def get_pergroup_param(self,
                           input_size: int,
                           output_size: int,
                           params_dtype: torch.dtype,
                           layer_type: Optional[str] = None) -> Dict[str, Any]:
        """Return per-group parameter specifications.

        Args:
            input_size: Input dimension of the linear layer.
            output_size: Output dimension of the linear layer.
            params_dtype: Data type for parameters.
            layer_type: Type of layer (e.g., "row" for RowParallelLinear).

        Returns:
            Dictionary mapping parameter names to empty tensors.
        """
        return {}

    @abstractmethod
    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None,
              tp_rank: Optional[int] = 0) -> torch.Tensor:
        """Forward computation.

        Args:
            layer: The linear layer module.
            x: Input tensor.
            bias: Optional bias tensor.
            tp_rank: Tensor parallel rank.

        Returns:
            Output tensor after quantized linear operation.
        """
        ...

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Post-loading weight processing (transpose, format conversion, etc.).

        Args:
            layer: The linear layer module.
        """
        pass


class AscendAttentionScheme(ABC):
    """Base class for all attention quantization schemes.

    Subclasses must implement apply() method.
    Other methods have default implementations.
    """

    def create_weights(self, layer: torch.nn.Module) -> None:
        """Create weights for attention quantization.

        Args:
            layer: The attention layer module.
        """
        pass

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Post-loading weight processing for attention layer.

        Args:
            layer: The attention layer module.
        """
        pass

    @abstractmethod
    def apply(self, layer: torch.nn.Module, query: torch.Tensor,
              key: torch.Tensor, value: torch.Tensor, kv_cache, attn_metadata,
              attn_type, scale, output) -> torch.Tensor:
        """Forward computation for attention layer.

        Args:
            layer: The attention layer module.
            query: Query tensor.
            key: Key tensor.
            value: Value tensor.
            kv_cache: KV cache.
            attn_metadata: Attention metadata.
            attn_type: Attention type.
            scale: Scale factor.
            output: Output tensor.

        Returns:
            Output tensor after attention computation.
        """
        ...


class AscendMoEScheme(ABC):
    """Base class for all MoE quantization schemes.

    Subclasses must implement get_weight(), get_dynamic_quant_param(),
    and apply() methods.

    Attributes:
        quant_type: The quantization type for this scheme. Subclasses should
            override this class attribute to declare their quant type.
    """

    # Default quant type - subclasses should override this
    quant_type: QuantType = QuantType.NONE

    @abstractmethod
    def get_weight(self, num_experts: int,
                   intermediate_size_per_partition: int, hidden_sizes: int,
                   params_dtype: torch.dtype) -> Dict[str, Any]:
        """Return weight tensor specifications for MoE layer.

        Args:
            num_experts: Number of experts.
            intermediate_size_per_partition: Intermediate size per partition.
            hidden_sizes: Hidden dimension size.
            params_dtype: Data type for parameters.

        Returns:
            Dictionary mapping parameter names to empty tensors.
        """
        ...

    @abstractmethod
    def get_dynamic_quant_param(self, num_experts: int,
                                intermediate_size_per_partition: int,
                                hidden_sizes: int,
                                params_dtype: torch.dtype) -> Dict[str, Any]:
        """Return dynamic quantization parameters for MoE layer.

        Args:
            num_experts: Number of experts.
            intermediate_size_per_partition: Intermediate size per partition.
            hidden_sizes: Hidden dimension size.
            params_dtype: Data type for parameters.

        Returns:
            Dictionary mapping parameter names to empty tensors.
        """
        ...

    @abstractmethod
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: Optional[torch.Tensor] = None,
        is_prefill: bool = True,
        enable_force_load_balance: bool = False,
        log2phy: Optional[torch.Tensor] = None,
        global_redundant_expert_num: int = 0,
        **kwargs,
    ) -> torch.Tensor:
        """Forward computation for MoE layer.

        Args:
            layer: The MoE layer module.
            x: Input hidden states.
            router_logits: Router logits for expert selection.
            top_k: Number of experts to select per token.
            renormalize: Whether to renormalize expert weights.
            use_grouped_topk: Whether to use grouped top-k selection.
            global_num_experts: Total number of experts globally.
            expert_map: Mapping from local to global expert indices.
            topk_group: Group size for grouped top-k.
            num_expert_group: Number of expert groups.
            custom_routing_function: Custom routing function.
            scoring_func: Scoring function name.
            routed_scaling_factor: Scaling factor for routed experts.
            e_score_correction_bias: Expert score correction bias.
            is_prefill: Whether in prefill phase.
            enable_force_load_balance: Whether to force load balancing.
            log2phy: Logical to physical expert mapping.
            global_redundant_expert_num: Number of redundant experts.
            **kwargs: Additional keyword arguments.

        Returns:
            Output tensor after MoE computation.
        """
        ...

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Post-loading weight processing for MoE layer.

        Args:
            layer: The MoE layer module.
        """
        pass
```
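Since `AscendMoEScheme.quant_type` is a class attribute, callers can read a scheme's quantization type from the class itself, before any instance exists. A minimal sketch of the override pattern the docstring describes, using `MoESchemeBase`, a trimmed, hypothetical stand-in that keeps only `apply()` abstract (the real base also declares `get_weight()` and `get_dynamic_quant_param()` as abstract); the demo schemes and `uses_quantized_moe` are illustrative only.

```python
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any


class QuantType(Enum):  # same members as the enum in base.py
    NONE = 0
    W8A8 = 1
    W4A8 = 2


class MoESchemeBase(ABC):
    """Trimmed stand-in for AscendMoEScheme (only apply() kept abstract)."""

    quant_type: QuantType = QuantType.NONE  # default; subclasses override

    @abstractmethod
    def apply(self, layer: Any, x: Any) -> Any:
        ...


class DemoW8A8MoE(MoESchemeBase):  # hypothetical concrete scheme
    quant_type = QuantType.W8A8

    def apply(self, layer: Any, x: Any) -> Any:
        return x  # placeholder for the real quantized MoE computation


class DemoW4A8MoE(MoESchemeBase):  # hypothetical concrete scheme
    quant_type = QuantType.W4A8

    def apply(self, layer: Any, x: Any) -> Any:
        return x


def uses_quantized_moe(scheme_cls: type) -> bool:
    """Illustrative helper: True unless the class keeps the NONE default."""
    return scheme_cls.quant_type is not QuantType.NONE
```

Declaring the type on the class rather than setting it in `__init__` keeps it inspectable from the registry alone, without constructing a scheme.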