[Refactor] Quantization Module Refactor (#5738)
### Summary
This PR refactors the `vllm_ascend/quantization` module to improve code
organization, maintainability, and extensibility. The refactoring
introduces a clear separation of concerns with a registry-based scheme
discovery pattern, abstract base classes for quantization schemes, and
dedicated wrapper classes.
### Key Changes
#### 1. **Modular Directory Structure**
| Before | After |
|--------|-------|
| Flat file structure with mixed responsibilities | Organized into `methods/` subpackage for schemes |
| Single `quant_config.py` (600+ lines) | Separate config files: `modelslim_config.py`, `compressed_tensors_config.py` |
| `utils.py` with scheme lookup logic | `methods/registry.py` with decorator-based registration |
#### 2. **Registry-Based Scheme Discovery**
Replaced the hardcoded `ASCEND_QUANTIZATION_METHOD_MAP` dictionary with a
decorator-based registry pattern:
```python
# Before: Manual dictionary mapping
ASCEND_QUANTIZATION_METHOD_MAP = {
    "W8A8_DYNAMIC": {"linear": AscendW8A8DynamicLinearMethod, ...},
    ...
}

# After: Decorator-based registration
@register_scheme("W8A8_DYNAMIC", "linear")
class AscendW8A8DynamicLinearMethod(AscendLinearScheme):
    ...
```
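The registration helpers themselves are not shown in this description; below is a minimal sketch of what `methods/registry.py` could look like (`_SCHEME_REGISTRY`, `register_scheme`, and `get_scheme_class` are the names used elsewhere in this PR, but the exact signatures and error handling are assumptions):
```python
# Illustrative sketch of methods/registry.py; the real implementation may differ.
from typing import Callable, Dict, Tuple, Type

# Maps (quant_type, layer_type) -> scheme class, e.g. ("W8A8_DYNAMIC", "linear").
_SCHEME_REGISTRY: Dict[Tuple[str, str], Type] = {}


def register_scheme(quant_type: str, layer_type: str) -> Callable[[Type], Type]:
    """Class decorator that records a scheme under (quant_type, layer_type)."""

    def wrap(cls: Type) -> Type:
        key = (quant_type, layer_type)
        if key in _SCHEME_REGISTRY:
            raise ValueError(f"Scheme already registered for {key}")
        _SCHEME_REGISTRY[key] = cls
        return cls

    return wrap


def get_scheme_class(quant_type: str, layer_type: str) -> Type:
    """Return the scheme class registered for (quant_type, layer_type)."""
    try:
        return _SCHEME_REGISTRY[(quant_type, layer_type)]
    except KeyError:
        raise NotImplementedError(
            f"No scheme registered for quant_type={quant_type!r}, "
            f"layer_type={layer_type!r}") from None
```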
#### 3. **Abstract Base Classes**
Introduced three abstract base classes in `methods/base.py` (see the sketch after this list):
- `AscendLinearScheme` - Base for linear layer quantization
- `AscendMoEScheme` - Base for MoE layer quantization
- `AscendAttentionScheme` - Base for attention layer quantization
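As an illustration, the linear base class might look roughly as follows; the method names come from the class diagram further down, while the signatures are assumptions rather than the actual `methods/base.py` code:
```python
# Illustrative sketch of AscendLinearScheme in methods/base.py.
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional

import torch


class AscendLinearScheme(ABC):
    """Contract every linear-layer quantization scheme implements."""

    @abstractmethod
    def get_weight(self, input_size: int, output_size: int,
                   params_dtype: torch.dtype) -> Dict[str, Any]:
        """Return the (quantized) weight tensors to create on the layer."""

    @abstractmethod
    def apply(self, layer: torch.nn.Module, x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the quantized matmul for this scheme."""

    def get_pertensor_param(self, params_dtype: torch.dtype) -> Dict[str, Any]:
        """Optional per-tensor parameters (e.g. input scale/offset)."""
        return {}

    def get_perchannel_param(self, output_size: int,
                             params_dtype: torch.dtype) -> Dict[str, Any]:
        """Optional per-channel parameters (e.g. weight scale/offset)."""
        return {}
```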
#### 4. **Separated Config and Wrapper Classes**
- **Config classes** (`AscendModelSlimConfig`,
`AscendCompressedTensorsConfig`): Handle config parsing and scheme
selection
- **Wrapper classes** (`AscendLinearMethod`, `AscendFusedMoEMethod`,
etc.): Implement vLLM interfaces and delegate to schemes (see the sketch below)
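A rough sketch of the delegation pattern; the real wrappers implement vLLM's quantization-method interfaces and carry more bookkeeping, so the bodies below are simplified assumptions:
```python
# Illustrative sketch of a wrapper class; simplified for this description.
from typing import Optional

import torch


class AscendLinearMethod:
    """Thin adapter: exposes the vLLM interface, delegates the work to a scheme."""

    def __init__(self, quant_method: "AscendLinearScheme") -> None:
        self.quant_method = quant_method

    def create_weights(self, layer: torch.nn.Module, *args, **kwargs) -> None:
        # Ask the scheme which tensors to create, then register them on the layer.
        for name, tensor in self.quant_method.get_weight(*args, **kwargs).items():
            layer.register_parameter(
                name, torch.nn.Parameter(tensor, requires_grad=False))

    def apply(self, layer: torch.nn.Module, x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        return self.quant_method.apply(layer, x, bias)
```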
#### 5. **Cleaner Public API**
```python
# New clean module interface
from vllm_ascend.quantization import (
    AscendModelSlimConfig,
    AscendCompressedTensorsConfig,
)
from vllm_ascend.quantization.methods import get_scheme_class
```
### Architecture Diagram
```mermaid
classDiagram
direction TB
class QuantizationConfig {
<<vLLM Interface>>
+get_quant_method()
}
class AscendModelSlimConfig {
+quant_description
+get_quant_method()
-create_scheme_for_layer()
}
class AscendCompressedTensorsConfig {
+target_scheme_map
+get_quant_method()
-_get_scheme_from_parts()
}
class AscendLinearMethod {
<<Wrapper>>
+quant_method: AscendLinearScheme
+create_weights()
+apply()
}
class AscendFusedMoEMethod {
<<Wrapper>>
+quant_method: AscendMoEScheme
+create_weights()
+apply()
}
class AscendLinearScheme {
<<Abstract>>
+get_weight()*
+apply()*
+get_pertensor_param()
+get_perchannel_param()
}
class AscendMoEScheme {
<<Abstract>>
+get_weight()*
+get_dynamic_quant_param()*
+apply()*
}
class W8A8DynamicLinear {
+get_weight()
+apply()
}
class W8A8DynamicMoE {
+get_weight()
+apply()
}
QuantizationConfig <|-- AscendModelSlimConfig
QuantizationConfig <|-- AscendCompressedTensorsConfig
AscendModelSlimConfig ..> AscendLinearMethod : creates
AscendModelSlimConfig ..> AscendFusedMoEMethod : creates
AscendCompressedTensorsConfig ..> AscendLinearMethod : creates
AscendCompressedTensorsConfig ..> AscendFusedMoEMethod : creates
AscendLinearMethod o-- AscendLinearScheme : delegates to
AscendFusedMoEMethod o-- AscendMoEScheme : delegates to
AscendLinearScheme <|-- W8A8DynamicLinear
AscendMoEScheme <|-- W8A8DynamicMoE
```
### Scheme Registration Flow
```mermaid
sequenceDiagram
participant Module as Scheme Module
participant Registry as _SCHEME_REGISTRY
participant Config as QuantConfig
participant Wrapper as Wrapper Class
Note over Module: At import time
Module->>Registry: @register_scheme("W8A8_DYNAMIC", "linear")
Registry->>Registry: Store (quant_type, layer_type) -> Class
Note over Config: At runtime
Config->>Config: Determine quant_type from description
Config->>Registry: get_scheme_class(quant_type, layer_type)
Registry-->>Config: Return scheme class
Config->>Config: scheme = scheme_cls()
Config->>Wrapper: Create wrapper with scheme
Wrapper-->>Config: Return wrapper instance
```
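In code, the runtime half of this flow roughly amounts to the following. This is a sketch only: `get_scheme_class` is the registry accessor from this PR, the `wrappers` module path is taken from the file table below, and the shape of the quant-description lookup is an assumption.
```python
# Illustrative sketch of the runtime scheme selection done by a config class.
from vllm_ascend.quantization.methods import get_scheme_class
from vllm_ascend.quantization.wrappers import (  # module path assumed
    AscendFusedMoEMethod, AscendLinearMethod)


def build_quant_method(quant_description: dict, prefix: str, layer_type: str):
    quant_type = quant_description[f"{prefix}.weight"]     # e.g. "W8A8_DYNAMIC"
    scheme_cls = get_scheme_class(quant_type, layer_type)  # registry lookup
    scheme = scheme_cls()                                   # instantiate the scheme
    if layer_type == "linear":
        return AscendLinearMethod(scheme)                   # wrapper delegates to the scheme
    if layer_type == "moe":
        return AscendFusedMoEMethod(scheme)
    raise NotImplementedError(f"Unsupported layer type: {layer_type}")
```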
### File Changes Summary
| Original Files | Refactored Files |
|----------------|------------------|
| `__init__.py` (empty) | `__init__.py` (exports public API) |
| `quant_config.py` | `modelslim_config.py` + `wrappers.py` |
| `compressed_tensors/` | `compressed_tensors_config.py` |
| `utils.py` | `methods/registry.py` |
| `w8a8_dynamic.py` | `methods/w8a8_dynamic.py` |
| `w8a8.py` | `methods/w8a8_static.py` |
| `w4a4_flatquant_dynamic.py` | `methods/w4a4_flatquant.py` |
| ... | `methods/base.py` (new) |
### Benefits
1. **Extensibility**: Adding a new quantization scheme only requires
implementing the scheme base class and adding the `@register_scheme` decorator (see the example after this list)
2. **Maintainability**: Clear separation between config parsing, wrapper
logic, and scheme implementation
3. **Testability**: Abstract base classes enable easier unit testing and
mocking
4. **Discoverability**: Registry pattern makes it easy to list all
supported schemes
5. **Reduced Coupling**: Config classes no longer need to know about all
scheme implementations
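For instance, wiring up a hypothetical new scheme would be a single self-contained file; the quant type string, class, and file name below are made up for illustration, and the signatures follow the sketches above rather than the actual code:
```python
# methods/w8a8_pertoken.py -- hypothetical example of adding a new scheme.
import torch

from vllm_ascend.quantization.methods.base import AscendLinearScheme
from vllm_ascend.quantization.methods.registry import register_scheme


@register_scheme("W8A8_PER_TOKEN", "linear")  # made-up quant type string
class AscendW8A8PerTokenLinearMethod(AscendLinearScheme):

    def get_weight(self, input_size: int, output_size: int,
                   params_dtype: torch.dtype):
        # int8 weight; scales would come from get_perchannel_param().
        return {"weight": torch.empty(output_size, input_size, dtype=torch.int8)}

    def apply(self, layer: torch.nn.Module, x: torch.Tensor, bias=None):
        ...  # quantize x, call the NPU quantized matmul, dequantize the result
```
No config or wrapper changes are needed: the registry picks the class up at import time.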
___
- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef
---------
Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
(Commit diff omitted here. It updates the unit tests to the new layout: `TestAscendQuantConfig` becomes `TestAscendModelSlimConfig` and patches `vllm_ascend.quantization.modelslim_config` / `vllm_ascend.quantization.method_adapters` rather than `vllm_ascend.quantization.quant_config`; the per-scheme tests import from `vllm_ascend.quantization.methods.*` (`w4a16`, `w4a4_flatquant`, `w4a8`, `w8a16`, `w8a8_static`, `w8a8_dynamic`) instead of the old flat modules; the tests for the removed `utils.py` (`ASCEND_QUANTIZATION_METHOD_MAP`, `get_quant_method`) are deleted; and the platform tests patch `AscendModelSlimConfig` in place of `AscendQuantConfig`.)