### Summary
This PR refactors the `vllm_ascend/quantization` module to improve code
organization, maintainability, and extensibility. The refactoring
introduces a clear separation of concerns with a registry-based scheme
discovery pattern, abstract base classes for quantization schemes, and
dedicated wrapper classes.
### Key Changes
#### 1. **Modular Directory Structure**
| Before | After |
|--------|-------|
| Flat file structure with mixed responsibilities | Organized into `methods/` subpackage for schemes |
| Single `quant_config.py` (600+ lines) | Separate config files: `modelslim_config.py`, `compressed_tensors_config.py` |
| `utils.py` with scheme lookup logic | `methods/registry.py` with decorator-based registration |
#### 2. **Registry-Based Scheme Discovery**
Replaced hardcoded `ASCEND_QUANTIZATION_METHOD_MAP` dictionary with a
decorator-based registry pattern:
```python
# Before: Manual dictionary mapping
ASCEND_QUANTIZATION_METHOD_MAP = {
    "W8A8_DYNAMIC": {"linear": AscendW8A8DynamicLinearMethod, ...},
    ...
}
# After: Decorator-based registration
@register_scheme("W8A8_DYNAMIC", "linear")
class AscendW8A8DynamicLinearMethod(AscendLinearScheme):
    ...
```
#### 3. **Abstract Base Classes**
Introduced three abstract base classes in `methods/base.py`:
- `AscendLinearScheme` - Base for linear layer quantization
- `AscendMoEScheme` - Base for MoE layer quantization
- `AscendAttentionScheme` - Base for attention layer quantization
#### 4. **Separated Config and Wrapper Classes**
- **Config classes** (`AscendModelSlimConfig`,
`AscendCompressedTensorsConfig`): Handle config parsing and scheme
selection
- **Wrapper classes** (`AscendLinearMethod`, `AscendFusedMoEMethod`,
etc.): Implement vLLM interfaces and delegate to schemes
#### 5. **Cleaner Public API**
```python
# New clean module interface
from vllm_ascend.quantization import (
    AscendModelSlimConfig,
    AscendCompressedTensorsConfig,
)
from vllm_ascend.quantization.methods import get_scheme_class
```
### Architecture Diagram
```mermaid
classDiagram
direction TB
class QuantizationConfig {
<<vLLM Interface>>
+get_quant_method()
}
class AscendModelSlimConfig {
+quant_description
+get_quant_method()
-create_scheme_for_layer()
}
class AscendCompressedTensorsConfig {
+target_scheme_map
+get_quant_method()
-_get_scheme_from_parts()
}
class AscendLinearMethod {
<<Wrapper>>
+quant_method: AscendLinearScheme
+create_weights()
+apply()
}
class AscendFusedMoEMethod {
<<Wrapper>>
+quant_method: AscendMoEScheme
+create_weights()
+apply()
}
class AscendLinearScheme {
<<Abstract>>
+get_weight()*
+apply()*
+get_pertensor_param()
+get_perchannel_param()
}
class AscendMoEScheme {
<<Abstract>>
+get_weight()*
+get_dynamic_quant_param()*
+apply()*
}
class W8A8DynamicLinear {
+get_weight()
+apply()
}
class W8A8DynamicMoE {
+get_weight()
+apply()
}
QuantizationConfig <|-- AscendModelSlimConfig
QuantizationConfig <|-- AscendCompressedTensorsConfig
AscendModelSlimConfig ..> AscendLinearMethod : creates
AscendModelSlimConfig ..> AscendFusedMoEMethod : creates
AscendCompressedTensorsConfig ..> AscendLinearMethod : creates
AscendCompressedTensorsConfig ..> AscendFusedMoEMethod : creates
AscendLinearMethod o-- AscendLinearScheme : delegates to
AscendFusedMoEMethod o-- AscendMoEScheme : delegates to
AscendLinearScheme <|-- W8A8DynamicLinear
AscendMoEScheme <|-- W8A8DynamicMoE
```
### Scheme Registration Flow
```mermaid
sequenceDiagram
participant Module as Scheme Module
participant Registry as _SCHEME_REGISTRY
participant Config as QuantConfig
participant Wrapper as Wrapper Class
Note over Module: At import time
Module->>Registry: @register_scheme("W8A8_DYNAMIC", "linear")
Registry->>Registry: Store (quant_type, layer_type) -> Class
Note over Config: At runtime
Config->>Config: Determine quant_type from description
Config->>Registry: get_scheme_class(quant_type, layer_type)
Registry-->>Config: Return scheme class
Config->>Config: scheme = scheme_cls()
Config->>Wrapper: Create wrapper with scheme
Wrapper-->>Config: Return wrapper instance
```
### File Changes Summary
| Original Files | Refactored Files |
|----------------|------------------|
| `__init__.py` (empty) | `__init__.py` (exports public API) |
| `quant_config.py` | `modelslim_config.py` + `wrappers.py` |
| `compressed_tensors/` | `compressed_tensors_config.py` |
| `utils.py` | `methods/registry.py` |
| `w8a8_dynamic.py` | `methods/w8a8_dynamic.py` |
| `w8a8.py` | `methods/w8a8_static.py` |
| `w4a4_flatquant_dynamic.py` | `methods/w4a4_flatquant.py` |
| ... | `methods/base.py` (new) |
### Benefits
1. **Extensibility**: Adding new quantization schemes only requires
implementing the base class and adding the `@register_scheme` decorator
2. **Maintainability**: Clear separation between config parsing, wrapper
logic, and scheme implementation
3. **Testability**: Abstract base classes enable easier unit testing and
mocking
4. **Discoverability**: Registry pattern makes it easy to list all
supported schemes
5. **Reduced Coupling**: Config classes no longer need to know about all
scheme implementations
___
- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef
---------
Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
163 lines
7.7 KiB
Python
163 lines
7.7 KiB
Python
from unittest.mock import MagicMock, patch
|
|
|
|
from vllm.attention.layer import Attention
|
|
from vllm.model_executor.layers.fused_moe import FusedMoE
|
|
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
|
|
from vllm.model_executor.layers.linear import LinearBase
|
|
|
|
from tests.ut.base import TestBase
|
|
from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod
|
|
from vllm_ascend.quantization.modelslim_config import AscendModelSlimConfig
|
|
from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD
|
|
|
|
|
|
class TestAscendModelSlimConfig(TestBase):
    """Unit tests for ``AscendModelSlimConfig``.

    Each test runs against a config built in ``setUp`` from a small
    quantization description that mixes ``INT8`` and ``FLOAT`` entries.
    """

    def setUp(self):
        """Create the sample description and the config under test."""
        self.sample_config = {
            "weight": "INT8",
            "fa_quant_type": "C8",
            "layer1.weight": "INT8",
            "layer2.weight": "FLOAT",
            "fused_layer.weight": "FLOAT",
            "fused_layer.shard1.weight": "FLOAT",
            "fused_layer.shard2.weight": "FLOAT",
            "shard1.weight": "FLOAT",
            "shard2.weight": "FLOAT",
        }
        self.ascend_config = AscendModelSlimConfig(self.sample_config)
        self.ascend_config.packed_modules_mapping = None

    def test_init(self):
        """The constructor keeps the description it was given."""
        stored = self.ascend_config.quant_description
        self.assertEqual(stored, self.sample_config)

    def test_repr(self):
        """``repr`` begins with the class-name header line."""
        text = repr(self.ascend_config)
        has_header = text.startswith("AscendModelSlimConfig:\n")
        self.assertTrue(has_header)

    def test_get_name(self):
        """``get_name`` returns ``ASCEND_QUANTIZATION_METHOD``."""
        name = AscendModelSlimConfig.get_name()
        self.assertEqual(name, ASCEND_QUANTIZATION_METHOD)

    def test_get_supported_act_dtypes(self):
        """Three activation dtypes are reported as supported."""
        dtypes = AscendModelSlimConfig.get_supported_act_dtypes()
        self.assertEqual(len(dtypes), 3)

    def test_get_min_capability(self):
        """``get_min_capability`` raises ``NotImplementedError``."""
        self.assertRaises(NotImplementedError,
                          AscendModelSlimConfig.get_min_capability)

    def test_get_config_filenames(self):
        """Only ``quant_model_description.json`` is listed."""
        names = AscendModelSlimConfig.get_config_filenames()
        self.assertEqual(names, ["quant_model_description.json"])

    def test_from_config(self):
        """``from_config`` builds an instance carrying the same description."""
        built = AscendModelSlimConfig.from_config(self.sample_config)
        self.assertIsInstance(built, AscendModelSlimConfig)
        self.assertEqual(built.quant_description, self.sample_config)

    @patch("torch.npu.is_available")
    def test_override_quantization_method(self, mock_is_available):
        """The override yields ``"ascend"`` only with an NPU and an HF config."""
        # NPU reported as available.
        mock_is_available.return_value = True
        outcome = AscendModelSlimConfig.override_quantization_method(None, None)
        self.assertIsNone(outcome)
        hf_cfg = {"quant_method": ""}
        outcome = AscendModelSlimConfig.override_quantization_method(
            hf_cfg, None)
        self.assertEqual(outcome, "ascend")

        # NPU reported as unavailable.
        mock_is_available.return_value = False
        outcome = AscendModelSlimConfig.override_quantization_method(None, None)
        self.assertIsNone(outcome)
        hf_cfg = {"quant_method": ""}
        outcome = AscendModelSlimConfig.override_quantization_method(
            hf_cfg, None)
        self.assertIsNone(outcome)

    def test_get_quant_method_for_linear(self):
        """Linear layers get the unquantized method when skipped, else the wrapper."""
        vllm_config = MagicMock()
        vllm_config.model_config.hf_config.model_type = None
        layer = MagicMock(spec=LinearBase)

        # Skipped layer: patchers are built up front and entered in the
        # same order as they are listed in the ``with`` header.
        vllm_config_patch = patch(
            "vllm_ascend.quantization.modelslim_config.get_current_vllm_config",
            return_value=vllm_config)
        skipped_patch = patch.object(self.ascend_config,
                                     "is_layer_skipped_ascend",
                                     return_value=True)
        with vllm_config_patch, skipped_patch:
            chosen = self.ascend_config.get_quant_method(layer, ".attn")
            self.assertIsInstance(chosen, AscendUnquantizedLinearMethod)

        # Quantized layer.
        scheme = MagicMock()
        not_skipped_patch = patch.object(self.ascend_config,
                                         "is_layer_skipped_ascend",
                                         return_value=False)
        vllm_config_patch = patch(
            "vllm_ascend.quantization.modelslim_config.get_current_vllm_config",
            return_value=vllm_config)
        scheme_patch = patch(
            "vllm_ascend.quantization.modelslim_config.create_scheme_for_layer",
            return_value=scheme)
        wrapper_patch = patch(
            "vllm_ascend.quantization.method_adapters.AscendLinearMethod",
            return_value=MagicMock())
        with not_skipped_patch, vllm_config_patch, scheme_patch, \
                wrapper_patch as linear_wrapper:
            chosen = self.ascend_config.get_quant_method(layer, ".attn")
            self.assertIs(chosen, linear_wrapper.return_value)
            linear_wrapper.assert_called_once_with(scheme)

    def test_get_quant_method_for_attention(self):
        """Attention layers receive the patched ``AscendKVCacheMethod`` instance."""
        layer = MagicMock(spec=Attention)
        vllm_config = MagicMock()
        vllm_config.model_config.hf_config.model_type = None
        scheme = MagicMock()

        vllm_config_patch = patch(
            "vllm_ascend.quantization.modelslim_config.get_current_vllm_config",
            return_value=vllm_config)
        scheme_patch = patch(
            "vllm_ascend.quantization.modelslim_config.create_scheme_for_layer",
            return_value=scheme)
        kvcache_patch = patch(
            "vllm_ascend.quantization.method_adapters.AscendKVCacheMethod",
            return_value=MagicMock())
        with vllm_config_patch, scheme_patch, kvcache_patch as kvcache_wrapper:
            # The sample description contains "fa_quant_type".
            chosen = self.ascend_config.get_quant_method(layer, ".attn")
            self.assertIs(chosen, kvcache_wrapper.return_value)

    def test_get_quant_method_for_fused_moe(self):
        """Fused MoE layers get the unquantized or the wrapped quantized method."""
        layer = MagicMock(spec=FusedMoE)
        layer.moe = MagicMock(spec=FusedMoEConfig)
        layer.moe_config = MagicMock(spec=FusedMoEConfig)
        vllm_config = MagicMock()
        vllm_config.model_config.hf_config.model_type = None

        # Skipped layer.
        skipped_patch = patch.object(self.ascend_config,
                                     "is_layer_skipped_ascend",
                                     return_value=True)
        vllm_config_patch = patch(
            "vllm_ascend.quantization.modelslim_config.get_current_vllm_config",
            return_value=vllm_config)
        unquantized_patch = patch(
            "vllm_ascend.ops.fused_moe.fused_moe.AscendUnquantizedFusedMoEMethod",
            return_value=MagicMock())
        with skipped_patch, vllm_config_patch, \
                unquantized_patch as unquantized_moe:
            chosen = self.ascend_config.get_quant_method(layer, "moe_layer")
            self.assertIs(chosen, unquantized_moe.return_value)

        # Quantized layer.
        scheme = MagicMock()
        not_skipped_patch = patch.object(self.ascend_config,
                                         "is_layer_skipped_ascend",
                                         return_value=False)
        vllm_config_patch = patch(
            "vllm_ascend.quantization.modelslim_config.get_current_vllm_config",
            return_value=vllm_config)
        scheme_patch = patch(
            "vllm_ascend.quantization.modelslim_config.create_scheme_for_layer",
            return_value=scheme)
        wrapper_patch = patch(
            "vllm_ascend.quantization.method_adapters.AscendFusedMoEMethod",
            return_value=MagicMock())
        with not_skipped_patch, vllm_config_patch, scheme_patch, \
                wrapper_patch as moe_wrapper:
            chosen = self.ascend_config.get_quant_method(layer, "moe_layer")
            self.assertIs(chosen, moe_wrapper.return_value)

    def test_is_layer_skipped_ascend(self):
        """Check skip detection for plain, fused, and inconsistent fused layers."""
        # Non-fused layer whose weight entry is "INT8": not skipped.
        int8_skipped = self.ascend_config.is_layer_skipped_ascend("layer1")
        self.assertFalse(int8_skipped)

        # Non-fused layer whose weight entry is "FLOAT": skipped.
        float_skipped = self.ascend_config.is_layer_skipped_ascend("layer2")
        self.assertTrue(float_skipped)

        # Fused layer whose shards are all "FLOAT": skipped.
        fused_mapping = {"fused_layer": ["shard1", "shard2"]}
        fused_skipped = self.ascend_config.is_layer_skipped_ascend(
            "fused_layer", fused_mapping)
        self.assertTrue(fused_skipped)

        # Shards that disagree ("FLOAT" vs "INT8") must raise ValueError.
        mismatched = AscendModelSlimConfig({
            "shard1.weight": "FLOAT",
            "shard2.weight": "INT8"
        })
        self.assertRaises(ValueError, mismatched.is_layer_skipped_ascend,
                          "fused_layer", fused_mapping)

    def test_get_scaled_act_names(self):
        """``get_scaled_act_names`` returns an empty list."""
        names = self.ascend_config.get_scaled_act_names()
        self.assertEqual(names, [])
|