[Refactor] Quantization Module Refactor (#5738)

### Summary

This PR refactors the `vllm_ascend/quantization` module to improve code
organization, maintainability, and extensibility. The refactoring
introduces a clear separation of concerns with a registry-based scheme
discovery pattern, abstract base classes for quantization schemes, and
dedicated wrapper classes.

### Key Changes

#### 1. **Modular Directory Structure**

| Before | After |
|--------|-------|
| Flat file structure with mixed responsibilities | Organized into a `methods/` subpackage for schemes |
| Single `quant_config.py` (600+ lines) | Separate config files: `modelslim_config.py`, `compressed_tensors_config.py` |
| `utils.py` with scheme lookup logic | `methods/registry.py` with decorator-based registration |

#### 2. **Registry-Based Scheme Discovery**

Replaced the hardcoded `ASCEND_QUANTIZATION_METHOD_MAP` dictionary with a decorator-based registry pattern:

```python
# Before: Manual dictionary mapping
ASCEND_QUANTIZATION_METHOD_MAP = {
    "W8A8_DYNAMIC": {"linear": AscendW8A8DynamicLinearMethod, ...},
    ...
}

# After: Decorator-based registration
@register_scheme("W8A8_DYNAMIC", "linear")
class AscendW8A8DynamicLinearMethod(AscendLinearScheme):
    ...
```
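
For reference, a registry of this shape fits in a few lines. The sketch below is illustrative only; the actual implementation lives in `methods/registry.py`, and the exact signatures there may differ:

```python
# Illustrative sketch of a decorator-based scheme registry.
# The real implementation lives in methods/registry.py; details may differ.
from typing import Callable, Dict, Optional, Tuple, Type

# Maps (quant_type, layer_type) -> scheme class, e.g. ("W8A8_DYNAMIC", "linear").
_SCHEME_REGISTRY: Dict[Tuple[str, str], Type] = {}


def register_scheme(quant_type: str, layer_type: str) -> Callable[[Type], Type]:
    """Class decorator that records a scheme class at import time."""
    def decorator(cls: Type) -> Type:
        _SCHEME_REGISTRY[(quant_type, layer_type)] = cls
        return cls
    return decorator


def get_scheme_class(quant_type: str, layer_type: str) -> Optional[Type]:
    """Look up a registered scheme class; returns None when nothing matches."""
    return _SCHEME_REGISTRY.get((quant_type, layer_type))
```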

#### 3. **Abstract Base Classes**

Introduced three abstract base classes in `methods/base.py`:
- `AscendLinearScheme` - Base for linear layer quantization
- `AscendMoEScheme` - Base for MoE layer quantization  
- `AscendAttentionScheme` - Base for attention layer quantization
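
A rough sketch of the linear base class is shown below; the method names follow the class diagram further down, but the exact signatures in `methods/base.py` are assumptions here:

```python
# Rough sketch of AscendLinearScheme; signatures are inferred from the class
# diagram and wrapper code in this PR, not copied from methods/base.py.
from abc import ABC, abstractmethod
from typing import Dict, Optional

import torch


class AscendLinearScheme(ABC):
    """Base class that concrete linear quantization schemes subclass."""

    @abstractmethod
    def get_weight(self, input_size: int, output_size: int,
                   params_dtype: torch.dtype) -> Dict[str, torch.Tensor]:
        """Return the weight tensors to register on the layer."""

    @abstractmethod
    def apply(self, layer: torch.nn.Module, x: torch.Tensor,
              bias: Optional[torch.Tensor], tp_rank: int) -> torch.Tensor:
        """Run the quantized computation for this scheme."""

    def get_pertensor_param(
            self, params_dtype: torch.dtype) -> Dict[str, torch.Tensor]:
        return {}

    def get_perchannel_param(
            self, output_size: int,
            params_dtype: torch.dtype) -> Dict[str, torch.Tensor]:
        return {}
```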

#### 4. **Separated Config and Wrapper Classes**

- **Config classes** (`AscendModelSlimConfig`,
`AscendCompressedTensorsConfig`): Handle config parsing and scheme
selection
- **Wrapper classes** (`AscendLinearMethod`, `AscendFusedMoEMethod`,
etc.): Implement vLLM interfaces and delegate to schemes
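
The delegation is deliberately thin. A simplified sketch is shown below; the real `AscendLinearMethod` (see `method_adapters.py` in the diff) subclasses vLLM's `LinearMethodBase` and also implements `create_weights()` and tensor-parallel rank selection:

```python
# Simplified sketch of the wrapper/scheme delegation; not the full adapter.
from typing import Optional

import torch


class AscendLinearMethod:
    def __init__(self, scheme) -> None:
        self.quant_method = scheme  # an AscendLinearScheme instance

    def apply(self, layer: torch.nn.Module, x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        # The wrapper adapts vLLM's calling convention; the scheme does the math.
        # (The real adapter computes the proper tensor-parallel rank here.)
        return self.quant_method.apply(layer, x, bias, tp_rank=0)
```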

#### 5. **Cleaner Public API**

```python
# New clean module interface
from vllm_ascend.quantization import (
    AscendModelSlimConfig,
    AscendCompressedTensorsConfig,
)
from vllm_ascend.quantization.methods import get_scheme_class
```

### Architecture Diagram

```mermaid
classDiagram
    direction TB
    
    class QuantizationConfig {
        <<vLLM Interface>>
        +get_quant_method()
    }
    
    class AscendModelSlimConfig {
        +quant_description
        +get_quant_method()
        -create_scheme_for_layer()
    }
    
    class AscendCompressedTensorsConfig {
        +target_scheme_map
        +get_quant_method()
        -_get_scheme_from_parts()
    }
    
    class AscendLinearMethod {
        <<Wrapper>>
        +quant_method: AscendLinearScheme
        +create_weights()
        +apply()
    }
    
    class AscendFusedMoEMethod {
        <<Wrapper>>
        +quant_method: AscendMoEScheme
        +create_weights()
        +apply()
    }
    
    class AscendLinearScheme {
        <<Abstract>>
        +get_weight()*
        +apply()*
        +get_pertensor_param()
        +get_perchannel_param()
    }
    
    class AscendMoEScheme {
        <<Abstract>>
        +get_weight()*
        +get_dynamic_quant_param()*
        +apply()*
    }
    
    class W8A8DynamicLinear {
        +get_weight()
        +apply()
    }
    
    class W8A8DynamicMoE {
        +get_weight()
        +apply()
    }
    
    QuantizationConfig <|-- AscendModelSlimConfig
    QuantizationConfig <|-- AscendCompressedTensorsConfig
    
    AscendModelSlimConfig ..> AscendLinearMethod : creates
    AscendModelSlimConfig ..> AscendFusedMoEMethod : creates
    AscendCompressedTensorsConfig ..> AscendLinearMethod : creates
    AscendCompressedTensorsConfig ..> AscendFusedMoEMethod : creates
    
    AscendLinearMethod o-- AscendLinearScheme : delegates to
    AscendFusedMoEMethod o-- AscendMoEScheme : delegates to
    
    AscendLinearScheme <|-- W8A8DynamicLinear
    AscendMoEScheme <|-- W8A8DynamicMoE
```

### Scheme Registration Flow

```mermaid
sequenceDiagram
    participant Module as Scheme Module
    participant Registry as _SCHEME_REGISTRY
    participant Config as QuantConfig
    participant Wrapper as Wrapper Class
    
    Note over Module: At import time
    Module->>Registry: @register_scheme("W8A8_DYNAMIC", "linear")
    Registry->>Registry: Store (quant_type, layer_type) -> Class
    
    Note over Config: At runtime
    Config->>Config: Determine quant_type from description
    Config->>Registry: get_scheme_class(quant_type, layer_type)
    Registry-->>Config: Return scheme class
    Config->>Config: scheme = scheme_cls()
    Config->>Wrapper: Create wrapper with scheme
    Wrapper-->>Config: Return wrapper instance
```
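
Put together, the runtime half of this flow is a registry lookup plus a wrapper construction. The helper below is a hedged sketch (the function name is hypothetical); the real logic in `modelslim_config.py` additionally handles skipped layers, packed-module mappings, and MoE/attention layers:

```python
# Hypothetical helper illustrating the per-layer runtime lookup; assumes a
# ModelSlim-style quant_description of the form {"<prefix>.weight": "<TYPE>"}.
from vllm_ascend.quantization.methods import get_scheme_class


def build_linear_method(quant_description: dict, prefix: str):
    quant_type = quant_description[f"{prefix}.weight"]  # e.g. "W8A8_DYNAMIC"
    scheme_cls = get_scheme_class(quant_type, "linear")
    if scheme_cls is None:
        raise NotImplementedError(
            f"No Ascend scheme registered for {quant_type}/linear")
    # Delayed import mirrors the real code and avoids a circular import.
    from vllm_ascend.quantization.method_adapters import AscendLinearMethod
    return AscendLinearMethod(scheme_cls())
```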

### File Changes Summary

| Original Files | Refactored Files |
|----------------|------------------|
| `__init__.py` (empty) | `__init__.py` (exports public API) |
| `quant_config.py` | `modelslim_config.py` + `method_adapters.py` |
| `compressed_tensors/` | `compressed_tensors_config.py` |
| `utils.py` | `methods/registry.py` |
| `w8a8_dynamic.py` | `methods/w8a8_dynamic.py` |
| `w8a8.py` | `methods/w8a8_static.py` |
| `w4a4_flatquant_dynamic.py` | `methods/w4a4_flatquant.py` |
| ... | `methods/base.py` (new) |

### Benefits

1. **Extensibility**: Adding a new quantization scheme only requires subclassing the appropriate base class and applying the `@register_scheme` decorator
2. **Maintainability**: Clear separation between config parsing, wrapper
logic, and scheme implementation
3. **Testability**: Abstract base classes enable easier unit testing and
mocking
4. **Discoverability**: Registry pattern makes it easy to list all
supported schemes
5. **Reduced Coupling**: Config classes no longer need to know about all
scheme implementations

___

- vLLM version: v0.13.0
- vLLM main: 2f4e6548ef

---------

Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
Committed by: Cao Yi (via GitHub), 2026-01-23 14:13:47 +08:00
Commit: a69ef10c3a (parent 8378bc28b0)
36 changed files with 2044 additions and 1524 deletions


@@ -10,7 +10,7 @@ The current process for registering and obtaining quantization methods in vLLM A
![get_quant_method](../../assets/quantization/get_quant_method.png)
vLLM Ascend registers a custom ascend quantization method. By configuring the `--quantization ascend` parameter (or `quantization="ascend"` for offline), the quantization feature is enabled. When constructing the `quant_config`, the registered `AscendQuantConfig` is initialized and `get_quant_method` is called to obtain the quantization method corresponding to each weight part, stored in the `quant_method` attribute.
vLLM Ascend registers a custom ascend quantization method. By configuring the `--quantization ascend` parameter (or `quantization="ascend"` for offline), the quantization feature is enabled. When constructing the `quant_config`, the registered `AscendModelSlimConfig` is initialized and `get_quant_method` is called to obtain the quantization method corresponding to each weight part, stored in the `quant_method` attribute.
Currently supported quantization methods include `AscendLinearMethod`, `AscendFusedMoEMethod`, `AscendEmbeddingMethod`, and their corresponding non-quantized methods:
@@ -51,18 +51,21 @@ Based on the above content, we present a brief description of the adaptation pro
### Quantization Algorithm Adaptation
- **Step 1: Algorithm Design**. Define the algorithm ID (e.g., `W4A8_DYNAMIC`), determine supported layers (linear, moe, attention), and design the quantization scheme (static/dynamic, pertensor/perchannel/pergroup).
- **Step 2: Registration**. Add the algorithm ID to `ASCEND_QUANTIZATION_METHOD_MAP` in `vllm_ascend/quantization/utils.py` and associate it with the corresponding method class.
- **Step 2: Registration**. Use the `@register_scheme` decorator in `vllm_ascend/quantization/methods/registry.py` to register your quantization scheme class.
```python
ASCEND_QUANTIZATION_METHOD_MAP: Dict[str, Dict[str, Type[Any]]] = {
"W4A8_DYNAMIC": {
"linear": AscendW4A8DynamicLinearMethod,
"moe": AscendW4A8DynamicFusedMoEMethod,
},
}
from vllm_ascend.quantization.methods import register_scheme, AscendLinearScheme, AscendMoEScheme
@register_scheme("W4A8_DYNAMIC", "linear")
class AscendW4A8DynamicLinearMethod(AscendLinearScheme):
...
@register_scheme("W4A8_DYNAMIC", "moe")
class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme):
...
```
- **Step 3: Implementation**. Create an algorithm implementation file, such as `vllm_ascend/quantization/w4a8_dynamic.py`, and implement the method class and logic.
- **Step 3: Implementation**. Create an algorithm implementation file, such as `vllm_ascend/quantization/methods/w4a8.py`, and implement the method class and logic.
- **Step 4: Testing**. Use your algorithm to generate quantization configurations and verify correctness and performance on target models and hardware.
### Quantized Model Adaptation
@@ -70,7 +73,7 @@ ASCEND_QUANTIZATION_METHOD_MAP: Dict[str, Dict[str, Type[Any]]] = {
Adapting a new quantized model requires ensuring the following three points:
- The original model has been successfully adapted in `vLLM Ascend`.
- **Fused Module Mapping**: Add the model's `model_type` to `packed_modules_model_mapping` in `vllm_ascend/quantization/quant_config.py` (e.g., `qkv_proj`, `gate_up_proj`, `experts`) to ensure sharding consistency and correct loading.
- **Fused Module Mapping**: Add the model's `model_type` to `packed_modules_model_mapping` in `vllm_ascend/quantization/modelslim_config.py` (e.g., `qkv_proj`, `gate_up_proj`, `experts`) to ensure sharding consistency and correct loading.
```python
packed_modules_model_mapping = {


@@ -20,8 +20,6 @@ import torch
import torch.nn as nn
import torch_npu
from pytest_mock import MockerFixture
from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase
from tests.ut.base import TestBase
from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
@@ -233,25 +231,6 @@ class MockQuantMethod(nn.Module):
self.apply = MagicMock(return_value=(torch.randn(num_tokens, 32)))
class MockFusedMoEMethod(FusedMoEMethodBase):
moe = MagicMock()
def __init__(self):
super().__init__(self.moe)
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
pass
def apply(self, hidden_states: torch.Tensor,
expert_weights: torch.Tensor) -> torch.Tensor:
pass
def get_fused_moe_quant_config(self, layer: torch.nn.Module):
pass
class TestExpertsSelector:
@pytest.mark.parametrize("global_num_experts", [256, 128])


@@ -7,11 +7,11 @@ from vllm.model_executor.layers.linear import LinearBase
from tests.ut.base import TestBase
from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod
from vllm_ascend.quantization.quant_config import AscendQuantConfig
from vllm_ascend.quantization.modelslim_config import AscendModelSlimConfig
from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD
class TestAscendQuantConfig(TestBase):
class TestAscendModelSlimConfig(TestBase):
def setUp(self):
self.sample_config = {
@@ -25,7 +25,7 @@ class TestAscendQuantConfig(TestBase):
"shard1.weight": "FLOAT",
"shard2.weight": "FLOAT",
}
self.ascend_config = AscendQuantConfig(self.sample_config)
self.ascend_config = AscendModelSlimConfig(self.sample_config)
self.ascend_config.packed_modules_mapping = None
def test_init(self):
@@ -34,55 +34,55 @@ class TestAscendQuantConfig(TestBase):
def test_repr(self):
repr_str = repr(self.ascend_config)
self.assertTrue(repr_str.startswith("AscendQuantConfig:\n"))
self.assertTrue(repr_str.startswith("AscendModelSlimConfig:\n"))
def test_get_name(self):
self.assertEqual(AscendQuantConfig.get_name(),
self.assertEqual(AscendModelSlimConfig.get_name(),
ASCEND_QUANTIZATION_METHOD)
def test_get_supported_act_dtypes(self):
supported_dtypes = AscendQuantConfig.get_supported_act_dtypes()
supported_dtypes = AscendModelSlimConfig.get_supported_act_dtypes()
self.assertEqual(len(supported_dtypes), 3)
def test_get_min_capability(self):
with self.assertRaises(NotImplementedError):
AscendQuantConfig.get_min_capability()
AscendModelSlimConfig.get_min_capability()
def test_get_config_filenames(self):
filenames = AscendQuantConfig.get_config_filenames()
filenames = AscendModelSlimConfig.get_config_filenames()
self.assertEqual(filenames, ["quant_model_description.json"])
def test_from_config(self):
config = AscendQuantConfig.from_config(self.sample_config)
self.assertIsInstance(config, AscendQuantConfig)
config = AscendModelSlimConfig.from_config(self.sample_config)
self.assertIsInstance(config, AscendModelSlimConfig)
self.assertEqual(config.quant_description, self.sample_config)
@patch('torch.npu.is_available')
def test_override_quantization_method(self, mock_is_available):
# Test when NPU is available
mock_is_available.return_value = True
result = AscendQuantConfig.override_quantization_method(None, None)
result = AscendModelSlimConfig.override_quantization_method(None, None)
self.assertIsNone(result)
hf_quant_cfg = {"quant_method": ""}
result = AscendQuantConfig.override_quantization_method(
result = AscendModelSlimConfig.override_quantization_method(
hf_quant_cfg, None)
self.assertEqual(result, "ascend")
# Test when NPU is not available
mock_is_available.return_value = False
result = AscendQuantConfig.override_quantization_method(None, None)
result = AscendModelSlimConfig.override_quantization_method(None, None)
self.assertIsNone(result)
hf_quant_cfg = {"quant_method": ""}
result = AscendQuantConfig.override_quantization_method(
result = AscendModelSlimConfig.override_quantization_method(
hf_quant_cfg, None)
self.assertIsNone(result)
def test_get_quant_method_for_linear(self):
mock_config = MagicMock()
mock_config.model_config.hf_text_config.model_type = None
mock_config.model_config.hf_config.model_type = None
linear_layer = MagicMock(spec=LinearBase)
# Test skipped layer
with patch("vllm_ascend.quantization.quant_config.get_current_vllm_config", return_value=mock_config), \
with patch("vllm_ascend.quantization.modelslim_config.get_current_vllm_config", return_value=mock_config), \
patch.object(self.ascend_config, \
'is_layer_skipped_ascend',
return_value=True):
@@ -90,22 +90,24 @@ class TestAscendQuantConfig(TestBase):
self.assertIsInstance(method, AscendUnquantizedLinearMethod)
# Test quantized layer
mock_scheme = MagicMock()
with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=False), \
patch("vllm_ascend.quantization.quant_config.get_current_vllm_config", return_value=mock_config), \
patch('vllm_ascend.quantization.quant_config.AscendLinearMethod', return_value=MagicMock()) as mock_ascend_linear:
patch("vllm_ascend.quantization.modelslim_config.get_current_vllm_config", return_value=mock_config), \
patch("vllm_ascend.quantization.modelslim_config.create_scheme_for_layer", return_value=mock_scheme), \
patch('vllm_ascend.quantization.method_adapters.AscendLinearMethod', return_value=MagicMock()) as mock_ascend_linear:
method = self.ascend_config.get_quant_method(linear_layer, ".attn")
self.assertIs(method, mock_ascend_linear.return_value)
mock_ascend_linear.assert_called_once_with(
self.ascend_config, ".attn",
self.ascend_config.packed_modules_mapping, linear_layer)
mock_ascend_linear.assert_called_once_with(mock_scheme)
def test_get_quant_method_for_attention(self):
attention_layer = MagicMock(spec=Attention)
mock_config = MagicMock()
mock_config.model_config.hf_text_config.model_type = None
with patch("vllm_ascend.quantization.quant_config.get_current_vllm_config", return_value=mock_config), \
patch('vllm_ascend.quantization.quant_config.AscendKVCacheMethod', \
mock_config.model_config.hf_config.model_type = None
mock_scheme = MagicMock()
with patch("vllm_ascend.quantization.modelslim_config.get_current_vllm_config", return_value=mock_config), \
patch("vllm_ascend.quantization.modelslim_config.create_scheme_for_layer", return_value=mock_scheme), \
patch('vllm_ascend.quantization.method_adapters.AscendKVCacheMethod', \
return_value=MagicMock()) as mock_ascend_kvcache:
# Test with fa_quant_type
method = self.ascend_config.get_quant_method(
@@ -117,20 +119,22 @@ class TestAscendQuantConfig(TestBase):
fused_moe_layer.moe = MagicMock(spec=FusedMoEConfig)
fused_moe_layer.moe_config = MagicMock(spec=FusedMoEConfig)
mock_config = MagicMock()
mock_config.model_config.hf_text_config.model_type = None
mock_config.model_config.hf_config.model_type = None
# Test skipped layer
with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=True), \
patch("vllm_ascend.quantization.quant_config.get_current_vllm_config", return_value=mock_config), \
patch('vllm_ascend.quantization.quant_config.AscendUnquantizedFusedMoEMethod', return_value=MagicMock()) as mock_ascend_moe:
patch("vllm_ascend.quantization.modelslim_config.get_current_vllm_config", return_value=mock_config), \
patch('vllm_ascend.ops.fused_moe.fused_moe.AscendUnquantizedFusedMoEMethod', return_value=MagicMock()) as mock_ascend_moe:
method = self.ascend_config.get_quant_method(
fused_moe_layer, "moe_layer")
self.assertIs(method, mock_ascend_moe.return_value)
# Test quantized layer
mock_scheme = MagicMock()
with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=False), \
patch("vllm_ascend.quantization.quant_config.get_current_vllm_config", return_value=mock_config), \
patch('vllm_ascend.quantization.quant_config.AscendFusedMoEMethod', return_value=MagicMock()) as mock_ascend_moe:
patch("vllm_ascend.quantization.modelslim_config.get_current_vllm_config", return_value=mock_config), \
patch("vllm_ascend.quantization.modelslim_config.create_scheme_for_layer", return_value=mock_scheme), \
patch('vllm_ascend.quantization.method_adapters.AscendFusedMoEMethod', return_value=MagicMock()) as mock_ascend_moe:
method = self.ascend_config.get_quant_method(
fused_moe_layer, "moe_layer")
self.assertIs(method, mock_ascend_moe.return_value)
@@ -150,7 +154,7 @@ class TestAscendQuantConfig(TestBase):
# Test inconsistent fused layer shards
bad_config = {"shard1.weight": "FLOAT", "shard2.weight": "INT8"}
config = AscendQuantConfig(bad_config)
config = AscendModelSlimConfig(bad_config)
with self.assertRaises(ValueError):
config.is_layer_skipped_ascend("fused_layer", fused_mapping)


@@ -1,50 +0,0 @@
import types
from tests.ut.base import TestBase
from vllm_ascend.quantization.utils import (ASCEND_QUANTIZATION_METHOD_MAP,
get_quant_method)
class TestGetQuantMethod(TestBase):
def setUp(self):
self.original_quantization_method_map = ASCEND_QUANTIZATION_METHOD_MAP.copy(
)
for quant_type, layer_map in ASCEND_QUANTIZATION_METHOD_MAP.items():
for layer_type in layer_map.keys():
ASCEND_QUANTIZATION_METHOD_MAP[quant_type][
layer_type] = types.new_class(f"{quant_type}_{layer_type}")
def tearDown(self):
# Restore original map
ASCEND_QUANTIZATION_METHOD_MAP.clear()
ASCEND_QUANTIZATION_METHOD_MAP.update(
self.original_quantization_method_map)
def test_linear_quant_methods(self):
for quant_type, layer_map in ASCEND_QUANTIZATION_METHOD_MAP.items():
if "linear" in layer_map.keys():
prefix = "linear_layer"
cls = layer_map["linear"]
method = get_quant_method({"linear_layer.weight": quant_type},
prefix, "linear")
self.assertIsInstance(method, cls)
def test_moe_quant_methods(self):
for quant_type, layer_map in ASCEND_QUANTIZATION_METHOD_MAP.items():
if "moe" in layer_map.keys():
prefix = "layer"
cls = layer_map["moe"]
method = get_quant_method({"layer.weight": quant_type}, prefix,
"moe")
self.assertIsInstance(method, cls)
def test_invalid_layer_type(self):
quant_description = {"linear_layer.weight": "W8A8"}
with self.assertRaises(NotImplementedError):
get_quant_method(quant_description, "linear_layer", "unsupported")
def test_invalid_quant_type(self):
quant_description = {"linear_layer.weight": "UNKNOWN"}
with self.assertRaises(NotImplementedError):
get_quant_method(quant_description, "linear_layer", "linear")


@@ -3,8 +3,9 @@ from unittest.mock import Mock, patch
import torch
from tests.ut.base import TestBase
from vllm_ascend.quantization.w4a16 import (AscendW4A16FusedMoEMethod,
pack_to_int32, unpack_from_int32)
from vllm_ascend.quantization.methods.w4a16 import (AscendW4A16FusedMoEMethod,
pack_to_int32,
unpack_from_int32)
class TestUnpackFromInt32(TestBase):
@@ -42,7 +43,7 @@ class TestUnpackFromInt32(TestBase):
class TestPackToInt32(TestBase):
@patch(
"vllm_ascend.quantization.w4a16.torch_npu.npu_convert_weight_to_int4pack"
"vllm_ascend.quantization.methods.w4a16.torch_npu.npu_convert_weight_to_int4pack"
)
def test_pack_to_int32_int8(self, mock_npu_convert_weight_to_int4pack):
mock_npu_convert_weight_to_int4pack.return_value = torch.zeros(
@@ -57,7 +58,7 @@ class TestPackToInt32(TestBase):
self.assertEqual(result.shape, torch.Size([2, 8, 4]))
@patch(
"vllm_ascend.quantization.w4a16.torch_npu.npu_convert_weight_to_int4pack"
"vllm_ascend.quantization.methods.w4a16.torch_npu.npu_convert_weight_to_int4pack"
)
def test_pack_to_int32_int32(self, mock_npu_convert_weight_to_int4pack):
@@ -97,8 +98,8 @@ class TestAscendW4A16FusedMoEMethod(TestBase):
output_size = 128
group_size = 32
@patch("vllm_ascend.quantization.w4a16.get_ascend_config")
@patch("vllm_ascend.quantization.w4a16.get_current_vllm_config")
@patch("vllm_ascend.quantization.methods.w4a16.get_ascend_config")
@patch("vllm_ascend.quantization.methods.w4a16.get_current_vllm_config")
def setUp(self, mock_get_current_vllm_config, mock_get_ascend_config):
mock_ascend_config = Mock()
mock_ascend_config.eplb_config.dynamic_eplb = False
@@ -218,7 +219,7 @@ class TestAscendW4A16FusedMoEMethod(TestBase):
return layer
@patch(
"vllm_ascend.quantization.w4a16.torch_npu.npu_convert_weight_to_int4pack"
"vllm_ascend.quantization.methods.w4a16.torch_npu.npu_convert_weight_to_int4pack"
)
def test_process_weights_after_loading_with_transpose(
self, mock_npu_convert_weight_to_int4pack):


@@ -4,7 +4,7 @@ from unittest.mock import MagicMock, patch
import torch
import torch.nn as nn
from vllm_ascend.quantization.w4a4_flatquant_dynamic import (
from vllm_ascend.quantization.methods.w4a4_flatquant import (
AscendW4A4FlatQuantDynamicLinearMethod, get_decompose_dim,
pack_int4_weights)
@@ -33,7 +33,7 @@ class TestW4A4FlatQuantDynamic(unittest.TestCase):
self.assertEqual(get_decompose_dim(100), (10, 10))
self.assertEqual(get_decompose_dim(99), (9, 11))
@patch('vllm_ascend.quantization.w4a4_flatquant_dynamic.torch_npu')
@patch('vllm_ascend.quantization.methods.w4a4_flatquant.torch_npu')
def test_pack_int4_weights_npu_success(self, mock_torch_npu):
"""
Tests weight packing using the mocked NPU kernel.
@@ -119,7 +119,7 @@ class TestW4A4FlatQuantDynamic(unittest.TestCase):
x = torch.randn(batch_size, self.input_size, dtype=self.params_dtype)
return layer, x, m, n
@patch('vllm_ascend.quantization.w4a4_flatquant_dynamic.torch_npu')
@patch('vllm_ascend.quantization.methods.w4a4_flatquant.torch_npu')
def test_apply_small_batch(self, mock_torch_npu):
"""Tests the apply method with a batch size smaller than MAX_BATCH_SIZE."""
batch_size = 128
@@ -143,9 +143,9 @@ class TestW4A4FlatQuantDynamic(unittest.TestCase):
self.assertEqual(output.shape, (batch_size, self.output_size))
@patch(
'vllm_ascend.quantization.w4a4_flatquant_dynamic.KRONECKER_QUANT_MAX_BATCH_SIZE',
'vllm_ascend.quantization.methods.w4a4_flatquant.KRONECKER_QUANT_MAX_BATCH_SIZE',
10)
@patch('vllm_ascend.quantization.w4a4_flatquant_dynamic.torch_npu')
@patch('vllm_ascend.quantization.methods.w4a4_flatquant.torch_npu')
def test_apply_large_batch(self, mock_torch_npu):
"""Tests the apply method with a batch size larger than MAX_BATCH_SIZE."""
batch_size = 25
@@ -178,7 +178,7 @@ class TestW4A4FlatQuantDynamic(unittest.TestCase):
ValueError, "FlatQuant transform matrices dimension mismatch"):
self.method.apply(layer, x)
@patch('vllm_ascend.quantization.w4a4_flatquant_dynamic.pack_int4_weights')
@patch('vllm_ascend.quantization.methods.w4a4_flatquant.pack_int4_weights')
def test_process_weights_after_loading(self, mock_pack_weights):
"""Tests weight processing after loading, without transpose."""
layer = nn.Module()


@@ -3,14 +3,14 @@ from unittest.mock import Mock, patch
import torch
from tests.ut.base import TestBase
from vllm_ascend.quantization.w4a8_dynamic import (
from vllm_ascend.quantization.methods.w4a8 import (
AscendW4A8DynamicFusedMoEMethod, AscendW4A8DynamicLinearMethod)
class TestAscendW4A8DynamicLinearMethod(TestBase):
@patch('vllm.distributed.get_tensor_model_parallel_world_size')
@patch('vllm_ascend.quantization.w4a8_dynamic.get_current_vllm_config')
@patch('vllm_ascend.quantization.methods.w4a8.get_current_vllm_config')
def setUp(self, mock_get_current_vllm_config, mock_get_tp_world_size):
mock_get_tp_world_size.return_value = 1
mock_vllm_config = Mock()
@@ -127,10 +127,10 @@ class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
output_size = 56
group_size = 2
@patch('vllm_ascend.quantization.w4a8_dynamic.get_ascend_config')
@patch('vllm_ascend.quantization.w4a8_dynamic.get_current_vllm_config')
@patch('vllm_ascend.quantization.w4a8_dynamic.get_ep_group')
@patch('vllm_ascend.quantization.w4a8_dynamic.get_mc2_group')
@patch('vllm_ascend.quantization.methods.w4a8.get_ascend_config')
@patch('vllm_ascend.quantization.methods.w4a8.get_current_vllm_config')
@patch('vllm_ascend.quantization.methods.w4a8.get_ep_group')
@patch('vllm_ascend.quantization.methods.w4a8.get_mc2_group')
@patch('torch.distributed.get_rank', return_value=0)
def setUp(self, mock_get_rank, mock_get_mc2_group, mock_get_ep_group,
get_current_vllm_config, mock_get_ascend_config):


@@ -4,7 +4,7 @@ from unittest.mock import MagicMock, patch
import torch
from tests.ut.base import TestBase
from vllm_ascend.quantization.w8a16 import AscendW8A16LinearMethod
from vllm_ascend.quantization.methods.w8a16 import AscendW8A16LinearMethod
class TestAscendW8A16LinearMethod(TestBase):


@@ -4,36 +4,10 @@ from unittest.mock import MagicMock, patch
import torch
from tests.ut.base import TestBase
from vllm_ascend.quantization.w8a8 import (AscendW8A8LinearMethod,
quant_per_tensor)
from vllm_ascend.quantization.methods.w8a8_static import AscendW8A8LinearMethod
from vllm_ascend.utils import AscendDeviceType
class TestQuantPerTensor(TestBase):
@patch("torch_npu.npu_quantize")
def test_quant_per_tensor(self, mock_npu_quantize):
in_tensor = torch.randn(32, 128)
input_scale = torch.tensor(0.1)
input_offset = torch.tensor(0)
expected_output = torch.randint(-128, 127, (32, 128), dtype=torch.int8)
mock_npu_quantize.return_value = expected_output
output = quant_per_tensor(in_tensor, input_scale, input_offset)
mock_npu_quantize.assert_called_once_with(
in_tensor,
input_scale,
input_offset,
torch.qint8,
-1,
False,
)
self.assertTrue(torch.equal(output, expected_output))
class TestAscendW8A8LinearMethod(TestBase):
def setUp(self):
@@ -63,7 +37,9 @@ class TestAscendW8A8LinearMethod(TestBase):
self.assertEqual(params['weight_scale'].shape, (10, 1))
self.assertEqual(params['weight_offset'].shape, (10, 1))
@patch("vllm_ascend.quantization.w8a8.get_weight_prefetch_method")
@patch(
"vllm_ascend.quantization.methods.w8a8_static.get_weight_prefetch_method"
)
@patch("torch.ops.vllm.quantize")
@patch("torch_npu.npu_quant_matmul")
def test_apply_with_x_not_int8(self, mock_npu_quant_matmul, mock_quantize,


@@ -3,7 +3,7 @@ from unittest.mock import Mock, patch
import torch
from tests.ut.base import TestBase
from vllm_ascend.quantization.w8a8_dynamic import \
from vllm_ascend.quantization.methods.w8a8_dynamic import \
AscendW8A8DynamicFusedMoEMethod
@@ -13,13 +13,13 @@ class TestAscendW8A8FusedMoEMethod(TestBase):
intermediate_size = 128
@patch("torch.distributed.get_rank")
@patch("vllm_ascend.quantization.w8a8_dynamic.get_mc2_group")
@patch("vllm_ascend.quantization.w8a8_dynamic.get_ascend_config")
@patch("vllm_ascend.quantization.w8a8_dynamic.get_ep_group")
@patch("vllm_ascend.quantization.methods.w8a8_dynamic.get_mc2_group")
@patch("vllm_ascend.quantization.methods.w8a8_dynamic.get_ascend_config")
@patch("vllm_ascend.quantization.methods.w8a8_dynamic.get_ep_group")
def setUp(self, mock_get_ep_group, mock_get_ascend_config,
mock_get_mc2_group, mock_get_rank):
with patch(
'vllm_ascend.quantization.w8a8_dynamic.get_current_vllm_config'
'vllm_ascend.quantization.methods.w8a8_dynamic.get_current_vllm_config'
) as mock_get_current_vllm_config:
mock_vllm_config = Mock()
mock_vllm_config.quant_config = Mock(


@@ -51,8 +51,9 @@ class TestNPUPlatform(TestBase):
self.assertTrue(self.platform.is_sleep_mode_available())
@patch("vllm_ascend.utils.adapt_patch")
@patch("vllm_ascend.quantization.quant_config.AscendQuantConfig")
def test_pre_register_and_update_with_parser(self, mock_quant_config, mock_adapt_patch):
@patch("vllm_ascend.quantization.modelslim_config.AscendModelSlimConfig")
def test_pre_register_and_update_with_parser(self, mock_quant_config,
mock_adapt_patch):
mock_parser = MagicMock()
mock_action = MagicMock()
mock_action.choices = ["awq", "gptq"]
@@ -66,15 +67,17 @@ class TestNPUPlatform(TestBase):
self.assertEqual(len(mock_action.choices), 3) # original 2 + ascend
@patch("vllm_ascend.utils.adapt_patch")
@patch("vllm_ascend.quantization.quant_config.AscendQuantConfig")
def test_pre_register_and_update_without_parser(self, mock_quant_config, mock_adapt_patch):
@patch("vllm_ascend.quantization.modelslim_config.AscendModelSlimConfig")
def test_pre_register_and_update_without_parser(self, mock_quant_config,
mock_adapt_patch):
self.platform.pre_register_and_update(None)
mock_adapt_patch.assert_called_once_with(is_global_patch=True)
@patch("vllm_ascend.utils.adapt_patch")
@patch("vllm_ascend.quantization.quant_config.AscendQuantConfig")
def test_pre_register_and_update_with_parser_no_quant_action(self, mock_quant_config, mock_adapt_patch):
@patch("vllm_ascend.quantization.modelslim_config.AscendModelSlimConfig")
def test_pre_register_and_update_with_parser_no_quant_action(
self, mock_quant_config, mock_adapt_patch):
mock_parser = MagicMock()
mock_parser._option_string_actions = {}
@@ -83,8 +86,9 @@ class TestNPUPlatform(TestBase):
mock_adapt_patch.assert_called_once_with(is_global_patch=True)
@patch("vllm_ascend.utils.adapt_patch")
@patch("vllm_ascend.quantization.quant_config.AscendQuantConfig")
def test_pre_register_and_update_with_existing_ascend_quant(self, mock_quant_config, mock_adapt_patch):
@patch("vllm_ascend.quantization.modelslim_config.AscendModelSlimConfig")
def test_pre_register_and_update_with_existing_ascend_quant(
self, mock_quant_config, mock_adapt_patch):
mock_parser = MagicMock()
mock_action = MagicMock()
mock_action.choices = ["awq", ASCEND_QUANTIZATION_METHOD]


@@ -36,7 +36,7 @@ from vllm_ascend.ops.layer_shard_linear import (
register_all_layers_to_shard_weight_series)
from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
from vllm_ascend.quantization.methods import AscendW8A8LinearMethod
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, maybe_trans_nz,
weak_ref_tensors)
from vllm_ascend.worker.npu_input_batch import NPUInputBatch


@@ -35,7 +35,7 @@ from vllm_ascend.ops.layer_shard_linear import (
from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
from vllm_ascend.ops.triton.rope import rope_forward_triton
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
from vllm_ascend.quantization.methods import AscendW8A8LinearMethod
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, _round_up, dispose_layer,
enable_dsa_cp, enable_dsa_cp_with_layer_shard, maybe_trans_nz)
from vllm_ascend.worker.npu_input_batch import NPUInputBatch


@@ -43,10 +43,6 @@ from vllm_ascend.ops.fused_moe.moe_comm_method import (AllGatherCommImpl,
FusedExpertsResult,
setup_moe_comm_method)
from vllm_ascend.ops.fused_moe.prepare_finalize import QuantType
from vllm_ascend.quantization.w4a8_dynamic import \
AscendW4A8DynamicFusedMoEMethod
from vllm_ascend.quantization.w8a8_dynamic import \
AscendW8A8DynamicFusedMoEMethod
from vllm_ascend.utils import (AscendDeviceType, enable_sp,
get_ascend_device_type, maybe_trans_nz,
npu_stream_switch, shared_expert_dp_enabled,
@@ -251,12 +247,16 @@ class AscendFusedMoE(FusedMoE):
method = quant_method.quant_method
if isinstance(method, AscendW8A8DynamicFusedMoEMethod):
return QuantType.W8A8
elif isinstance(method, AscendW4A8DynamicFusedMoEMethod):
return QuantType.W4A8
else:
return QuantType.NONE
if hasattr(method, "quant_type"):
from vllm_ascend.quantization.methods.base import \
QuantType as SchemeQuantType
scheme_quant_type = method.quant_type
if scheme_quant_type == SchemeQuantType.W8A8:
return QuantType.W8A8
elif scheme_quant_type == SchemeQuantType.W4A8:
return QuantType.W4A8
return QuantType.NONE
def update_expert_map(self, new_expert_map):
self._expert_map = new_expert_map


@@ -368,7 +368,8 @@ class Flashcomm2OProjRowParallelOp(CustomRowParallelOp):
"communication_fn"] = otp_maybe_quant_comm
actual_quant_method = getattr(self.quant_method, 'quant_method',
self.quant_method)
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
from vllm_ascend.quantization.methods.w8a8_static import \
AscendW8A8LinearMethod
if not isinstance(actual_quant_method, AscendW8A8LinearMethod):
# Check if w8a8 quantization is enabled. If not, communicate immediately.
input_parallel = otp_maybe_quant_comm(input_parallel)
@@ -586,8 +587,8 @@ class SequenceRowParallelOp(CustomRowParallelOp):
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm_ascend.quantization.quant_config import AscendLinearMethod
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
from vllm_ascend.quantization.methods import AscendW8A8LinearMethod
from vllm_ascend.quantization.method_adapters import AscendLinearMethod
# For unquant
if mmrs_fusion and isinstance(self.layer.quant_method,


@@ -151,8 +151,7 @@ class NPUPlatform(Platform):
if ASCEND_QUANTIZATION_METHOD not in quant_action.choices:
quant_action.choices.append(ASCEND_QUANTIZATION_METHOD)
from vllm_ascend.quantization.compressed_tensors.compressed_tensors import AscendCompressedTensorsConfig # noqa: F401
from vllm_ascend.quantization.quant_config import AscendQuantConfig # noqa: F401
from vllm_ascend.quantization import AscendCompressedTensorsConfig, AscendModelSlimConfig # noqa: F401
config_deprecated_logging()


@@ -0,0 +1,38 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Ascend quantization module.
This module provides quantization support for Ascend NPU.
Supported quantization tools:
- ModelSlim: Use AscendModelSlimConfig
- LLM-Compressor (compressed_tensors): Use AscendCompressedTensorsConfig
Public API:
- Config classes: AscendModelSlimConfig, AscendCompressedTensorsConfig
- For scheme implementations, import from vllm_ascend.quantization.methods
"""
# LLM-Compressor (compressed_tensors) quantization config
from .compressed_tensors_config import AscendCompressedTensorsConfig
# ModelSlim quantization config
from .modelslim_config import AscendModelSlimConfig
__all__ = [
"AscendModelSlimConfig",
"AscendCompressedTensorsConfig",
]


@@ -1,4 +1,23 @@
from typing import TYPE_CHECKING, Any, Optional, cast
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
"""LLM-Compressor (compressed_tensors) quantization configuration for Ascend."""
from typing import Any, Optional, Union, cast
import torch
from compressed_tensors.quantization import (QuantizationArgs,
@@ -12,40 +31,37 @@ from vllm.model_executor.layers.quantization import (
QUANTIZATION_METHODS, register_quantization_config)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import \
CompressedTensorsScheme
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
find_matched_target, is_activation_quantization_format,
should_ignore_layer)
from vllm.model_executor.models.utils import WeightsMapper
from vllm_ascend.ops.fused_moe.fused_moe import AscendUnquantizedFusedMoEMethod
from vllm_ascend.quantization.quant_config import (AscendFusedMoEMethod,
AscendLinearMethod,
AscendQuantConfig)
from vllm_ascend.quantization.w4a16 import AscendW4A16FusedMoEMethod
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
from vllm_ascend.quantization.w8a8_dynamic import (
AscendW8A8DynamicFusedMoEMethod, AscendW8A8DynamicLinearMethod)
from vllm_ascend.utils import COMPRESSED_TENSORS_METHOD
if TYPE_CHECKING:
from vllm.model_executor.models.utils import WeightsMapper
from .methods import AscendLinearScheme, AscendMoEScheme
logger = init_logger(__name__)
QUANTIZATION_SCHEME_MAP_TYPE = dict[str, Optional[dict[str, QuantizationArgs]]]
def remove_quantization_method():
# Remove the original compressed_tensors method to replace with our implementation
def _remove_quantization_method():
if COMPRESSED_TENSORS_METHOD in QUANTIZATION_METHODS:
QUANTIZATION_METHODS.remove(COMPRESSED_TENSORS_METHOD)
remove_quantization_method()
_remove_quantization_method()
QUANTIZATION_SCHEME_MAP_TYPE = dict[str, Optional[dict[str,
"QuantizationArgs"]]]
@register_quantization_config(COMPRESSED_TENSORS_METHOD)
class AscendCompressedTensorsConfig(QuantizationConfig):
"""Config class for LLM-Compressor (compressed_tensors) quantization on Ascend.
This class adapts the compressed_tensors format to work with Ascend's
quantization implementations.
"""
def __init__(
self,
@@ -107,23 +123,16 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
@classmethod
def _quantization_scheme_map_from_config(
cls, config: dict[str, Any]) -> QUANTIZATION_SCHEME_MAP_TYPE:
"""
"""Build target scheme map from config.
:param config: The `quantization_config` dictionary from config.json
:return: A dictionary mapping target layer names to their corresponding
quantization_args for weights and input activations
"""
target_scheme_map: dict[str, Any] = dict()
quant_format = cast(str, config.get("format"))
# The quant_config has multiple config_groups, each containing
# an input_activations key with details about how the activations are
# quantized, a weights key indicating how the weights are quantized,
# and a list of targets under the `targets` key, dictating which
# layers are impacted by the quantization details. The quantization
# details follow the structure defined by the QuantizationArgs
# pydantic model, which is used to verify the structure of the
# quant_config and also store the details for later use.
config_groups = config.get("config_groups", dict())
for _, quant_config in config_groups.items():
targets = quant_config.get("targets")
@@ -154,70 +163,102 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
layer: torch.nn.Module,
prefix: str,
) -> Optional["QuantizeMethodBase"]:
from .method_adapters import AscendFusedMoEMethod, AscendLinearMethod
if isinstance(layer, LinearBase):
layer.ascend_quant_method = COMPRESSED_TENSORS_METHOD
# collect schemes
quant_scheme = self.get_scheme(layer=layer, layer_name=prefix)
# Get the scheme for this layer
linear_scheme = self._get_linear_scheme(layer=layer,
layer_name=prefix)
# Return unquantized method if no scheme found
if linear_scheme is None:
return UnquantizedLinearMethod()
# Store scheme on layer for reference (optional, for debugging)
layer.scheme = linear_scheme
logger.info_once(
"Using the vLLM Ascend llmcompressor Quantization now!")
return AscendLinearMethod(linear_scheme)
# choose quantization method
quant_method = UnquantizedLinearMethod()
if quant_scheme is not None:
layer.scheme = quant_scheme
ascend_quant_config = AscendQuantConfig(self.quant_description
or {})
quant_method = AscendLinearMethod(ascend_quant_config, prefix,
None, layer)
return quant_method
if isinstance(layer, FusedMoE):
self._add_fused_moe_to_target_scheme_map()
unfused_names = [
prefix + proj_name for proj_name in
[".0.gate_proj", ".0.up_proj", ".0.down_proj"]
]
# TODO: refactor this to use expert_mapping and check all layer numbers
all_scheme_dicts = [
self.get_scheme_dict(layer, name) for name in unfused_names
]
scheme_dict = all_scheme_dicts.pop()
# Delayed import to avoid circular import
from vllm_ascend.ops.fused_moe.fused_moe import \
AscendUnquantizedFusedMoEMethod
# multiple schemes found
if not all(
[cur_dict == scheme_dict for cur_dict in all_scheme_dicts]):
raise ValueError("All MoE projections need to have same "
"quantization scheme but found multiple")
layer.ascend_quant_method = COMPRESSED_TENSORS_METHOD
# Get the scheme for this layer
moe_scheme = self._get_moe_scheme(layer=layer, layer_name=prefix)
if scheme_dict is None:
# Return unquantized method if no scheme found
if moe_scheme is None:
return AscendUnquantizedFusedMoEMethod(layer.moe_config)
weight_quant = scheme_dict.get("weights")
input_quant = scheme_dict.get("input_activations")
# Store scheme on layer for reference (optional, for debugging)
layer.scheme = moe_scheme
logger.info_once(
"Using the vLLM Ascend llmcompressor Quantization now!")
return AscendFusedMoEMethod(moe_scheme, layer.moe_config)
quant_scheme = None
act_quant_format = is_activation_quantization_format(self.quant_format)
if act_quant_format:
if self._is_dynamic_token_w8a8(weight_quant, input_quant):
quant_scheme = AscendW8A8DynamicFusedMoEMethod()
else:
if self._is_w4a16(weight_quant, input_quant):
quant_scheme = AscendW4A16FusedMoEMethod()
if quant_scheme is None:
raise RuntimeError(
f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}"
)
layer.scheme = quant_scheme
layer.ascend_quant_method = COMPRESSED_TENSORS_METHOD
ascend_quant_config = AscendQuantConfig(self.quant_description
or {})
return AscendFusedMoEMethod(ascend_quant_config, prefix,
self.packed_modules_mapping, layer)
return None
def get_scheme(self,
layer: torch.nn.Module,
layer_name: Optional[str] = None
) -> Optional["CompressedTensorsScheme"]:
def _get_linear_scheme(
self,
layer: torch.nn.Module,
layer_name: Optional[str] = None) -> Optional[AscendLinearScheme]:
"""Get the linear quantization scheme for a layer.
Returns:
An AscendLinearScheme instance, or None if the layer
should use unquantized method.
"""
weight_quant, input_quant, format = self._get_quant_args(
layer, layer_name)
if weight_quant is None:
return None
scheme = self._create_scheme_for_layer_type(
weight_quant=weight_quant,
input_quant=input_quant,
format=format,
layer_type="linear",
)
return cast(AscendLinearScheme, scheme)
def _get_moe_scheme(
self,
layer: torch.nn.Module,
layer_name: Optional[str] = None) -> Optional[AscendMoEScheme]:
"""Get the MoE quantization scheme for a layer.
Returns:
An AscendMoEScheme instance, or None if the layer
should use unquantized method.
"""
# Add FusedMoE to target scheme map if needed
self._add_fused_moe_to_target_scheme_map()
weight_quant, input_quant, format = self._get_quant_args(
layer, layer_name)
if weight_quant is None:
return None
scheme = self._create_scheme_for_layer_type(
weight_quant=weight_quant,
input_quant=input_quant,
format=format,
layer_type="moe",
)
return cast(AscendMoEScheme, scheme)
def _get_quant_args(
self,
layer: torch.nn.Module,
layer_name: Optional[str] = None
) -> tuple[Optional["QuantizationArgs"], Optional["QuantizationArgs"],
Optional[str]]:
"""Extract quantization arguments for a layer.
compressed-tensors supports non uniform in the following way:
targets of config_groups: There can be N config_groups which each
@@ -226,10 +267,12 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
an nn.Module name.
Detect whether a layer_name is found in any target and
use the quantization scheme corresponding to the matched target
to select the CompressedTensorsScheme used for inference.
use the quantization scheme corresponding to the matched target.
Returns:
A tuple of (weight_quant, input_quant, format). weight_quant is
None if the layer should use unquantized method.
"""
scheme_dict = self.get_scheme_dict(layer, layer_name)
weight_quant = None
input_quant = None
@@ -243,16 +286,8 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
logger.warning_once("Acceleration for non-quantized schemes is "
"not supported by Compressed Tensors. "
"Falling back to UnquantizedLinearMethod")
return None
else:
# Find the quant_scheme
scheme = self._get_scheme_from_parts(
weight_quant=weight_quant,
input_quant=input_quant,
format=format,
)
return scheme
return weight_quant, input_quant, format
def get_scheme_dict(
self,
@@ -288,28 +323,73 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
return None
def _get_scheme_from_parts(
def _create_scheme_for_layer_type(
self,
weight_quant: QuantizationArgs,
input_quant: QuantizationArgs,
format: str | None = None,
) -> "CompressedTensorsScheme":
weight_quant: "QuantizationArgs",
input_quant: Optional["QuantizationArgs"],
format: Optional[str],
layer_type: str,
) -> Union[AscendLinearScheme, AscendMoEScheme]:
"""Create the appropriate Ascend scheme based on quantization args and layer type.
Args:
weight_quant: Weight quantization arguments.
input_quant: Input activation quantization arguments.
format: Per-layer format, if defined.
layer_type: Type of layer ("linear" or "moe").
Returns:
An instance of the appropriate Ascend quantization scheme.
"""
from .methods import get_scheme_class
# Determine the quantization type
quant_type = self._detect_quant_type(weight_quant, input_quant, format)
# Get the scheme class from registry
scheme_cls = get_scheme_class(quant_type, layer_type)
if scheme_cls is None:
raise NotImplementedError(
f"No compressed-tensors compatible scheme was found for "
f"quant_type={quant_type}, layer_type={layer_type}.")
return scheme_cls()
def _detect_quant_type(
self,
weight_quant: "QuantizationArgs",
input_quant: Optional["QuantizationArgs"],
format: Optional[str],
) -> str:
"""Detect the quantization type from quantization arguments.
Args:
weight_quant: Weight quantization arguments.
input_quant: Input activation quantization arguments.
format: Per-layer format, if defined.
Returns:
A string representing the quantization type (e.g., "W8A8", "W8A8_DYNAMIC").
"""
# use the per-layer format if defined, otherwise, use global format
format = format if format is not None else self.quant_format
act_quant_format = is_activation_quantization_format(format)
if act_quant_format and input_quant is not None:
if self._is_static_tensor_w8a8(weight_quant, input_quant):
return AscendW8A8LinearMethod()
return "W8A8"
if self._is_dynamic_token_w8a8(weight_quant, input_quant):
return AscendW8A8DynamicLinearMethod()
return "W8A8_DYNAMIC"
if self._is_w4a16(weight_quant, input_quant):
return "W4A16"
raise NotImplementedError(
"No compressed-tensors compatible scheme was found.")
"No compressed-tensors compatible quantization type was found.")
def _is_static_tensor_w8a8(self, weight_quant: QuantizationArgs,
input_quant: QuantizationArgs) -> bool:
def _is_static_tensor_w8a8(self, weight_quant: "QuantizationArgs",
input_quant: "QuantizationArgs") -> bool:
is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
weight_strategy = (
weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
@@ -322,8 +402,8 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
# Only symmetric weight quantization supported.
return is_8_bits and is_tensor and is_symmetric and is_static
def _is_dynamic_token_w8a8(self, weight_quant: QuantizationArgs,
input_quant: QuantizationArgs) -> bool:
def _is_dynamic_token_w8a8(self, weight_quant: "QuantizationArgs",
input_quant: "QuantizationArgs") -> bool:
is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
weight_strategy = (
weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
@@ -336,13 +416,13 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
# Only symmetric weight quantization supported.
return is_8_bits and is_token and is_symmetric and is_dynamic
def _is_w4a16(self, weight_quant: QuantizationArgs,
input_quant: QuantizationArgs) -> bool:
def _is_w4a16(self, weight_quant: "QuantizationArgs",
input_quant: Optional["QuantizationArgs"]) -> bool:
# Confirm weights quantized.
if weight_quant is None:
return False
# Confirm we have floating points.
# Confirm we have integer type.
if weight_quant.type != QuantizationType.INT:
return False


@@ -0,0 +1,288 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from typing import Callable, List, Optional
import torch
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.model_executor.layers.fused_moe import (FusedMoEMethodBase,
FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.linear import (LinearMethodBase,
RowParallelLinear)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.parameter import PerTensorScaleParameter
from vllm.model_executor.utils import set_weight_attrs
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import (get_flashcomm2_otp_group,
get_mlp_tp_group,
get_otp_group)
from vllm_ascend.utils import flashcomm2_enable, mlp_tp_enable, oproj_tp_enable
from .methods import (AscendAttentionScheme, AscendLinearScheme,
AscendMoEScheme, is_mx_quant_type)
class AscendLinearMethod(LinearMethodBase):
"""Linear method for Ascend quantization.
This wrapper class delegates to the actual quantization scheme implementation.
The scheme is determined by the Config class and passed directly to this wrapper.
Args:
scheme: The quantization scheme instance (e.g., AscendW8A8DynamicLinearMethod).
"""
def __init__(self, scheme: AscendLinearScheme) -> None:
self.quant_method = scheme
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
) -> None:
output_size_per_partition = sum(output_partition_sizes)
weight_loader = extra_weight_attrs.get("weight_loader")
weight_dict = self.quant_method.get_weight(input_size_per_partition,
output_size_per_partition,
params_dtype)
# Extract packing information (if present)
packed_dim = weight_dict.pop("_packed_dim", None)
packed_factor = weight_dict.pop("_packed_factor", None)
for weight_name, weight_param in weight_dict.items():
param = torch.nn.Parameter(weight_param, requires_grad=False)
set_weight_attrs(param, {"input_dim": 1, "output_dim": 0})
# Set packing attributes if the weight is packed
if packed_dim is not None and packed_factor is not None:
set_weight_attrs(param, {
"packed_dim": packed_dim,
"packed_factor": packed_factor
})
layer.register_parameter(weight_name, param)
set_weight_attrs(param, extra_weight_attrs)
pertensor_dict = self.quant_method.get_pertensor_param(params_dtype)
for pertensor_name, pertensor_param in pertensor_dict.items():
param = PerTensorScaleParameter(data=pertensor_param,
weight_loader=weight_loader)
# disable warning
param.ignore_warning = True
layer.register_parameter(pertensor_name, param)
param.weight_loader = extra_weight_attrs.get("weight_loader")
perchannel_dict = self.quant_method.get_perchannel_param(
output_size_per_partition, params_dtype)
for perchannel_name, perchannel_param in perchannel_dict.items():
param = torch.nn.Parameter(perchannel_param, requires_grad=False)
set_weight_attrs(param, {"output_dim": 0})
layer.register_parameter(perchannel_name, param)
set_weight_attrs(param, extra_weight_attrs)
# NOTE: In w4a8 quantization implementation,
# for down_proj and o_proj scale_bias shape is [output_size, 16],
# others are [output_size, 1]
layer_type = "row" if isinstance(layer,
RowParallelLinear) else "others"
pergroup_dict = self.quant_method.get_pergroup_param(
input_size_per_partition,
output_size_per_partition,
params_dtype,
layer_type=layer_type)
for pergroup_name, pergroup_param in pergroup_dict.items():
param = torch.nn.Parameter(pergroup_param, requires_grad=False)
set_weight_attrs(param, {"output_dim": 0})
layer.register_parameter(pergroup_name, param)
set_weight_attrs(param, extra_weight_attrs)
if "weight_scale_second" in pergroup_name or "weight_offset_second" in pergroup_name \
or is_mx_quant_type(self.quant_method):
setattr(param, "input_dim", 1)
param.input_dim = 1
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if hasattr(self.quant_method, "process_weights_after_loading"):
self.quant_method.process_weights_after_loading(layer)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if isinstance(layer, RowParallelLinear):
if layer.prefix.find("o_proj") != -1 and oproj_tp_enable():
tp_rank = get_otp_group().rank_in_group
elif layer.prefix.find("down_proj") != -1 and mlp_tp_enable():
tp_rank = get_mlp_tp_group().rank_in_group
elif (layer.prefix.find("o_proj") != -1 or
layer.prefix.find("out_proj") != -1) and flashcomm2_enable():
if get_ascend_config(
).flashcomm2_oproj_tensor_parallel_size == 1:
tp_rank = 0
else:
tp_rank = get_flashcomm2_otp_group().rank_in_group
else:
tp_rank = get_tensor_model_parallel_rank()
else:
tp_rank = 0
return self.quant_method.apply(layer, x, bias, tp_rank)
class AscendKVCacheMethod(BaseKVCacheMethod):
"""KVCache method for Ascend quantization.
This wrapper class delegates to the actual attention quantization scheme.
Args:
scheme: The attention quantization scheme instance.
"""
def __init__(self, scheme: AscendAttentionScheme) -> None:
self.quant_method = scheme
def create_weights(self, layer: torch.nn.Module) -> None:
# Different from linear method, there are no weight processing/slicing
# steps for attention in vllm. So the whole process of create weights
# is hidden into the specific quant method.
self.quant_method.create_weights(layer)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
self.quant_method.process_weights_after_loading(layer)
def apply(self, layer: torch.nn.Module, query: torch.Tensor,
key: torch.Tensor, value: torch.Tensor, kv_cache, attn_metadata,
attn_type, scale, output) -> torch.Tensor:
return self.quant_method.apply(layer, query, key, value, kv_cache,
attn_metadata, attn_type, scale, output)
class AscendFusedMoEMethod(FusedMoEMethodBase):
"""FusedMoE method for Ascend quantization.
This wrapper class delegates to the actual MoE quantization scheme.
Args:
scheme: The MoE quantization scheme instance.
moe_config: The FusedMoE configuration.
"""
def __init__(self, scheme: AscendMoEScheme,
moe_config: FusedMoEConfig) -> None:
super().__init__(moe_config)
self.quant_method = scheme
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
) -> None:
weight_param = self.quant_method.get_weight(
num_experts, intermediate_size_per_partition, hidden_size,
params_dtype)
for param_key, param_value in weight_param.items():
param = torch.nn.Parameter(param_value, requires_grad=False)
layer.register_parameter(param_key, param)
set_weight_attrs(param, extra_weight_attrs)
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
per_group_param = [
"weight_scale_second", "weight_offset_second", "scale_bias"
] + ["weight_scale", "weight_offset"] if hasattr(
self.quant_method,
"group_size") and self.quant_method.group_size > 0 else []
dynamic_quant_param = self.quant_method.get_dynamic_quant_param(
num_experts, intermediate_size_per_partition, hidden_size,
params_dtype)
for param_key, param_value in dynamic_quant_param.items():
param = torch.nn.Parameter(param_value, requires_grad=False)
layer.register_parameter(param_key, param)
set_weight_attrs(param, extra_weight_attrs)
if any(fields in param_key for fields in per_group_param):
setattr(param, "quant_method",
FusedMoeWeightScaleSupported.GROUP.value)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
is_prefill: bool = True,
enable_force_load_balance: bool = False,
log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num=0,
**kwargs,
) -> torch.Tensor:
return self.quant_method.apply(
layer, x, router_logits, top_k, renormalize, use_grouped_topk,
global_num_experts, expert_map, topk_group, num_expert_group,
custom_routing_function, scoring_func, routed_scaling_factor,
e_score_correction_bias, is_prefill, enable_force_load_balance,
log2phy, global_redundant_expert_num, **kwargs)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if hasattr(self.quant_method, "process_weights_after_loading"):
self.quant_method.process_weights_after_loading(layer)
def get_fused_moe_quant_config(self, layer: torch.nn.Module):
pass
@property
def supports_eplb(self):
supports_eplb = getattr(self.quant_method, "supports_eplb", False)
return supports_eplb
class AscendEmbeddingMethod(AscendLinearMethod):
"""Embedding method for Ascend quantization.
This is essentially the same as AscendLinearMethod, just with a different name
for clarity when used with VocabParallelEmbedding layers.
Args:
scheme: The quantization scheme instance.
"""
def __init__(self, scheme: AscendLinearScheme) -> None:
self.quant_method = scheme
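
For reference, a minimal sketch of how a config class could turn a registry lookup into one of these wrappers. The helper below is hypothetical, and it assumes `AscendLinearMethod` accepts a scheme instance in its constructor, mirroring `AscendEmbeddingMethod.__init__` above.

```python
from vllm_ascend.quantization.methods import get_scheme_class

def build_linear_method(quant_type: str) -> "AscendLinearMethod":
    # Resolve the scheme class registered for (quant_type, "linear") and
    # hand an instance to the wrapper, which delegates create_weights/apply.
    scheme_cls = get_scheme_class(quant_type, "linear")
    if scheme_cls is None:
        raise ValueError(f"No linear scheme registered for {quant_type!r}")
    return AscendLinearMethod(scheme_cls())

# e.g. method = build_linear_method("W8A8_DYNAMIC")
```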

View File

@@ -0,0 +1,82 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Ascend quantization scheme implementations.
This module provides all quantization scheme implementations for Ascend NPU.
Schemes are automatically registered via the @register_scheme decorator.
Usage:
from vllm_ascend.quantization.methods import get_scheme_class
# Get a scheme class by quant_type and layer_type
scheme_cls = get_scheme_class("W8A8_DYNAMIC", "linear")
scheme = scheme_cls()
"""
from typing import Any
# Import base classes
from .base import (AscendAttentionScheme, AscendLinearScheme, AscendMoEScheme,
QuantType)
# Import registry functions
from .registry import get_scheme_class, register_scheme
# Import all scheme classes for external access
from .w4a4_flatquant import AscendW4A4FlatQuantDynamicLinearMethod
from .w4a4_laos_dynamic import AscendW4A4LaosDynamicLinearMethod
from .w4a8 import (AscendW4A8DynamicFusedMoEMethod,
AscendW4A8DynamicLinearMethod)
from .w4a16 import AscendW4A16FusedMoEMethod
from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod,
AscendW8A8DynamicLinearMethod)
from .w8a8_mxfp8 import AscendW8A8MXFP8DynamicLinearMethod
from .w8a8_pdmix import (AscendW8A8PDMixFusedMoeMethod,
AscendW8A8PDMixLinearMethod)
from .w8a8_static import AscendW8A8LinearMethod
from .w8a16 import AscendW8A16LinearMethod
def is_mx_quant_type(instance: Any) -> bool:
"""Checks if the quantization method is a microscaling (MX) type."""
MX_QUANT_TYPES = (AscendW8A8MXFP8DynamicLinearMethod, )
return isinstance(instance, MX_QUANT_TYPES)
__all__ = [
# Base classes
"AscendAttentionScheme",
"AscendLinearScheme",
"AscendMoEScheme",
"QuantType",
# Registry functions
"register_scheme",
"get_scheme_class",
# Utility functions
"is_mx_quant_type",
# Scheme classes
"AscendW8A8LinearMethod",
"AscendW8A8DynamicLinearMethod",
"AscendW8A8DynamicFusedMoEMethod",
"AscendW8A8MXFP8DynamicLinearMethod",
"AscendW8A8PDMixLinearMethod",
"AscendW8A8PDMixFusedMoeMethod",
"AscendW8A16LinearMethod",
"AscendW4A8DynamicLinearMethod",
"AscendW4A8DynamicFusedMoEMethod",
"AscendW4A16FusedMoEMethod",
"AscendW4A4FlatQuantDynamicLinearMethod",
"AscendW4A4LaosDynamicLinearMethod",
]
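
Since every scheme module above applies `@register_scheme` at import time, importing this package is what populates the registry. A quick sanity check (assuming vllm-ascend is importable in the current environment):

```python
from vllm_ascend.quantization.methods import (AscendW8A8DynamicLinearMethod,
                                               get_scheme_class)

# Decorator-based registration happens on import, so lookups work right away.
assert get_scheme_class("W8A8_DYNAMIC", "linear") is AscendW8A8DynamicLinearMethod
```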

View File

@@ -0,0 +1,279 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Abstract base classes for Ascend quantization schemes."""
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Callable, Dict, Optional
import torch
class QuantType(Enum):
"""Quantization type enum for MoE schemes."""
NONE = 0
W8A8 = 1
W4A8 = 2
class AscendLinearScheme(ABC):
"""Base class for all linear quantization schemes.
Subclasses must implement get_weight() and apply() methods.
Other methods have default implementations that return empty dicts
or do nothing.
"""
@abstractmethod
def get_weight(self, input_size: int, output_size: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
"""Return weight tensor specifications.
Args:
input_size: Input dimension of the linear layer.
output_size: Output dimension of the linear layer.
params_dtype: Data type for parameters.
Returns:
Dictionary mapping parameter names to empty tensors with
the correct shape and dtype.
"""
...
def get_pertensor_param(self, params_dtype: torch.dtype) -> Dict[str, Any]:
"""Return per-tensor parameter specifications (e.g., input_scale).
Args:
params_dtype: Data type for parameters.
Returns:
Dictionary mapping parameter names to empty tensors.
"""
return {}
def get_perchannel_param(self, output_size: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
"""Return per-channel parameter specifications (e.g., weight_scale).
Args:
output_size: Output dimension of the linear layer.
params_dtype: Data type for parameters.
Returns:
Dictionary mapping parameter names to empty tensors.
"""
return {}
def get_pergroup_param(self,
input_size: int,
output_size: int,
params_dtype: torch.dtype,
layer_type: Optional[str] = None) -> Dict[str, Any]:
"""Return per-group parameter specifications.
Args:
input_size: Input dimension of the linear layer.
output_size: Output dimension of the linear layer.
params_dtype: Data type for parameters.
layer_type: Type of layer (e.g., "row" for RowParallelLinear).
Returns:
Dictionary mapping parameter names to empty tensors.
"""
return {}
@abstractmethod
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
tp_rank: Optional[int] = 0) -> torch.Tensor:
"""Forward computation.
Args:
layer: The linear layer module.
x: Input tensor.
bias: Optional bias tensor.
tp_rank: Tensor parallel rank.
Returns:
Output tensor after quantized linear operation.
"""
...
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
"""Post-loading weight processing (transpose, format conversion, etc.).
Args:
layer: The linear layer module.
"""
pass
class AscendAttentionScheme(ABC):
"""Base class for all attention quantization schemes.
Subclasses must implement apply() method.
Other methods have default implementations.
"""
def create_weights(self, layer: torch.nn.Module) -> None:
"""Create weights for attention quantization.
Args:
layer: The attention layer module.
"""
pass
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
"""Post-loading weight processing for attention layer.
Args:
layer: The attention layer module.
"""
pass
@abstractmethod
def apply(self, layer: torch.nn.Module, query: torch.Tensor,
key: torch.Tensor, value: torch.Tensor, kv_cache, attn_metadata,
attn_type, scale, output) -> torch.Tensor:
"""Forward computation for attention layer.
Args:
layer: The attention layer module.
query: Query tensor.
key: Key tensor.
value: Value tensor.
kv_cache: KV cache.
attn_metadata: Attention metadata.
attn_type: Attention type.
scale: Scale factor.
output: Output tensor.
Returns:
Output tensor after attention computation.
"""
...
class AscendMoEScheme(ABC):
"""Base class for all MoE quantization schemes.
Subclasses must implement get_weight(), get_dynamic_quant_param(),
and apply() methods.
Attributes:
quant_type: The quantization type for this scheme. Subclasses should
override this class attribute to declare their quant type.
"""
# Default quant type - subclasses should override this
quant_type: QuantType = QuantType.NONE
@abstractmethod
def get_weight(self, num_experts: int,
intermediate_size_per_partition: int, hidden_sizes: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
"""Return weight tensor specifications for MoE layer.
Args:
num_experts: Number of experts.
intermediate_size_per_partition: Intermediate size per partition.
hidden_sizes: Hidden dimension size.
params_dtype: Data type for parameters.
Returns:
Dictionary mapping parameter names to empty tensors.
"""
...
@abstractmethod
def get_dynamic_quant_param(self, num_experts: int,
intermediate_size_per_partition: int,
hidden_sizes: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
"""Return dynamic quantization parameters for MoE layer.
Args:
num_experts: Number of experts.
intermediate_size_per_partition: Intermediate size per partition.
hidden_sizes: Hidden dimension size.
params_dtype: Data type for parameters.
Returns:
Dictionary mapping parameter names to empty tensors.
"""
...
@abstractmethod
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
is_prefill: bool = True,
enable_force_load_balance: bool = False,
log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0,
**kwargs,
) -> torch.Tensor:
"""Forward computation for MoE layer.
Args:
layer: The MoE layer module.
x: Input hidden states.
router_logits: Router logits for expert selection.
top_k: Number of experts to select per token.
renormalize: Whether to renormalize expert weights.
use_grouped_topk: Whether to use grouped top-k selection.
global_num_experts: Total number of experts globally.
expert_map: Mapping from local to global expert indices.
topk_group: Group size for grouped top-k.
num_expert_group: Number of expert groups.
custom_routing_function: Custom routing function.
scoring_func: Scoring function name.
routed_scaling_factor: Scaling factor for routed experts.
e_score_correction_bias: Expert score correction bias.
is_prefill: Whether in prefill phase.
enable_force_load_balance: Whether to force load balancing.
log2phy: Logical to physical expert mapping.
global_redundant_expert_num: Number of redundant experts.
**kwargs: Additional keyword arguments.
Returns:
Output tensor after MoE computation.
"""
...
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
"""Post-loading weight processing for MoE layer.
Args:
layer: The MoE layer module.
"""
pass
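
To make the linear contract concrete, here is a toy scheme that satisfies the two abstract methods. It is purely illustrative (an unquantized pass-through), not one of the shipped schemes.

```python
from typing import Any, Dict, Optional

import torch

class IdentityLinearScheme(AscendLinearScheme):
    """Toy example: stores a plain fp weight and applies an ordinary matmul."""

    def get_weight(self, input_size: int, output_size: int,
                   params_dtype: torch.dtype) -> Dict[str, Any]:
        # Empty tensor specs; the wrapper turns these into nn.Parameters.
        return {"weight": torch.empty(output_size, input_size,
                                      dtype=params_dtype)}

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None,
              tp_rank: Optional[int] = 0) -> torch.Tensor:
        return torch.nn.functional.linear(x, layer.weight, bias)
```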

View File

@@ -0,0 +1,62 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Any, Dict, Optional, Tuple, Type
# Registry: maps (quant_type, layer_type) -> SchemeClass
_SCHEME_REGISTRY: Dict[Tuple[str, str], Type[Any]] = {}
def register_scheme(quant_type: str, layer_type: str):
"""Decorator to register a quantization scheme.
Args:
quant_type: Quantization type (e.g., "W8A8", "W8A8_DYNAMIC").
layer_type: Layer type (e.g., "linear", "moe").
Returns:
Decorator function that registers the class.
Example:
@register_scheme("W8A8_DYNAMIC", "linear")
class W8A8DynamicLinearScheme(AscendLinearScheme):
...
"""
def decorator(cls: Type[Any]) -> Type[Any]:
key = (quant_type, layer_type)
if key in _SCHEME_REGISTRY:
raise ValueError(
f"Scheme already registered for {quant_type}/{layer_type}: "
f"{_SCHEME_REGISTRY[key].__name__}")
_SCHEME_REGISTRY[key] = cls
return cls
return decorator
def get_scheme_class(quant_type: str, layer_type: str) -> Optional[Type[Any]]:
"""Get scheme class for given quant_type and layer_type.
Args:
quant_type: Quantization type (e.g., "W8A8", "W8A8_DYNAMIC").
layer_type: Layer type (e.g., "linear", "moe").
Returns:
The registered scheme class, or None if not found.
"""
return _SCHEME_REGISTRY.get((quant_type, layer_type))
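
A short sketch of the registration round trip, using a made-up quant type so it does not collide with any real scheme:

```python
@register_scheme("TOY_W8A8", "linear")
class ToyLinearScheme:
    ...

assert get_scheme_class("TOY_W8A8", "linear") is ToyLinearScheme
assert get_scheme_class("TOY_W8A8", "moe") is None

# Registering the same (quant_type, layer_type) pair a second time raises
# ValueError, which catches accidental duplicate registrations early.
```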

View File

@@ -1,277 +1,278 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Any, Callable, Dict, Optional
import torch
import torch_npu
from vllm.config import get_current_vllm_config
from vllm.forward_context import get_forward_context
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
def unpack_from_int32(
weight: torch.Tensor,
shape: torch.Size,
num_bits: int,
packed_dim: int = 1,
) -> torch.Tensor:
"""
Unpacks quantized weights from int32 format back to original bits.
:param weight: The packed int32 tensor containing quantized weights
:param shape: Original (unpacked) shape to restore
:param num_bits: The number of bits used for quantization (<= 8)
:param packed_dim: Dimension along which weights are packed (0 or 1), defaults to 1
:return: Unpacked tensor with int8 dtype after applying offset correction
"""
assert weight.dtype == torch.int32, f"Expecting `weight.dtype` is torch.int32 but got {weight.dtype}."
assert num_bits <= 8, f"Expecting `num_bits` should not be larger than 8 but got {num_bits}."
pack_factor = 32 // num_bits
mask = (1 << num_bits) - 1
if packed_dim == 1:
unpacked_weight = torch.zeros(
(weight.shape[0], weight.shape[1] * pack_factor),
device=weight.device,
dtype=torch.int32,
)
for i in range(pack_factor):
unpacked_weight[:, i::pack_factor] = (weight >>
(num_bits * i)) & mask
original_row_size = int(shape[1])
unpacked_weight = unpacked_weight[:, :original_row_size]
else:
unpacked_weight = torch.zeros(
(weight.shape[0] * pack_factor, weight.shape[1]),
device=weight.device,
dtype=torch.int32,
)
for i in range(pack_factor):
unpacked_weight[i::pack_factor, :] = (weight >>
(num_bits * i)) & mask
original_row_size = int(shape[0])
unpacked_weight = unpacked_weight[:original_row_size, :]
offset = pow(2, num_bits) // 2
unpacked_weight = (unpacked_weight - offset).to(torch.int8)
return unpacked_weight
def pack_to_int32(weight: torch.Tensor) -> torch.Tensor:
"""
Packs quantized weights into int32 format for storage.
:param weight: The 3D tensor to pack, must be int8 or int32 dtype
:return: Packed tensor with int32 dtype optimized for storage
"""
assert weight.dim(
) == 3, f"Expecting `weight.dim()` is 3 ([e, n, k] or [e, k, n]) but got {weight.dim()}."
assert weight.dtype in [
torch.int8, torch.int32
], f"Expecting `weight.dtype` is torch.int8 or torch.int32 bug got {weight.dtype}."
if weight.dtype == torch.int32:
assert weight.shape[
-1] % 8 == 0, "the last dim of weight must be divisible by 8."
packed_weight = torch_npu.npu_convert_weight_to_int4pack(
weight.flatten(0, 1))
packed_weight = packed_weight.view(weight.shape[0], weight.shape[1],
-1)
else:
assert weight.shape[
-1] % 4 == 0, "the last dim of weight must be divisible by 4."
packed_weight = weight.view(torch.int32).contiguous()
return packed_weight
class AscendW4A16FusedMoEMethod:
"""FusedMoe method for Ascend W4A16.
"""
def __init__(self) -> None:
self.transpose_weight = True
self.num_bits = 4 # dtype = torch.int4
self.pack_factor = 8 # pack 8 of torch.int4 tensors to torch.int32
vllm_config = get_current_vllm_config()
self.group_size = vllm_config.quant_config.quant_description.get(
"group_size", 32)
self.dynamic_eplb = get_ascend_config().eplb_config.dynamic_eplb
def get_weight(
self,
num_experts: int,
intermediate_size_per_partition: int,
hidden_sizes: int,
params_dtype: torch.dtype,
) -> Dict[str, Any]:
assert intermediate_size_per_partition % self.pack_factor == 0, f"Expecting `intermediate_size_per_partition` {intermediate_size_per_partition} to be divisible by `pack_factor` {self.pack_factor}"
assert hidden_sizes % self.pack_factor == 0, f"Expecting `hidden_sizes` {hidden_sizes} to be divisible by `pack_factor` {self.pack_factor}"
param_dict = {}
param_dict["w13_weight_packed"] = torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_sizes // self.pack_factor,
dtype=torch.int32)
param_dict["w2_weight_packed"] = torch.empty(
num_experts,
hidden_sizes,
intermediate_size_per_partition // self.pack_factor,
dtype=torch.int32)
return param_dict
def get_dynamic_quant_param(
self,
num_experts: int,
intermediate_size_per_partition: int,
hidden_sizes: int,
params_dtype: torch.dtype,
) -> Dict[str, Any]:
assert intermediate_size_per_partition % self.group_size == 0, f"Expecting `intermediate_size_per_partition` {intermediate_size_per_partition} to be divisible by `group_size` {self.group_size}"
assert hidden_sizes % self.group_size == 0, f"Expecting `hidden_sizes` {hidden_sizes} to be divisible by `group_size` {self.group_size}"
param_dict = {}
param_dict["w13_weight_scale"] = torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_sizes // self.group_size,
dtype=torch.bfloat16)
param_dict["w2_weight_scale"] = torch.empty(
num_experts,
hidden_sizes,
intermediate_size_per_partition // self.group_size,
dtype=torch.bfloat16)
param_dict["w13_weight_shape"] = torch.empty(num_experts,
2,
dtype=torch.int32)
param_dict["w2_weight_shape"] = torch.empty(num_experts,
2,
dtype=torch.int32)
param_dict["w13_weight_offset"] = torch.zeros(
num_experts,
2 * intermediate_size_per_partition,
hidden_sizes // self.group_size,
dtype=torch.bfloat16)
param_dict["w2_weight_offset"] = torch.zeros(
num_experts,
hidden_sizes,
intermediate_size_per_partition // self.group_size,
dtype=torch.bfloat16)
return param_dict
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
is_prefill: bool = True,
enable_force_load_balance: bool = True,
log2phy: torch.Tensor = None,
global_redundant_expert_num: int = 0,
**kwargs,
) -> torch.Tensor:
assert router_logits.shape[
1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
global_num_experts=global_num_experts)
topk_ids = topk_ids.to(torch.int32)
topk_weights = topk_weights.to(x.dtype)
moe_comm_method = get_forward_context().moe_comm_method
return moe_comm_method.fused_experts(hidden_states=x,
w1=layer.w13_weight_packed,
w2=layer.w2_weight_packed,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
w1_offset=layer.w13_weight_offset,
w2_offset=layer.w2_weight_offset,
topk_weights=topk_weights,
topk_ids=topk_ids,
use_int4_w4a16=True,
expert_map=expert_map,
log2phy=log2phy,
dynamic_eplb=self.dynamic_eplb,
mc2_mask=kwargs.get(
"mc2_mask", None))
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if self.transpose_weight:
w13_shape = layer.w13_weight_packed.data.shape
w2_shape = layer.w2_weight_packed.data.shape
unpacked_w13_weight = (unpack_from_int32(
layer.w13_weight_packed.data.flatten(0, 1),
torch.Size([
w13_shape[0] * w13_shape[1],
w13_shape[2] * self.pack_factor
]),
self.num_bits,
).view(w13_shape[0], w13_shape[1],
-1).transpose(1, 2).contiguous().int())
unpacked_w2_weight = (unpack_from_int32(
layer.w2_weight_packed.data.flatten(0, 1),
torch.Size([
w2_shape[0] * w2_shape[1], w2_shape[2] * self.pack_factor
]),
self.num_bits,
).view(w2_shape[0], w2_shape[1],
-1).transpose(1, 2).contiguous().int())
layer.w13_weight_packed.data = pack_to_int32(unpacked_w13_weight)
layer.w2_weight_packed.data = pack_to_int32(unpacked_w2_weight)
layer.w13_weight_scale.data = layer.w13_weight_scale.data.transpose(
1, 2).contiguous()
layer.w2_weight_scale.data = layer.w2_weight_scale.data.transpose(
1, 2).contiguous()
layer.w13_weight_offset.data = layer.w13_weight_offset.data.transpose(
1, 2).contiguous()
layer.w2_weight_offset.data = layer.w2_weight_offset.data.transpose(
1, 2).contiguous()
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Any, Callable, Dict, Optional
import torch
import torch_npu
from vllm.config import get_current_vllm_config
from vllm.forward_context import get_forward_context
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
from .base import AscendMoEScheme
from .registry import register_scheme
def unpack_from_int32(
weight: torch.Tensor,
shape: torch.Size,
num_bits: int,
packed_dim: int = 1,
) -> torch.Tensor:
"""Unpacks quantized weights from int32 format back to original bits.
:param weight: The packed int32 tensor containing quantized weights
:param shape: Original (unpacked) shape to restore
:param num_bits: The number of bits used for quantization (<= 8)
:param packed_dim: Dimension along which weights are packed (0 or 1), defaults to 1
:return: Unpacked tensor with int8 dtype after applying offset correction
"""
assert weight.dtype == torch.int32, f"Expecting `weight.dtype` is torch.int32 but got {weight.dtype}."
assert num_bits <= 8, f"Expecting `num_bits` should not be larger than 8 but got {num_bits}."
pack_factor = 32 // num_bits
mask = (1 << num_bits) - 1
if packed_dim == 1:
unpacked_weight = torch.zeros(
(weight.shape[0], weight.shape[1] * pack_factor),
device=weight.device,
dtype=torch.int32,
)
for i in range(pack_factor):
unpacked_weight[:, i::pack_factor] = (weight >>
(num_bits * i)) & mask
original_row_size = int(shape[1])
unpacked_weight = unpacked_weight[:, :original_row_size]
else:
unpacked_weight = torch.zeros(
(weight.shape[0] * pack_factor, weight.shape[1]),
device=weight.device,
dtype=torch.int32,
)
for i in range(pack_factor):
unpacked_weight[i::pack_factor, :] = (weight >>
(num_bits * i)) & mask
original_row_size = int(shape[0])
unpacked_weight = unpacked_weight[:original_row_size, :]
offset = pow(2, num_bits) // 2
unpacked_weight = (unpacked_weight - offset).to(torch.int8)
return unpacked_weight
def pack_to_int32(weight: torch.Tensor) -> torch.Tensor:
"""Packs quantized weights into int32 format for storage.
:param weight: The 3D tensor to pack, must be int8 or int32 dtype
:return: Packed tensor with int32 dtype optimized for storage
"""
assert weight.dim(
) == 3, f"Expecting `weight.dim()` is 3 ([e, n, k] or [e, k, n]) but got {weight.dim()}."
assert weight.dtype in [
torch.int8, torch.int32
], f"Expecting `weight.dtype` is torch.int8 or torch.int32 bug got {weight.dtype}."
if weight.dtype == torch.int32:
assert weight.shape[
-1] % 8 == 0, "the last dim of weight must be divisible by 8."
packed_weight = torch_npu.npu_convert_weight_to_int4pack(
weight.flatten(0, 1))
packed_weight = packed_weight.view(weight.shape[0], weight.shape[1],
-1)
else:
assert weight.shape[
-1] % 4 == 0, "the last dim of weight must be divisible by 4."
packed_weight = weight.view(torch.int32).contiguous()
return packed_weight
@register_scheme("W4A16", "moe")
class AscendW4A16FusedMoEMethod(AscendMoEScheme):
"""FusedMoE method for Ascend W4A16."""
def __init__(self) -> None:
self.transpose_weight = True
self.num_bits = 4 # dtype = torch.int4
self.pack_factor = 8 # pack 8 of torch.int4 tensors to torch.int32
vllm_config = get_current_vllm_config()
self.group_size = vllm_config.quant_config.quant_description.get(
"group_size", 32)
self.dynamic_eplb = get_ascend_config().eplb_config.dynamic_eplb
def get_weight(
self,
num_experts: int,
intermediate_size_per_partition: int,
hidden_sizes: int,
params_dtype: torch.dtype,
) -> Dict[str, Any]:
assert intermediate_size_per_partition % self.pack_factor == 0, f"Expecting `intermediate_size_per_partition` {intermediate_size_per_partition} to be divisible by `pack_factor` {self.pack_factor}"
assert hidden_sizes % self.pack_factor == 0, f"Expecting `hidden_sizes` {hidden_sizes} to be divisible by `pack_factor` {self.pack_factor}"
param_dict = {}
param_dict["w13_weight_packed"] = torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_sizes // self.pack_factor,
dtype=torch.int32)
param_dict["w2_weight_packed"] = torch.empty(
num_experts,
hidden_sizes,
intermediate_size_per_partition // self.pack_factor,
dtype=torch.int32)
return param_dict
def get_dynamic_quant_param(
self,
num_experts: int,
intermediate_size_per_partition: int,
hidden_sizes: int,
params_dtype: torch.dtype,
) -> Dict[str, Any]:
assert intermediate_size_per_partition % self.group_size == 0, f"Expecting `intermediate_size_per_partition` {intermediate_size_per_partition} to be divisible by `group_size` {self.group_size}"
assert hidden_sizes % self.group_size == 0, f"Expecting `hidden_sizes` {hidden_sizes} to be divisible by `group_size` {self.group_size}"
param_dict = {}
param_dict["w13_weight_scale"] = torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_sizes // self.group_size,
dtype=torch.bfloat16)
param_dict["w2_weight_scale"] = torch.empty(
num_experts,
hidden_sizes,
intermediate_size_per_partition // self.group_size,
dtype=torch.bfloat16)
param_dict["w13_weight_shape"] = torch.empty(num_experts,
2,
dtype=torch.int32)
param_dict["w2_weight_shape"] = torch.empty(num_experts,
2,
dtype=torch.int32)
param_dict["w13_weight_offset"] = torch.zeros(
num_experts,
2 * intermediate_size_per_partition,
hidden_sizes // self.group_size,
dtype=torch.bfloat16)
param_dict["w2_weight_offset"] = torch.zeros(
num_experts,
hidden_sizes,
intermediate_size_per_partition // self.group_size,
dtype=torch.bfloat16)
return param_dict
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
is_prefill: bool = True,
enable_force_load_balance: bool = True,
log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0,
**kwargs,
) -> torch.Tensor:
assert router_logits.shape[
1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
global_num_experts=global_num_experts)
topk_ids = topk_ids.to(torch.int32)
topk_weights = topk_weights.to(x.dtype)
moe_comm_method = get_forward_context().moe_comm_method
return moe_comm_method.fused_experts(
hidden_states=x,
w1=layer.w13_weight_packed,
w2=layer.w2_weight_packed,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
w1_offset=layer.w13_weight_offset,
w2_offset=layer.w2_weight_offset,
topk_weights=topk_weights,
topk_ids=topk_ids,
use_int4_w4a16=True,
expert_map=expert_map,
log2phy=log2phy,
dynamic_eplb=self.dynamic_eplb,
mc2_mask=kwargs.get("mc2_mask", None))
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if self.transpose_weight:
w13_shape = layer.w13_weight_packed.data.shape
w2_shape = layer.w2_weight_packed.data.shape
unpacked_w13_weight = (unpack_from_int32(
layer.w13_weight_packed.data.flatten(0, 1),
torch.Size([
w13_shape[0] * w13_shape[1],
w13_shape[2] * self.pack_factor
]),
self.num_bits,
).view(w13_shape[0], w13_shape[1],
-1).transpose(1, 2).contiguous().int())
unpacked_w2_weight = (unpack_from_int32(
layer.w2_weight_packed.data.flatten(0, 1),
torch.Size([
w2_shape[0] * w2_shape[1], w2_shape[2] * self.pack_factor
]),
self.num_bits,
).view(w2_shape[0], w2_shape[1],
-1).transpose(1, 2).contiguous().int())
layer.w13_weight_packed.data = pack_to_int32(unpacked_w13_weight)
layer.w2_weight_packed.data = pack_to_int32(unpacked_w2_weight)
layer.w13_weight_scale.data = layer.w13_weight_scale.data.transpose(
1, 2).contiguous()
layer.w2_weight_scale.data = layer.w2_weight_scale.data.transpose(
1, 2).contiguous()
layer.w13_weight_offset.data = layer.w13_weight_offset.data.transpose(
1, 2).contiguous()
layer.w2_weight_offset.data = layer.w2_weight_offset.data.transpose(
1, 2).contiguous()
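
To make the bit layout concrete, here is a small self-contained round trip in plain torch (no `torch_npu`), mirroring the shift/mask and offset scheme that `unpack_from_int32` uses for `num_bits=4`. Real packed weights also use the sign bit of the int32; this toy example keeps the top nibble zero so all intermediates stay positive.

```python
import torch

num_bits = 4
pack_factor = 32 // num_bits       # 8 int4 values per int32
offset = (1 << num_bits) // 2      # 8: nibbles are stored unsigned
mask = (1 << num_bits) - 1

# Signed int4 values to store (last value maps to nibble 0).
vals = torch.tensor([[-3, 0, 1, 2, 5, 6, 7, -8]], dtype=torch.int32)

# Pack along dim 1: shift each unsigned nibble into its slot of one int32.
packed = torch.zeros(1, 1, dtype=torch.int32)
for i in range(pack_factor):
    packed |= ((vals[:, i:i + 1] + offset) & mask) << (num_bits * i)

# Unpack with the same shift/mask scheme used by unpack_from_int32 above.
unpacked = torch.cat(
    [((packed >> (num_bits * i)) & mask) - offset for i in range(pack_factor)],
    dim=1).to(torch.int8)

assert torch.equal(unpacked, vals.to(torch.int8))
```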

View File

@@ -14,16 +14,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import math
from typing import Any, Dict, Optional, Tuple
import torch
import torch_npu
from .base import AscendLinearScheme
from .registry import register_scheme
KRONECKER_QUANT_MAX_BATCH_SIZE = 32768
def pack_int4_weights(weight_tensor: torch.Tensor) -> torch.Tensor:
"""Pack int4 weights for NPU."""
original_device = weight_tensor.device
weight_tensor_npu = weight_tensor.npu()
weight_int4_packed = torch_npu.npu_convert_weight_to_int4pack(
@@ -32,6 +37,7 @@ def pack_int4_weights(weight_tensor: torch.Tensor) -> torch.Tensor:
def get_decompose_dim(n):
"""Get decomposed dimensions for Kronecker quantization."""
a = int(math.sqrt(n))
if a * a < n:
a += 1
@@ -53,6 +59,7 @@ def batched_kronecker_quant(
right_trans: torch.Tensor,
clip_ratio: float,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Batched Kronecker quantization with batch size limit handling."""
batch_tokens = x.shape[0]
if batch_tokens <= KRONECKER_QUANT_MAX_BATCH_SIZE:
return torch_npu.npu_kronecker_quant(x,
@@ -75,7 +82,8 @@ def batched_kronecker_quant(
return x_quantized_int4, activation_scale
class AscendW4A4FlatQuantDynamicLinearMethod:
@register_scheme("W4A4_FLATQUANT_DYNAMIC", "linear")
class AscendW4A4FlatQuantDynamicLinearMethod(AscendLinearScheme):
"""Linear method for Ascend W4A4_FLATQUANT_DYNAMIC.
This class implements W4A4 quantization with FlatQuant approach and dynamic activation quantization.
@@ -88,8 +96,7 @@ class AscendW4A4FlatQuantDynamicLinearMethod:
def __init__(self):
self.sym = True
@staticmethod
def get_weight(input_size: int, output_size: int,
def get_weight(self, input_size: int, output_size: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
if input_size % 8 != 0:
raise ValueError(
@@ -101,8 +108,7 @@ class AscendW4A4FlatQuantDynamicLinearMethod:
}
return params_dict
@staticmethod
def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
def get_pertensor_param(self, params_dtype: torch.dtype) -> Dict[str, Any]:
params_dict = {}
left_trans_dim, right_trans_dim = get_decompose_dim(
AscendW4A4FlatQuantDynamicLinearMethod.input_size)
@@ -115,8 +121,8 @@ class AscendW4A4FlatQuantDynamicLinearMethod:
params_dict["clip_ratio"] = torch.empty(1, dtype=torch.float32)
return params_dict
@staticmethod
def get_perchannel_param(
self,
output_size: int,
params_dtype: torch.dtype,
) -> Dict[str, Any]:
@@ -129,15 +135,8 @@ class AscendW4A4FlatQuantDynamicLinearMethod:
dtype=torch.float32)
return params_dict
def get_pergroup_param(self,
input_size: int,
output_size: int,
params_dtype: torch.dtype,
layer_type: Optional[str] = None) -> Dict[str, Any]:
return {}
@staticmethod
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,

View File

@@ -14,22 +14,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Any, Callable, Dict, Optional
from typing import Any, Dict, Optional
import torch
import torch_npu
import torch.nn.functional as F
class AscendW4A4LaosDynamicLinearMethod:
"""Linear method for Ascend W4A4_LAOS_DYNAMIC.
from .base import AscendLinearScheme
from .registry import register_scheme
@register_scheme("W4A4_DYNAMIC", "linear")
class AscendW4A4LaosDynamicLinearMethod(AscendLinearScheme):
"""Linear method for Ascend W4A4_DYNAMIC.
This class implements W4A4 quantization with LAOS approach and dynamic activation quantization.
- Weight: 4-bit quantization (per-channel) with scale and offset, stored as int8.
- Activation: 4-bit dynamic quantization.
"""
def __init__(self):
self.transpose_weight = True
self.rotation_type = None
def set_rotation_config(self, prefix, metadata):
def set_rotation_config(self, prefix: str, metadata: Dict) -> Optional[str]:
"""Set rotation config based on prefix and metadata."""
layer_idx = prefix.split(".")[2]
if prefix.endswith("o_proj"):
layers = metadata["quarot"]["heads_rotation"]["layers"]
@@ -39,17 +48,17 @@ class AscendW4A4LaosDynamicLinearMethod:
layers = metadata["quarot"]["kronecker_rotation"]["layers"]
if layer_idx in layers:
return "kronecker_rotation"
return None
@staticmethod
def get_weight(input_size: int, output_size: int, params_dtype: torch.dtype) -> Dict[str, Any]:
params_dict = {"weight": torch.empty(output_size, input_size, dtype=torch.int8)}
def get_weight(self, input_size: int, output_size: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
params_dict = {
"weight": torch.empty(output_size, input_size, dtype=torch.int8)
}
return params_dict
@staticmethod
def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
return {}
def get_perchannel_param(self, output_size: int, params_dtype: torch.dtype) -> Dict[str, Any]:
def get_perchannel_param(self, output_size: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
params_dict = {}
params_dict["weight_scale"] = torch.empty(output_size,
1,
@@ -58,20 +67,18 @@ class AscendW4A4LaosDynamicLinearMethod:
1,
dtype=torch.float32)
if self.rotation_type == "heads_rotation":
params_dict["heads_rotation"] = torch.zeros((64, 64), dtype=torch.float32)
params_dict["heads_rotation"] = torch.zeros((64, 64),
dtype=torch.float32)
if self.rotation_type == "kronecker_rotation":
params_dict["kronecker_rotation_n"] = torch.zeros((160, 160), dtype=torch.float32)
params_dict["kronecker_rotation_m"] = torch.zeros((160, 160), dtype=torch.float32)
params_dict["kronecker_rotation_n"] = torch.zeros(
(160, 160), dtype=torch.float32)
params_dict["kronecker_rotation_m"] = torch.zeros(
(160, 160), dtype=torch.float32)
return params_dict
def get_pergroup_param(self,
input_size: int,
output_size: int,
params_dtype: torch.dtype,
layer_type: Optional[str] = None) -> Dict[str, Any]:
return {}
def apply_rotation(self, layer, x):
def apply_rotation(self, layer: torch.nn.Module,
x: torch.Tensor) -> torch.Tensor:
"""Apply rotation transformation to input tensor."""
init_shape = x.shape
dtype = x.dtype
if self.rotation_type == "heads_rotation":
@@ -94,17 +101,26 @@ class AscendW4A4LaosDynamicLinearMethod:
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
tp_rank: Optional[int] = None,
tp_rank: Optional[int] = 0,
) -> torch.Tensor:
dtype = x.dtype
x, pertoken_scale = torch_npu.npu_dynamic_quant(x, dst_type=torch.quint4x2)
pertoken_scale = pertoken_scale.reshape(-1, 1)
pertoken_scale = pertoken_scale.squeeze(-1)
y2 = torch_npu.npu_quant_matmul(x, layer.weight.data, scale=layer.weight_scale.data.view(-1), pertoken_scale=pertoken_scale, bias=None, output_dtype=dtype)
return y2
output = torch_npu.npu_quant_matmul(
x,
layer.weight.data,
scale=layer.weight_scale.data.view(-1),
pertoken_scale=pertoken_scale,
bias=None,
output_dtype=dtype)
if bias is not None:
output = output + bias.to(dtype)
return output
def process_weights_after_loading(self, layer: torch.nn.Module):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.weight_scale.data = layer.weight_scale.data.to(torch.float32)
layer.weight.data = torch_npu.npu_convert_weight_to_int4pack(layer.weight.data.to(torch.int32))
layer.weight.data = torch_npu.npu_convert_weight_to_int4pack(
layer.weight.data.to(torch.int32))
if self.transpose_weight:
layer.weight.data = layer.weight.data.transpose(-1, -2)

View File

@@ -29,10 +29,13 @@ from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
from vllm_ascend.utils import maybe_trans_nz
from .base import AscendLinearScheme, AscendMoEScheme, QuantType
from .registry import register_scheme
class AscendW4A8DynamicLinearMethod:
"""Linear method for Ascend W4A8_DYNAMIC
"""
@register_scheme("W4A8_DYNAMIC", "linear")
class AscendW4A8DynamicLinearMethod(AscendLinearScheme):
"""Linear method for Ascend W4A8_DYNAMIC."""
def __init__(self):
vllm_config = get_current_vllm_config()
@@ -72,23 +75,12 @@ class AscendW4A8DynamicLinearMethod:
return params_dict
@staticmethod
def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
return {}
@staticmethod
def get_perchannel_param(output_size: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
return {}
def get_pergroup_param(self,
input_size: int,
output_size: int,
params_dtype: torch.dtype,
layer_type: Optional[str] = None) -> Dict[str, Any]:
"""
Create per-group quantization parameters.
"""
"""Create per-group quantization parameters."""
params_dict = {}
params_dict["weight_scale"] = torch.empty(output_size,
1,
@@ -121,8 +113,7 @@ class AscendW4A8DynamicLinearMethod:
scale: torch.Tensor,
per_group_scale: torch.Tensor,
is_new_quant: bool = False):
"""
Process the scale for second-level quantization.
"""Process the scale for second-level quantization.
Args:
weight: weight tensor [k, n] (in new version, n is already compressed to n/2)
@@ -207,9 +198,12 @@ class AscendW4A8DynamicLinearMethod:
layer.weight.data.to(torch.int32))
class AscendW4A8DynamicFusedMoEMethod:
"""FusedMoe method for Ascend W4A8_DYNAMIC.
"""
@register_scheme("W4A8_DYNAMIC", "moe")
class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme):
"""FusedMoE method for Ascend W4A8_DYNAMIC."""
# Declare the quantization type for this scheme
quant_type: QuantType = QuantType.W4A8
def __init__(self):
self.ep_group = get_ep_group()
@@ -339,7 +333,7 @@ class AscendW4A8DynamicFusedMoEMethod:
e_score_correction_bias: Optional[torch.Tensor] = None,
is_prefill: bool = True,
enable_force_load_balance: bool = False,
log2phy: torch.Tensor = None,
log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0,
**kwargs,
) -> torch.Tensor:

View File

@@ -22,17 +22,22 @@ import torch_npu
from vllm_ascend.utils import maybe_trans_nz
from .base import AscendLinearScheme
from .registry import register_scheme
class AscendW8A16LinearMethod:
@register_scheme("W8A16", "linear")
class AscendW8A16LinearMethod(AscendLinearScheme):
"""Linear method for Ascend W8A16.
This scheme uses 8-bit quantized weights with 16-bit activations.
"""
def __init__(self) -> None:
pass
@staticmethod
def get_weight(
self,
input_size: int,
output_size: int,
params_dtype: torch.dtype = torch.bfloat16,
@@ -42,12 +47,8 @@ class AscendW8A16LinearMethod:
}
return params_dict
@staticmethod
def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
return {}
@staticmethod
def get_perchannel_param(
self,
output_size: int,
params_dtype: torch.dtype,
) -> Dict[str, Any]:
@@ -60,15 +61,8 @@ class AscendW8A16LinearMethod:
dtype=params_dtype)
return params_dict
def get_pergroup_param(self,
input_size: int,
output_size: int,
params_dtype: torch.dtype,
layer_type: Optional[str] = None) -> Dict[str, Any]:
return {}
@staticmethod
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,

View File

@@ -32,28 +32,39 @@ from vllm_ascend.ops.fused_moe.experts_selector import (select_experts,
zero_experts_compute)
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, maybe_trans_nz
from .base import AscendLinearScheme, AscendMoEScheme, QuantType
from .registry import register_scheme
class AscendW8A8DynamicLinearMethod:
def scale_from_float_to_int64(scale):
"""Convert float32 scale to int64 representation."""
import numpy as np
scale = torch.from_numpy(
np.frombuffer(scale.cpu().to(torch.float32).numpy().tobytes(),
dtype=np.int32).astype(np.int64)).to(scale.device)
return scale
@register_scheme("W8A8_DYNAMIC", "linear")
class AscendW8A8DynamicLinearMethod(AscendLinearScheme):
"""Linear method for Ascend W8A8_DYNAMIC.
This scheme uses dynamic per-token quantization for activations
and per-channel quantization for weights.
"""
def __init__(self):
pass
@staticmethod
def get_weight(input_size: int, output_size: int,
def get_weight(self, input_size: int, output_size: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
params_dict = {
"weight": torch.empty(output_size, input_size, dtype=torch.int8)
}
return params_dict
@staticmethod
def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
return {}
@staticmethod
def get_perchannel_param(
self,
output_size: int,
params_dtype: torch.dtype,
) -> Dict[str, Any]:
@@ -66,15 +77,8 @@ class AscendW8A8DynamicLinearMethod:
dtype=params_dtype)
return params_dict
def get_pergroup_param(self,
input_size: int,
output_size: int,
params_dtype: torch.dtype,
layer_type: Optional[str] = None) -> Dict[str, Any]:
return {}
@staticmethod
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
@@ -100,9 +104,12 @@ class AscendW8A8DynamicLinearMethod:
layer.weight_offset.data = layer.weight_offset.data.flatten()
class AscendW8A8DynamicFusedMoEMethod:
"""FusedMoe method for Ascend W8A8_DYNAMIC.
"""
@register_scheme("W8A8_DYNAMIC", "moe")
class AscendW8A8DynamicFusedMoEMethod(AscendMoEScheme):
"""FusedMoE method for Ascend W8A8_DYNAMIC."""
# Declare the quantization type for this scheme
quant_type: QuantType = QuantType.W8A8
def __init__(self):
self.ep_group = get_ep_group()
@@ -128,9 +135,8 @@ class AscendW8A8DynamicFusedMoEMethod:
except AttributeError:
self.moe_all_to_all_group_name = ""
@staticmethod
def get_weight(num_experts: int, intermediate_size_per_partition: int,
hidden_sizes: int,
def get_weight(self, num_experts: int,
intermediate_size_per_partition: int, hidden_sizes: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
param_dict = {}
param_dict["w13_weight"] = torch.empty(num_experts,
@@ -144,8 +150,7 @@ class AscendW8A8DynamicFusedMoEMethod:
dtype=torch.int8)
return param_dict
@staticmethod
def get_dynamic_quant_param(num_experts: int,
def get_dynamic_quant_param(self, num_experts: int,
intermediate_size_per_partition: int,
hidden_sizes: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
@@ -188,7 +193,7 @@ class AscendW8A8DynamicFusedMoEMethod:
e_score_correction_bias: Optional[torch.Tensor] = None,
is_prefill: bool = True,
enable_force_load_balance: bool = False,
log2phy: torch.Tensor = None,
log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0,
pertoken_scale: Optional[Any] = None,
**kwargs,
@@ -324,11 +329,3 @@ class AscendW8A8DynamicFusedMoEMethod:
del layer.w13_weight_scale_fp32
del layer.w2_weight_scale
torch.npu.empty_cache()
def scale_from_float_to_int64(scale):
import numpy as np
scale = torch.from_numpy(
np.frombuffer(scale.cpu().to(torch.float32).numpy().tobytes(),
dtype=np.int32).astype(np.int64)).to(scale.device)
return scale

View File

@@ -21,9 +21,17 @@ import torch
import torch_npu
from vllm.config import get_current_vllm_config
from .base import AscendLinearScheme
from .registry import register_scheme
class AscendW8A8MXFP8DynamicLinearMethod:
"""Linear method for Ascend W8A8_DYNAMIC.
@register_scheme("W8A8_MXFP8", "linear")
class AscendW8A8MXFP8DynamicLinearMethod(AscendLinearScheme):
"""Linear method for Ascend W8A8_MXFP8 (Microscaling FP8) quantization.
This scheme uses microscaling FP8 quantization with per-group scales.
The activation is dynamically quantized to FP8 (E4M3FN format) with
microscaling, and weights are stored in FP8 format with per-group scales.
"""
model_dtype = None
@@ -32,8 +40,7 @@ class AscendW8A8MXFP8DynamicLinearMethod:
self.group_size = vllm_config.quant_config.quant_description.get(
"group_size", 32)
@staticmethod
def get_weight(input_size: int, output_size: int,
def get_weight(self, input_size: int, output_size: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
params_dict = {
"weight":
@@ -41,17 +48,6 @@ class AscendW8A8MXFP8DynamicLinearMethod:
}
return params_dict
@staticmethod
def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
return {}
@staticmethod
def get_perchannel_param(
output_size: int,
params_dtype: torch.dtype,
) -> Dict[str, Any]:
return {}
def get_pergroup_param(self,
input_size: int,
output_size: int,

View File

@@ -0,0 +1,117 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""W8A8 Prefill-Decode Mix quantization methods.
This module provides quantization methods that use different strategies
for prefill and decode phases:
- Prefill: Uses dynamic W8A8 quantization
- Decode (KV consumer): Uses static W8A8 quantization
"""
from typing import Any, Dict, Optional
import torch
from vllm.config import get_current_vllm_config
from .base import AscendLinearScheme
from .registry import register_scheme
from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod,
AscendW8A8DynamicLinearMethod)
from .w8a8_static import AscendW8A8LinearMethod
@register_scheme("W8A8_MIX", "linear")
class AscendW8A8PDMixLinearMethod(AscendLinearScheme):
"""Linear method for W8A8 prefill-decode mix quantization.
This scheme uses composition to delegate to the appropriate quantization
method based on the execution phase:
- Static W8A8 for KV consumer (decode phase)
- Dynamic W8A8 for prefill phase
The static method is used for weight/parameter specifications since
it requires more parameters (input_scale, deq_scale, etc.) that are
needed for static quantization during decode.
"""
def __init__(self):
self._static_method = AscendW8A8LinearMethod()
self._dynamic_method = AscendW8A8DynamicLinearMethod()
kv_transfer_config = get_current_vllm_config().kv_transfer_config
self._is_kv_consumer = (kv_transfer_config is not None
and kv_transfer_config.is_kv_consumer)
def get_weight(self, input_size: int, output_size: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
return self._static_method.get_weight(input_size, output_size,
params_dtype)
def get_pertensor_param(self, params_dtype: torch.dtype) -> Dict[str, Any]:
return self._static_method.get_pertensor_param(params_dtype)
def get_perchannel_param(
self,
output_size: int,
params_dtype: torch.dtype,
) -> Dict[str, Any]:
return self._static_method.get_perchannel_param(
output_size, params_dtype)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
tp_rank: Optional[int] = 0,
) -> torch.Tensor:
if layer.is_kv_consumer:
return self._static_method.apply(layer, x, bias, tp_rank)
else:
return self._dynamic_method.apply(layer, x, bias, tp_rank)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
self._static_method.process_weights_after_loading(layer)
layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
layer.is_kv_consumer = self._is_kv_consumer
@register_scheme("W8A8_MIX", "moe")
class AscendW8A8PDMixFusedMoeMethod(AscendW8A8DynamicFusedMoEMethod):
def get_dynamic_quant_param(self, num_experts: int,
intermediate_size_per_partition: int,
hidden_sizes: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
param_dict = super().get_dynamic_quant_param(
num_experts, intermediate_size_per_partition, hidden_sizes,
params_dtype)
param_dict["w2_deq_scale"] = torch.empty(num_experts,
hidden_sizes,
dtype=torch.float32)
param_dict["w13_deq_scale"] = torch.empty(
num_experts,
2 * intermediate_size_per_partition,
dtype=torch.float32)
param_dict["w2_input_offset"] = torch.empty(num_experts,
1,
dtype=torch.int8)
param_dict["w13_input_offset"] = torch.empty(num_experts,
1,
dtype=torch.int8)
return param_dict

View File

@@ -24,27 +24,23 @@ from vllm_ascend.utils import (COMPRESSED_TENSORS_METHOD, AscendDeviceType,
get_ascend_device_type,
get_weight_prefetch_method, maybe_trans_nz)
def quant_per_tensor(in_tensor: torch.Tensor,
input_scale: torch.Tensor,
input_offset: torch.Tensor,
function=False):
return torch_npu.npu_quantize(in_tensor, input_scale, input_offset,
torch.qint8, -1, function)
from .base import AscendLinearScheme
from .registry import register_scheme
class AscendW8A8LinearMethod:
"""Linear method for Ascend W8A8.
@register_scheme("W8A8", "linear")
class AscendW8A8LinearMethod(AscendLinearScheme):
"""Linear method for Ascend W8A8 static quantization.
Args:
w_sym: whether the linear weight is symmetrically quantized.
This scheme uses static per-tensor quantization for activations
and per-channel quantization for weights.
"""
def __init__(self) -> None:
pass
@staticmethod
def get_weight(
self,
input_size: int,
output_size: int,
params_dtype: torch.dtype = torch.bfloat16,
@@ -54,15 +50,14 @@ class AscendW8A8LinearMethod:
}
return params_dict
@staticmethod
def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
def get_pertensor_param(self, params_dtype: torch.dtype) -> Dict[str, Any]:
params_dict = {}
params_dict["input_scale"] = torch.empty(1, dtype=params_dtype)
params_dict["input_offset"] = torch.empty(1, dtype=torch.int8)
return params_dict
@staticmethod
def get_perchannel_param(
self,
output_size: int,
params_dtype: torch.dtype,
) -> Dict[str, Any]:
@@ -82,15 +77,8 @@ class AscendW8A8LinearMethod:
dtype=params_dtype)
return params_dict
def get_pergroup_param(self,
input_size: int,
output_size: int,
params_dtype: torch.dtype,
layer_type: Optional[str] = None) -> Dict[str, Any]:
return {}
@staticmethod
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,

View File

@@ -0,0 +1,471 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
"""ModelSlim quantization configuration and model mappings for Ascend.
This module provides the AscendModelSlimConfig class for parsing quantization
configs generated by the ModelSlim tool, along with model-specific mappings.
"""
from types import MappingProxyType
from typing import Any, Dict, List, Mapping, Optional
import torch
from vllm.config import get_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.quantization import \
register_quantization_config
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.vocab_parallel_embedding import (
UnquantizedEmbeddingMethod, VocabParallelEmbedding)
from vllm.model_executor.models.utils import WeightsMapper
from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD
from .methods import get_scheme_class
logger = init_logger(__name__)
# key: model_type
# value: orig_to_new_prefix
QUANT_MODEL_PREFIX_MAPPINGS: Dict[str, Dict[str, str]] = {
"qwen3_vl_moe": {
"visual.": "model.visual.",
"language_model.lm_head.": "lm_head.",
"language_model.model.": "model.language_model.",
},
"qwen3_vl_text": {
"visual.": "model.visual.",
"language_model.lm_head.": "lm_head.",
"language_model.model.": "model.language_model.",
},
}
# key: model_type
# value: dict of fused module name -> list of original module names
packed_modules_model_mapping: Dict[str, Dict[str, List[str]]] = {
"qwen3_moe": {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
},
"deepseek_v2": {
"gate_up_proj": ["gate_proj", "up_proj"],
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]
},
"deepseek_v3": {
"gate_up_proj": ["gate_proj", "up_proj"],
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]
},
"pangu_ultra_moe": {
"gate_up_proj": ["gate_proj", "up_proj"],
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]
},
"kimi_k2": {
"gate_up_proj": ["gate_proj", "up_proj"],
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]
},
"deepseek_v32": {
"gate_up_proj": ["gate_proj", "up_proj"],
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]
},
# NOTE 1. For quantized DeepSeek models, the MTP layer is not quantized on the NPU;
# NOTE 2. The description file generated by the current msmodelslim tool does not include
# MTP layer info. Please add it manually and set the value to FLOAT.
"deepseek_mtp": {
"gate_up_proj": ["gate_proj", "up_proj"],
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
},
"pangu_ultra_moe_mtp": {
"gate_up_proj": ["gate_proj", "up_proj"],
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]
},
"qwen3_next": {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": ["gate_proj", "up_proj"],
"in_proj": ["in_proj_qkvz", "in_proj_ba"],
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
},
"qwen2_5_vl": {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
},
"qwen3_vl_moe": {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
},
"glm4_moe": {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
},
"longcat_flash": {
"gate_up_proj": ["gate_proj", "up_proj"],
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]
},
"minimax_m2": {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"experts": ["experts.0.w1", "experts.0.w2", "experts.0.w3"]
}
}
def get_packed_modules_mapping(model_type: str) -> Dict[str, List[str]]:
"""Get packed modules mapping for a model type.
Args:
model_type: The model type string (e.g., "deepseek_v3").
Returns:
Dictionary mapping fused module names to their component module names.
Returns empty dict if model_type is not found.
"""
return packed_modules_model_mapping.get(model_type, {})
def get_prefix_mapping(model_type: str) -> Dict[str, str]:
"""Get prefix mapping for a model type.
Args:
model_type: The model type string (e.g., "qwen3_vl_moe").
Returns:
Dictionary mapping original prefixes to new prefixes.
Returns empty dict if model_type is not found.
"""
return QUANT_MODEL_PREFIX_MAPPINGS.get(model_type, {})
def get_linear_quant_type(
quant_description: Dict[str, Any], prefix: str,
packed_modules_mapping: Dict[str, Any]) -> Optional[str]:
"""Determine the quantization type for a linear layer.
Args:
quant_description: The quantization description dictionary.
prefix: The layer prefix.
packed_modules_mapping: Mapping for packed/fused modules.
Returns:
The quantization type string (e.g., "W8A8_DYNAMIC").
"""
proj_name = prefix.split(".")[-1]
if proj_name in packed_modules_mapping:
quant_type = None
shard_prefixes = [
prefix.replace(proj_name, shard_proj_name)
for shard_proj_name in packed_modules_mapping[proj_name]
]
for shard_prefix in shard_prefixes:
shard_quant_type = quant_description[shard_prefix + '.weight']
if quant_type is None:
quant_type = shard_quant_type
elif shard_quant_type != quant_type:
raise ValueError(
f"Not all shards of {prefix} are quantized with the same quant type. "
f"Shard {proj_name} uses {shard_quant_type}, but another shard "
f"uses {quant_type}. Please check the quantization config.")
else:
quant_type = quant_description[prefix + '.weight']
return quant_type
def get_quant_type_for_layer(
quant_description: Dict[str, Any],
prefix: str,
layer_type: str,
packed_modules_mapping: Optional[Dict[str,
Any]] = None) -> Optional[str]:
"""Determine the quantization type for a layer.
Args:
quant_description: The quantization description dictionary.
prefix: The layer prefix.
layer_type: The type of layer ("linear", "moe", "attention").
packed_modules_mapping: Mapping for packed/fused modules.
Returns:
The quantization type string (e.g., "W8A8_DYNAMIC").
"""
if packed_modules_mapping is None:
packed_modules_mapping = dict()
# Attention
if layer_type == "attention" and 'fa_quant_type' in quant_description.keys(
):
return quant_description['fa_quant_type']
# Linear / MoE
return get_linear_quant_type(quant_description, prefix,
packed_modules_mapping)
def create_scheme_for_layer(
quant_description: Dict[str, Any],
prefix: str,
layer_type: str,
packed_modules_mapping: Optional[Dict[str, Any]] = None):
"""Create a quantization scheme instance for a layer.
Args:
quant_description: The quantization description dictionary.
prefix: The layer prefix.
layer_type: The type of layer ("linear", "moe", "attention").
packed_modules_mapping: Mapping for packed/fused modules.
Returns:
An instance of the appropriate quantization scheme class.
"""
logger.info_once("Using the vLLM Ascend modelslim Quantization now!")
quant_type = get_quant_type_for_layer(quant_description, prefix,
layer_type, packed_modules_mapping)
if quant_type is None:
raise ValueError(
f"Could not determine quantization type for layer {prefix}.")
# Use registry to get scheme class
scheme_cls = get_scheme_class(quant_type, layer_type)
if scheme_cls is not None:
return scheme_cls()
raise NotImplementedError(
f"Currently, vLLM Ascend doesn't support {quant_type} for {layer_type}."
)
@register_quantization_config(ASCEND_QUANTIZATION_METHOD)
class AscendModelSlimConfig(QuantizationConfig):
"""Config class for Ascend ModelSlim quantization.
This class is a general class that parses quantization configs
that are supported on Ascend hardware, specifically for models
quantized using the ModelSlim tool.
"""
def __init__(self, quant_config: Dict[str, Any]):
super().__init__()
self.quant_description = quant_config
# TODO(whx): remove this adaptation after adding "shared_head"
# to prefix of DeepSeekShareHead in vLLM.
extra_quant_dict = {}
for k in self.quant_description.keys():
if "shared_head" in k:
new_k = k.replace(".shared_head.", ".")
extra_quant_dict[new_k] = self.quant_description[k]
if "weight_packed" in k:
new_k = k.replace("weight_packed", "weight")
extra_quant_dict[new_k] = self.quant_description[k]
self.quant_description.update(extra_quant_dict)
def __repr__(self) -> str:
return "AscendModelSlimConfig:\n" + super().__repr__()
@classmethod
def get_name(cls) -> str:
return ASCEND_QUANTIZATION_METHOD
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.int8, torch.float16, torch.bfloat16]
@classmethod
def get_min_capability(cls) -> int:
raise NotImplementedError(
"Ascend hardware dose not support \"get_min_capability\" feature.")
@classmethod
def get_config_filenames(cls) -> List[str]:
return ["quant_model_description.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "AscendModelSlimConfig":
return cls(config)
@classmethod
def override_quantization_method(cls, hf_quant_cfg,
user_quant) -> Optional[str]:
if hf_quant_cfg is not None:
quant_method = hf_quant_cfg.get("quant_method", None)
if not quant_method and torch.npu.is_available():
return ASCEND_QUANTIZATION_METHOD
return None
def quant_prefix_mapper(self, model_type: str, prefix: str) -> str:
# TODO (Levi-JQ): will be removed when QuantizationConfig.apply_vllm_mapper is implemented
prefix_mapping = QUANT_MODEL_PREFIX_MAPPINGS.get(model_type)
if prefix_mapping:
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix=prefix_mapping)
return hf_to_vllm_mapper._map_name(prefix)
return prefix
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["QuantizeMethodBase"]:
from .method_adapters import (AscendEmbeddingMethod, AscendFusedMoEMethod,
AscendKVCacheMethod, AscendLinearMethod)
vllm_config = get_current_vllm_config()
model_type = vllm_config.model_config.hf_config.model_type
if model_type in ["minimax", "minimax_m2"]:
# Adapt to Minimax architecture: update layer names to MoE convention
prefix = prefix.replace("mlp", "block_sparse_moe")
# Normalize the prefix by stripping specific expert indices (e.g., 'experts.0' -> 'experts')
parts = prefix.split('.')
if "experts" in parts and len(parts) > 2:
exp_idx = parts.index("experts")
if exp_idx + 1 < len(parts) and parts[exp_idx + 1].isdigit():
parts = parts[:exp_idx + 1]
prefix = ".".join(parts)
if model_type in packed_modules_model_mapping:
self.packed_modules_mapping = packed_modules_model_mapping[
model_type]
prefix = self.quant_prefix_mapper(model_type, prefix)
from vllm.attention.layer import Attention
if prefix.startswith("language_model"):
prefix = prefix.split('.', 1)[-1]
if isinstance(layer, LinearBase):
if self.is_layer_skipped_ascend(prefix,
self.packed_modules_mapping):
# Delayed import to avoid circular import
from vllm_ascend.ops.linear import \
AscendUnquantizedLinearMethod
return AscendUnquantizedLinearMethod()
scheme = create_scheme_for_layer(self.quant_description, prefix,
"linear",
self.packed_modules_mapping)
return AscendLinearMethod(scheme)
elif isinstance(layer, Attention) and \
'fa_quant_type' in self.quant_description.keys() and \
self.quant_description['fa_quant_type'] is not None:
scheme = create_scheme_for_layer(self.quant_description, prefix,
"attention",
self.packed_modules_mapping)
return AscendKVCacheMethod(scheme)
elif isinstance(layer, FusedMoE):
if self.is_layer_skipped_ascend(prefix,
self.packed_modules_mapping):
# Delayed import to avoid circular import
from vllm_ascend.ops.fused_moe.fused_moe import \
AscendUnquantizedFusedMoEMethod
return AscendUnquantizedFusedMoEMethod(layer.moe_config)
scheme = create_scheme_for_layer(self.quant_description, prefix,
"moe",
self.packed_modules_mapping)
return AscendFusedMoEMethod(scheme, layer.moe_config)
elif isinstance(layer, VocabParallelEmbedding):
if self.is_layer_skipped_ascend(prefix,
self.packed_modules_mapping):
return UnquantizedEmbeddingMethod()
scheme = create_scheme_for_layer(self.quant_description, prefix,
"linear",
self.packed_modules_mapping)
return AscendEmbeddingMethod(scheme)
return None
def is_layer_skipped_ascend(
self,
prefix: str,
fused_mapping: Mapping[str, List[str]] = MappingProxyType({})):
# adapted from vllm.model_executor.layers.quantization.utils.quant_utils.is_layer_skipped
proj_name = prefix.split(".")[-1]
if proj_name in fused_mapping:
shard_prefixes = [
prefix.replace(proj_name, shard_proj_name)
for shard_proj_name in fused_mapping[proj_name]
]
is_skipped = None
for shard_prefix in shard_prefixes:
is_shard_skipped = self.quant_description[shard_prefix +
'.weight'] == "FLOAT"
if is_skipped is None:
is_skipped = is_shard_skipped
elif is_shard_skipped != is_skipped:
raise ValueError(
f"Detected some but not all shards of {prefix} "
"are quantized. All shards of fused layers "
"to have the same precision.")
else:
is_skipped = self.quant_description[prefix + '.weight'] == "FLOAT"
assert is_skipped is not None
return is_skipped
def get_scaled_act_names(self) -> List[str]:
return []
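To make the shard-resolution logic above concrete, here is a minimal, self-contained sketch using the helpers defined in this file; the layer names and description entries are hypothetical, in the style of `quant_model_description.json`:

```python
# Hypothetical ModelSlim description: all qkv_proj shards share one quant type,
# while down_proj stays in floating point.
quant_description = {
    "model.layers.0.self_attn.q_proj.weight": "W8A8_DYNAMIC",
    "model.layers.0.self_attn.k_proj.weight": "W8A8_DYNAMIC",
    "model.layers.0.self_attn.v_proj.weight": "W8A8_DYNAMIC",
    "model.layers.0.mlp.down_proj.weight": "FLOAT",
}
packed_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

# The fused qkv_proj prefix is expanded into its shard prefixes, which must all
# agree on a single quant type.
quant_type = get_quant_type_for_layer(
    quant_description,
    prefix="model.layers.0.self_attn.qkv_proj",
    layer_type="linear",
    packed_modules_mapping=packed_mapping,
)
assert quant_type == "W8A8_DYNAMIC"

# A "FLOAT" entry marks an unquantized layer; get_quant_method() then falls back
# to the unquantized linear/MoE/embedding methods.
assert get_quant_type_for_layer(
    quant_description, "model.layers.0.mlp.down_proj", "linear") == "FLOAT"
```

Given the resolved quant type, `create_scheme_for_layer` looks the scheme class up in the registry and instantiates it, so supporting a new quant type only requires registering a new scheme class.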


@@ -1,600 +0,0 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from types import MappingProxyType
from typing import Any, Callable, Dict, List, Mapping, Optional
import torch
from vllm.config import get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
RowParallelLinear)
from vllm.model_executor.layers.quantization import \
register_quantization_config
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.vocab_parallel_embedding import (
UnquantizedEmbeddingMethod, VocabParallelEmbedding)
from vllm.model_executor.models.utils import WeightsMapper
from vllm.model_executor.parameter import PerTensorScaleParameter
from vllm.model_executor.utils import set_weight_attrs
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import (get_flashcomm2_otp_group,
get_mlp_tp_group,
get_otp_group)
from vllm_ascend.ops.fused_moe.fused_moe import AscendUnquantizedFusedMoEMethod
from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod
from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, flashcomm2_enable,
mlp_tp_enable, oproj_tp_enable)
from .utils import get_quant_method, is_mx_quant_type
@register_quantization_config(ASCEND_QUANTIZATION_METHOD)
class AscendQuantConfig(QuantizationConfig):
"""Config class for Ascend
This class is a general class that parses quantization configs
that are supported on Ascend hardware.
"""
def __init__(self, quant_config: Dict[str, Any]):
super().__init__()
self.quant_description = quant_config
# TODO(whx): remove this adaptation after adding "shared_head"
# to prefix of DeepSeekShareHead in vLLM.
extra_quant_dict = {}
for k in self.quant_description.keys():
if "shared_head" in k:
new_k = k.replace(".shared_head.", ".")
extra_quant_dict[new_k] = self.quant_description[k]
if "weight_packed" in k:
new_k = k.replace("weight_packed", "weight")
extra_quant_dict[new_k] = self.quant_description[k]
self.quant_description.update(extra_quant_dict)
def __repr__(self) -> str:
return "AscendQuantConfig:\n" + super().__repr__()
@classmethod
def get_name(cls) -> str:
return ASCEND_QUANTIZATION_METHOD
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.int8, torch.float16, torch.bfloat16]
@classmethod
def get_min_capability(cls) -> int:
raise NotImplementedError(
"Ascend hardware dose not support \"get_min_capability\" feature.")
@classmethod
def get_config_filenames(cls) -> List[str]:
return ["quant_model_description.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "AscendQuantConfig":
return cls(config)
@classmethod
def override_quantization_method(cls, hf_quant_cfg,
user_quant) -> Optional[str]:
if hf_quant_cfg is not None:
quant_method = hf_quant_cfg.get("quant_method", None)
if not quant_method and torch.npu.is_available():
return ASCEND_QUANTIZATION_METHOD
return None
def quant_prefix_mapper(self, model_type: str, prefix: str) -> str:
# TODO (Levi-JQ): will be removed when QuantizationConfig.apply_vllm_mapper is implemented
prefix_mapping = QUANT_MODEL_PREFIX_MAPPINGS.get(model_type)
if prefix_mapping:
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix=prefix_mapping)
return hf_to_vllm_mapper._map_name(prefix)
return prefix
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["QuantizeMethodBase"]:
vllm_config = get_current_vllm_config()
model_type = vllm_config.model_config.hf_text_config.model_type
if model_type in ["minimax", "minimax_m2"]:
prefix = prefix.replace("mlp", "block_sparse_moe")
#To adapt to minimax, modify the prefix of the model layer name
parts = prefix.split('.')
if "experts" in parts and len(parts) > 2:
exp_idx = parts.index("experts")
if exp_idx + 1 < len(parts) and parts[exp_idx + 1].isdigit():
parts = parts[:exp_idx + 1]
prefix = ".".join(parts)
if model_type in packed_modules_model_mapping:
self.packed_modules_mapping = packed_modules_model_mapping[
model_type]
prefix = self.quant_prefix_mapper(model_type, prefix)
from vllm.attention.layer import Attention
if prefix.startswith("language_model"):
prefix = prefix.split('.', 1)[-1]
if isinstance(layer, LinearBase):
if self.is_layer_skipped_ascend(prefix,
self.packed_modules_mapping):
return AscendUnquantizedLinearMethod()
return AscendLinearMethod(self, prefix,
self.packed_modules_mapping, layer)
elif isinstance(layer, Attention) and \
'fa_quant_type' in self.quant_description.keys() and \
self.quant_description['fa_quant_type'] is not None:
return AscendKVCacheMethod(self, prefix)
elif isinstance(layer, FusedMoE):
if self.is_layer_skipped_ascend(prefix,
self.packed_modules_mapping):
return AscendUnquantizedFusedMoEMethod(layer.moe_config)
return AscendFusedMoEMethod(self, prefix,
self.packed_modules_mapping, layer)
elif isinstance(layer, VocabParallelEmbedding):
if self.is_layer_skipped_ascend(prefix,
self.packed_modules_mapping):
return UnquantizedEmbeddingMethod()
return AscendEmbeddingMethod(self, prefix,
self.packed_modules_mapping, layer)
return None
def is_layer_skipped_ascend(
self,
prefix: str,
fused_mapping: Mapping[str, List[str]] = MappingProxyType({})):
# adapted from vllm.model_executor.layers.quantization.utils.quant_utils.is_layer_skipped
proj_name = prefix.split(".")[-1]
if proj_name in fused_mapping:
shard_prefixes = [
prefix.replace(proj_name, shard_proj_name)
for shard_proj_name in fused_mapping[proj_name]
]
is_skipped = None
for shard_prefix in shard_prefixes:
is_shard_skipped = self.quant_description[shard_prefix +
'.weight'] == "FLOAT"
if is_skipped is None:
is_skipped = is_shard_skipped
elif is_shard_skipped != is_skipped:
raise ValueError(
f"Detected some but not all shards of {prefix} "
"are quantized. All shards of fused layers "
"to have the same precision.")
else:
# NOTE: In GLM4.6, the MTP draft model shares the same LM head weights
# with the main model. Therefore, before `load_weights()` runs, some parameter
# names may not include the expected prefix and may appear only with the
# ".head" suffix. This can trigger a load-time error, so here we replace the
# key with "lm_head.weight".
key = prefix + '.weight'
if key not in self.quant_description and ".head" in prefix:
key = 'lm_head.weight'
is_skipped = self.quant_description[key] == "FLOAT"
assert is_skipped is not None
return is_skipped
def get_scaled_act_names(self) -> List[str]:
return []
# key: model_type
# value: orig_to_new_prefix
QUANT_MODEL_PREFIX_MAPPINGS = {
"qwen3_vl_moe": {
"visual.": "model.visual.",
"language_model.lm_head.": "lm_head.",
"language_model.model.": "model.language_model.",
},
"qwen3_vl_text": {
"visual.": "model.visual.",
"language_model.lm_head.": "lm_head.",
"language_model.model.": "model.language_model.",
},
}
packed_modules_model_mapping = {
"qwen3_moe": {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
},
"deepseek_v2": {
"gate_up_proj": ["gate_proj", "up_proj"],
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]
},
"deepseek_v3": {
"gate_up_proj": ["gate_proj", "up_proj"],
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]
},
"pangu_ultra_moe": {
"gate_up_proj": ["gate_proj", "up_proj"],
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]
},
"kimi_k2": {
"gate_up_proj": ["gate_proj", "up_proj"],
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]
},
"deepseek_v32": {
"gate_up_proj": ["gate_proj", "up_proj"],
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]
},
# NOTE 1. For quantized DeepSeek models, the MTP layer is not quantized on the NPU;
# NOTE 2. The description file generated by the current msmodelslim tool does not include
# MTP layer info. Please add it manually and set the value to FLOAT.
"deepseek_mtp": {
"gate_up_proj": ["gate_proj", "up_proj"],
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
},
"pangu_ultra_moe_mtp": {
"gate_up_proj": ["gate_proj", "up_proj"],
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]
},
"qwen3_next": {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": ["gate_proj", "up_proj"],
"in_proj": ["in_proj_qkvz", "in_proj_ba"],
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
},
"qwen2_5_vl": {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
},
"qwen3_vl_moe": {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
},
"glm4_moe": {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
},
"longcat_flash": {
"gate_up_proj": ["gate_proj", "up_proj"],
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]
},
"minimax_m2": {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"experts": ["experts.0.w1", "experts.0.w2", "experts.0.w3"]
}
}
class AscendLinearMethod(LinearMethodBase):
"""Linear method for Ascend quantization.
Args:
quant_config: The Ascend quantization config.
"""
def __init__(self,
quant_config: AscendQuantConfig,
prefix: str,
packed_modules_mapping: Dict[str, Any] | None,
layer: torch.nn.Module = None) -> None:
self.quant_method = get_quant_method(quant_config.quant_description,
prefix,
"linear",
packed_modules_mapping,
layer=layer)
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
) -> None:
output_size_per_partition = sum(output_partition_sizes)
weight_loader = extra_weight_attrs.get("weight_loader")
weight_dict = self.quant_method.get_weight(input_size_per_partition,
output_size_per_partition,
params_dtype)
# Extract packing information (if present)
packed_dim = weight_dict.pop("_packed_dim", None)
packed_factor = weight_dict.pop("_packed_factor", None)
for weight_name, weight_param in weight_dict.items():
param = torch.nn.Parameter(weight_param, requires_grad=False)
set_weight_attrs(param, {"input_dim": 1, "output_dim": 0})
# Set packing attributes if the weight is packed
if packed_dim is not None and packed_factor is not None:
set_weight_attrs(param, {
"packed_dim": packed_dim,
"packed_factor": packed_factor
})
layer.register_parameter(weight_name, param)
set_weight_attrs(param, extra_weight_attrs)
pertensor_dict = self.quant_method.get_pertensor_param(params_dtype)
for pertensor_name, pertensor_param in pertensor_dict.items():
param = PerTensorScaleParameter(data=pertensor_param,
weight_loader=weight_loader)
# disable warning
param.ignore_warning = True
layer.register_parameter(pertensor_name, param)
param.weight_loader = extra_weight_attrs.get("weight_loader")
perchannel_dict = self.quant_method.get_perchannel_param(
output_size_per_partition, params_dtype)
for perchannel_name, perchannel_param in perchannel_dict.items():
param = torch.nn.Parameter(perchannel_param, requires_grad=False)
set_weight_attrs(param, {"output_dim": 0})
layer.register_parameter(perchannel_name, param)
set_weight_attrs(param, extra_weight_attrs)
# NOTE: In w4a8 quantization implementation,
# for down_proj and o_proj scale_bias shape is [output_size, 16],
# others are [output_size, 1]
layer_type = "row" if isinstance(layer,
RowParallelLinear) else "others"
pergroup_dict = self.quant_method.get_pergroup_param(
input_size_per_partition,
output_size_per_partition,
params_dtype,
layer_type=layer_type)
for pergroup_name, pergroup_param in pergroup_dict.items():
param = torch.nn.Parameter(pergroup_param, requires_grad=False)
set_weight_attrs(param, {"output_dim": 0})
layer.register_parameter(pergroup_name, param)
set_weight_attrs(param, extra_weight_attrs)
if "weight_scale_second" in pergroup_name or "weight_offset_second" in pergroup_name \
or is_mx_quant_type(self.quant_method):
setattr(param, "input_dim", 1)
param.input_dim = 1
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if hasattr(self.quant_method, "process_weights_after_loading"):
self.quant_method.process_weights_after_loading(layer)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if isinstance(layer, RowParallelLinear):
if layer.prefix.find("o_proj") != -1 and oproj_tp_enable():
tp_rank = get_otp_group().rank_in_group
elif layer.prefix.find("down_proj") != -1 and mlp_tp_enable():
tp_rank = get_mlp_tp_group().rank_in_group
elif (layer.prefix.find("o_proj") != -1 or
layer.prefix.find("out_proj") != -1) and flashcomm2_enable():
if get_ascend_config(
).flashcomm2_oproj_tensor_parallel_size == 1:
tp_rank = 0
else:
tp_rank = get_flashcomm2_otp_group().rank_in_group
else:
tp_rank = get_tensor_model_parallel_rank()
else:
tp_rank = 0
return self.quant_method.apply(layer, x, bias, tp_rank)
class AscendKVCacheMethod(BaseKVCacheMethod):
"""KVCache method for Ascend quantization.
Args:
quant_config: The Ascend quantization config.
"""
def __init__(self, quant_config: AscendQuantConfig, prefix: str) -> None:
self.quant_method = get_quant_method(quant_config.quant_description,
prefix, "attention")
def create_weights(self, layer: torch.nn.Module) -> None:
# Unlike the linear method, there are no weight processing/slicing
# steps for attention in vLLM, so the whole weight-creation process
# is delegated to the specific quant method.
self.quant_method.create_weights(layer)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if hasattr(self.quant_method, "process_weights_after_loading"):
self.quant_method.process_weights_after_loading(layer)
def apply(self, layer: torch.nn.Module, query: torch.Tensor,
key: torch.Tensor, value: torch.Tensor, kv_cache, attn_metadata,
attn_type, scale, output) -> torch.Tensor:
return self.quant_method.apply(layer, query, key, value, kv_cache,
attn_metadata, attn_type, scale, output)
class AscendFusedMoEMethod(FusedMoEMethodBase):
"""FusedMoE method for Ascend quantization.
Args:
quant_config: The Ascend quantization config.
"""
def __init__(self, quant_config: AscendQuantConfig, prefix: str,
packed_modules_mapping: Dict[str,
Any], layer: torch.nn.Module):
super().__init__(layer.moe_config)
self.quant_method = get_quant_method(quant_config.quant_description,
prefix,
"moe",
packed_modules_mapping,
layer=layer)
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
) -> None:
weight_param = self.quant_method.get_weight(
num_experts, intermediate_size_per_partition, hidden_size,
params_dtype)
for param_key, param_value in weight_param.items():
param = torch.nn.Parameter(param_value, requires_grad=False)
layer.register_parameter(param_key, param)
set_weight_attrs(param, extra_weight_attrs)
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
per_group_param = [
"weight_scale_second", "weight_offset_second", "scale_bias"
] + ["weight_scale", "weight_offset"] if hasattr(
self.quant_method,
"group_size") and self.quant_method.group_size > 0 else []
dynamic_quant_param = self.quant_method.get_dynamic_quant_param(
num_experts, intermediate_size_per_partition, hidden_size,
params_dtype)
for param_key, param_value in dynamic_quant_param.items():
param = torch.nn.Parameter(param_value, requires_grad=False)
layer.register_parameter(param_key, param)
set_weight_attrs(param, extra_weight_attrs)
if any(fields in param_key for fields in per_group_param):
setattr(param, "quant_method",
FusedMoeWeightScaleSupported.GROUP.value)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
is_prefill: bool = True,
enable_force_load_balance: bool = False,
log2phy: torch.Tensor = None,
global_redundant_expert_num=0,
**kwargs,
) -> torch.Tensor:
return self.quant_method.apply(
layer, x, router_logits, top_k, renormalize, use_grouped_topk,
global_num_experts, expert_map, topk_group, num_expert_group,
custom_routing_function, scoring_func, routed_scaling_factor,
e_score_correction_bias, is_prefill, enable_force_load_balance,
log2phy, global_redundant_expert_num, **kwargs)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if hasattr(self.quant_method, "process_weights_after_loading"):
self.quant_method.process_weights_after_loading(layer)
def get_fused_moe_quant_config(self, layer: torch.nn.Module):
# TODO: implement this function
pass
@property
def supports_eplb(self):
supports_eplb = getattr(self.quant_method, "supports_eplb", False)
return supports_eplb
class AscendEmbeddingMethod(AscendLinearMethod):
"""Embedding method for Ascend quantization.
Args:
quant_config: The Ascend quantization config.
"""
def __init__(self, quant_config: AscendQuantConfig, prefix: str,
packed_modules_mapping: Dict[str, Any],
layer: torch.nn.Module) -> None:
self.quant_method = get_quant_method(quant_config.quant_description,
prefix,
"linear",
packed_modules_mapping,
layer=layer)


@@ -1,129 +0,0 @@
from typing import Any, Dict, Optional, Type
import torch
from vllm.logger import logger
from vllm_ascend.utils import COMPRESSED_TENSORS_METHOD
from .w4a4_flatquant_dynamic import AscendW4A4FlatQuantDynamicLinearMethod
from .w4a4_laos_dynamic import AscendW4A4LaosDynamicLinearMethod
from .w4a8_dynamic import (AscendW4A8DynamicFusedMoEMethod,
AscendW4A8DynamicLinearMethod)
from .w4a16 import AscendW4A16FusedMoEMethod
from .w8a8 import AscendW8A8LinearMethod
from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod,
AscendW8A8DynamicLinearMethod)
from .w8a8_pdmix import (AscendW8A8PDMixFusedMoeMethod,
AscendW8A8PDMixLinearMethod)
from .w8a8mxfp8 import AscendW8A8MXFP8DynamicLinearMethod
from .w8a16 import AscendW8A16LinearMethod
ASCEND_QUANTIZATION_METHOD_MAP: Dict[str, Dict[str, Type[Any]]] = {
"W4A16": {
"moe": AscendW4A16FusedMoEMethod,
},
"W4A8_DYNAMIC": {
"linear": AscendW4A8DynamicLinearMethod,
"moe": AscendW4A8DynamicFusedMoEMethod,
},
"W4A4_DYNAMIC": {
"linear": AscendW4A4LaosDynamicLinearMethod,
},
"W4A4_FLATQUANT_DYNAMIC": {
"linear": AscendW4A4FlatQuantDynamicLinearMethod,
},
"W8A8": {
"linear": AscendW8A8LinearMethod,
},
"W8A8_DYNAMIC": {
"linear": AscendW8A8DynamicLinearMethod,
"moe": AscendW8A8DynamicFusedMoEMethod,
},
"W8A8_MIX": {
"linear": AscendW8A8PDMixLinearMethod,
"moe": AscendW8A8PDMixFusedMoeMethod,
},
"W8A16": {
"linear": AscendW8A16LinearMethod,
},
"W8A8_MXFP8": {
"linear": AscendW8A8MXFP8DynamicLinearMethod,
},
}
def get_linear_quant_type(quant_description: Dict[str, Any], prefix: str,
packed_modules_mapping: Dict[str, Any]):
proj_name = prefix.split(".")[-1]
if proj_name in packed_modules_mapping:
quant_type = None
shard_prefixes = [
prefix.replace(proj_name, shard_proj_name)
for shard_proj_name in packed_modules_mapping[proj_name]
]
for shard_prefix in shard_prefixes:
shard_quant_type = quant_description[shard_prefix + '.weight']
if quant_type is None:
quant_type = shard_quant_type
elif shard_quant_type != quant_type:
raise ValueError(
f"Not all shards of {prefix} are quantized with the same quant type. "
f"Shard {proj_name} uses {shard_quant_type}, but another shard "
f"uses {quant_type}. Please check the quantization config.")
else:
quant_type = quant_description[prefix + '.weight']
return quant_type
def get_quant_method(quant_description: Dict[str, Any],
prefix: str,
layer_type: str,
packed_modules_mapping: Optional[Dict[str, Any]] = None,
layer: torch.nn.Module = None):
if quant_description.get("quant_method") == COMPRESSED_TENSORS_METHOD:
return get_quant_method_llmcompressor(layer)
return get_quant_method_modelslim(quant_description, prefix, layer_type,
packed_modules_mapping)
def get_quant_method_llmcompressor(layer: torch.nn.Module):
logger.info_once("Using the vLLM Ascend llmcompressor Quantization now!")
if layer.scheme is None:
raise ValueError("A scheme must be defined for each layer")
return layer.scheme
def get_quant_method_modelslim(
quant_description: Dict[str, Any],
prefix: str,
layer_type: str,
packed_modules_mapping: Optional[Dict[str, Any]] = None):
logger.info_once("Using the vLLM Ascend modelslim Quantization now!")
if packed_modules_mapping is None:
packed_modules_mapping = dict()
# Attention
if '.attn' in prefix and 'fa_quant_type' in quant_description.keys():
quant_type = quant_description['fa_quant_type']
# Linear
else:
quant_type = get_linear_quant_type(quant_description, prefix,
packed_modules_mapping)
if quant_type in ASCEND_QUANTIZATION_METHOD_MAP.keys():
method_map = ASCEND_QUANTIZATION_METHOD_MAP[quant_type]
if layer_type in method_map.keys():
method_cls = method_map[layer_type]
return method_cls()
else:
raise NotImplementedError(
f"Currently, vLLM Ascend doesn't support {quant_type} for {layer_type}."
)
raise NotImplementedError("Currently, vLLM Ascend only supports following quant types:" \
f"{list(ASCEND_QUANTIZATION_METHOD_MAP.keys())}")
def is_mx_quant_type(instance: Any) -> bool:
"""Checks if the quantization method is a mix-precision type."""
MX_QUANT_TYPES = (AscendW8A8MXFP8DynamicLinearMethod, )
return isinstance(instance, MX_QUANT_TYPES)


@@ -1,70 +0,0 @@
from typing import Any, Dict, cast
import torch
from vllm.config import get_current_vllm_config
from .w8a8 import AscendW8A8LinearMethod
from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod,
AscendW8A8DynamicLinearMethod)
class AscendW8A8PDMixLinearMethod(AscendW8A8DynamicLinearMethod):
def __init__(self):
self.kv_transfer_config = get_current_vllm_config().kv_transfer_config
super().__init__()
@staticmethod
def apply(layer, x, bias=None, tp_rank=0):
if layer.is_kv_consumer:
return AscendW8A8LinearMethod.apply(layer, x, bias, tp_rank)
else:
return AscendW8A8DynamicLinearMethod.apply(layer, x, bias, tp_rank)
@staticmethod
def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
return AscendW8A8LinearMethod.get_pertensor_param(params_dtype)
@staticmethod
def get_perchannel_param(
output_size: int,
params_dtype: torch.dtype,
) -> Dict[str, Any]:
return AscendW8A8LinearMethod.get_perchannel_param(
output_size, params_dtype)
def process_weights_after_loading(self, layer):
AscendW8A8LinearMethod.process_weights_after_loading(
cast(AscendW8A8LinearMethod, self), layer)
layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
layer.is_kv_consumer = self.kv_transfer_config is not None and self.kv_transfer_config.is_kv_consumer
class AscendW8A8PDMixFusedMoeMethod(AscendW8A8DynamicFusedMoEMethod):
def __init__(self):
super().__init__()
@staticmethod
def get_dynamic_quant_param(num_experts: int,
intermediate_size_per_partition: int,
hidden_sizes: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
param_dict = AscendW8A8DynamicFusedMoEMethod.get_dynamic_quant_param(
num_experts, intermediate_size_per_partition, hidden_sizes,
params_dtype)
param_dict["w2_deq_scale"] = torch.empty(num_experts,
hidden_sizes,
dtype=torch.float32)
param_dict["w13_deq_scale"] = torch.empty(
num_experts,
2 * intermediate_size_per_partition,
dtype=torch.float32)
param_dict["w2_input_offset"] = torch.empty(num_experts,
1,
dtype=torch.int8)
param_dict["w13_input_offset"] = torch.empty(num_experts,
1,
dtype=torch.int8)
return param_dict
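Lastly, a small sketch of the dispatch rule implemented by `AscendW8A8PDMixLinearMethod.apply` above: on KV-consumer instances (typically the decode side of a disaggregated prefill/decode deployment) the static W8A8 path is used, otherwise the dynamic path. The stub below is purely illustrative and only mirrors the `is_kv_consumer` flag that the branch reads:

```python
class _StubLayer:
    """Hypothetical stand-in for a linear layer after weight loading."""

    def __init__(self, is_kv_consumer: bool) -> None:
        self.is_kv_consumer = is_kv_consumer


def pdmix_apply_path(layer: _StubLayer) -> str:
    # Mirrors the branch in AscendW8A8PDMixLinearMethod.apply.
    if layer.is_kv_consumer:
        return "AscendW8A8LinearMethod.apply (static per-tensor activation path)"
    return "AscendW8A8DynamicLinearMethod.apply (dynamic per-token activation path)"


assert pdmix_apply_path(_StubLayer(is_kv_consumer=True)).startswith("AscendW8A8LinearMethod")
assert pdmix_apply_path(_StubLayer(is_kv_consumer=False)).startswith("AscendW8A8Dynamic")
```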