### Summary
This PR refactors the `vllm_ascend/quantization` module to improve code
organization, maintainability, and extensibility. The refactoring
introduces a clear separation of concerns with a registry-based scheme
discovery pattern, abstract base classes for quantization schemes, and
dedicated wrapper classes.
### Key Changes
#### 1. **Modular Directory Structure**
| Before | After |
|--------|-------|
| Flat file structure with mixed responsibilities | Organized into a `methods/` subpackage for schemes |
| Single `quant_config.py` (600+ lines) | Separate config files: `modelslim_config.py`, `compressed_tensors_config.py` |
| `utils.py` with scheme lookup logic | `methods/registry.py` with decorator-based registration |
#### 2. **Registry-Based Scheme Discovery**
Replaced hardcoded `ASCEND_QUANTIZATION_METHOD_MAP` dictionary with a
decorator-based registry pattern:
```python
# Before: Manual dictionary mapping
ASCEND_QUANTIZATION_METHOD_MAP = {
"W8A8_DYNAMIC": {"linear": AscendW8A8DynamicLinearMethod, ...},
...
}
# After: Decorator-based registration
@register_scheme("W8A8_DYNAMIC", "linear")
class AscendW8A8DynamicLinearMethod(AscendLinearScheme):
...
```
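A minimal sketch of what the decorator-based registry in `methods/registry.py` could look like, assuming the `(quant_type, layer_type)` keying shown in the sequence diagram below (the actual implementation may differ in detail):
```python
# Sketch of methods/registry.py; names follow this PR description,
# but the real implementation may differ.
_SCHEME_REGISTRY: dict[tuple[str, str], type] = {}


def register_scheme(quant_type: str, layer_type: str):
    """Class decorator mapping (quant_type, layer_type) to a scheme class."""

    def decorator(cls: type) -> type:
        _SCHEME_REGISTRY[(quant_type, layer_type)] = cls
        return cls

    return decorator


def get_scheme_class(quant_type: str, layer_type: str) -> type:
    """Look up a registered scheme class, failing loudly if absent."""
    try:
        return _SCHEME_REGISTRY[(quant_type, layer_type)]
    except KeyError:
        raise ValueError(f"No quantization scheme registered for "
                         f"({quant_type!r}, {layer_type!r})") from None
```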
#### 3. **Abstract Base Classes**
Introduced three abstract base classes in `methods/base.py` (a sketch follows the list):
- `AscendLinearScheme` - Base for linear layer quantization
- `AscendMoEScheme` - Base for MoE layer quantization
- `AscendAttentionScheme` - Base for attention layer quantization
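Based on the class diagram further down, the linear base class could look roughly like this; the signatures are a sketch, and the real `methods/base.py` likely carries more parameters:
```python
# Sketch of AscendLinearScheme derived from the class diagram in this PR;
# actual signatures in methods/base.py may differ.
from abc import ABC, abstractmethod


class AscendLinearScheme(ABC):
    """Base class for linear-layer quantization schemes."""

    @abstractmethod
    def get_weight(self, input_size: int, output_size: int,
                   params_dtype) -> dict:
        """Return the (quantized) weight tensors for the layer."""

    @abstractmethod
    def apply(self, layer, x, bias=None):
        """Run the quantized forward pass."""

    def get_pertensor_param(self, params_dtype) -> dict:
        # Optional per-tensor quantization parameters (e.g. scales).
        return {}

    def get_perchannel_param(self, params_dtype) -> dict:
        # Optional per-channel quantization parameters.
        return {}
```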
#### 4. **Separated Config and Wrapper Classes**
- **Config classes** (`AscendModelSlimConfig`,
`AscendCompressedTensorsConfig`): Handle config parsing and scheme
selection
- **Wrapper classes** (`AscendLinearMethod`, `AscendFusedMoEMethod`,
etc.): Implement vLLM interfaces and delegate to schemes (see the sketch below)
#### 5. **Cleaner Public API**
```python
# New clean module interface
from vllm_ascend.quantization import (
AscendModelSlimConfig,
AscendCompressedTensorsConfig,
)
from vllm_ascend.quantization.methods import get_scheme_class
```
### Architecture Diagram
```mermaid
classDiagram
direction TB
class QuantizationConfig {
<<vLLM Interface>>
+get_quant_method()
}
class AscendModelSlimConfig {
+quant_description
+get_quant_method()
-create_scheme_for_layer()
}
class AscendCompressedTensorsConfig {
+target_scheme_map
+get_quant_method()
-_get_scheme_from_parts()
}
class AscendLinearMethod {
<<Wrapper>>
+quant_method: AscendLinearScheme
+create_weights()
+apply()
}
class AscendFusedMoEMethod {
<<Wrapper>>
+quant_method: AscendMoEScheme
+create_weights()
+apply()
}
class AscendLinearScheme {
<<Abstract>>
+get_weight()*
+apply()*
+get_pertensor_param()
+get_perchannel_param()
}
class AscendMoEScheme {
<<Abstract>>
+get_weight()*
+get_dynamic_quant_param()*
+apply()*
}
class W8A8DynamicLinear {
+get_weight()
+apply()
}
class W8A8DynamicMoE {
+get_weight()
+apply()
}
QuantizationConfig <|-- AscendModelSlimConfig
QuantizationConfig <|-- AscendCompressedTensorsConfig
AscendModelSlimConfig ..> AscendLinearMethod : creates
AscendModelSlimConfig ..> AscendFusedMoEMethod : creates
AscendCompressedTensorsConfig ..> AscendLinearMethod : creates
AscendCompressedTensorsConfig ..> AscendFusedMoEMethod : creates
AscendLinearMethod o-- AscendLinearScheme : delegates to
AscendFusedMoEMethod o-- AscendMoEScheme : delegates to
AscendLinearScheme <|-- W8A8DynamicLinear
AscendMoEScheme <|-- W8A8DynamicMoE
```
### Scheme Registration Flow
```mermaid
sequenceDiagram
participant Module as Scheme Module
participant Registry as _SCHEME_REGISTRY
participant Config as QuantConfig
participant Wrapper as Wrapper Class
Note over Module: At import time
Module->>Registry: @register_scheme("W8A8_DYNAMIC", "linear")
Registry->>Registry: Store (quant_type, layer_type) -> Class
Note over Config: At runtime
Config->>Config: Determine quant_type from description
Config->>Registry: get_scheme_class(quant_type, layer_type)
Registry-->>Config: Return scheme class
Config->>Config: scheme = scheme_cls()
Config->>Wrapper: Create wrapper with scheme
Wrapper-->>Config: Return wrapper instance
```
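Concretely, the runtime half of this flow inside a config's `get_quant_method` might look like the following sketch; `_quant_type_for()` and the fixed `"linear"` layer type are illustrative stand-ins for the real dispatch logic:
```python
from vllm_ascend.quantization.methods import get_scheme_class


class AscendModelSlimConfig:  # abridged sketch
    def _quant_type_for(self, prefix: str) -> str:
        # Hypothetical helper: the real config derives the type from
        # quant_description; hard-coded here for illustration.
        return "W8A8_DYNAMIC"

    def get_quant_method(self, layer, prefix: str):
        quant_type = self._quant_type_for(prefix)
        # The real code also dispatches on layer type (linear/moe/attention).
        scheme_cls = get_scheme_class(quant_type, "linear")
        # Wrap the scheme in the wrapper class sketched in section 4.
        return AscendLinearMethod(scheme_cls())
```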
### File Changes Summary
| Original Files | Refactored Files |
|----------------|------------------|
| `__init__.py` (empty) | `__init__.py` (exports public API) |
| `quant_config.py` | `modelslim_config.py` + `wrappers.py` |
| `compressed_tensors/` | `compressed_tensors_config.py` |
| `utils.py` | `methods/registry.py` |
| `w8a8_dynamic.py` | `methods/w8a8_dynamic.py` |
| `w8a8.py` | `methods/w8a8_static.py` |
| `w4a4_flatquant_dynamic.py` | `methods/w4a4_flatquant.py` |
| ... | `methods/base.py` (new) |
### Benefits
1. **Extensibility**: Adding a new quantization scheme only requires
implementing the base class and adding the `@register_scheme` decorator
2. **Maintainability**: Clear separation between config parsing, wrapper
logic, and scheme implementation
3. **Testability**: Abstract base classes enable easier unit testing and
mocking
4. **Discoverability**: Registry pattern makes it easy to list all
supported schemes (see the sketch after this list)
5. **Reduced Coupling**: Config classes no longer need to know about all
scheme implementations
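As an illustration of point 4, enumerating every supported scheme could be a one-liner over the registry; this helper is hypothetical and not necessarily part of the PR:
```python
def list_supported_schemes():
    """Return all registered (quant_type, layer_type) pairs."""
    return sorted(_SCHEME_REGISTRY)
```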
___
- vLLM version: v0.13.0
- vLLM main: 2f4e6548ef
---------
Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
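### Included Unit Tests
The PR also includes unit tests for the new `methods/w4a16.py` module, covering `pack_to_int32`, `unpack_from_int32`, and `AscendW4A16FusedMoEMethod`: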
```python
from unittest.mock import Mock, patch

import torch

from tests.ut.base import TestBase
from vllm_ascend.quantization.methods.w4a16 import (AscendW4A16FusedMoEMethod,
                                                    pack_to_int32,
                                                    unpack_from_int32)


class TestUnpackFromInt32(TestBase):

    def test_unpack_from_int32_packed_dim_1(self):
        weight = torch.tensor([[305419896, -1420531520]], dtype=torch.int32)
        shape = torch.Size([1, 8])
        num_bits = 4

        result = unpack_from_int32(weight, shape, num_bits, packed_dim=1)

        self.assertEqual(result.dtype, torch.int8)
        self.assertEqual(result.shape, shape)

    def test_unpack_from_int32_packed_dim_0(self):
        weight = torch.tensor([[305419896], [-1420531520]], dtype=torch.int32)
        shape = torch.Size([8, 1])
        num_bits = 4

        result = unpack_from_int32(weight, shape, num_bits, packed_dim=0)

        self.assertEqual(result.dtype, torch.int8)
        self.assertEqual(result.shape, shape)

    def test_unpack_from_int32_assertions(self):
        with self.assertRaises(AssertionError):
            weight = torch.tensor([[1, 2]], dtype=torch.int64)
            unpack_from_int32(weight, torch.Size([8, 1]), 4)

        with self.assertRaises(AssertionError):
            weight = torch.tensor([[1, 2]], dtype=torch.int32)
            unpack_from_int32(weight, torch.Size([8, 1]), 16)


class TestPackToInt32(TestBase):

    @patch(
        "vllm_ascend.quantization.methods.w4a16.torch_npu.npu_convert_weight_to_int4pack"
    )
    def test_pack_to_int32_int8(self, mock_npu_convert_weight_to_int4pack):
        mock_npu_convert_weight_to_int4pack.return_value = torch.zeros(
            (2, 4), dtype=torch.int32)

        weight = torch.zeros((2, 8, 16), dtype=torch.int8)
        result = pack_to_int32(weight)

        self.assertEqual(result.dtype, torch.int32)
        mock_npu_convert_weight_to_int4pack.assert_not_called()

        self.assertEqual(result.shape, torch.Size([2, 8, 4]))

    @patch(
        "vllm_ascend.quantization.methods.w4a16.torch_npu.npu_convert_weight_to_int4pack"
    )
    def test_pack_to_int32_int32(self, mock_npu_convert_weight_to_int4pack):

        def mock_convert_weight(weight):
            return weight

        mock_npu_convert_weight_to_int4pack.side_effect = mock_convert_weight
        weight = torch.zeros((2, 8, 8), dtype=torch.int32)
        result = pack_to_int32(weight)

        self.assertEqual(result.dtype, torch.int32)
        self.assertEqual(result.shape, weight.shape)

    def test_pack_to_int32_assertion_dim(self):
        with self.assertRaises(AssertionError):
            weight = torch.zeros((8, 8), dtype=torch.int8)
            pack_to_int32(weight)

    def test_pack_to_int32_assertion_dtype(self):
        with self.assertRaises(AssertionError):
            weight = torch.zeros((2, 8, 8), dtype=torch.float32)
            pack_to_int32(weight)

    def test_pack_to_int32_assertion_divisible(self):
        with self.assertRaises(AssertionError):
            weight = torch.zeros((2, 8, 7), dtype=torch.int32)
            pack_to_int32(weight)

        with self.assertRaises(AssertionError):
            weight = torch.zeros((2, 8, 7), dtype=torch.int8)
            pack_to_int32(weight)


class TestAscendW4A16FusedMoEMethod(TestBase):
    experts = 8
    input_size = 32
    output_size = 128
    group_size = 32

    @patch("vllm_ascend.quantization.methods.w4a16.get_ascend_config")
    @patch("vllm_ascend.quantization.methods.w4a16.get_current_vllm_config")
    def setUp(self, mock_get_current_vllm_config, mock_get_ascend_config):
        mock_ascend_config = Mock()
        mock_ascend_config.eplb_config.dynamic_eplb = False
        mock_ascend_config.eplb_config.expert_map_record_path = None
        mock_get_ascend_config.return_value = mock_ascend_config

        mock_vllm_config = Mock()
        mock_vllm_config.quant_config = Mock(quant_description={
            "group_size": self.group_size,
        })
        mock_get_current_vllm_config.return_value = mock_vllm_config

        self.quant_method = AscendW4A16FusedMoEMethod()

    def test_init(self):
        self.assertTrue(self.quant_method.transpose_weight)
        self.assertEqual(self.quant_method.num_bits, 4)
        self.assertEqual(self.quant_method.pack_factor, 8)
        self.assertEqual(self.quant_method.group_size, self.group_size)
        self.assertFalse(self.quant_method.dynamic_eplb)

    def test_get_weight(self):
        param_dict = self.quant_method.get_weight(self.experts,
                                                  self.input_size,
                                                  self.output_size,
                                                  torch.bfloat16)

        self.assertEqual(param_dict["w13_weight_packed"].dtype, torch.int32)
        expected_w13_shape = (self.experts, 2 * self.input_size,
                              self.output_size //
                              self.quant_method.pack_factor)
        self.assertEqual(param_dict["w13_weight_packed"].shape,
                         expected_w13_shape)

        self.assertEqual(param_dict["w2_weight_packed"].dtype, torch.int32)
        expected_w2_shape = (self.experts, self.output_size,
                             self.input_size // self.quant_method.pack_factor)
        self.assertEqual(param_dict["w2_weight_packed"].shape,
                         expected_w2_shape)

    def test_get_dynamic_quant_param(self):
        param_dict = self.quant_method.get_dynamic_quant_param(
            self.experts, self.input_size, self.output_size, torch.bfloat16)

        self.assertEqual(param_dict["w13_weight_scale"].dtype, torch.bfloat16)
        expected_w13_scale_shape = (self.experts, 2 * self.input_size,
                                    self.output_size // self.group_size)
        self.assertEqual(param_dict["w13_weight_scale"].shape,
                         expected_w13_scale_shape)

        self.assertEqual(param_dict["w2_weight_scale"].dtype, torch.bfloat16)
        expected_w2_scale_shape = (self.experts, self.output_size,
                                   self.input_size // self.group_size)
        self.assertEqual(param_dict["w2_weight_scale"].shape,
                         expected_w2_scale_shape)

        self.assertEqual(param_dict["w13_weight_shape"].dtype, torch.int32)
        self.assertEqual(param_dict["w13_weight_shape"].shape,
                         (self.experts, 2))

        self.assertEqual(param_dict["w2_weight_shape"].dtype, torch.int32)
        self.assertEqual(param_dict["w2_weight_shape"].shape,
                         (self.experts, 2))

        self.assertEqual(param_dict["w13_weight_offset"].dtype,
                         torch.bfloat16)
        self.assertEqual(param_dict["w13_weight_offset"].shape,
                         expected_w13_scale_shape)

        self.assertEqual(param_dict["w2_weight_offset"].dtype, torch.bfloat16)
        self.assertEqual(param_dict["w2_weight_offset"].shape,
                         expected_w2_scale_shape)

    def build_layer(self):
        """Build a mock layer for testing."""
        layer = torch.nn.Module()

        w13_shape = (self.experts, 2 * self.input_size,
                     self.output_size // self.quant_method.pack_factor)
        w2_shape = (self.experts, self.output_size,
                    self.input_size // self.quant_method.pack_factor)

        layer.w13_weight_packed = torch.nn.Parameter(torch.randint(
            -100, 100, w13_shape, dtype=torch.int32),
                                                     requires_grad=False)
        layer.w2_weight_packed = torch.nn.Parameter(torch.randint(
            -100, 100, w2_shape, dtype=torch.int32),
                                                    requires_grad=False)

        w13_scale_shape = (self.experts, 2 * self.input_size,
                           self.output_size // self.group_size)
        w2_scale_shape = (self.experts, self.output_size,
                          self.input_size // self.group_size)

        layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
            w13_scale_shape, dtype=torch.bfloat16),
                                                    requires_grad=False)
        layer.w2_weight_scale = torch.nn.Parameter(torch.ones(
            w2_scale_shape, dtype=torch.bfloat16),
                                                   requires_grad=False)

        layer.w13_weight_offset = torch.nn.Parameter(torch.zeros(
            w13_scale_shape, dtype=torch.bfloat16),
                                                     requires_grad=False)
        layer.w2_weight_offset = torch.nn.Parameter(torch.zeros(
            w2_scale_shape, dtype=torch.bfloat16),
                                                    requires_grad=False)

        layer.w13_weight_shape = torch.nn.Parameter(torch.tensor(
            [[2 * self.input_size, self.output_size]] * self.experts,
            dtype=torch.int32),
                                                    requires_grad=False)
        layer.w2_weight_shape = torch.nn.Parameter(torch.tensor(
            [[self.output_size, self.input_size]] * self.experts,
            dtype=torch.int32),
                                                   requires_grad=False)

        return layer

    @patch(
        "vllm_ascend.quantization.methods.w4a16.torch_npu.npu_convert_weight_to_int4pack"
    )
    def test_process_weights_after_loading_with_transpose(
            self, mock_npu_convert_weight_to_int4pack):

        def mock_convert_weight(weight):
            new_shape = list(weight.shape)
            new_shape[-1] = new_shape[-1] // 8
            return torch.zeros(new_shape, dtype=torch.int32)

        mock_npu_convert_weight_to_int4pack.side_effect = mock_convert_weight

        layer = self.build_layer()
        self.quant_method.transpose_weight = True

        self.quant_method.process_weights_after_loading(layer)

        self.assertEqual(layer.w13_weight_packed.data.shape,
                         torch.Size([8, 128, 8]))
        self.assertEqual(layer.w2_weight_packed.data.shape,
                         torch.Size([8, 32, 16]))

        self.assertEqual(layer.w13_weight_scale.data.shape,
                         torch.Size([8, 4, 64]))
        self.assertEqual(layer.w2_weight_scale.data.shape,
                         torch.Size([8, 1, 128]))
        self.assertEqual(layer.w13_weight_offset.data.shape,
                         torch.Size([8, 4, 64]))
        self.assertEqual(layer.w2_weight_offset.data.shape,
                         torch.Size([8, 1, 128]))

        self.assertTrue(layer.w13_weight_scale.data.is_contiguous())
        self.assertTrue(layer.w2_weight_scale.data.is_contiguous())
        self.assertTrue(layer.w13_weight_offset.data.is_contiguous())
        self.assertTrue(layer.w2_weight_offset.data.is_contiguous())

    def test_process_weights_after_loading_without_transpose(self):
        layer = self.build_layer()
        self.quant_method.transpose_weight = False

        original_w13_data = layer.w13_weight_packed.data.clone()
        original_w2_data = layer.w2_weight_packed.data.clone()

        self.quant_method.process_weights_after_loading(layer)

        self.assertTrue(
            torch.equal(layer.w13_weight_packed.data, original_w13_data))
        self.assertTrue(
            torch.equal(layer.w2_weight_packed.data, original_w2_data))
```