[Refactor] Quantization Module Refactor (#5738)

### Summary

This PR refactors the `vllm_ascend/quantization` module to improve code
organization, maintainability, and extensibility. The refactoring
introduces a clear separation of concerns with a registry-based scheme
discovery pattern, abstract base classes for quantization schemes, and
dedicated wrapper classes.

### Key Changes

#### 1. **Modular Directory Structure**

| Before | After |
|--------|-------|
| Flat file structure with mixed responsibilities | Organized into
`methods/` subpackage for schemes |
| Single `quant_config.py` (600+ lines) | Separate config files:
`modelslim_config.py`, `compressed_tensors_config.py` |
| `utils.py` with scheme lookup logic | `methods/registry.py` with
decorator-based registration |

#### 2. **Registry-Based Scheme Discovery**

Replaced hardcoded `ASCEND_QUANTIZATION_METHOD_MAP` dictionary with a
decorator-based registry pattern:

```python
# Before: Manual dictionary mapping
ASCEND_QUANTIZATION_METHOD_MAP = {
    "W8A8_DYNAMIC": {"linear": AscendW8A8DynamicLinearMethod, ...},
    ...
}

# After: Decorator-based registration
@register_scheme("W8A8_DYNAMIC", "linear")
class AscendW8A8DynamicLinearMethod(AscendLinearScheme):
    ...
```

#### 3. **Abstract Base Classes**

Introduced three abstract base classes in `methods/base.py`:
- `AscendLinearScheme` - Base for linear layer quantization
- `AscendMoEScheme` - Base for MoE layer quantization  
- `AscendAttentionScheme` - Base for attention layer quantization

#### 4. **Separated Config and Wrapper Classes**

- **Config classes** (`AscendModelSlimConfig`,
`AscendCompressedTensorsConfig`): Handle config parsing and scheme
selection
- **Wrapper classes** (`AscendLinearMethod`, `AscendFusedMoEMethod`,
etc.): Implement vLLM interfaces and delegate to schemes

#### 5. **Cleaner Public API**

```python
# New clean module interface
from vllm_ascend.quantization import (
    AscendModelSlimConfig,
    AscendCompressedTensorsConfig,
)
from vllm_ascend.quantization.methods import get_scheme_class
```

### Architecture Diagram

```mermaid
classDiagram
    direction TB
    
    class QuantizationConfig {
        <<vLLM Interface>>
        +get_quant_method()
    }
    
    class AscendModelSlimConfig {
        +quant_description
        +get_quant_method()
        -create_scheme_for_layer()
    }
    
    class AscendCompressedTensorsConfig {
        +target_scheme_map
        +get_quant_method()
        -_get_scheme_from_parts()
    }
    
    class AscendLinearMethod {
        <<Wrapper>>
        +quant_method: AscendLinearScheme
        +create_weights()
        +apply()
    }
    
    class AscendFusedMoEMethod {
        <<Wrapper>>
        +quant_method: AscendMoEScheme
        +create_weights()
        +apply()
    }
    
    class AscendLinearScheme {
        <<Abstract>>
        +get_weight()*
        +apply()*
        +get_pertensor_param()
        +get_perchannel_param()
    }
    
    class AscendMoEScheme {
        <<Abstract>>
        +get_weight()*
        +get_dynamic_quant_param()*
        +apply()*
    }
    
    class W8A8DynamicLinear {
        +get_weight()
        +apply()
    }
    
    class W8A8DynamicMoE {
        +get_weight()
        +apply()
    }
    
    QuantizationConfig <|-- AscendModelSlimConfig
    QuantizationConfig <|-- AscendCompressedTensorsConfig
    
    AscendModelSlimConfig ..> AscendLinearMethod : creates
    AscendModelSlimConfig ..> AscendFusedMoEMethod : creates
    AscendCompressedTensorsConfig ..> AscendLinearMethod : creates
    AscendCompressedTensorsConfig ..> AscendFusedMoEMethod : creates
    
    AscendLinearMethod o-- AscendLinearScheme : delegates to
    AscendFusedMoEMethod o-- AscendMoEScheme : delegates to
    
    AscendLinearScheme <|-- W8A8DynamicLinear
    AscendMoEScheme <|-- W8A8DynamicMoE
```

### Scheme Registration Flow

```mermaid
sequenceDiagram
    participant Module as Scheme Module
    participant Registry as _SCHEME_REGISTRY
    participant Config as QuantConfig
    participant Wrapper as Wrapper Class
    
    Note over Module: At import time
    Module->>Registry: @register_scheme("W8A8_DYNAMIC", "linear")
    Registry->>Registry: Store (quant_type, layer_type) -> Class
    
    Note over Config: At runtime
    Config->>Config: Determine quant_type from description
    Config->>Registry: get_scheme_class(quant_type, layer_type)
    Registry-->>Config: Return scheme class
    Config->>Config: scheme = scheme_cls()
    Config->>Wrapper: Create wrapper with scheme
    Wrapper-->>Config: Return wrapper instance
```

### File Changes Summary

| Original Files | Refactored Files |
|----------------|------------------|
| `__init__.py` (empty) | `__init__.py` (exports public API) |
| `quant_config.py` | `modelslim_config.py` + `wrappers.py` |
| `compressed_tensors/` | `compressed_tensors_config.py` |
| `utils.py` | `methods/registry.py` |
| `w8a8_dynamic.py` | `methods/w8a8_dynamic.py` |
| `w8a8.py` | `methods/w8a8_static.py` |
| `w4a4_flatquant_dynamic.py` | `methods/w4a4_flatquant.py` |
| ... | `methods/base.py` (new) |

### Benefits

1. **Extensibility**: Adding new quantization schemes only requires
implementing the base class and adding `@register_scheme` decorator
2. **Maintainability**: Clear separation between config parsing, wrapper
logic, and scheme implementation
3. **Testability**: Abstract base classes enable easier unit testing and
mocking
4. **Discoverability**: Registry pattern makes it easy to list all
supported schemes
5. **Reduced Coupling**: Config classes no longer need to know about all
scheme implementations

___

- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef

---------

Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
This commit is contained in:
Cao Yi
2026-01-23 14:13:47 +08:00
committed by GitHub
parent 8378bc28b0
commit a69ef10c3a
36 changed files with 2044 additions and 1524 deletions

View File

@@ -3,14 +3,14 @@ from unittest.mock import Mock, patch
import torch
from tests.ut.base import TestBase
from vllm_ascend.quantization.w4a8_dynamic import (
from vllm_ascend.quantization.methods.w4a8 import (
AscendW4A8DynamicFusedMoEMethod, AscendW4A8DynamicLinearMethod)
class TestAscendW4A8DynamicLinearMethod(TestBase):
@patch('vllm.distributed.get_tensor_model_parallel_world_size')
@patch('vllm_ascend.quantization.w4a8_dynamic.get_current_vllm_config')
@patch('vllm_ascend.quantization.methods.w4a8.get_current_vllm_config')
def setUp(self, mock_get_current_vllm_config, mock_get_tp_world_size):
mock_get_tp_world_size.return_value = 1
mock_vllm_config = Mock()
@@ -127,10 +127,10 @@ class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
output_size = 56
group_size = 2
@patch('vllm_ascend.quantization.w4a8_dynamic.get_ascend_config')
@patch('vllm_ascend.quantization.w4a8_dynamic.get_current_vllm_config')
@patch('vllm_ascend.quantization.w4a8_dynamic.get_ep_group')
@patch('vllm_ascend.quantization.w4a8_dynamic.get_mc2_group')
@patch('vllm_ascend.quantization.methods.w4a8.get_ascend_config')
@patch('vllm_ascend.quantization.methods.w4a8.get_current_vllm_config')
@patch('vllm_ascend.quantization.methods.w4a8.get_ep_group')
@patch('vllm_ascend.quantization.methods.w4a8.get_mc2_group')
@patch('torch.distributed.get_rank', return_value=0)
def setUp(self, mock_get_rank, mock_get_mc2_group, mock_get_ep_group,
get_current_vllm_config, mock_get_ascend_config):