[Refactor] Quantization Module Refactor (#5738)

### Summary

This PR refactors the `vllm_ascend/quantization` module to improve code
organization, maintainability, and extensibility. The refactoring
introduces a clear separation of concerns with a registry-based scheme
discovery pattern, abstract base classes for quantization schemes, and
dedicated wrapper classes.

### Key Changes

#### 1. **Modular Directory Structure**

| Before | After |
|--------|-------|
| Flat file structure with mixed responsibilities | Organized into a `methods/` subpackage for schemes |
| Single `quant_config.py` (600+ lines) | Separate config files: `modelslim_config.py`, `compressed_tensors_config.py` |
| `utils.py` with scheme lookup logic | `methods/registry.py` with decorator-based registration |

#### 2. **Registry-Based Scheme Discovery**

Replaced hardcoded `ASCEND_QUANTIZATION_METHOD_MAP` dictionary with a
decorator-based registry pattern:

```python
# Before: Manual dictionary mapping
ASCEND_QUANTIZATION_METHOD_MAP = {
    "W8A8_DYNAMIC": {"linear": AscendW8A8DynamicLinearMethod, ...},
    ...
}

# After: Decorator-based registration
@register_scheme("W8A8_DYNAMIC", "linear")
class AscendW8A8DynamicLinearMethod(AscendLinearScheme):
    ...
```
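
For reference, a minimal sketch of how such a decorator-based registry can be implemented. The names `_SCHEME_REGISTRY` and `get_scheme_class` follow the diagrams below; the details are assumptions and the actual `methods/registry.py` may differ:

```python
# Hypothetical sketch of methods/registry.py; details such as error
# types and duplicate handling may differ in the actual implementation.
from typing import Callable, Dict, Tuple, Type

# Maps (quant_type, layer_type) -> scheme class, populated at import time.
_SCHEME_REGISTRY: Dict[Tuple[str, str], Type] = {}


def register_scheme(quant_type: str,
                    layer_type: str) -> Callable[[Type], Type]:
    """Class decorator that records a scheme under (quant_type, layer_type)."""

    def decorator(cls: Type) -> Type:
        key = (quant_type, layer_type)
        if key in _SCHEME_REGISTRY:
            raise ValueError(f"Duplicate scheme registration for {key}")
        _SCHEME_REGISTRY[key] = cls
        return cls

    return decorator


def get_scheme_class(quant_type: str, layer_type: str) -> Type:
    """Look up a scheme class; raise if the combination is unsupported."""
    try:
        return _SCHEME_REGISTRY[(quant_type, layer_type)]
    except KeyError:
        raise NotImplementedError(
            f"No scheme registered for ({quant_type!r}, {layer_type!r})")
```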

#### 3. **Abstract Base Classes**

Introduced three abstract base classes in `methods/base.py`:
- `AscendLinearScheme` - Base for linear layer quantization
- `AscendMoEScheme` - Base for MoE layer quantization  
- `AscendAttentionScheme` - Base for attention layer quantization
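
A minimal sketch of the first of these, with signatures inferred from the method names in the architecture diagram below (the exact signatures are assumptions, not the actual code):

```python
# Hypothetical sketch of AscendLinearScheme in methods/base.py; the
# signatures are inferred from the class diagram, not the actual code.
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional

import torch


class AscendLinearScheme(ABC):
    """Base class for linear-layer quantization schemes."""

    @abstractmethod
    def get_weight(self, input_size: int, output_size: int,
                   params_dtype: torch.dtype) -> Dict[str, Any]:
        """Return the quantized weight tensors to register on the layer."""

    @abstractmethod
    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the quantized forward pass."""

    def get_pertensor_param(self, params_dtype: torch.dtype) -> Dict[str, Any]:
        """Per-tensor quantization parameters (none by default)."""
        return {}

    def get_perchannel_param(self, output_size: int,
                             params_dtype: torch.dtype) -> Dict[str, Any]:
        """Per-channel quantization parameters (none by default)."""
        return {}
```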

#### 4. **Separated Config and Wrapper Classes**

- **Config classes** (`AscendModelSlimConfig`,
`AscendCompressedTensorsConfig`): Handle config parsing and scheme
selection
- **Wrapper classes** (`AscendLinearMethod`, `AscendFusedMoEMethod`,
etc., in `method_adapters.py`): Implement vLLM interfaces and delegate to schemes
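
The updated unit tests show that a wrapper is constructed with a single scheme argument (e.g. `AscendLinearMethod(scheme)`); beyond that, the delegation details below are an assumption:

```python
# Hypothetical sketch of the wrapper/delegation pattern; only the
# single-argument constructor is confirmed by the updated unit tests.
from typing import Optional

import torch


class AscendLinearMethod:  # in vLLM terms, a LinearMethodBase subclass
    """vLLM-facing wrapper that forwards all work to a scheme."""

    def __init__(self, quant_method) -> None:
        self.quant_method = quant_method  # an AscendLinearScheme instance

    def create_weights(self, layer: torch.nn.Module, *weight_args) -> None:
        # Let the scheme decide which parameters the layer needs,
        # then register them on the layer.
        for name, tensor in self.quant_method.get_weight(*weight_args).items():
            layer.register_parameter(
                name, torch.nn.Parameter(tensor, requires_grad=False))

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        return self.quant_method.apply(layer, x, bias)
```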

#### 5. **Cleaner Public API**

```python
# New clean module interface
from vllm_ascend.quantization import (
    AscendModelSlimConfig,
    AscendCompressedTensorsConfig,
)
from vllm_ascend.quantization.methods import get_scheme_class
```
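
With `get_scheme_class` imported as above, resolving a scheme at runtime reduces to a registry lookup:

```python
# Illustrative lookup; "W8A8_DYNAMIC" is one of the quant types used above.
scheme_cls = get_scheme_class("W8A8_DYNAMIC", "linear")
scheme = scheme_cls()
```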

### Architecture Diagram

```mermaid
classDiagram
    direction TB
    
    class QuantizationConfig {
        <<vLLM Interface>>
        +get_quant_method()
    }
    
    class AscendModelSlimConfig {
        +quant_description
        +get_quant_method()
        -create_scheme_for_layer()
    }
    
    class AscendCompressedTensorsConfig {
        +target_scheme_map
        +get_quant_method()
        -_get_scheme_from_parts()
    }
    
    class AscendLinearMethod {
        <<Wrapper>>
        +quant_method: AscendLinearScheme
        +create_weights()
        +apply()
    }
    
    class AscendFusedMoEMethod {
        <<Wrapper>>
        +quant_method: AscendMoEScheme
        +create_weights()
        +apply()
    }
    
    class AscendLinearScheme {
        <<Abstract>>
        +get_weight()*
        +apply()*
        +get_pertensor_param()
        +get_perchannel_param()
    }
    
    class AscendMoEScheme {
        <<Abstract>>
        +get_weight()*
        +get_dynamic_quant_param()*
        +apply()*
    }
    
    class W8A8DynamicLinear {
        +get_weight()
        +apply()
    }
    
    class W8A8DynamicMoE {
        +get_weight()
        +apply()
    }
    
    QuantizationConfig <|-- AscendModelSlimConfig
    QuantizationConfig <|-- AscendCompressedTensorsConfig
    
    AscendModelSlimConfig ..> AscendLinearMethod : creates
    AscendModelSlimConfig ..> AscendFusedMoEMethod : creates
    AscendCompressedTensorsConfig ..> AscendLinearMethod : creates
    AscendCompressedTensorsConfig ..> AscendFusedMoEMethod : creates
    
    AscendLinearMethod o-- AscendLinearScheme : delegates to
    AscendFusedMoEMethod o-- AscendMoEScheme : delegates to
    
    AscendLinearScheme <|-- W8A8DynamicLinear
    AscendMoEScheme <|-- W8A8DynamicMoE
```

### Scheme Registration Flow

```mermaid
sequenceDiagram
    participant Module as Scheme Module
    participant Registry as _SCHEME_REGISTRY
    participant Config as QuantConfig
    participant Wrapper as Wrapper Class
    
    Note over Module: At import time
    Module->>Registry: @register_scheme("W8A8_DYNAMIC", "linear")
    Registry->>Registry: Store (quant_type, layer_type) -> Class
    
    Note over Config: At runtime
    Config->>Config: Determine quant_type from description
    Config->>Registry: get_scheme_class(quant_type, layer_type)
    Registry-->>Config: Return scheme class
    Config->>Config: scheme = scheme_cls()
    Config->>Wrapper: Create wrapper with scheme
    Wrapper-->>Config: Return wrapper instance
```
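
In code, the runtime half of this flow condenses to roughly the following. `create_scheme_for_layer` is the helper name that appears in the updated unit tests, and the `f"{prefix}.weight"` key format matches the quant descriptions used in the old tests; the body itself is an assumption:

```python
# Hypothetical condensation of the runtime flow shown above.
from vllm_ascend.quantization.methods import get_scheme_class


def create_scheme_for_layer(quant_description: dict, prefix: str,
                            layer_type: str):
    # 1. Determine the quant type for this layer from the description.
    quant_type = quant_description[f"{prefix}.weight"]
    # 2. Resolve the scheme class through the registry and instantiate it.
    scheme_cls = get_scheme_class(quant_type, layer_type)
    return scheme_cls()
```

The config then hands the scheme instance to the matching wrapper (e.g. `AscendLinearMethod(scheme)`), and the wrapper is what vLLM ultimately receives from `get_quant_method()`.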

### File Changes Summary

| Original Files | Refactored Files |
|----------------|------------------|
| `__init__.py` (empty) | `__init__.py` (exports public API) |
| `quant_config.py` | `modelslim_config.py` + `method_adapters.py` |
| `compressed_tensors/` | `compressed_tensors_config.py` |
| `utils.py` | `methods/registry.py` |
| `w8a8_dynamic.py` | `methods/w8a8_dynamic.py` |
| `w8a8.py` | `methods/w8a8_static.py` |
| `w4a4_flatquant_dynamic.py` | `methods/w4a4_flatquant.py` |
| ... | `methods/base.py` (new) |

### Benefits

1. **Extensibility**: Adding a new quantization scheme only requires
implementing the base class and adding the `@register_scheme` decorator
(see the example after this list)
2. **Maintainability**: Clear separation between config parsing, wrapper
logic, and scheme implementation
3. **Testability**: Abstract base classes enable easier unit testing and
mocking
4. **Discoverability**: Registry pattern makes it easy to list all
supported schemes
5. **Reduced Coupling**: Config classes no longer need to know about all
scheme implementations
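
To make the extensibility point concrete, registering a hypothetical new scheme would look like this (the quant-type string and class are illustrative only, not part of this PR):

```python
# Illustrative only: a hypothetical new scheme following the new layout.
from vllm_ascend.quantization.methods.base import AscendLinearScheme
from vllm_ascend.quantization.methods.registry import register_scheme


@register_scheme("W8A8_EXAMPLE", "linear")
class AscendW8A8ExampleLinearMethod(AscendLinearScheme):
    """Implement the abstract hooks; the registry handles discovery."""

    def get_weight(self, input_size, output_size, params_dtype):
        raise NotImplementedError  # a real scheme returns its weight params

    def apply(self, layer, x, bias=None):
        raise NotImplementedError  # a real scheme runs the quantized matmul
```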

___

- vLLM version: v0.13.0
- vLLM main: 2f4e6548ef

---------

Signed-off-by: SlightwindSec <slightwindsec@gmail.com>

Author: Cao Yi, 2026-01-23 14:13:47 +08:00 (committed by GitHub)
Commit: a69ef10c3a, parent: 8378bc28b0
36 changed files with 2044 additions and 1524 deletions

```diff
@@ -20,8 +20,6 @@ import torch
 import torch.nn as nn
 import torch_npu
 from pytest_mock import MockerFixture
-from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase
 from tests.ut.base import TestBase
 from vllm_ascend.ascend_forward_context import MoECommType
 from vllm_ascend.ops.fused_moe.experts_selector import select_experts
@@ -233,25 +231,6 @@ class MockQuantMethod(nn.Module):
         self.apply = MagicMock(return_value=(torch.randn(num_tokens, 32)))
-class MockFusedMoEMethod(FusedMoEMethodBase):
-    moe = MagicMock()
-
-    def __init__(self):
-        super().__init__(self.moe)
-
-    def create_weights(self, layer: torch.nn.Module, num_experts: int,
-                       hidden_size: int, intermediate_size_per_partition: int,
-                       params_dtype: torch.dtype, **extra_weight_attrs):
-        pass
-
-    def apply(self, hidden_states: torch.Tensor,
-              expert_weights: torch.Tensor) -> torch.Tensor:
-        pass
-
-    def get_fused_moe_quant_config(self, layer: torch.nn.Module):
-        pass
 class TestExpertsSelector:
     @pytest.mark.parametrize("global_num_experts", [256, 128])
```

```diff
@@ -7,11 +7,11 @@ from vllm.model_executor.layers.linear import LinearBase
 from tests.ut.base import TestBase
 from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod
-from vllm_ascend.quantization.quant_config import AscendQuantConfig
+from vllm_ascend.quantization.modelslim_config import AscendModelSlimConfig
 from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD

-class TestAscendQuantConfig(TestBase):
+class TestAscendModelSlimConfig(TestBase):

     def setUp(self):
         self.sample_config = {
@@ -25,7 +25,7 @@ class TestAscendQuantConfig(TestBase):
             "shard1.weight": "FLOAT",
             "shard2.weight": "FLOAT",
         }
-        self.ascend_config = AscendQuantConfig(self.sample_config)
+        self.ascend_config = AscendModelSlimConfig(self.sample_config)
         self.ascend_config.packed_modules_mapping = None

     def test_init(self):
@@ -34,55 +34,55 @@ class TestAscendQuantConfig(TestBase):
     def test_repr(self):
         repr_str = repr(self.ascend_config)
-        self.assertTrue(repr_str.startswith("AscendQuantConfig:\n"))
+        self.assertTrue(repr_str.startswith("AscendModelSlimConfig:\n"))

     def test_get_name(self):
-        self.assertEqual(AscendQuantConfig.get_name(),
+        self.assertEqual(AscendModelSlimConfig.get_name(),
                          ASCEND_QUANTIZATION_METHOD)

     def test_get_supported_act_dtypes(self):
-        supported_dtypes = AscendQuantConfig.get_supported_act_dtypes()
+        supported_dtypes = AscendModelSlimConfig.get_supported_act_dtypes()
         self.assertEqual(len(supported_dtypes), 3)

     def test_get_min_capability(self):
         with self.assertRaises(NotImplementedError):
-            AscendQuantConfig.get_min_capability()
+            AscendModelSlimConfig.get_min_capability()

     def test_get_config_filenames(self):
-        filenames = AscendQuantConfig.get_config_filenames()
+        filenames = AscendModelSlimConfig.get_config_filenames()
         self.assertEqual(filenames, ["quant_model_description.json"])

     def test_from_config(self):
-        config = AscendQuantConfig.from_config(self.sample_config)
-        self.assertIsInstance(config, AscendQuantConfig)
+        config = AscendModelSlimConfig.from_config(self.sample_config)
+        self.assertIsInstance(config, AscendModelSlimConfig)
         self.assertEqual(config.quant_description, self.sample_config)

     @patch('torch.npu.is_available')
     def test_override_quantization_method(self, mock_is_available):
         # Test when NPU is available
         mock_is_available.return_value = True
-        result = AscendQuantConfig.override_quantization_method(None, None)
+        result = AscendModelSlimConfig.override_quantization_method(None, None)
         self.assertIsNone(result)
         hf_quant_cfg = {"quant_method": ""}
-        result = AscendQuantConfig.override_quantization_method(
+        result = AscendModelSlimConfig.override_quantization_method(
             hf_quant_cfg, None)
         self.assertEqual(result, "ascend")
         # Test when NPU is not available
         mock_is_available.return_value = False
-        result = AscendQuantConfig.override_quantization_method(None, None)
+        result = AscendModelSlimConfig.override_quantization_method(None, None)
         self.assertIsNone(result)
         hf_quant_cfg = {"quant_method": ""}
-        result = AscendQuantConfig.override_quantization_method(
+        result = AscendModelSlimConfig.override_quantization_method(
             hf_quant_cfg, None)
         self.assertIsNone(result)

     def test_get_quant_method_for_linear(self):
         mock_config = MagicMock()
-        mock_config.model_config.hf_text_config.model_type = None
+        mock_config.model_config.hf_config.model_type = None
         linear_layer = MagicMock(spec=LinearBase)
         # Test skipped layer
-        with patch("vllm_ascend.quantization.quant_config.get_current_vllm_config", return_value=mock_config), \
+        with patch("vllm_ascend.quantization.modelslim_config.get_current_vllm_config", return_value=mock_config), \
                 patch.object(self.ascend_config,
                              'is_layer_skipped_ascend',
                              return_value=True):
@@ -90,22 +90,24 @@ class TestAscendQuantConfig(TestBase):
             self.assertIsInstance(method, AscendUnquantizedLinearMethod)
         # Test quantized layer
+        mock_scheme = MagicMock()
         with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=False), \
-                patch("vllm_ascend.quantization.quant_config.get_current_vllm_config", return_value=mock_config), \
-                patch('vllm_ascend.quantization.quant_config.AscendLinearMethod', return_value=MagicMock()) as mock_ascend_linear:
+                patch("vllm_ascend.quantization.modelslim_config.get_current_vllm_config", return_value=mock_config), \
+                patch("vllm_ascend.quantization.modelslim_config.create_scheme_for_layer", return_value=mock_scheme), \
+                patch('vllm_ascend.quantization.method_adapters.AscendLinearMethod', return_value=MagicMock()) as mock_ascend_linear:
             method = self.ascend_config.get_quant_method(linear_layer, ".attn")
             self.assertIs(method, mock_ascend_linear.return_value)
-            mock_ascend_linear.assert_called_once_with(
-                self.ascend_config, ".attn",
-                self.ascend_config.packed_modules_mapping, linear_layer)
+            mock_ascend_linear.assert_called_once_with(mock_scheme)

     def test_get_quant_method_for_attention(self):
         attention_layer = MagicMock(spec=Attention)
         mock_config = MagicMock()
-        mock_config.model_config.hf_text_config.model_type = None
-        with patch("vllm_ascend.quantization.quant_config.get_current_vllm_config", return_value=mock_config), \
-                patch('vllm_ascend.quantization.quant_config.AscendKVCacheMethod', \
+        mock_config.model_config.hf_config.model_type = None
+        mock_scheme = MagicMock()
+        with patch("vllm_ascend.quantization.modelslim_config.get_current_vllm_config", return_value=mock_config), \
+                patch("vllm_ascend.quantization.modelslim_config.create_scheme_for_layer", return_value=mock_scheme), \
+                patch('vllm_ascend.quantization.method_adapters.AscendKVCacheMethod', \
                       return_value=MagicMock()) as mock_ascend_kvcache:
             # Test with fa_quant_type
             method = self.ascend_config.get_quant_method(
@@ -117,20 +119,22 @@ class TestAscendQuantConfig(TestBase):
         fused_moe_layer.moe = MagicMock(spec=FusedMoEConfig)
         fused_moe_layer.moe_config = MagicMock(spec=FusedMoEConfig)
         mock_config = MagicMock()
-        mock_config.model_config.hf_text_config.model_type = None
+        mock_config.model_config.hf_config.model_type = None
         # Test skipped layer
         with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=True), \
-                patch("vllm_ascend.quantization.quant_config.get_current_vllm_config", return_value=mock_config), \
-                patch('vllm_ascend.quantization.quant_config.AscendUnquantizedFusedMoEMethod', return_value=MagicMock()) as mock_ascend_moe:
+                patch("vllm_ascend.quantization.modelslim_config.get_current_vllm_config", return_value=mock_config), \
+                patch('vllm_ascend.ops.fused_moe.fused_moe.AscendUnquantizedFusedMoEMethod', return_value=MagicMock()) as mock_ascend_moe:
             method = self.ascend_config.get_quant_method(
                 fused_moe_layer, "moe_layer")
             self.assertIs(method, mock_ascend_moe.return_value)
         # Test quantized layer
+        mock_scheme = MagicMock()
         with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=False), \
-                patch("vllm_ascend.quantization.quant_config.get_current_vllm_config", return_value=mock_config), \
-                patch('vllm_ascend.quantization.quant_config.AscendFusedMoEMethod', return_value=MagicMock()) as mock_ascend_moe:
+                patch("vllm_ascend.quantization.modelslim_config.get_current_vllm_config", return_value=mock_config), \
+                patch("vllm_ascend.quantization.modelslim_config.create_scheme_for_layer", return_value=mock_scheme), \
+                patch('vllm_ascend.quantization.method_adapters.AscendFusedMoEMethod', return_value=MagicMock()) as mock_ascend_moe:
             method = self.ascend_config.get_quant_method(
                 fused_moe_layer, "moe_layer")
             self.assertIs(method, mock_ascend_moe.return_value)
@@ -150,7 +154,7 @@ class TestAscendQuantConfig(TestBase):
         # Test inconsistent fused layer shards
         bad_config = {"shard1.weight": "FLOAT", "shard2.weight": "INT8"}
-        config = AscendQuantConfig(bad_config)
+        config = AscendModelSlimConfig(bad_config)
         with self.assertRaises(ValueError):
             config.is_layer_skipped_ascend("fused_layer", fused_mapping)
```

```diff
@@ -1,50 +0,0 @@
-import types
-
-from tests.ut.base import TestBase
-from vllm_ascend.quantization.utils import (ASCEND_QUANTIZATION_METHOD_MAP,
-                                            get_quant_method)
-
-
-class TestGetQuantMethod(TestBase):
-
-    def setUp(self):
-        self.original_quantization_method_map = ASCEND_QUANTIZATION_METHOD_MAP.copy(
-        )
-        for quant_type, layer_map in ASCEND_QUANTIZATION_METHOD_MAP.items():
-            for layer_type in layer_map.keys():
-                ASCEND_QUANTIZATION_METHOD_MAP[quant_type][
-                    layer_type] = types.new_class(f"{quant_type}_{layer_type}")
-
-    def tearDown(self):
-        # Restore original map
-        ASCEND_QUANTIZATION_METHOD_MAP.clear()
-        ASCEND_QUANTIZATION_METHOD_MAP.update(
-            self.original_quantization_method_map)
-
-    def test_linear_quant_methods(self):
-        for quant_type, layer_map in ASCEND_QUANTIZATION_METHOD_MAP.items():
-            if "linear" in layer_map.keys():
-                prefix = "linear_layer"
-                cls = layer_map["linear"]
-                method = get_quant_method({"linear_layer.weight": quant_type},
-                                          prefix, "linear")
-                self.assertIsInstance(method, cls)
-
-    def test_moe_quant_methods(self):
-        for quant_type, layer_map in ASCEND_QUANTIZATION_METHOD_MAP.items():
-            if "moe" in layer_map.keys():
-                prefix = "layer"
-                cls = layer_map["moe"]
-                method = get_quant_method({"layer.weight": quant_type}, prefix,
-                                          "moe")
-                self.assertIsInstance(method, cls)
-
-    def test_invalid_layer_type(self):
-        quant_description = {"linear_layer.weight": "W8A8"}
-        with self.assertRaises(NotImplementedError):
-            get_quant_method(quant_description, "linear_layer", "unsupported")
-
-    def test_invalid_quant_type(self):
-        quant_description = {"linear_layer.weight": "UNKNOWN"}
-        with self.assertRaises(NotImplementedError):
-            get_quant_method(quant_description, "linear_layer", "linear")
```

```diff
@@ -3,8 +3,9 @@ from unittest.mock import Mock, patch
 import torch
 from tests.ut.base import TestBase
-from vllm_ascend.quantization.w4a16 import (AscendW4A16FusedMoEMethod,
-                                            pack_to_int32, unpack_from_int32)
+from vllm_ascend.quantization.methods.w4a16 import (AscendW4A16FusedMoEMethod,
+                                                    pack_to_int32,
+                                                    unpack_from_int32)

 class TestUnpackFromInt32(TestBase):
@@ -42,7 +43,7 @@ class TestUnpackFromInt32(TestBase):
 class TestPackToInt32(TestBase):
     @patch(
-        "vllm_ascend.quantization.w4a16.torch_npu.npu_convert_weight_to_int4pack"
+        "vllm_ascend.quantization.methods.w4a16.torch_npu.npu_convert_weight_to_int4pack"
     )
     def test_pack_to_int32_int8(self, mock_npu_convert_weight_to_int4pack):
         mock_npu_convert_weight_to_int4pack.return_value = torch.zeros(
@@ -57,7 +58,7 @@ class TestPackToInt32(TestBase):
         self.assertEqual(result.shape, torch.Size([2, 8, 4]))
     @patch(
-        "vllm_ascend.quantization.w4a16.torch_npu.npu_convert_weight_to_int4pack"
+        "vllm_ascend.quantization.methods.w4a16.torch_npu.npu_convert_weight_to_int4pack"
     )
     def test_pack_to_int32_int32(self, mock_npu_convert_weight_to_int4pack):
@@ -97,8 +98,8 @@ class TestAscendW4A16FusedMoEMethod(TestBase):
     output_size = 128
     group_size = 32
-    @patch("vllm_ascend.quantization.w4a16.get_ascend_config")
-    @patch("vllm_ascend.quantization.w4a16.get_current_vllm_config")
+    @patch("vllm_ascend.quantization.methods.w4a16.get_ascend_config")
+    @patch("vllm_ascend.quantization.methods.w4a16.get_current_vllm_config")
     def setUp(self, mock_get_current_vllm_config, mock_get_ascend_config):
         mock_ascend_config = Mock()
         mock_ascend_config.eplb_config.dynamic_eplb = False
@@ -218,7 +219,7 @@ class TestAscendW4A16FusedMoEMethod(TestBase):
         return layer
     @patch(
-        "vllm_ascend.quantization.w4a16.torch_npu.npu_convert_weight_to_int4pack"
+        "vllm_ascend.quantization.methods.w4a16.torch_npu.npu_convert_weight_to_int4pack"
    )
     def test_process_weights_after_loading_with_transpose(
             self, mock_npu_convert_weight_to_int4pack):
```

```diff
@@ -4,7 +4,7 @@ from unittest.mock import MagicMock, patch
 import torch
 import torch.nn as nn
-from vllm_ascend.quantization.w4a4_flatquant_dynamic import (
+from vllm_ascend.quantization.methods.w4a4_flatquant import (
     AscendW4A4FlatQuantDynamicLinearMethod, get_decompose_dim,
     pack_int4_weights)
@@ -33,7 +33,7 @@ class TestW4A4FlatQuantDynamic(unittest.TestCase):
         self.assertEqual(get_decompose_dim(100), (10, 10))
         self.assertEqual(get_decompose_dim(99), (9, 11))
-    @patch('vllm_ascend.quantization.w4a4_flatquant_dynamic.torch_npu')
+    @patch('vllm_ascend.quantization.methods.w4a4_flatquant.torch_npu')
     def test_pack_int4_weights_npu_success(self, mock_torch_npu):
         """
         Tests weight packing using the mocked NPU kernel.
@@ -119,7 +119,7 @@ class TestW4A4FlatQuantDynamic(unittest.TestCase):
         x = torch.randn(batch_size, self.input_size, dtype=self.params_dtype)
         return layer, x, m, n
-    @patch('vllm_ascend.quantization.w4a4_flatquant_dynamic.torch_npu')
+    @patch('vllm_ascend.quantization.methods.w4a4_flatquant.torch_npu')
     def test_apply_small_batch(self, mock_torch_npu):
         """Tests the apply method with a batch size smaller than MAX_BATCH_SIZE."""
         batch_size = 128
@@ -143,9 +143,9 @@ class TestW4A4FlatQuantDynamic(unittest.TestCase):
         self.assertEqual(output.shape, (batch_size, self.output_size))
     @patch(
-        'vllm_ascend.quantization.w4a4_flatquant_dynamic.KRONECKER_QUANT_MAX_BATCH_SIZE',
+        'vllm_ascend.quantization.methods.w4a4_flatquant.KRONECKER_QUANT_MAX_BATCH_SIZE',
         10)
-    @patch('vllm_ascend.quantization.w4a4_flatquant_dynamic.torch_npu')
+    @patch('vllm_ascend.quantization.methods.w4a4_flatquant.torch_npu')
     def test_apply_large_batch(self, mock_torch_npu):
         """Tests the apply method with a batch size larger than MAX_BATCH_SIZE."""
         batch_size = 25
@@ -178,7 +178,7 @@ class TestW4A4FlatQuantDynamic(unittest.TestCase):
                 ValueError, "FlatQuant transform matrices dimension mismatch"):
             self.method.apply(layer, x)
-    @patch('vllm_ascend.quantization.w4a4_flatquant_dynamic.pack_int4_weights')
+    @patch('vllm_ascend.quantization.methods.w4a4_flatquant.pack_int4_weights')
     def test_process_weights_after_loading(self, mock_pack_weights):
         """Tests weight processing after loading, without transpose."""
         layer = nn.Module()
```

```diff
@@ -3,14 +3,14 @@ from unittest.mock import Mock, patch
 import torch
 from tests.ut.base import TestBase
-from vllm_ascend.quantization.w4a8_dynamic import (
+from vllm_ascend.quantization.methods.w4a8 import (
     AscendW4A8DynamicFusedMoEMethod, AscendW4A8DynamicLinearMethod)

 class TestAscendW4A8DynamicLinearMethod(TestBase):
     @patch('vllm.distributed.get_tensor_model_parallel_world_size')
-    @patch('vllm_ascend.quantization.w4a8_dynamic.get_current_vllm_config')
+    @patch('vllm_ascend.quantization.methods.w4a8.get_current_vllm_config')
     def setUp(self, mock_get_current_vllm_config, mock_get_tp_world_size):
         mock_get_tp_world_size.return_value = 1
         mock_vllm_config = Mock()
@@ -127,10 +127,10 @@ class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
     output_size = 56
     group_size = 2
-    @patch('vllm_ascend.quantization.w4a8_dynamic.get_ascend_config')
-    @patch('vllm_ascend.quantization.w4a8_dynamic.get_current_vllm_config')
-    @patch('vllm_ascend.quantization.w4a8_dynamic.get_ep_group')
-    @patch('vllm_ascend.quantization.w4a8_dynamic.get_mc2_group')
+    @patch('vllm_ascend.quantization.methods.w4a8.get_ascend_config')
+    @patch('vllm_ascend.quantization.methods.w4a8.get_current_vllm_config')
+    @patch('vllm_ascend.quantization.methods.w4a8.get_ep_group')
+    @patch('vllm_ascend.quantization.methods.w4a8.get_mc2_group')
     @patch('torch.distributed.get_rank', return_value=0)
     def setUp(self, mock_get_rank, mock_get_mc2_group, mock_get_ep_group,
               get_current_vllm_config, mock_get_ascend_config):
```

```diff
@@ -4,7 +4,7 @@ from unittest.mock import MagicMock, patch
 import torch
 from tests.ut.base import TestBase
-from vllm_ascend.quantization.w8a16 import AscendW8A16LinearMethod
+from vllm_ascend.quantization.methods.w8a16 import AscendW8A16LinearMethod

 class TestAscendW8A16LinearMethod(TestBase):
```

```diff
@@ -4,36 +4,10 @@ from unittest.mock import MagicMock, patch
 import torch
 from tests.ut.base import TestBase
-from vllm_ascend.quantization.w8a8 import (AscendW8A8LinearMethod,
-                                           quant_per_tensor)
+from vllm_ascend.quantization.methods.w8a8_static import AscendW8A8LinearMethod
 from vllm_ascend.utils import AscendDeviceType

-class TestQuantPerTensor(TestBase):
-    @patch("torch_npu.npu_quantize")
-    def test_quant_per_tensor(self, mock_npu_quantize):
-        in_tensor = torch.randn(32, 128)
-        input_scale = torch.tensor(0.1)
-        input_offset = torch.tensor(0)
-        expected_output = torch.randint(-128, 127, (32, 128), dtype=torch.int8)
-        mock_npu_quantize.return_value = expected_output
-        output = quant_per_tensor(in_tensor, input_scale, input_offset)
-        mock_npu_quantize.assert_called_once_with(
-            in_tensor,
-            input_scale,
-            input_offset,
-            torch.qint8,
-            -1,
-            False,
-        )
-        self.assertTrue(torch.equal(output, expected_output))

 class TestAscendW8A8LinearMethod(TestBase):
     def setUp(self):
@@ -63,7 +37,9 @@ class TestAscendW8A8LinearMethod(TestBase):
         self.assertEqual(params['weight_scale'].shape, (10, 1))
         self.assertEqual(params['weight_offset'].shape, (10, 1))
-    @patch("vllm_ascend.quantization.w8a8.get_weight_prefetch_method")
+    @patch(
+        "vllm_ascend.quantization.methods.w8a8_static.get_weight_prefetch_method"
+    )
     @patch("torch.ops.vllm.quantize")
     @patch("torch_npu.npu_quant_matmul")
     def test_apply_with_x_not_int8(self, mock_npu_quant_matmul, mock_quantize,
```

```diff
@@ -3,7 +3,7 @@ from unittest.mock import Mock, patch
 import torch
 from tests.ut.base import TestBase
-from vllm_ascend.quantization.w8a8_dynamic import \
+from vllm_ascend.quantization.methods.w8a8_dynamic import \
     AscendW8A8DynamicFusedMoEMethod
@@ -13,13 +13,13 @@ class TestAscendW8A8FusedMoEMethod(TestBase):
     intermediate_size = 128
     @patch("torch.distributed.get_rank")
-    @patch("vllm_ascend.quantization.w8a8_dynamic.get_mc2_group")
-    @patch("vllm_ascend.quantization.w8a8_dynamic.get_ascend_config")
-    @patch("vllm_ascend.quantization.w8a8_dynamic.get_ep_group")
+    @patch("vllm_ascend.quantization.methods.w8a8_dynamic.get_mc2_group")
+    @patch("vllm_ascend.quantization.methods.w8a8_dynamic.get_ascend_config")
+    @patch("vllm_ascend.quantization.methods.w8a8_dynamic.get_ep_group")
     def setUp(self, mock_get_ep_group, mock_get_ascend_config,
               mock_get_mc2_group, mock_get_rank):
         with patch(
-                'vllm_ascend.quantization.w8a8_dynamic.get_current_vllm_config'
+                'vllm_ascend.quantization.methods.w8a8_dynamic.get_current_vllm_config'
         ) as mock_get_current_vllm_config:
             mock_vllm_config = Mock()
             mock_vllm_config.quant_config = Mock(
```

```diff
@@ -51,8 +51,9 @@ class TestNPUPlatform(TestBase):
         self.assertTrue(self.platform.is_sleep_mode_available())
     @patch("vllm_ascend.utils.adapt_patch")
-    @patch("vllm_ascend.quantization.quant_config.AscendQuantConfig")
-    def test_pre_register_and_update_with_parser(self, mock_quant_config, mock_adapt_patch):
+    @patch("vllm_ascend.quantization.modelslim_config.AscendModelSlimConfig")
+    def test_pre_register_and_update_with_parser(self, mock_quant_config,
+                                                 mock_adapt_patch):
         mock_parser = MagicMock()
         mock_action = MagicMock()
         mock_action.choices = ["awq", "gptq"]
@@ -66,15 +67,17 @@ class TestNPUPlatform(TestBase):
         self.assertEqual(len(mock_action.choices), 3)  # original 2 + ascend
     @patch("vllm_ascend.utils.adapt_patch")
-    @patch("vllm_ascend.quantization.quant_config.AscendQuantConfig")
-    def test_pre_register_and_update_without_parser(self, mock_quant_config, mock_adapt_patch):
+    @patch("vllm_ascend.quantization.modelslim_config.AscendModelSlimConfig")
+    def test_pre_register_and_update_without_parser(self, mock_quant_config,
+                                                    mock_adapt_patch):
         self.platform.pre_register_and_update(None)
         mock_adapt_patch.assert_called_once_with(is_global_patch=True)
     @patch("vllm_ascend.utils.adapt_patch")
-    @patch("vllm_ascend.quantization.quant_config.AscendQuantConfig")
-    def test_pre_register_and_update_with_parser_no_quant_action(self, mock_quant_config, mock_adapt_patch):
+    @patch("vllm_ascend.quantization.modelslim_config.AscendModelSlimConfig")
+    def test_pre_register_and_update_with_parser_no_quant_action(
+            self, mock_quant_config, mock_adapt_patch):
         mock_parser = MagicMock()
         mock_parser._option_string_actions = {}
@@ -83,8 +86,9 @@ class TestNPUPlatform(TestBase):
         mock_adapt_patch.assert_called_once_with(is_global_patch=True)
     @patch("vllm_ascend.utils.adapt_patch")
-    @patch("vllm_ascend.quantization.quant_config.AscendQuantConfig")
-    def test_pre_register_and_update_with_existing_ascend_quant(self, mock_quant_config, mock_adapt_patch):
+    @patch("vllm_ascend.quantization.modelslim_config.AscendModelSlimConfig")
+    def test_pre_register_and_update_with_existing_ascend_quant(
+            self, mock_quant_config, mock_adapt_patch):
         mock_parser = MagicMock()
         mock_action = MagicMock()
         mock_action.choices = ["awq", ASCEND_QUANTIZATION_METHOD]
```