### Summary
This PR refactors the `vllm_ascend/quantization` module to improve code
organization, maintainability, and extensibility. The refactoring
introduces a clear separation of concerns with a registry-based scheme
discovery pattern, abstract base classes for quantization schemes, and
dedicated wrapper classes.
### Key Changes
#### 1. **Modular Directory Structure**
| Before | After |
|--------|-------|
| Flat file structure with mixed responsibilities | Organized into `methods/` subpackage for schemes |
| Single `quant_config.py` (600+ lines) | Separate config files: `modelslim_config.py`, `compressed_tensors_config.py` |
| `utils.py` with scheme lookup logic | `methods/registry.py` with decorator-based registration |
#### 2. **Registry-Based Scheme Discovery**
Replaced hardcoded `ASCEND_QUANTIZATION_METHOD_MAP` dictionary with a
decorator-based registry pattern:
```python
# Before: Manual dictionary mapping
ASCEND_QUANTIZATION_METHOD_MAP = {
    "W8A8_DYNAMIC": {"linear": AscendW8A8DynamicLinearMethod, ...},
    ...
}

# After: Decorator-based registration
@register_scheme("W8A8_DYNAMIC", "linear")
class AscendW8A8DynamicLinearMethod(AscendLinearScheme):
    ...
```
#### 3. **Abstract Base Classes**
Introduced three abstract base classes in `methods/base.py`:
- `AscendLinearScheme` - Base for linear layer quantization
- `AscendMoEScheme` - Base for MoE layer quantization
- `AscendAttentionScheme` - Base for attention layer quantization
#### 4. **Separated Config and Wrapper Classes**
- **Config classes** (`AscendModelSlimConfig`,
`AscendCompressedTensorsConfig`): Handle config parsing and scheme
selection
- **Wrapper classes** (`AscendLinearMethod`, `AscendFusedMoEMethod`,
etc.): Implement vLLM interfaces and delegate to schemes
#### 5. **Cleaner Public API**
```python
# New clean module interface
from vllm_ascend.quantization import (
AscendModelSlimConfig,
AscendCompressedTensorsConfig,
)
from vllm_ascend.quantization.methods import get_scheme_class
```
### Architecture Diagram
```mermaid
classDiagram
direction TB
class QuantizationConfig {
<<vLLM Interface>>
+get_quant_method()
}
class AscendModelSlimConfig {
+quant_description
+get_quant_method()
-create_scheme_for_layer()
}
class AscendCompressedTensorsConfig {
+target_scheme_map
+get_quant_method()
-_get_scheme_from_parts()
}
class AscendLinearMethod {
<<Wrapper>>
+quant_method: AscendLinearScheme
+create_weights()
+apply()
}
class AscendFusedMoEMethod {
<<Wrapper>>
+quant_method: AscendMoEScheme
+create_weights()
+apply()
}
class AscendLinearScheme {
<<Abstract>>
+get_weight()*
+apply()*
+get_pertensor_param()
+get_perchannel_param()
}
class AscendMoEScheme {
<<Abstract>>
+get_weight()*
+get_dynamic_quant_param()*
+apply()*
}
class W8A8DynamicLinear {
+get_weight()
+apply()
}
class W8A8DynamicMoE {
+get_weight()
+apply()
}
QuantizationConfig <|-- AscendModelSlimConfig
QuantizationConfig <|-- AscendCompressedTensorsConfig
AscendModelSlimConfig ..> AscendLinearMethod : creates
AscendModelSlimConfig ..> AscendFusedMoEMethod : creates
AscendCompressedTensorsConfig ..> AscendLinearMethod : creates
AscendCompressedTensorsConfig ..> AscendFusedMoEMethod : creates
AscendLinearMethod o-- AscendLinearScheme : delegates to
AscendFusedMoEMethod o-- AscendMoEScheme : delegates to
AscendLinearScheme <|-- W8A8DynamicLinear
AscendMoEScheme <|-- W8A8DynamicMoE
```
### Scheme Registration Flow
```mermaid
sequenceDiagram
participant Module as Scheme Module
participant Registry as _SCHEME_REGISTRY
participant Config as QuantConfig
participant Wrapper as Wrapper Class
Note over Module: At import time
Module->>Registry: @register_scheme("W8A8_DYNAMIC", "linear")
Registry->>Registry: Store (quant_type, layer_type) -> Class
Note over Config: At runtime
Config->>Config: Determine quant_type from description
Config->>Registry: get_scheme_class(quant_type, layer_type)
Registry-->>Config: Return scheme class
Config->>Config: scheme = scheme_cls()
Config->>Wrapper: Create wrapper with scheme
Wrapper-->>Config: Return wrapper instance
```
### File Changes Summary
| Original Files | Refactored Files |
|----------------|------------------|
| `__init__.py` (empty) | `__init__.py` (exports public API) |
| `quant_config.py` | `modelslim_config.py` + `wrappers.py` |
| `compressed_tensors/` | `compressed_tensors_config.py` |
| `utils.py` | `methods/registry.py` |
| `w8a8_dynamic.py` | `methods/w8a8_dynamic.py` |
| `w8a8.py` | `methods/w8a8_static.py` |
| `w4a4_flatquant_dynamic.py` | `methods/w4a4_flatquant.py` |
| ... | `methods/base.py` (new) |
### Benefits
1. **Extensibility**: Adding a new quantization scheme only requires
implementing the appropriate base class and applying the `@register_scheme` decorator
2. **Maintainability**: Clear separation between config parsing, wrapper
logic, and scheme implementation
3. **Testability**: Abstract base classes enable easier unit testing and
mocking
4. **Discoverability**: Registry pattern makes it easy to list all
supported schemes
5. **Reduced Coupling**: Config classes no longer need to know about all
scheme implementations
___
- vLLM version: v0.13.0
- vLLM main: 2f4e6548ef
---------
Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
195 lines
7.8 KiB
Python
195 lines
7.8 KiB
Python
import os
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import torch
|
|
|
|
from tests.ut.base import TestBase
|
|
from vllm_ascend.quantization.methods.w8a8_static import AscendW8A8LinearMethod
|
|
from vllm_ascend.utils import AscendDeviceType
|
|
|
|
|
|
class TestAscendW8A8LinearMethod(TestBase):
    """Unit tests for the static W8A8 linear scheme ``AscendW8A8LinearMethod``.

    The NPU kernels touched by ``apply`` and ``process_weights_after_loading``
    (``torch_npu.npu_quant_matmul``, ``torch.ops.vllm.quantize``,
    ``torch_npu.npu_format_cast``) are replaced with mocks. These tests
    therefore check parameter creation, what ``apply`` returns relative to
    the mocked kernel output, and weight post-processing -- not numerics.
    """

    def setUp(self):
        self.method = AscendW8A8LinearMethod()

    # ------------------------------------------------------------------
    # Shared fixtures / helpers (underscore-prefixed: not collected)
    # ------------------------------------------------------------------

    @staticmethod
    def _make_apply_layer():
        """Build the mock layer passed to ``apply`` in the apply tests."""
        layer = MagicMock()
        layer.aclnn_input_scale = 0.1
        layer.aclnn_input_offset = 0.2
        layer.weight = torch.randn(128, 256)
        layer.deq_scale = 0.3
        return layer

    @staticmethod
    def _make_loading_layer():
        """Build the mock layer passed to ``process_weights_after_loading``."""
        layer = MagicMock()
        layer.weight.data = torch.randint(-127,
                                          128, (128, 256),
                                          dtype=torch.int8)
        layer.input_scale.data = torch.tensor([0.1])
        layer.input_offset.data = torch.tensor([0])
        layer.deq_scale = torch.tensor([0.5])
        layer.weight_scale.data = torch.randn(128, 1)
        layer.weight_offset.data = torch.randn(128, 1)
        return layer

    def _run_apply_and_compare(self, mock_npu_quant_matmul, x):
        """Run ``apply`` on ``x`` and compare with the mocked kernel output
        plus ``bias``.

        Args:
            mock_npu_quant_matmul: the patched ``torch_npu.npu_quant_matmul``.
            x: input activation tensor of shape (32, 128).

        NOTE(review): ``expected_y_output`` is the same tensor object the
        mocked ``npu_quant_matmul`` returns. If ``apply`` returns that tensor
        without copying it, the in-place ``+=`` below also mutates ``output``.
        The equality then holds regardless of how ``bias`` is used. Comparing
        against a clone would be stricter -- confirm against the
        implementation before tightening this.
        """
        layer = self._make_apply_layer()
        bias = torch.randn(256)
        expected_y_output = torch.randn(32, 256)
        mock_npu_quant_matmul.return_value = expected_y_output

        output = self.method.apply(layer, x, bias)

        expected_y_output += bias
        self.assertTrue(torch.equal(output, expected_y_output))

    def _assert_loaded_layer(self, layer):
        """Check the post-processing expected for every NZ setting."""
        # The single input offset is expected to be expanded to 256 int8
        # entries (256 == the weight's second dimension in the fixture).
        expected_offset = torch.tensor([0]).repeat(256).to(torch.int8)
        self.assertTrue(
            torch.equal(layer.aclnn_input_offset.data, expected_offset))
        self.assertFalse(layer.aclnn_input_offset.requires_grad)

        self.assertFalse(layer.deq_scale.requires_grad)

        # Per-channel scale/offset are expected to be flattened from
        # (128, 1) to (128,).
        self.assertEqual(layer.weight_scale.data.shape, (128, ))
        self.assertEqual(layer.weight_offset.data.shape, (128, ))

    # ------------------------------------------------------------------
    # Parameter creation
    # ------------------------------------------------------------------

    def test_get_weight(self):
        weight = self.method.get_weight(10, 20)
        self.assertEqual(weight['weight'].dtype, torch.int8)
        self.assertEqual(weight['weight'].shape, (20, 10))

    def test_get_pertensor_param(self):
        params = self.method.get_pertensor_param(torch.bfloat16)
        self.assertEqual(params['input_scale'].dtype, torch.bfloat16)
        self.assertEqual(params['input_offset'].dtype, torch.int8)
        self.assertEqual(params['input_scale'].shape, (1, ))
        self.assertEqual(params['input_offset'].shape, (1, ))

    def test_get_perchannel_param(self):
        params = self.method.get_perchannel_param(10, torch.bfloat16)

        self.assertEqual(params['quant_bias'].dtype, torch.int32)
        self.assertEqual(params['deq_scale'].dtype, torch.float32)
        self.assertEqual(params['weight_scale'].dtype, torch.bfloat16)
        self.assertEqual(params['weight_offset'].dtype, torch.bfloat16)
        self.assertEqual(params['quant_bias'].shape, (10, ))
        self.assertEqual(params['deq_scale'].shape, (10, ))
        self.assertEqual(params['weight_scale'].shape, (10, 1))
        self.assertEqual(params['weight_offset'].shape, (10, 1))

    # ------------------------------------------------------------------
    # apply
    # ------------------------------------------------------------------

    @patch(
        "vllm_ascend.quantization.methods.w8a8_static.get_weight_prefetch_method"
    )
    @patch("torch.ops.vllm.quantize")
    @patch("torch_npu.npu_quant_matmul")
    def test_apply_with_x_not_int8(self, mock_npu_quant_matmul, mock_quantize,
                                   mock_get_weight_prefetch_method):
        mock_get_weight_prefetch_method.return_value = MagicMock()

        x = torch.randn(32, 128)
        # Stub for the quantize op, presumably invoked because ``x`` is not
        # int8. The high bound of randint is exclusive, so 128 covers 127.
        mock_quantize.return_value = torch.randint(-128,
                                                   128,
                                                   x.shape,
                                                   dtype=torch.int8)

        self._run_apply_and_compare(mock_npu_quant_matmul, x)

    @patch("torch_npu.npu_quant_matmul")
    def test_apply_with_x_is_int8(self, mock_npu_quant_matmul):
        x = torch.randint(-128, 128, (32, 128), dtype=torch.int8)
        self._run_apply_and_compare(mock_npu_quant_matmul, x)

    # NOTE(review): this patches the function on ``vllm_ascend.utils``. If
    # ``w8a8_static`` imports ``get_ascend_device_type`` by name, the patch
    # does not reach it and the default device path is exercised instead --
    # confirm the lookup site and patch that module if needed.
    @patch('vllm_ascend.utils.get_ascend_device_type',
           return_value=AscendDeviceType._310P)
    @patch("torch_npu.npu_quant_matmul")
    def test_apply_with_x_is_310p(self, mock_npu_quant_matmul,
                                  _mock_device_type):
        x = torch.randint(-128, 128, (32, 128), dtype=torch.int8)
        self._run_apply_and_compare(mock_npu_quant_matmul, x)

    # ------------------------------------------------------------------
    # process_weights_after_loading
    # ------------------------------------------------------------------

    @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "0"})
    @patch('torch_npu.npu_format_cast')
    def test_process_weights_after_loading_with_nz0(self,
                                                    mock_npu_format_cast):
        layer = self._make_loading_layer()
        # Return a mock instance, not the MagicMock class itself.
        mock_npu_format_cast.return_value = MagicMock()

        self.method.process_weights_after_loading(layer)

        self._assert_loaded_layer(layer)
        # NZ disabled: no format cast expected.
        mock_npu_format_cast.assert_not_called()

    @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "1"})
    @patch('torch_npu.npu_format_cast')
    def test_process_weights_after_loading_with_nz1(self,
                                                    mock_npu_format_cast):
        layer = self._make_loading_layer()
        mock_npu_format_cast.return_value = MagicMock()

        self.method.process_weights_after_loading(layer)

        self._assert_loaded_layer(layer)
        mock_npu_format_cast.assert_called_once()

    @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "2"})
    @patch('torch_npu.npu_format_cast')
    def test_process_weights_after_loading_with_nz2(self,
                                                    mock_npu_format_cast):
        layer = self._make_loading_layer()
        mock_npu_format_cast.return_value = MagicMock()

        self.method.process_weights_after_loading(layer)

        self._assert_loaded_layer(layer)
        mock_npu_format_cast.assert_called_once()