xc-llm-ascend/vllm_ascend/quantization/methods/base.py
Cao Yi a69ef10c3a [Refactor] Quantization Module Refactor (#5738)
### Summary

This PR refactors the `vllm_ascend/quantization` module to improve code
organization, maintainability, and extensibility. The refactoring
introduces a clear separation of concerns with a registry-based scheme
discovery pattern, abstract base classes for quantization schemes, and
dedicated wrapper classes.

### Key Changes

#### 1. **Modular Directory Structure**

| Before | After |
|--------|-------|
| Flat file structure with mixed responsibilities | Organized into `methods/` subpackage for schemes |
| Single `quant_config.py` (600+ lines) | Separate config files: `modelslim_config.py`, `compressed_tensors_config.py` |
| `utils.py` with scheme lookup logic | `methods/registry.py` with decorator-based registration |

#### 2. **Registry-Based Scheme Discovery**

Replaced the hardcoded `ASCEND_QUANTIZATION_METHOD_MAP` dictionary with a
decorator-based registry pattern:

```python
# Before: Manual dictionary mapping
ASCEND_QUANTIZATION_METHOD_MAP = {
    "W8A8_DYNAMIC": {"linear": AscendW8A8DynamicLinearMethod, ...},
    ...
}

# After: Decorator-based registration
@register_scheme("W8A8_DYNAMIC", "linear")
class AscendW8A8DynamicLinearMethod(AscendLinearScheme):
    ...
```
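
For readers unfamiliar with the pattern, here is a minimal sketch of how a decorator-based registry can work. It is an illustration, not the code in `methods/registry.py`: the dictionary layout, the duplicate-key check, and the `DemoW8A8DynamicLinear` class are assumptions made for the example.

```python
from typing import Callable, Dict, Optional, Tuple, Type

# Assumed layout: one module-level dict keyed by (quant_type, layer_type).
_SCHEME_REGISTRY: Dict[Tuple[str, str], Type] = {}


def register_scheme(quant_type: str, layer_type: str) -> Callable[[Type], Type]:
    """Return a class decorator that records the class in the registry."""

    def decorator(cls: Type) -> Type:
        key = (quant_type, layer_type)
        # Rejecting duplicates is an assumption of this sketch.
        if key in _SCHEME_REGISTRY:
            raise ValueError(f"Scheme already registered for {key}")
        _SCHEME_REGISTRY[key] = cls
        return cls  # The class itself is returned unchanged.

    return decorator


def get_scheme_class(quant_type: str, layer_type: str) -> Optional[Type]:
    """Look up a registered class; None if nothing matches."""
    return _SCHEME_REGISTRY.get((quant_type, layer_type))


# Hypothetical scheme class used only for this demonstration.
@register_scheme("W8A8_DYNAMIC", "linear")
class DemoW8A8DynamicLinear:
    pass
```

Because registration happens as a side effect of the class definition, the module that defines a scheme has to be imported before a lookup can find it.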

#### 3. **Abstract Base Classes**

Introduced three abstract base classes in `methods/base.py`:
- `AscendLinearScheme` - Base for linear layer quantization
- `AscendMoEScheme` - Base for MoE layer quantization  
- `AscendAttentionScheme` - Base for attention layer quantization
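
The practical effect of using `ABC` with `@abstractmethod` can be sketched with a torch-free stand-in. `DemoLinearScheme` and its subclasses are hypothetical names with simplified signatures; only the split between abstract methods and default implementations mirrors `AscendLinearScheme`.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List


class DemoLinearScheme(ABC):
    """Simplified stand-in for AscendLinearScheme (no torch dependency)."""

    @abstractmethod
    def get_weight(self, input_size: int, output_size: int) -> Dict[str, Any]:
        ...

    @abstractmethod
    def apply(self, x: List[float]) -> List[float]:
        ...

    def get_pertensor_param(self) -> Dict[str, Any]:
        # Optional hook: subclasses override only when they need it.
        return {}


class IncompleteScheme(DemoLinearScheme):
    """Forgets to implement apply(), so it stays abstract."""

    def get_weight(self, input_size: int, output_size: int) -> Dict[str, Any]:
        return {}


class CompleteScheme(DemoLinearScheme):
    def get_weight(self, input_size: int, output_size: int) -> Dict[str, Any]:
        # Shapes stand in for the empty tensors a real scheme would return.
        return {"weight": (output_size, input_size)}

    def apply(self, x: List[float]) -> List[float]:
        return list(x)


try:
    IncompleteScheme()
    instantiation_failed = False
except TypeError:
    # Python refuses to instantiate a class with unimplemented abstract methods.
    instantiation_failed = True
```

A scheme that omits a required method therefore fails when it is instantiated, not later during a forward pass.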

#### 4. **Separated Config and Wrapper Classes**

- **Config classes** (`AscendModelSlimConfig`,
`AscendCompressedTensorsConfig`): Handle config parsing and scheme
selection
- **Wrapper classes** (`AscendLinearMethod`, `AscendFusedMoEMethod`,
etc.): Implement vLLM interfaces and delegate to schemes
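
The delegation between a wrapper and its scheme might look roughly like the sketch below. It assumes a plain dict in place of a `torch.nn.Module`, shape tuples in place of tensors, and hypothetical class names; the real `AscendLinearMethod` implements vLLM interfaces whose signatures are not reproduced here.

```python
from typing import Any, Dict, List, Tuple


class DemoScheme:
    """Hypothetical scheme: describes parameters and does the computation."""

    def get_weight(self, input_size: int, output_size: int) -> Dict[str, Tuple[int, int]]:
        return {"weight": (output_size, input_size)}

    def get_perchannel_param(self, output_size: int) -> Dict[str, Tuple[int, int]]:
        return {"weight_scale": (output_size, 1)}

    def apply(self, layer: Dict[str, Any], x: List[float]) -> List[float]:
        # Placeholder arithmetic standing in for a quantized matmul.
        return [value * layer["scale"] for value in x]


class DemoLinearWrapper:
    """Hypothetical wrapper: owns a scheme and forwards calls to it."""

    def __init__(self, scheme: DemoScheme) -> None:
        self.quant_method = scheme

    def create_weights(self, layer: Dict[str, Any], input_size: int,
                       output_size: int) -> None:
        # Collect every parameter spec the scheme declares and attach it.
        specs: Dict[str, Tuple[int, int]] = {}
        specs.update(self.quant_method.get_weight(input_size, output_size))
        specs.update(self.quant_method.get_perchannel_param(output_size))
        layer["param_shapes"] = specs

    def apply(self, layer: Dict[str, Any], x: List[float]) -> List[float]:
        return self.quant_method.apply(layer, x)
```

The wrapper carries no quantization logic of its own, so one wrapper class can serve every scheme registered for a layer type.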

#### 5. **Cleaner Public API**

```python
# New clean module interface
from vllm_ascend.quantization import (
    AscendModelSlimConfig,
    AscendCompressedTensorsConfig,
)
from vllm_ascend.quantization.methods import get_scheme_class
```

### Architecture Diagram

```mermaid
classDiagram
    direction TB
    
    class QuantizationConfig {
        <<vLLM Interface>>
        +get_quant_method()
    }
    
    class AscendModelSlimConfig {
        +quant_description
        +get_quant_method()
        -create_scheme_for_layer()
    }
    
    class AscendCompressedTensorsConfig {
        +target_scheme_map
        +get_quant_method()
        -_get_scheme_from_parts()
    }
    
    class AscendLinearMethod {
        <<Wrapper>>
        +quant_method: AscendLinearScheme
        +create_weights()
        +apply()
    }
    
    class AscendFusedMoEMethod {
        <<Wrapper>>
        +quant_method: AscendMoEScheme
        +create_weights()
        +apply()
    }
    
    class AscendLinearScheme {
        <<Abstract>>
        +get_weight()*
        +apply()*
        +get_pertensor_param()
        +get_perchannel_param()
    }
    
    class AscendMoEScheme {
        <<Abstract>>
        +get_weight()*
        +get_dynamic_quant_param()*
        +apply()*
    }
    
    class W8A8DynamicLinear {
        +get_weight()
        +apply()
    }
    
    class W8A8DynamicMoE {
        +get_weight()
        +apply()
    }
    
    QuantizationConfig <|-- AscendModelSlimConfig
    QuantizationConfig <|-- AscendCompressedTensorsConfig
    
    AscendModelSlimConfig ..> AscendLinearMethod : creates
    AscendModelSlimConfig ..> AscendFusedMoEMethod : creates
    AscendCompressedTensorsConfig ..> AscendLinearMethod : creates
    AscendCompressedTensorsConfig ..> AscendFusedMoEMethod : creates
    
    AscendLinearMethod o-- AscendLinearScheme : delegates to
    AscendFusedMoEMethod o-- AscendMoEScheme : delegates to
    
    AscendLinearScheme <|-- W8A8DynamicLinear
    AscendMoEScheme <|-- W8A8DynamicMoE
```

### Scheme Registration Flow

```mermaid
sequenceDiagram
    participant Module as Scheme Module
    participant Registry as _SCHEME_REGISTRY
    participant Config as QuantConfig
    participant Wrapper as Wrapper Class
    
    Note over Module: At import time
    Module->>Registry: @register_scheme("W8A8_DYNAMIC", "linear")
    Registry->>Registry: Store (quant_type, layer_type) -> Class
    
    Note over Config: At runtime
    Config->>Config: Determine quant_type from description
    Config->>Registry: get_scheme_class(quant_type, layer_type)
    Registry-->>Config: Return scheme class
    Config->>Config: scheme = scheme_cls()
    Config->>Wrapper: Create wrapper with scheme
    Wrapper-->>Config: Return wrapper instance
```
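
The runtime half of the sequence can be sketched as follows. The description format (a mapping from `<prefix>.weight` to a quant type string), the function name, and every `Demo*` class are assumptions for illustration; a dict literal stands in for the registry that real scheme modules populate at import time.

```python
from typing import Dict, Optional, Tuple, Type


class DemoW8A8DynamicLinear:
    """Hypothetical stand-in for a registered linear scheme."""


# Stand-in for the registry filled in by @register_scheme at import time.
_DEMO_REGISTRY: Dict[Tuple[str, str], Type] = {
    ("W8A8_DYNAMIC", "linear"): DemoW8A8DynamicLinear,
}


class DemoWrapper:
    """Hypothetical wrapper that holds the scheme instance."""

    def __init__(self, scheme: object) -> None:
        self.quant_method = scheme


def demo_get_quant_method(quant_description: Dict[str, str], prefix: str,
                          layer_type: str) -> Optional[DemoWrapper]:
    # 1. Determine quant_type from the description (assumed key format).
    quant_type = quant_description.get(f"{prefix}.weight")
    if quant_type is None:
        return None  # Layer is not described, so no quantized method.
    # 2. Ask the registry for the scheme class.
    scheme_cls = _DEMO_REGISTRY.get((quant_type, layer_type))
    if scheme_cls is None:
        raise NotImplementedError(
            f"No scheme registered for ({quant_type}, {layer_type})")
    # 3. Instantiate the scheme and hand it to the wrapper.
    return DemoWrapper(scheme_cls())
```

Note that the config function never names a concrete scheme class; it only passes the `(quant_type, layer_type)` pair to the registry.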

### File Changes Summary

| Original Files | Refactored Files |
|----------------|------------------|
| `__init__.py` (empty) | `__init__.py` (exports public API) |
| `quant_config.py` | `modelslim_config.py` + `wrappers.py` |
| `compressed_tensors/` | `compressed_tensors_config.py` |
| `utils.py` | `methods/registry.py` |
| `w8a8_dynamic.py` | `methods/w8a8_dynamic.py` |
| `w8a8.py` | `methods/w8a8_static.py` |
| `w4a4_flatquant_dynamic.py` | `methods/w4a4_flatquant.py` |
| ... | `methods/base.py` (new) |

### Benefits

1. **Extensibility**: Adding a new quantization scheme only requires
implementing the base class and adding the `@register_scheme` decorator
2. **Maintainability**: Clear separation between config parsing, wrapper
logic, and scheme implementation
3. **Testability**: Abstract base classes enable easier unit testing and
mocking
4. **Discoverability**: Registry pattern makes it easy to list all
supported schemes
5. **Reduced Coupling**: Config classes no longer need to know about all
scheme implementations

___

- vLLM version: v0.13.0
- vLLM main: 2f4e6548ef

---------

Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
2026-01-23 14:13:47 +08:00

280 lines
9.4 KiB
Python

#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Abstract base classes for Ascend quantization schemes."""

from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Callable, Dict, Optional

import torch


class QuantType(Enum):
    """Quantization type enum for MoE schemes."""
    NONE = 0
    W8A8 = 1
    W4A8 = 2


class AscendLinearScheme(ABC):
    """Base class for all linear quantization schemes.

    Subclasses must implement get_weight() and apply() methods.
    Other methods have default implementations that return empty dicts
    or do nothing.
    """

    @abstractmethod
    def get_weight(self, input_size: int, output_size: int,
                   params_dtype: torch.dtype) -> Dict[str, Any]:
        """Return weight tensor specifications.

        Args:
            input_size: Input dimension of the linear layer.
            output_size: Output dimension of the linear layer.
            params_dtype: Data type for parameters.

        Returns:
            Dictionary mapping parameter names to empty tensors with
            the correct shape and dtype.
        """
        ...

    def get_pertensor_param(self, params_dtype: torch.dtype) -> Dict[str, Any]:
        """Return per-tensor parameter specifications (e.g., input_scale).

        Args:
            params_dtype: Data type for parameters.

        Returns:
            Dictionary mapping parameter names to empty tensors.
        """
        return {}

    def get_perchannel_param(self, output_size: int,
                             params_dtype: torch.dtype) -> Dict[str, Any]:
        """Return per-channel parameter specifications (e.g., weight_scale).

        Args:
            output_size: Output dimension of the linear layer.
            params_dtype: Data type for parameters.

        Returns:
            Dictionary mapping parameter names to empty tensors.
        """
        return {}

    def get_pergroup_param(self,
                           input_size: int,
                           output_size: int,
                           params_dtype: torch.dtype,
                           layer_type: Optional[str] = None) -> Dict[str, Any]:
        """Return per-group parameter specifications.

        Args:
            input_size: Input dimension of the linear layer.
            output_size: Output dimension of the linear layer.
            params_dtype: Data type for parameters.
            layer_type: Type of layer (e.g., "row" for RowParallelLinear).

        Returns:
            Dictionary mapping parameter names to empty tensors.
        """
        return {}

    @abstractmethod
    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None,
              tp_rank: Optional[int] = 0) -> torch.Tensor:
        """Forward computation.

        Args:
            layer: The linear layer module.
            x: Input tensor.
            bias: Optional bias tensor.
            tp_rank: Tensor parallel rank.

        Returns:
            Output tensor after quantized linear operation.
        """
        ...

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Post-loading weight processing (transpose, format conversion, etc.).

        Args:
            layer: The linear layer module.
        """
        pass


class AscendAttentionScheme(ABC):
    """Base class for all attention quantization schemes.

    Subclasses must implement apply() method.
    Other methods have default implementations.
    """

    def create_weights(self, layer: torch.nn.Module) -> None:
        """Create weights for attention quantization.

        Args:
            layer: The attention layer module.
        """
        pass

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Post-loading weight processing for attention layer.

        Args:
            layer: The attention layer module.
        """
        pass

    @abstractmethod
    def apply(self, layer: torch.nn.Module, query: torch.Tensor,
              key: torch.Tensor, value: torch.Tensor, kv_cache, attn_metadata,
              attn_type, scale, output) -> torch.Tensor:
        """Forward computation for attention layer.

        Args:
            layer: The attention layer module.
            query: Query tensor.
            key: Key tensor.
            value: Value tensor.
            kv_cache: KV cache.
            attn_metadata: Attention metadata.
            attn_type: Attention type.
            scale: Scale factor.
            output: Output tensor.

        Returns:
            Output tensor after attention computation.
        """
        ...


class AscendMoEScheme(ABC):
    """Base class for all MoE quantization schemes.

    Subclasses must implement get_weight(), get_dynamic_quant_param(),
    and apply() methods.

    Attributes:
        quant_type: The quantization type for this scheme. Subclasses should
            override this class attribute to declare their quant type.
    """

    # Default quant type - subclasses should override this
    quant_type: QuantType = QuantType.NONE

    @abstractmethod
    def get_weight(self, num_experts: int,
                   intermediate_size_per_partition: int, hidden_sizes: int,
                   params_dtype: torch.dtype) -> Dict[str, Any]:
        """Return weight tensor specifications for MoE layer.

        Args:
            num_experts: Number of experts.
            intermediate_size_per_partition: Intermediate size per partition.
            hidden_sizes: Hidden dimension size.
            params_dtype: Data type for parameters.

        Returns:
            Dictionary mapping parameter names to empty tensors.
        """
        ...

    @abstractmethod
    def get_dynamic_quant_param(self, num_experts: int,
                                intermediate_size_per_partition: int,
                                hidden_sizes: int,
                                params_dtype: torch.dtype) -> Dict[str, Any]:
        """Return dynamic quantization parameters for MoE layer.

        Args:
            num_experts: Number of experts.
            intermediate_size_per_partition: Intermediate size per partition.
            hidden_sizes: Hidden dimension size.
            params_dtype: Data type for parameters.

        Returns:
            Dictionary mapping parameter names to empty tensors.
        """
        ...

    @abstractmethod
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: Optional[torch.Tensor] = None,
        is_prefill: bool = True,
        enable_force_load_balance: bool = False,
        log2phy: Optional[torch.Tensor] = None,
        global_redundant_expert_num: int = 0,
        **kwargs,
    ) -> torch.Tensor:
        """Forward computation for MoE layer.

        Args:
            layer: The MoE layer module.
            x: Input hidden states.
            router_logits: Router logits for expert selection.
            top_k: Number of experts to select per token.
            renormalize: Whether to renormalize expert weights.
            use_grouped_topk: Whether to use grouped top-k selection.
            global_num_experts: Total number of experts globally.
            expert_map: Mapping from local to global expert indices.
            topk_group: Group size for grouped top-k.
            num_expert_group: Number of expert groups.
            custom_routing_function: Custom routing function.
            scoring_func: Scoring function name.
            routed_scaling_factor: Scaling factor for routed experts.
            e_score_correction_bias: Expert score correction bias.
            is_prefill: Whether in prefill phase.
            enable_force_load_balance: Whether to force load balancing.
            log2phy: Logical to physical expert mapping.
            global_redundant_expert_num: Number of redundant experts.
            **kwargs: Additional keyword arguments.

        Returns:
            Output tensor after MoE computation.
        """
        ...

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Post-loading weight processing for MoE layer.

        Args:
            layer: The MoE layer module.
        """
        pass