[Refactor] Quantization Module Refactor (#5738)
### Summary
This PR refactors the `vllm_ascend/quantization` module to improve code
organization, maintainability, and extensibility. The refactoring
introduces a clear separation of concerns with a registry-based scheme
discovery pattern, abstract base classes for quantization schemes, and
dedicated wrapper classes.
### Key Changes
#### 1. **Modular Directory Structure**
| Before | After |
|--------|-------|
| Flat file structure with mixed responsibilities | Organized into
`methods/` subpackage for schemes |
| Single `quant_config.py` (600+ lines) | Separate config files:
`modelslim_config.py`, `compressed_tensors_config.py` |
| `utils.py` with scheme lookup logic | `methods/registry.py` with
decorator-based registration |
#### 2. **Registry-Based Scheme Discovery**
Replaced hardcoded `ASCEND_QUANTIZATION_METHOD_MAP` dictionary with a
decorator-based registry pattern:
```python
# Before: Manual dictionary mapping
ASCEND_QUANTIZATION_METHOD_MAP = {
"W8A8_DYNAMIC": {"linear": AscendW8A8DynamicLinearMethod, ...},
...
}
# After: Decorator-based registration
@register_scheme("W8A8_DYNAMIC", "linear")
class AscendW8A8DynamicLinearMethod(AscendLinearScheme):
...
```
#### 3. **Abstract Base Classes**
Introduced three abstract base classes in `methods/base.py`:
- `AscendLinearScheme` - Base for linear layer quantization
- `AscendMoEScheme` - Base for MoE layer quantization
- `AscendAttentionScheme` - Base for attention layer quantization
#### 4. **Separated Config and Wrapper Classes**
- **Config classes** (`AscendModelSlimConfig`,
`AscendCompressedTensorsConfig`): Handle config parsing and scheme
selection
- **Wrapper classes** (`AscendLinearMethod`, `AscendFusedMoEMethod`,
etc.): Implement vLLM interfaces and delegate to schemes
#### 5. **Cleaner Public API**
```python
# New clean module interface
from vllm_ascend.quantization import (
AscendModelSlimConfig,
AscendCompressedTensorsConfig,
)
from vllm_ascend.quantization.methods import get_scheme_class
```
### Architecture Diagram
```mermaid
classDiagram
direction TB
class QuantizationConfig {
<<vLLM Interface>>
+get_quant_method()
}
class AscendModelSlimConfig {
+quant_description
+get_quant_method()
-create_scheme_for_layer()
}
class AscendCompressedTensorsConfig {
+target_scheme_map
+get_quant_method()
-_get_scheme_from_parts()
}
class AscendLinearMethod {
<<Wrapper>>
+quant_method: AscendLinearScheme
+create_weights()
+apply()
}
class AscendFusedMoEMethod {
<<Wrapper>>
+quant_method: AscendMoEScheme
+create_weights()
+apply()
}
class AscendLinearScheme {
<<Abstract>>
+get_weight()*
+apply()*
+get_pertensor_param()
+get_perchannel_param()
}
class AscendMoEScheme {
<<Abstract>>
+get_weight()*
+get_dynamic_quant_param()*
+apply()*
}
class W8A8DynamicLinear {
+get_weight()
+apply()
}
class W8A8DynamicMoE {
+get_weight()
+apply()
}
QuantizationConfig <|-- AscendModelSlimConfig
QuantizationConfig <|-- AscendCompressedTensorsConfig
AscendModelSlimConfig ..> AscendLinearMethod : creates
AscendModelSlimConfig ..> AscendFusedMoEMethod : creates
AscendCompressedTensorsConfig ..> AscendLinearMethod : creates
AscendCompressedTensorsConfig ..> AscendFusedMoEMethod : creates
AscendLinearMethod o-- AscendLinearScheme : delegates to
AscendFusedMoEMethod o-- AscendMoEScheme : delegates to
AscendLinearScheme <|-- W8A8DynamicLinear
AscendMoEScheme <|-- W8A8DynamicMoE
```
### Scheme Registration Flow
```mermaid
sequenceDiagram
participant Module as Scheme Module
participant Registry as _SCHEME_REGISTRY
participant Config as QuantConfig
participant Wrapper as Wrapper Class
Note over Module: At import time
Module->>Registry: @register_scheme("W8A8_DYNAMIC", "linear")
Registry->>Registry: Store (quant_type, layer_type) -> Class
Note over Config: At runtime
Config->>Config: Determine quant_type from description
Config->>Registry: get_scheme_class(quant_type, layer_type)
Registry-->>Config: Return scheme class
Config->>Config: scheme = scheme_cls()
Config->>Wrapper: Create wrapper with scheme
Wrapper-->>Config: Return wrapper instance
```
### File Changes Summary
| Original Files | Refactored Files |
|----------------|------------------|
| `__init__.py` (empty) | `__init__.py` (exports public API) |
| `quant_config.py` | `modelslim_config.py` + `wrappers.py` |
| `compressed_tensors/` | `compressed_tensors_config.py` |
| `utils.py` | `methods/registry.py` |
| `w8a8_dynamic.py` | `methods/w8a8_dynamic.py` |
| `w8a8.py` | `methods/w8a8_static.py` |
| `w4a4_flatquant_dynamic.py` | `methods/w4a4_flatquant.py` |
| ... | `methods/base.py` (new) |
### Benefits
1. **Extensibility**: Adding new quantization schemes only requires
implementing the base class and adding `@register_scheme` decorator
2. **Maintainability**: Clear separation between config parsing, wrapper
logic, and scheme implementation
3. **Testability**: Abstract base classes enable easier unit testing and
mocking
4. **Discoverability**: Registry pattern makes it easy to list all
supported schemes
5. **Reduced Coupling**: Config classes no longer need to know about all
scheme implementations
___
- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef
---------
Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
This commit is contained in:
@@ -10,7 +10,7 @@ The current process for registering and obtaining quantization methods in vLLM A
|
||||
|
||||

|
||||
|
||||
vLLM Ascend registers a custom ascend quantization method. By configuring the `--quantization ascend` parameter (or `quantization="ascend"` for offline), the quantization feature is enabled. When constructing the `quant_config`, the registered `AscendQuantConfig` is initialized and `get_quant_method` is called to obtain the quantization method corresponding to each weight part, stored in the `quant_method` attribute.
|
||||
vLLM Ascend registers a custom ascend quantization method. By configuring the `--quantization ascend` parameter (or `quantization="ascend"` for offline), the quantization feature is enabled. When constructing the `quant_config`, the registered `AscendModelSlimConfig` is initialized and `get_quant_method` is called to obtain the quantization method corresponding to each weight part, stored in the `quant_method` attribute.
|
||||
|
||||
Currently supported quantization methods include `AscendLinearMethod`, `AscendFusedMoEMethod`, `AscendEmbeddingMethod`, and their corresponding non-quantized methods:
|
||||
|
||||
@@ -51,18 +51,21 @@ Based on the above content, we present a brief description of the adaptation pro
|
||||
### Quantization Algorithm Adaptation
|
||||
|
||||
- **Step 1: Algorithm Design**. Define the algorithm ID (e.g., `W4A8_DYNAMIC`), determine supported layers (linear, moe, attention), and design the quantization scheme (static/dynamic, pertensor/perchannel/pergroup).
|
||||
- **Step 2: Registration**. Add the algorithm ID to `ASCEND_QUANTIZATION_METHOD_MAP` in `vllm_ascend/quantization/utils.py` and associate it with the corresponding method class.
|
||||
- **Step 2: Registration**. Use the `@register_scheme` decorator in `vllm_ascend/quantization/methods/registry.py` to register your quantization scheme class.
|
||||
|
||||
```python
|
||||
ASCEND_QUANTIZATION_METHOD_MAP: Dict[str, Dict[str, Type[Any]]] = {
|
||||
"W4A8_DYNAMIC": {
|
||||
"linear": AscendW4A8DynamicLinearMethod,
|
||||
"moe": AscendW4A8DynamicFusedMoEMethod,
|
||||
},
|
||||
}
|
||||
from vllm_ascend.quantization.methods import register_scheme, AscendLinearScheme
|
||||
|
||||
@register_scheme("W4A8_DYNAMIC", "linear")
|
||||
class AscendW4A8DynamicLinearMethod(AscendLinearScheme):
|
||||
...
|
||||
|
||||
@register_scheme("W4A8_DYNAMIC", "moe")
|
||||
class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme):
|
||||
...
|
||||
```
|
||||
|
||||
- **Step 3: Implementation**. Create an algorithm implementation file, such as `vllm_ascend/quantization/w4a8_dynamic.py`, and implement the method class and logic.
|
||||
- **Step 3: Implementation**. Create an algorithm implementation file, such as `vllm_ascend/quantization/methods/w4a8.py`, and implement the method class and logic.
|
||||
- **Step 4: Testing**. Use your algorithm to generate quantization configurations and verify correctness and performance on target models and hardware.
|
||||
|
||||
### Quantized Model Adaptation
|
||||
@@ -70,7 +73,7 @@ ASCEND_QUANTIZATION_METHOD_MAP: Dict[str, Dict[str, Type[Any]]] = {
|
||||
Adapting a new quantized model requires ensuring the following three points:
|
||||
|
||||
- The original model has been successfully adapted in `vLLM Ascend`.
|
||||
- **Fused Module Mapping**: Add the model's `model_type` to `packed_modules_model_mapping` in `vllm_ascend/quantization/quant_config.py` (e.g., `qkv_proj`, `gate_up_proj`, `experts`) to ensure sharding consistency and correct loading.
|
||||
- **Fused Module Mapping**: Add the model's `model_type` to `packed_modules_model_mapping` in `vllm_ascend/quantization/modelslim_config.py` (e.g., `qkv_proj`, `gate_up_proj`, `experts`) to ensure sharding consistency and correct loading.
|
||||
|
||||
```python
|
||||
packed_modules_model_mapping = {
|
||||
|
||||
@@ -20,8 +20,6 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch_npu
|
||||
from pytest_mock import MockerFixture
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.ascend_forward_context import MoECommType
|
||||
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
|
||||
@@ -233,25 +231,6 @@ class MockQuantMethod(nn.Module):
|
||||
self.apply = MagicMock(return_value=(torch.randn(num_tokens, 32)))
|
||||
|
||||
|
||||
class MockFusedMoEMethod(FusedMoEMethodBase):
|
||||
moe = MagicMock()
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(self.moe)
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
||||
hidden_size: int, intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype, **extra_weight_attrs):
|
||||
pass
|
||||
|
||||
def apply(self, hidden_states: torch.Tensor,
|
||||
expert_weights: torch.Tensor) -> torch.Tensor:
|
||||
pass
|
||||
|
||||
def get_fused_moe_quant_config(self, layer: torch.nn.Module):
|
||||
pass
|
||||
|
||||
|
||||
class TestExpertsSelector:
|
||||
|
||||
@pytest.mark.parametrize("global_num_experts", [256, 128])
|
||||
|
||||
@@ -7,11 +7,11 @@ from vllm.model_executor.layers.linear import LinearBase
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod
|
||||
from vllm_ascend.quantization.quant_config import AscendQuantConfig
|
||||
from vllm_ascend.quantization.modelslim_config import AscendModelSlimConfig
|
||||
from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD
|
||||
|
||||
|
||||
class TestAscendQuantConfig(TestBase):
|
||||
class TestAscendModelSlimConfig(TestBase):
|
||||
|
||||
def setUp(self):
|
||||
self.sample_config = {
|
||||
@@ -25,7 +25,7 @@ class TestAscendQuantConfig(TestBase):
|
||||
"shard1.weight": "FLOAT",
|
||||
"shard2.weight": "FLOAT",
|
||||
}
|
||||
self.ascend_config = AscendQuantConfig(self.sample_config)
|
||||
self.ascend_config = AscendModelSlimConfig(self.sample_config)
|
||||
self.ascend_config.packed_modules_mapping = None
|
||||
|
||||
def test_init(self):
|
||||
@@ -34,55 +34,55 @@ class TestAscendQuantConfig(TestBase):
|
||||
|
||||
def test_repr(self):
|
||||
repr_str = repr(self.ascend_config)
|
||||
self.assertTrue(repr_str.startswith("AscendQuantConfig:\n"))
|
||||
self.assertTrue(repr_str.startswith("AscendModelSlimConfig:\n"))
|
||||
|
||||
def test_get_name(self):
|
||||
self.assertEqual(AscendQuantConfig.get_name(),
|
||||
self.assertEqual(AscendModelSlimConfig.get_name(),
|
||||
ASCEND_QUANTIZATION_METHOD)
|
||||
|
||||
def test_get_supported_act_dtypes(self):
|
||||
supported_dtypes = AscendQuantConfig.get_supported_act_dtypes()
|
||||
supported_dtypes = AscendModelSlimConfig.get_supported_act_dtypes()
|
||||
self.assertEqual(len(supported_dtypes), 3)
|
||||
|
||||
def test_get_min_capability(self):
|
||||
with self.assertRaises(NotImplementedError):
|
||||
AscendQuantConfig.get_min_capability()
|
||||
AscendModelSlimConfig.get_min_capability()
|
||||
|
||||
def test_get_config_filenames(self):
|
||||
filenames = AscendQuantConfig.get_config_filenames()
|
||||
filenames = AscendModelSlimConfig.get_config_filenames()
|
||||
self.assertEqual(filenames, ["quant_model_description.json"])
|
||||
|
||||
def test_from_config(self):
|
||||
config = AscendQuantConfig.from_config(self.sample_config)
|
||||
self.assertIsInstance(config, AscendQuantConfig)
|
||||
config = AscendModelSlimConfig.from_config(self.sample_config)
|
||||
self.assertIsInstance(config, AscendModelSlimConfig)
|
||||
self.assertEqual(config.quant_description, self.sample_config)
|
||||
|
||||
@patch('torch.npu.is_available')
|
||||
def test_override_quantization_method(self, mock_is_available):
|
||||
# Test when NPU is available
|
||||
mock_is_available.return_value = True
|
||||
result = AscendQuantConfig.override_quantization_method(None, None)
|
||||
result = AscendModelSlimConfig.override_quantization_method(None, None)
|
||||
self.assertIsNone(result)
|
||||
hf_quant_cfg = {"quant_method": ""}
|
||||
result = AscendQuantConfig.override_quantization_method(
|
||||
result = AscendModelSlimConfig.override_quantization_method(
|
||||
hf_quant_cfg, None)
|
||||
self.assertEqual(result, "ascend")
|
||||
|
||||
# Test when NPU is not available
|
||||
mock_is_available.return_value = False
|
||||
result = AscendQuantConfig.override_quantization_method(None, None)
|
||||
result = AscendModelSlimConfig.override_quantization_method(None, None)
|
||||
self.assertIsNone(result)
|
||||
hf_quant_cfg = {"quant_method": ""}
|
||||
result = AscendQuantConfig.override_quantization_method(
|
||||
result = AscendModelSlimConfig.override_quantization_method(
|
||||
hf_quant_cfg, None)
|
||||
self.assertIsNone(result)
|
||||
|
||||
def test_get_quant_method_for_linear(self):
|
||||
mock_config = MagicMock()
|
||||
mock_config.model_config.hf_text_config.model_type = None
|
||||
mock_config.model_config.hf_config.model_type = None
|
||||
linear_layer = MagicMock(spec=LinearBase)
|
||||
# Test skipped layer
|
||||
with patch("vllm_ascend.quantization.quant_config.get_current_vllm_config", return_value=mock_config), \
|
||||
with patch("vllm_ascend.quantization.modelslim_config.get_current_vllm_config", return_value=mock_config), \
|
||||
patch.object(self.ascend_config, \
|
||||
'is_layer_skipped_ascend',
|
||||
return_value=True):
|
||||
@@ -90,22 +90,24 @@ class TestAscendQuantConfig(TestBase):
|
||||
self.assertIsInstance(method, AscendUnquantizedLinearMethod)
|
||||
|
||||
# Test quantized layer
|
||||
mock_scheme = MagicMock()
|
||||
with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=False), \
|
||||
patch("vllm_ascend.quantization.quant_config.get_current_vllm_config", return_value=mock_config), \
|
||||
patch('vllm_ascend.quantization.quant_config.AscendLinearMethod', return_value=MagicMock()) as mock_ascend_linear:
|
||||
patch("vllm_ascend.quantization.modelslim_config.get_current_vllm_config", return_value=mock_config), \
|
||||
patch("vllm_ascend.quantization.modelslim_config.create_scheme_for_layer", return_value=mock_scheme), \
|
||||
patch('vllm_ascend.quantization.method_adapters.AscendLinearMethod', return_value=MagicMock()) as mock_ascend_linear:
|
||||
|
||||
method = self.ascend_config.get_quant_method(linear_layer, ".attn")
|
||||
self.assertIs(method, mock_ascend_linear.return_value)
|
||||
mock_ascend_linear.assert_called_once_with(
|
||||
self.ascend_config, ".attn",
|
||||
self.ascend_config.packed_modules_mapping, linear_layer)
|
||||
mock_ascend_linear.assert_called_once_with(mock_scheme)
|
||||
|
||||
def test_get_quant_method_for_attention(self):
|
||||
attention_layer = MagicMock(spec=Attention)
|
||||
mock_config = MagicMock()
|
||||
mock_config.model_config.hf_text_config.model_type = None
|
||||
with patch("vllm_ascend.quantization.quant_config.get_current_vllm_config", return_value=mock_config), \
|
||||
patch('vllm_ascend.quantization.quant_config.AscendKVCacheMethod', \
|
||||
mock_config.model_config.hf_config.model_type = None
|
||||
mock_scheme = MagicMock()
|
||||
with patch("vllm_ascend.quantization.modelslim_config.get_current_vllm_config", return_value=mock_config), \
|
||||
patch("vllm_ascend.quantization.modelslim_config.create_scheme_for_layer", return_value=mock_scheme), \
|
||||
patch('vllm_ascend.quantization.method_adapters.AscendKVCacheMethod', \
|
||||
return_value=MagicMock()) as mock_ascend_kvcache:
|
||||
# Test with fa_quant_type
|
||||
method = self.ascend_config.get_quant_method(
|
||||
@@ -117,20 +119,22 @@ class TestAscendQuantConfig(TestBase):
|
||||
fused_moe_layer.moe = MagicMock(spec=FusedMoEConfig)
|
||||
fused_moe_layer.moe_config = MagicMock(spec=FusedMoEConfig)
|
||||
mock_config = MagicMock()
|
||||
mock_config.model_config.hf_text_config.model_type = None
|
||||
mock_config.model_config.hf_config.model_type = None
|
||||
|
||||
# Test skipped layer
|
||||
with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=True), \
|
||||
patch("vllm_ascend.quantization.quant_config.get_current_vllm_config", return_value=mock_config), \
|
||||
patch('vllm_ascend.quantization.quant_config.AscendUnquantizedFusedMoEMethod', return_value=MagicMock()) as mock_ascend_moe:
|
||||
patch("vllm_ascend.quantization.modelslim_config.get_current_vllm_config", return_value=mock_config), \
|
||||
patch('vllm_ascend.ops.fused_moe.fused_moe.AscendUnquantizedFusedMoEMethod', return_value=MagicMock()) as mock_ascend_moe:
|
||||
method = self.ascend_config.get_quant_method(
|
||||
fused_moe_layer, "moe_layer")
|
||||
self.assertIs(method, mock_ascend_moe.return_value)
|
||||
|
||||
# Test quantized layer
|
||||
mock_scheme = MagicMock()
|
||||
with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=False), \
|
||||
patch("vllm_ascend.quantization.quant_config.get_current_vllm_config", return_value=mock_config), \
|
||||
patch('vllm_ascend.quantization.quant_config.AscendFusedMoEMethod', return_value=MagicMock()) as mock_ascend_moe:
|
||||
patch("vllm_ascend.quantization.modelslim_config.get_current_vllm_config", return_value=mock_config), \
|
||||
patch("vllm_ascend.quantization.modelslim_config.create_scheme_for_layer", return_value=mock_scheme), \
|
||||
patch('vllm_ascend.quantization.method_adapters.AscendFusedMoEMethod', return_value=MagicMock()) as mock_ascend_moe:
|
||||
method = self.ascend_config.get_quant_method(
|
||||
fused_moe_layer, "moe_layer")
|
||||
self.assertIs(method, mock_ascend_moe.return_value)
|
||||
@@ -150,7 +154,7 @@ class TestAscendQuantConfig(TestBase):
|
||||
|
||||
# Test inconsistent fused layer shards
|
||||
bad_config = {"shard1.weight": "FLOAT", "shard2.weight": "INT8"}
|
||||
config = AscendQuantConfig(bad_config)
|
||||
config = AscendModelSlimConfig(bad_config)
|
||||
with self.assertRaises(ValueError):
|
||||
config.is_layer_skipped_ascend("fused_layer", fused_mapping)
|
||||
|
||||
@@ -1,50 +0,0 @@
|
||||
import types
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.quantization.utils import (ASCEND_QUANTIZATION_METHOD_MAP,
|
||||
get_quant_method)
|
||||
|
||||
|
||||
class TestGetQuantMethod(TestBase):
|
||||
|
||||
def setUp(self):
|
||||
self.original_quantization_method_map = ASCEND_QUANTIZATION_METHOD_MAP.copy(
|
||||
)
|
||||
for quant_type, layer_map in ASCEND_QUANTIZATION_METHOD_MAP.items():
|
||||
for layer_type in layer_map.keys():
|
||||
ASCEND_QUANTIZATION_METHOD_MAP[quant_type][
|
||||
layer_type] = types.new_class(f"{quant_type}_{layer_type}")
|
||||
|
||||
def tearDown(self):
|
||||
# Restore original map
|
||||
ASCEND_QUANTIZATION_METHOD_MAP.clear()
|
||||
ASCEND_QUANTIZATION_METHOD_MAP.update(
|
||||
self.original_quantization_method_map)
|
||||
|
||||
def test_linear_quant_methods(self):
|
||||
for quant_type, layer_map in ASCEND_QUANTIZATION_METHOD_MAP.items():
|
||||
if "linear" in layer_map.keys():
|
||||
prefix = "linear_layer"
|
||||
cls = layer_map["linear"]
|
||||
method = get_quant_method({"linear_layer.weight": quant_type},
|
||||
prefix, "linear")
|
||||
self.assertIsInstance(method, cls)
|
||||
|
||||
def test_moe_quant_methods(self):
|
||||
for quant_type, layer_map in ASCEND_QUANTIZATION_METHOD_MAP.items():
|
||||
if "moe" in layer_map.keys():
|
||||
prefix = "layer"
|
||||
cls = layer_map["moe"]
|
||||
method = get_quant_method({"layer.weight": quant_type}, prefix,
|
||||
"moe")
|
||||
self.assertIsInstance(method, cls)
|
||||
|
||||
def test_invalid_layer_type(self):
|
||||
quant_description = {"linear_layer.weight": "W8A8"}
|
||||
with self.assertRaises(NotImplementedError):
|
||||
get_quant_method(quant_description, "linear_layer", "unsupported")
|
||||
|
||||
def test_invalid_quant_type(self):
|
||||
quant_description = {"linear_layer.weight": "UNKNOWN"}
|
||||
with self.assertRaises(NotImplementedError):
|
||||
get_quant_method(quant_description, "linear_layer", "linear")
|
||||
@@ -3,8 +3,9 @@ from unittest.mock import Mock, patch
|
||||
import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.quantization.w4a16 import (AscendW4A16FusedMoEMethod,
|
||||
pack_to_int32, unpack_from_int32)
|
||||
from vllm_ascend.quantization.methods.w4a16 import (AscendW4A16FusedMoEMethod,
|
||||
pack_to_int32,
|
||||
unpack_from_int32)
|
||||
|
||||
|
||||
class TestUnpackFromInt32(TestBase):
|
||||
@@ -42,7 +43,7 @@ class TestUnpackFromInt32(TestBase):
|
||||
class TestPackToInt32(TestBase):
|
||||
|
||||
@patch(
|
||||
"vllm_ascend.quantization.w4a16.torch_npu.npu_convert_weight_to_int4pack"
|
||||
"vllm_ascend.quantization.methods.w4a16.torch_npu.npu_convert_weight_to_int4pack"
|
||||
)
|
||||
def test_pack_to_int32_int8(self, mock_npu_convert_weight_to_int4pack):
|
||||
mock_npu_convert_weight_to_int4pack.return_value = torch.zeros(
|
||||
@@ -57,7 +58,7 @@ class TestPackToInt32(TestBase):
|
||||
self.assertEqual(result.shape, torch.Size([2, 8, 4]))
|
||||
|
||||
@patch(
|
||||
"vllm_ascend.quantization.w4a16.torch_npu.npu_convert_weight_to_int4pack"
|
||||
"vllm_ascend.quantization.methods.w4a16.torch_npu.npu_convert_weight_to_int4pack"
|
||||
)
|
||||
def test_pack_to_int32_int32(self, mock_npu_convert_weight_to_int4pack):
|
||||
|
||||
@@ -97,8 +98,8 @@ class TestAscendW4A16FusedMoEMethod(TestBase):
|
||||
output_size = 128
|
||||
group_size = 32
|
||||
|
||||
@patch("vllm_ascend.quantization.w4a16.get_ascend_config")
|
||||
@patch("vllm_ascend.quantization.w4a16.get_current_vllm_config")
|
||||
@patch("vllm_ascend.quantization.methods.w4a16.get_ascend_config")
|
||||
@patch("vllm_ascend.quantization.methods.w4a16.get_current_vllm_config")
|
||||
def setUp(self, mock_get_current_vllm_config, mock_get_ascend_config):
|
||||
mock_ascend_config = Mock()
|
||||
mock_ascend_config.eplb_config.dynamic_eplb = False
|
||||
@@ -218,7 +219,7 @@ class TestAscendW4A16FusedMoEMethod(TestBase):
|
||||
return layer
|
||||
|
||||
@patch(
|
||||
"vllm_ascend.quantization.w4a16.torch_npu.npu_convert_weight_to_int4pack"
|
||||
"vllm_ascend.quantization.methods.w4a16.torch_npu.npu_convert_weight_to_int4pack"
|
||||
)
|
||||
def test_process_weights_after_loading_with_transpose(
|
||||
self, mock_npu_convert_weight_to_int4pack):
|
||||
|
||||
@@ -4,7 +4,7 @@ from unittest.mock import MagicMock, patch
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm_ascend.quantization.w4a4_flatquant_dynamic import (
|
||||
from vllm_ascend.quantization.methods.w4a4_flatquant import (
|
||||
AscendW4A4FlatQuantDynamicLinearMethod, get_decompose_dim,
|
||||
pack_int4_weights)
|
||||
|
||||
@@ -33,7 +33,7 @@ class TestW4A4FlatQuantDynamic(unittest.TestCase):
|
||||
self.assertEqual(get_decompose_dim(100), (10, 10))
|
||||
self.assertEqual(get_decompose_dim(99), (9, 11))
|
||||
|
||||
@patch('vllm_ascend.quantization.w4a4_flatquant_dynamic.torch_npu')
|
||||
@patch('vllm_ascend.quantization.methods.w4a4_flatquant.torch_npu')
|
||||
def test_pack_int4_weights_npu_success(self, mock_torch_npu):
|
||||
"""
|
||||
Tests weight packing using the mocked NPU kernel.
|
||||
@@ -119,7 +119,7 @@ class TestW4A4FlatQuantDynamic(unittest.TestCase):
|
||||
x = torch.randn(batch_size, self.input_size, dtype=self.params_dtype)
|
||||
return layer, x, m, n
|
||||
|
||||
@patch('vllm_ascend.quantization.w4a4_flatquant_dynamic.torch_npu')
|
||||
@patch('vllm_ascend.quantization.methods.w4a4_flatquant.torch_npu')
|
||||
def test_apply_small_batch(self, mock_torch_npu):
|
||||
"""Tests the apply method with a batch size smaller than MAX_BATCH_SIZE."""
|
||||
batch_size = 128
|
||||
@@ -143,9 +143,9 @@ class TestW4A4FlatQuantDynamic(unittest.TestCase):
|
||||
self.assertEqual(output.shape, (batch_size, self.output_size))
|
||||
|
||||
@patch(
|
||||
'vllm_ascend.quantization.w4a4_flatquant_dynamic.KRONECKER_QUANT_MAX_BATCH_SIZE',
|
||||
'vllm_ascend.quantization.methods.w4a4_flatquant.KRONECKER_QUANT_MAX_BATCH_SIZE',
|
||||
10)
|
||||
@patch('vllm_ascend.quantization.w4a4_flatquant_dynamic.torch_npu')
|
||||
@patch('vllm_ascend.quantization.methods.w4a4_flatquant.torch_npu')
|
||||
def test_apply_large_batch(self, mock_torch_npu):
|
||||
"""Tests the apply method with a batch size larger than MAX_BATCH_SIZE."""
|
||||
batch_size = 25
|
||||
@@ -178,7 +178,7 @@ class TestW4A4FlatQuantDynamic(unittest.TestCase):
|
||||
ValueError, "FlatQuant transform matrices dimension mismatch"):
|
||||
self.method.apply(layer, x)
|
||||
|
||||
@patch('vllm_ascend.quantization.w4a4_flatquant_dynamic.pack_int4_weights')
|
||||
@patch('vllm_ascend.quantization.methods.w4a4_flatquant.pack_int4_weights')
|
||||
def test_process_weights_after_loading(self, mock_pack_weights):
|
||||
"""Tests weight processing after loading, without transpose."""
|
||||
layer = nn.Module()
|
||||
|
||||
@@ -3,14 +3,14 @@ from unittest.mock import Mock, patch
|
||||
import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.quantization.w4a8_dynamic import (
|
||||
from vllm_ascend.quantization.methods.w4a8 import (
|
||||
AscendW4A8DynamicFusedMoEMethod, AscendW4A8DynamicLinearMethod)
|
||||
|
||||
|
||||
class TestAscendW4A8DynamicLinearMethod(TestBase):
|
||||
|
||||
@patch('vllm.distributed.get_tensor_model_parallel_world_size')
|
||||
@patch('vllm_ascend.quantization.w4a8_dynamic.get_current_vllm_config')
|
||||
@patch('vllm_ascend.quantization.methods.w4a8.get_current_vllm_config')
|
||||
def setUp(self, mock_get_current_vllm_config, mock_get_tp_world_size):
|
||||
mock_get_tp_world_size.return_value = 1
|
||||
mock_vllm_config = Mock()
|
||||
@@ -127,10 +127,10 @@ class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
|
||||
output_size = 56
|
||||
group_size = 2
|
||||
|
||||
@patch('vllm_ascend.quantization.w4a8_dynamic.get_ascend_config')
|
||||
@patch('vllm_ascend.quantization.w4a8_dynamic.get_current_vllm_config')
|
||||
@patch('vllm_ascend.quantization.w4a8_dynamic.get_ep_group')
|
||||
@patch('vllm_ascend.quantization.w4a8_dynamic.get_mc2_group')
|
||||
@patch('vllm_ascend.quantization.methods.w4a8.get_ascend_config')
|
||||
@patch('vllm_ascend.quantization.methods.w4a8.get_current_vllm_config')
|
||||
@patch('vllm_ascend.quantization.methods.w4a8.get_ep_group')
|
||||
@patch('vllm_ascend.quantization.methods.w4a8.get_mc2_group')
|
||||
@patch('torch.distributed.get_rank', return_value=0)
|
||||
def setUp(self, mock_get_rank, mock_get_mc2_group, mock_get_ep_group,
|
||||
get_current_vllm_config, mock_get_ascend_config):
|
||||
|
||||
@@ -4,7 +4,7 @@ from unittest.mock import MagicMock, patch
|
||||
import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.quantization.w8a16 import AscendW8A16LinearMethod
|
||||
from vllm_ascend.quantization.methods.w8a16 import AscendW8A16LinearMethod
|
||||
|
||||
|
||||
class TestAscendW8A16LinearMethod(TestBase):
|
||||
|
||||
@@ -4,36 +4,10 @@ from unittest.mock import MagicMock, patch
|
||||
import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.quantization.w8a8 import (AscendW8A8LinearMethod,
|
||||
quant_per_tensor)
|
||||
from vllm_ascend.quantization.methods.w8a8_static import AscendW8A8LinearMethod
|
||||
from vllm_ascend.utils import AscendDeviceType
|
||||
|
||||
|
||||
class TestQuantPerTensor(TestBase):
|
||||
|
||||
@patch("torch_npu.npu_quantize")
|
||||
def test_quant_per_tensor(self, mock_npu_quantize):
|
||||
in_tensor = torch.randn(32, 128)
|
||||
input_scale = torch.tensor(0.1)
|
||||
input_offset = torch.tensor(0)
|
||||
|
||||
expected_output = torch.randint(-128, 127, (32, 128), dtype=torch.int8)
|
||||
mock_npu_quantize.return_value = expected_output
|
||||
|
||||
output = quant_per_tensor(in_tensor, input_scale, input_offset)
|
||||
|
||||
mock_npu_quantize.assert_called_once_with(
|
||||
in_tensor,
|
||||
input_scale,
|
||||
input_offset,
|
||||
torch.qint8,
|
||||
-1,
|
||||
False,
|
||||
)
|
||||
|
||||
self.assertTrue(torch.equal(output, expected_output))
|
||||
|
||||
|
||||
class TestAscendW8A8LinearMethod(TestBase):
|
||||
|
||||
def setUp(self):
|
||||
@@ -63,7 +37,9 @@ class TestAscendW8A8LinearMethod(TestBase):
|
||||
self.assertEqual(params['weight_scale'].shape, (10, 1))
|
||||
self.assertEqual(params['weight_offset'].shape, (10, 1))
|
||||
|
||||
@patch("vllm_ascend.quantization.w8a8.get_weight_prefetch_method")
|
||||
@patch(
|
||||
"vllm_ascend.quantization.methods.w8a8_static.get_weight_prefetch_method"
|
||||
)
|
||||
@patch("torch.ops.vllm.quantize")
|
||||
@patch("torch_npu.npu_quant_matmul")
|
||||
def test_apply_with_x_not_int8(self, mock_npu_quant_matmul, mock_quantize,
|
||||
|
||||
@@ -3,7 +3,7 @@ from unittest.mock import Mock, patch
|
||||
import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.quantization.w8a8_dynamic import \
|
||||
from vllm_ascend.quantization.methods.w8a8_dynamic import \
|
||||
AscendW8A8DynamicFusedMoEMethod
|
||||
|
||||
|
||||
@@ -13,13 +13,13 @@ class TestAscendW8A8FusedMoEMethod(TestBase):
|
||||
intermediate_size = 128
|
||||
|
||||
@patch("torch.distributed.get_rank")
|
||||
@patch("vllm_ascend.quantization.w8a8_dynamic.get_mc2_group")
|
||||
@patch("vllm_ascend.quantization.w8a8_dynamic.get_ascend_config")
|
||||
@patch("vllm_ascend.quantization.w8a8_dynamic.get_ep_group")
|
||||
@patch("vllm_ascend.quantization.methods.w8a8_dynamic.get_mc2_group")
|
||||
@patch("vllm_ascend.quantization.methods.w8a8_dynamic.get_ascend_config")
|
||||
@patch("vllm_ascend.quantization.methods.w8a8_dynamic.get_ep_group")
|
||||
def setUp(self, mock_get_ep_group, mock_get_ascend_config,
|
||||
mock_get_mc2_group, mock_get_rank):
|
||||
with patch(
|
||||
'vllm_ascend.quantization.w8a8_dynamic.get_current_vllm_config'
|
||||
'vllm_ascend.quantization.methods.w8a8_dynamic.get_current_vllm_config'
|
||||
) as mock_get_current_vllm_config:
|
||||
mock_vllm_config = Mock()
|
||||
mock_vllm_config.quant_config = Mock(
|
||||
|
||||
@@ -51,8 +51,9 @@ class TestNPUPlatform(TestBase):
|
||||
self.assertTrue(self.platform.is_sleep_mode_available())
|
||||
|
||||
@patch("vllm_ascend.utils.adapt_patch")
|
||||
@patch("vllm_ascend.quantization.quant_config.AscendQuantConfig")
|
||||
def test_pre_register_and_update_with_parser(self, mock_quant_config, mock_adapt_patch):
|
||||
@patch("vllm_ascend.quantization.modelslim_config.AscendModelSlimConfig")
|
||||
def test_pre_register_and_update_with_parser(self, mock_quant_config,
|
||||
mock_adapt_patch):
|
||||
mock_parser = MagicMock()
|
||||
mock_action = MagicMock()
|
||||
mock_action.choices = ["awq", "gptq"]
|
||||
@@ -66,15 +67,17 @@ class TestNPUPlatform(TestBase):
|
||||
self.assertEqual(len(mock_action.choices), 3) # original 2 + ascend
|
||||
|
||||
@patch("vllm_ascend.utils.adapt_patch")
|
||||
@patch("vllm_ascend.quantization.quant_config.AscendQuantConfig")
|
||||
def test_pre_register_and_update_without_parser(self, mock_quant_config, mock_adapt_patch):
|
||||
@patch("vllm_ascend.quantization.modelslim_config.AscendModelSlimConfig")
|
||||
def test_pre_register_and_update_without_parser(self, mock_quant_config,
|
||||
mock_adapt_patch):
|
||||
self.platform.pre_register_and_update(None)
|
||||
|
||||
mock_adapt_patch.assert_called_once_with(is_global_patch=True)
|
||||
|
||||
@patch("vllm_ascend.utils.adapt_patch")
|
||||
@patch("vllm_ascend.quantization.quant_config.AscendQuantConfig")
|
||||
def test_pre_register_and_update_with_parser_no_quant_action(self, mock_quant_config, mock_adapt_patch):
|
||||
@patch("vllm_ascend.quantization.modelslim_config.AscendModelSlimConfig")
|
||||
def test_pre_register_and_update_with_parser_no_quant_action(
|
||||
self, mock_quant_config, mock_adapt_patch):
|
||||
mock_parser = MagicMock()
|
||||
mock_parser._option_string_actions = {}
|
||||
|
||||
@@ -83,8 +86,9 @@ class TestNPUPlatform(TestBase):
|
||||
mock_adapt_patch.assert_called_once_with(is_global_patch=True)
|
||||
|
||||
@patch("vllm_ascend.utils.adapt_patch")
|
||||
@patch("vllm_ascend.quantization.quant_config.AscendQuantConfig")
|
||||
def test_pre_register_and_update_with_existing_ascend_quant(self, mock_quant_config, mock_adapt_patch):
|
||||
@patch("vllm_ascend.quantization.modelslim_config.AscendModelSlimConfig")
|
||||
def test_pre_register_and_update_with_existing_ascend_quant(
|
||||
self, mock_quant_config, mock_adapt_patch):
|
||||
mock_parser = MagicMock()
|
||||
mock_action = MagicMock()
|
||||
mock_action.choices = ["awq", ASCEND_QUANTIZATION_METHOD]
|
||||
|
||||
@@ -36,7 +36,7 @@ from vllm_ascend.ops.layer_shard_linear import (
|
||||
register_all_layers_to_shard_weight_series)
|
||||
from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
|
||||
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
|
||||
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
|
||||
from vllm_ascend.quantization.methods import AscendW8A8LinearMethod
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, maybe_trans_nz,
|
||||
weak_ref_tensors)
|
||||
from vllm_ascend.worker.npu_input_batch import NPUInputBatch
|
||||
|
||||
@@ -35,7 +35,7 @@ from vllm_ascend.ops.layer_shard_linear import (
|
||||
from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
|
||||
from vllm_ascend.ops.triton.rope import rope_forward_triton
|
||||
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
|
||||
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
|
||||
from vllm_ascend.quantization.methods import AscendW8A8LinearMethod
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, _round_up, dispose_layer,
|
||||
enable_dsa_cp, enable_dsa_cp_with_layer_shard, maybe_trans_nz)
|
||||
from vllm_ascend.worker.npu_input_batch import NPUInputBatch
|
||||
|
||||
@@ -43,10 +43,6 @@ from vllm_ascend.ops.fused_moe.moe_comm_method import (AllGatherCommImpl,
|
||||
FusedExpertsResult,
|
||||
setup_moe_comm_method)
|
||||
from vllm_ascend.ops.fused_moe.prepare_finalize import QuantType
|
||||
from vllm_ascend.quantization.w4a8_dynamic import \
|
||||
AscendW4A8DynamicFusedMoEMethod
|
||||
from vllm_ascend.quantization.w8a8_dynamic import \
|
||||
AscendW8A8DynamicFusedMoEMethod
|
||||
from vllm_ascend.utils import (AscendDeviceType, enable_sp,
|
||||
get_ascend_device_type, maybe_trans_nz,
|
||||
npu_stream_switch, shared_expert_dp_enabled,
|
||||
@@ -251,12 +247,16 @@ class AscendFusedMoE(FusedMoE):
|
||||
|
||||
method = quant_method.quant_method
|
||||
|
||||
if isinstance(method, AscendW8A8DynamicFusedMoEMethod):
|
||||
return QuantType.W8A8
|
||||
elif isinstance(method, AscendW4A8DynamicFusedMoEMethod):
|
||||
return QuantType.W4A8
|
||||
else:
|
||||
return QuantType.NONE
|
||||
if hasattr(method, "quant_type"):
|
||||
from vllm_ascend.quantization.methods.base import \
|
||||
QuantType as SchemeQuantType
|
||||
scheme_quant_type = method.quant_type
|
||||
if scheme_quant_type == SchemeQuantType.W8A8:
|
||||
return QuantType.W8A8
|
||||
elif scheme_quant_type == SchemeQuantType.W4A8:
|
||||
return QuantType.W4A8
|
||||
|
||||
return QuantType.NONE
|
||||
|
||||
def update_expert_map(self, new_expert_map):
|
||||
self._expert_map = new_expert_map
|
||||
|
||||
@@ -368,7 +368,8 @@ class Flashcomm2OProjRowParallelOp(CustomRowParallelOp):
|
||||
"communication_fn"] = otp_maybe_quant_comm
|
||||
actual_quant_method = getattr(self.quant_method, 'quant_method',
|
||||
self.quant_method)
|
||||
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
|
||||
from vllm_ascend.quantization.methods.w8a8_static import \
|
||||
AscendW8A8LinearMethod
|
||||
if not isinstance(actual_quant_method, AscendW8A8LinearMethod):
|
||||
# Check if w8a8 quantization is enabled. If not, communicate immediately.
|
||||
input_parallel = otp_maybe_quant_comm(input_parallel)
|
||||
@@ -586,8 +587,8 @@ class SequenceRowParallelOp(CustomRowParallelOp):
|
||||
|
||||
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
|
||||
|
||||
from vllm_ascend.quantization.quant_config import AscendLinearMethod
|
||||
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
|
||||
from vllm_ascend.quantization.methods import AscendW8A8LinearMethod
|
||||
from vllm_ascend.quantization.method_adapters import AscendLinearMethod
|
||||
|
||||
# For unquant
|
||||
if mmrs_fusion and isinstance(self.layer.quant_method,
|
||||
|
||||
@@ -151,8 +151,7 @@ class NPUPlatform(Platform):
|
||||
if ASCEND_QUANTIZATION_METHOD not in quant_action.choices:
|
||||
quant_action.choices.append(ASCEND_QUANTIZATION_METHOD)
|
||||
|
||||
from vllm_ascend.quantization.compressed_tensors.compressed_tensors import AscendCompressedTensorsConfig # noqa: F401
|
||||
from vllm_ascend.quantization.quant_config import AscendQuantConfig # noqa: F401
|
||||
from vllm_ascend.quantization import AscendCompressedTensorsConfig, AscendModelSlimConfig # noqa: F401
|
||||
|
||||
config_deprecated_logging()
|
||||
|
||||
|
||||
@@ -0,0 +1,38 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""Ascend quantization module.
|
||||
|
||||
This module provides quantization support for Ascend NPU.
|
||||
|
||||
Supported quantization tools:
|
||||
- ModelSlim: Use AscendModelSlimConfig
|
||||
- LLM-Compressor (compressed_tensors): Use AscendCompressedTensorsConfig
|
||||
|
||||
Public API:
|
||||
- Config classes: AscendModelSlimConfig, AscendCompressedTensorsConfig
|
||||
- For scheme implementations, import from vllm_ascend.quantization.methods
|
||||
"""
|
||||
|
||||
# LLM-Compressor (compressed_tensors) quantization config
|
||||
from .compressed_tensors_config import AscendCompressedTensorsConfig
|
||||
# ModelSlim quantization config
|
||||
from .modelslim_config import AscendModelSlimConfig
|
||||
|
||||
__all__ = [
|
||||
"AscendModelSlimConfig",
|
||||
"AscendCompressedTensorsConfig",
|
||||
]
|
||||
|
||||
@@ -1,4 +1,23 @@
|
||||
from typing import TYPE_CHECKING, Any, Optional, cast
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
"""LLM-Compressor (compressed_tensors) quantization configuration for Ascend."""
|
||||
|
||||
from typing import Any, Optional, Union, cast
|
||||
|
||||
import torch
|
||||
from compressed_tensors.quantization import (QuantizationArgs,
|
||||
@@ -12,40 +31,37 @@ from vllm.model_executor.layers.quantization import (
|
||||
QUANTIZATION_METHODS, register_quantization_config)
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import \
|
||||
CompressedTensorsScheme
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
|
||||
find_matched_target, is_activation_quantization_format,
|
||||
should_ignore_layer)
|
||||
from vllm.model_executor.models.utils import WeightsMapper
|
||||
|
||||
from vllm_ascend.ops.fused_moe.fused_moe import AscendUnquantizedFusedMoEMethod
|
||||
from vllm_ascend.quantization.quant_config import (AscendFusedMoEMethod,
|
||||
AscendLinearMethod,
|
||||
AscendQuantConfig)
|
||||
from vllm_ascend.quantization.w4a16 import AscendW4A16FusedMoEMethod
|
||||
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
|
||||
from vllm_ascend.quantization.w8a8_dynamic import (
|
||||
AscendW8A8DynamicFusedMoEMethod, AscendW8A8DynamicLinearMethod)
|
||||
from vllm_ascend.utils import COMPRESSED_TENSORS_METHOD
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.models.utils import WeightsMapper
|
||||
from .methods import AscendLinearScheme, AscendMoEScheme
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
QUANTIZATION_SCHEME_MAP_TYPE = dict[str, Optional[dict[str, QuantizationArgs]]]
|
||||
|
||||
|
||||
def remove_quantization_method():
|
||||
# Remove the original compressed_tensors method to replace with our implementation
|
||||
def _remove_quantization_method():
|
||||
if COMPRESSED_TENSORS_METHOD in QUANTIZATION_METHODS:
|
||||
QUANTIZATION_METHODS.remove(COMPRESSED_TENSORS_METHOD)
|
||||
|
||||
|
||||
remove_quantization_method()
|
||||
_remove_quantization_method()
|
||||
|
||||
QUANTIZATION_SCHEME_MAP_TYPE = dict[str, Optional[dict[str,
|
||||
"QuantizationArgs"]]]
|
||||
|
||||
|
||||
@register_quantization_config(COMPRESSED_TENSORS_METHOD)
|
||||
class AscendCompressedTensorsConfig(QuantizationConfig):
|
||||
"""Config class for LLM-Compressor (compressed_tensors) quantization on Ascend.
|
||||
|
||||
This class adapts the compressed_tensors format to work with Ascend's
|
||||
quantization implementations.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -107,23 +123,16 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
|
||||
@classmethod
|
||||
def _quantization_scheme_map_from_config(
|
||||
cls, config: dict[str, Any]) -> QUANTIZATION_SCHEME_MAP_TYPE:
|
||||
"""
|
||||
"""Build target scheme map from config.
|
||||
|
||||
:param config: The `quantization_config` dictionary from config.json
|
||||
:return: A dictionary mapping target layer names to their corresponding
|
||||
quantization_args for weights and input activations
|
||||
"""
|
||||
|
||||
target_scheme_map: dict[str, Any] = dict()
|
||||
quant_format = cast(str, config.get("format"))
|
||||
|
||||
# The quant_config has multiple config_groups, each containing
|
||||
# an input_activations key with details about how the activations are
|
||||
# quantized, a weights key indicating how the weights are quantized,
|
||||
# and a list of targets under the `targets` key, dictating which
|
||||
# layers are impacted by the quantization details. The quantization
|
||||
# details follow the structure defined by the QuantizationArgs
|
||||
# pydantic model, which is used to verify the structure of the
|
||||
# quant_config and also store the details for later use.
|
||||
|
||||
config_groups = config.get("config_groups", dict())
|
||||
for _, quant_config in config_groups.items():
|
||||
targets = quant_config.get("targets")
|
||||
@@ -154,70 +163,102 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
|
||||
layer: torch.nn.Module,
|
||||
prefix: str,
|
||||
) -> Optional["QuantizeMethodBase"]:
|
||||
from .method_adapters import AscendFusedMoEMethod, AscendLinearMethod
|
||||
|
||||
if isinstance(layer, LinearBase):
|
||||
layer.ascend_quant_method = COMPRESSED_TENSORS_METHOD
|
||||
# collect schemes
|
||||
quant_scheme = self.get_scheme(layer=layer, layer_name=prefix)
|
||||
# Get the scheme for this layer
|
||||
linear_scheme = self._get_linear_scheme(layer=layer,
|
||||
layer_name=prefix)
|
||||
|
||||
# Return unquantized method if no scheme found
|
||||
if linear_scheme is None:
|
||||
return UnquantizedLinearMethod()
|
||||
|
||||
# Store scheme on layer for reference (optional, for debugging)
|
||||
layer.scheme = linear_scheme
|
||||
logger.info_once(
|
||||
"Using the vLLM Ascend llmcompressor Quantization now!")
|
||||
return AscendLinearMethod(linear_scheme)
|
||||
|
||||
# choose quantization method
|
||||
quant_method = UnquantizedLinearMethod()
|
||||
if quant_scheme is not None:
|
||||
layer.scheme = quant_scheme
|
||||
ascend_quant_config = AscendQuantConfig(self.quant_description
|
||||
or {})
|
||||
quant_method = AscendLinearMethod(ascend_quant_config, prefix,
|
||||
None, layer)
|
||||
return quant_method
|
||||
if isinstance(layer, FusedMoE):
|
||||
self._add_fused_moe_to_target_scheme_map()
|
||||
unfused_names = [
|
||||
prefix + proj_name for proj_name in
|
||||
[".0.gate_proj", ".0.up_proj", ".0.down_proj"]
|
||||
]
|
||||
# TODO: refactor this to use expert_mapping and check all layer numbers
|
||||
all_scheme_dicts = [
|
||||
self.get_scheme_dict(layer, name) for name in unfused_names
|
||||
]
|
||||
scheme_dict = all_scheme_dicts.pop()
|
||||
# Delayed import to avoid circular import
|
||||
from vllm_ascend.ops.fused_moe.fused_moe import \
|
||||
AscendUnquantizedFusedMoEMethod
|
||||
|
||||
# multiple schemes found
|
||||
if not all(
|
||||
[cur_dict == scheme_dict for cur_dict in all_scheme_dicts]):
|
||||
raise ValueError("All MoE projections need to have same "
|
||||
"quantization scheme but found multiple")
|
||||
layer.ascend_quant_method = COMPRESSED_TENSORS_METHOD
|
||||
# Get the scheme for this layer
|
||||
moe_scheme = self._get_moe_scheme(layer=layer, layer_name=prefix)
|
||||
|
||||
if scheme_dict is None:
|
||||
# Return unquantized method if no scheme found
|
||||
if moe_scheme is None:
|
||||
return AscendUnquantizedFusedMoEMethod(layer.moe_config)
|
||||
|
||||
weight_quant = scheme_dict.get("weights")
|
||||
input_quant = scheme_dict.get("input_activations")
|
||||
# Store scheme on layer for reference (optional, for debugging)
|
||||
layer.scheme = moe_scheme
|
||||
logger.info_once(
|
||||
"Using the vLLM Ascend llmcompressor Quantization now!")
|
||||
return AscendFusedMoEMethod(moe_scheme, layer.moe_config)
|
||||
|
||||
quant_scheme = None
|
||||
act_quant_format = is_activation_quantization_format(self.quant_format)
|
||||
if act_quant_format:
|
||||
if self._is_dynamic_token_w8a8(weight_quant, input_quant):
|
||||
quant_scheme = AscendW8A8DynamicFusedMoEMethod()
|
||||
else:
|
||||
if self._is_w4a16(weight_quant, input_quant):
|
||||
quant_scheme = AscendW4A16FusedMoEMethod()
|
||||
if quant_scheme is None:
|
||||
raise RuntimeError(
|
||||
f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}"
|
||||
)
|
||||
layer.scheme = quant_scheme
|
||||
layer.ascend_quant_method = COMPRESSED_TENSORS_METHOD
|
||||
|
||||
ascend_quant_config = AscendQuantConfig(self.quant_description
|
||||
or {})
|
||||
return AscendFusedMoEMethod(ascend_quant_config, prefix,
|
||||
self.packed_modules_mapping, layer)
|
||||
return None
|
||||
|
||||
def get_scheme(self,
|
||||
layer: torch.nn.Module,
|
||||
layer_name: Optional[str] = None
|
||||
) -> Optional["CompressedTensorsScheme"]:
|
||||
def _get_linear_scheme(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
layer_name: Optional[str] = None) -> Optional[AscendLinearScheme]:
|
||||
"""Get the linear quantization scheme for a layer.
|
||||
|
||||
Returns:
|
||||
An AscendLinearScheme instance, or None if the layer
|
||||
should use unquantized method.
|
||||
"""
|
||||
weight_quant, input_quant, format = self._get_quant_args(
|
||||
layer, layer_name)
|
||||
if weight_quant is None:
|
||||
return None
|
||||
|
||||
scheme = self._create_scheme_for_layer_type(
|
||||
weight_quant=weight_quant,
|
||||
input_quant=input_quant,
|
||||
format=format,
|
||||
layer_type="linear",
|
||||
)
|
||||
return cast(AscendLinearScheme, scheme)
|
||||
|
||||
def _get_moe_scheme(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
layer_name: Optional[str] = None) -> Optional[AscendMoEScheme]:
|
||||
"""Get the MoE quantization scheme for a layer.
|
||||
|
||||
Returns:
|
||||
An AscendMoEScheme instance, or None if the layer
|
||||
should use unquantized method.
|
||||
"""
|
||||
# Add FusedMoE to target scheme map if needed
|
||||
self._add_fused_moe_to_target_scheme_map()
|
||||
|
||||
weight_quant, input_quant, format = self._get_quant_args(
|
||||
layer, layer_name)
|
||||
if weight_quant is None:
|
||||
return None
|
||||
|
||||
scheme = self._create_scheme_for_layer_type(
|
||||
weight_quant=weight_quant,
|
||||
input_quant=input_quant,
|
||||
format=format,
|
||||
layer_type="moe",
|
||||
)
|
||||
return cast(AscendMoEScheme, scheme)
|
||||
|
||||
def _get_quant_args(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
layer_name: Optional[str] = None
|
||||
) -> tuple[Optional["QuantizationArgs"], Optional["QuantizationArgs"],
|
||||
Optional[str]]:
|
||||
"""Extract quantization arguments for a layer.
|
||||
|
||||
compressed-tensors supports non uniform in the following way:
|
||||
|
||||
targets of config_groups: There can be N config_groups which each
|
||||
@@ -226,10 +267,12 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
|
||||
an nn.Module name.
|
||||
|
||||
Detect whether a layer_name is found in any target and
|
||||
use the quantization scheme corresponding to the matched target
|
||||
to select the CompressedTensorsScheme used for inference.
|
||||
use the quantization scheme corresponding to the matched target.
|
||||
|
||||
Returns:
|
||||
A tuple of (weight_quant, input_quant, format). weight_quant is
|
||||
None if the layer should use unquantized method.
|
||||
"""
|
||||
|
||||
scheme_dict = self.get_scheme_dict(layer, layer_name)
|
||||
weight_quant = None
|
||||
input_quant = None
|
||||
@@ -243,16 +286,8 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
|
||||
logger.warning_once("Acceleration for non-quantized schemes is "
|
||||
"not supported by Compressed Tensors. "
|
||||
"Falling back to UnquantizedLinearMethod")
|
||||
return None
|
||||
|
||||
else:
|
||||
# Find the quant_scheme
|
||||
scheme = self._get_scheme_from_parts(
|
||||
weight_quant=weight_quant,
|
||||
input_quant=input_quant,
|
||||
format=format,
|
||||
)
|
||||
return scheme
|
||||
return weight_quant, input_quant, format
|
||||
|
||||
def get_scheme_dict(
|
||||
self,
|
||||
@@ -288,28 +323,73 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
|
||||
|
||||
return None
|
||||
|
||||
def _get_scheme_from_parts(
|
||||
def _create_scheme_for_layer_type(
|
||||
self,
|
||||
weight_quant: QuantizationArgs,
|
||||
input_quant: QuantizationArgs,
|
||||
format: str | None = None,
|
||||
) -> "CompressedTensorsScheme":
|
||||
weight_quant: "QuantizationArgs",
|
||||
input_quant: Optional["QuantizationArgs"],
|
||||
format: Optional[str],
|
||||
layer_type: str,
|
||||
) -> Union[AscendLinearScheme, AscendMoEScheme]:
|
||||
"""Create the appropriate Ascend scheme based on quantization args and layer type.
|
||||
|
||||
Args:
|
||||
weight_quant: Weight quantization arguments.
|
||||
input_quant: Input activation quantization arguments.
|
||||
format: Per-layer format, if defined.
|
||||
layer_type: Type of layer ("linear" or "moe").
|
||||
|
||||
Returns:
|
||||
An instance of the appropriate Ascend quantization scheme.
|
||||
"""
|
||||
from .methods import get_scheme_class
|
||||
|
||||
# Determine the quantization type
|
||||
quant_type = self._detect_quant_type(weight_quant, input_quant, format)
|
||||
|
||||
# Get the scheme class from registry
|
||||
scheme_cls = get_scheme_class(quant_type, layer_type)
|
||||
if scheme_cls is None:
|
||||
raise NotImplementedError(
|
||||
f"No compressed-tensors compatible scheme was found for "
|
||||
f"quant_type={quant_type}, layer_type={layer_type}.")
|
||||
|
||||
return scheme_cls()
|
||||
|
||||
def _detect_quant_type(
|
||||
self,
|
||||
weight_quant: "QuantizationArgs",
|
||||
input_quant: Optional["QuantizationArgs"],
|
||||
format: Optional[str],
|
||||
) -> str:
|
||||
"""Detect the quantization type from quantization arguments.
|
||||
|
||||
Args:
|
||||
weight_quant: Weight quantization arguments.
|
||||
input_quant: Input activation quantization arguments.
|
||||
format: Per-layer format, if defined.
|
||||
|
||||
Returns:
|
||||
A string representing the quantization type (e.g., "W8A8", "W8A8_DYNAMIC").
|
||||
"""
|
||||
# use the per-layer format if defined, otherwise, use global format
|
||||
format = format if format is not None else self.quant_format
|
||||
|
||||
act_quant_format = is_activation_quantization_format(format)
|
||||
|
||||
if act_quant_format and input_quant is not None:
|
||||
if self._is_static_tensor_w8a8(weight_quant, input_quant):
|
||||
return AscendW8A8LinearMethod()
|
||||
return "W8A8"
|
||||
|
||||
if self._is_dynamic_token_w8a8(weight_quant, input_quant):
|
||||
return AscendW8A8DynamicLinearMethod()
|
||||
return "W8A8_DYNAMIC"
|
||||
|
||||
if self._is_w4a16(weight_quant, input_quant):
|
||||
return "W4A16"
|
||||
|
||||
raise NotImplementedError(
|
||||
"No compressed-tensors compatible scheme was found.")
|
||||
"No compressed-tensors compatible quantization type was found.")
|
||||
|
||||
def _is_static_tensor_w8a8(self, weight_quant: QuantizationArgs,
|
||||
input_quant: QuantizationArgs) -> bool:
|
||||
def _is_static_tensor_w8a8(self, weight_quant: "QuantizationArgs",
|
||||
input_quant: "QuantizationArgs") -> bool:
|
||||
is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
|
||||
weight_strategy = (
|
||||
weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
|
||||
@@ -322,8 +402,8 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
|
||||
# Only symmetric weight quantization supported.
|
||||
return is_8_bits and is_tensor and is_symmetric and is_static
|
||||
|
||||
def _is_dynamic_token_w8a8(self, weight_quant: QuantizationArgs,
|
||||
input_quant: QuantizationArgs) -> bool:
|
||||
def _is_dynamic_token_w8a8(self, weight_quant: "QuantizationArgs",
|
||||
input_quant: "QuantizationArgs") -> bool:
|
||||
is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
|
||||
weight_strategy = (
|
||||
weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
|
||||
@@ -336,13 +416,13 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
|
||||
# Only symmetric weight quantization supported.
|
||||
return is_8_bits and is_token and is_symmetric and is_dynamic
|
||||
|
||||
def _is_w4a16(self, weight_quant: QuantizationArgs,
|
||||
input_quant: QuantizationArgs) -> bool:
|
||||
def _is_w4a16(self, weight_quant: "QuantizationArgs",
|
||||
input_quant: Optional["QuantizationArgs"]) -> bool:
|
||||
# Confirm weights quantized.
|
||||
if weight_quant is None:
|
||||
return False
|
||||
|
||||
# Confirm we have floating points.
|
||||
# Confirm we have integer type.
|
||||
if weight_quant.type != QuantizationType.INT:
|
||||
return False
|
||||
|
||||
288
vllm_ascend/quantization/method_adapters.py
Normal file
288
vllm_ascend/quantization/method_adapters.py
Normal file
@@ -0,0 +1,288 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
import torch
|
||||
from vllm.distributed import get_tensor_model_parallel_rank
|
||||
from vllm.model_executor.layers.fused_moe import (FusedMoEMethodBase,
|
||||
FusedMoeWeightScaleSupported)
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.model_executor.parameter import PerTensorScaleParameter
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.distributed.parallel_state import (get_flashcomm2_otp_group,
|
||||
get_mlp_tp_group,
|
||||
get_otp_group)
|
||||
from vllm_ascend.utils import flashcomm2_enable, mlp_tp_enable, oproj_tp_enable
|
||||
|
||||
from .methods import (AscendAttentionScheme, AscendLinearScheme,
|
||||
AscendMoEScheme, is_mx_quant_type)
|
||||
|
||||
|
||||
class AscendLinearMethod(LinearMethodBase):
|
||||
"""Linear method for Ascend quantization.
|
||||
|
||||
This wrapper class delegates to the actual quantization scheme implementation.
|
||||
The scheme is determined by the Config class and passed directly to this wrapper.
|
||||
|
||||
Args:
|
||||
scheme: The quantization scheme instance (e.g., AscendW8A8DynamicLinearMethod).
|
||||
"""
|
||||
|
||||
def __init__(self, scheme: AscendLinearScheme) -> None:
|
||||
self.quant_method = scheme
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: List[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
) -> None:
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
|
||||
weight_dict = self.quant_method.get_weight(input_size_per_partition,
|
||||
output_size_per_partition,
|
||||
params_dtype)
|
||||
|
||||
# Extract packing information (if present)
|
||||
packed_dim = weight_dict.pop("_packed_dim", None)
|
||||
packed_factor = weight_dict.pop("_packed_factor", None)
|
||||
|
||||
for weight_name, weight_param in weight_dict.items():
|
||||
param = torch.nn.Parameter(weight_param, requires_grad=False)
|
||||
set_weight_attrs(param, {"input_dim": 1, "output_dim": 0})
|
||||
|
||||
# Set packing attributes if the weight is packed
|
||||
if packed_dim is not None and packed_factor is not None:
|
||||
set_weight_attrs(param, {
|
||||
"packed_dim": packed_dim,
|
||||
"packed_factor": packed_factor
|
||||
})
|
||||
|
||||
layer.register_parameter(weight_name, param)
|
||||
set_weight_attrs(param, extra_weight_attrs)
|
||||
|
||||
pertensor_dict = self.quant_method.get_pertensor_param(params_dtype)
|
||||
for pertensor_name, pertensor_param in pertensor_dict.items():
|
||||
param = PerTensorScaleParameter(data=pertensor_param,
|
||||
weight_loader=weight_loader)
|
||||
# disable warning
|
||||
param.ignore_warning = True
|
||||
layer.register_parameter(pertensor_name, param)
|
||||
param.weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
|
||||
perchannel_dict = self.quant_method.get_perchannel_param(
|
||||
output_size_per_partition, params_dtype)
|
||||
for perchannel_name, perchannel_param in perchannel_dict.items():
|
||||
param = torch.nn.Parameter(perchannel_param, requires_grad=False)
|
||||
set_weight_attrs(param, {"output_dim": 0})
|
||||
layer.register_parameter(perchannel_name, param)
|
||||
set_weight_attrs(param, extra_weight_attrs)
|
||||
|
||||
# NOTE: In w4a8 quantization implementation,
|
||||
# for down_proj and o_proj scale_bias shape is [output_size, 16],
|
||||
# others are [output_size, 1]
|
||||
layer_type = "row" if isinstance(layer,
|
||||
RowParallelLinear) else "others"
|
||||
|
||||
pergroup_dict = self.quant_method.get_pergroup_param(
|
||||
input_size_per_partition,
|
||||
output_size_per_partition,
|
||||
params_dtype,
|
||||
layer_type=layer_type)
|
||||
for pergroup_name, pergroup_param in pergroup_dict.items():
|
||||
param = torch.nn.Parameter(pergroup_param, requires_grad=False)
|
||||
set_weight_attrs(param, {"output_dim": 0})
|
||||
layer.register_parameter(pergroup_name, param)
|
||||
set_weight_attrs(param, extra_weight_attrs)
|
||||
if "weight_scale_second" in pergroup_name or "weight_offset_second" in pergroup_name \
|
||||
or is_mx_quant_type(self.quant_method):
|
||||
setattr(param, "input_dim", 1)
|
||||
param.input_dim = 1
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
if hasattr(self.quant_method, "process_weights_after_loading"):
|
||||
self.quant_method.process_weights_after_loading(layer)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if isinstance(layer, RowParallelLinear):
|
||||
if layer.prefix.find("o_proj") != -1 and oproj_tp_enable():
|
||||
tp_rank = get_otp_group().rank_in_group
|
||||
elif layer.prefix.find("down_proj") != -1 and mlp_tp_enable():
|
||||
tp_rank = get_mlp_tp_group().rank_in_group
|
||||
elif (layer.prefix.find("o_proj") != -1 or
|
||||
layer.prefix.find("out_proj") != -1) and flashcomm2_enable():
|
||||
if get_ascend_config(
|
||||
).flashcomm2_oproj_tensor_parallel_size == 1:
|
||||
tp_rank = 0
|
||||
else:
|
||||
tp_rank = get_flashcomm2_otp_group().rank_in_group
|
||||
else:
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
else:
|
||||
tp_rank = 0
|
||||
return self.quant_method.apply(layer, x, bias, tp_rank)
|
||||
|
||||
|
||||
class AscendKVCacheMethod(BaseKVCacheMethod):
|
||||
"""KVCache method for Ascend quantization.
|
||||
|
||||
This wrapper class delegates to the actual attention quantization scheme.
|
||||
|
||||
Args:
|
||||
scheme: The attention quantization scheme instance.
|
||||
"""
|
||||
|
||||
def __init__(self, scheme: AscendAttentionScheme) -> None:
|
||||
self.quant_method = scheme
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module) -> None:
|
||||
# Different from linear method, there are no weight processing/slicing
|
||||
# steps for attention in vllm. So the whole process of create weights
|
||||
# is hidden into the specific quant method.
|
||||
self.quant_method.create_weights(layer)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
self.quant_method.process_weights_after_loading(layer)
|
||||
|
||||
def apply(self, layer: torch.nn.Module, query: torch.Tensor,
|
||||
key: torch.Tensor, value: torch.Tensor, kv_cache, attn_metadata,
|
||||
attn_type, scale, output) -> torch.Tensor:
|
||||
return self.quant_method.apply(layer, query, key, value, kv_cache,
|
||||
attn_metadata, attn_type, scale, output)
|
||||
|
||||
|
||||
class AscendFusedMoEMethod(FusedMoEMethodBase):
|
||||
"""FusedMoE method for Ascend quantization.
|
||||
|
||||
This wrapper class delegates to the actual MoE quantization scheme.
|
||||
|
||||
Args:
|
||||
scheme: The MoE quantization scheme instance.
|
||||
moe_config: The FusedMoE configuration.
|
||||
"""
|
||||
|
||||
def __init__(self, scheme: AscendMoEScheme,
|
||||
moe_config: FusedMoEConfig) -> None:
|
||||
super().__init__(moe_config)
|
||||
self.quant_method = scheme
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
) -> None:
|
||||
weight_param = self.quant_method.get_weight(
|
||||
num_experts, intermediate_size_per_partition, hidden_size,
|
||||
params_dtype)
|
||||
for param_key, param_value in weight_param.items():
|
||||
param = torch.nn.Parameter(param_value, requires_grad=False)
|
||||
layer.register_parameter(param_key, param)
|
||||
set_weight_attrs(param, extra_weight_attrs)
|
||||
|
||||
extra_weight_attrs.update(
|
||||
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
|
||||
per_group_param = [
|
||||
"weight_scale_second", "weight_offset_second", "scale_bias"
|
||||
] + ["weight_scale", "weight_offset"] if hasattr(
|
||||
self.quant_method,
|
||||
"group_size") and self.quant_method.group_size > 0 else []
|
||||
dynamic_quant_param = self.quant_method.get_dynamic_quant_param(
|
||||
num_experts, intermediate_size_per_partition, hidden_size,
|
||||
params_dtype)
|
||||
for param_key, param_value in dynamic_quant_param.items():
|
||||
param = torch.nn.Parameter(param_value, requires_grad=False)
|
||||
layer.register_parameter(param_key, param)
|
||||
set_weight_attrs(param, extra_weight_attrs)
|
||||
if any(fields in param_key for fields in per_group_param):
|
||||
setattr(param, "quant_method",
|
||||
FusedMoeWeightScaleSupported.GROUP.value)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
is_prefill: bool = True,
|
||||
enable_force_load_balance: bool = False,
|
||||
log2phy: Optional[torch.Tensor] = None,
|
||||
global_redundant_expert_num=0,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
return self.quant_method.apply(
|
||||
layer, x, router_logits, top_k, renormalize, use_grouped_topk,
|
||||
global_num_experts, expert_map, topk_group, num_expert_group,
|
||||
custom_routing_function, scoring_func, routed_scaling_factor,
|
||||
e_score_correction_bias, is_prefill, enable_force_load_balance,
|
||||
log2phy, global_redundant_expert_num, **kwargs)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
if hasattr(self.quant_method, "process_weights_after_loading"):
|
||||
self.quant_method.process_weights_after_loading(layer)
|
||||
|
||||
def get_fused_moe_quant_config(self, layer: torch.nn.Module):
|
||||
pass
|
||||
|
||||
@property
|
||||
def supports_eplb(self):
|
||||
supports_eplb = getattr(self.quant_method, "supports_eplb", False)
|
||||
return supports_eplb
|
||||
|
||||
|
||||
class AscendEmbeddingMethod(AscendLinearMethod):
|
||||
"""Embedding method for Ascend quantization.
|
||||
|
||||
This is essentially the same as AscendLinearMethod, just with a different name
|
||||
for clarity when used with VocabParallelEmbedding layers.
|
||||
|
||||
Args:
|
||||
scheme: The quantization scheme instance.
|
||||
"""
|
||||
|
||||
def __init__(self, scheme: AscendLinearScheme) -> None:
|
||||
self.quant_method = scheme
|
||||
82
vllm_ascend/quantization/methods/__init__.py
Normal file
82
vllm_ascend/quantization/methods/__init__.py
Normal file
@@ -0,0 +1,82 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""Ascend quantization scheme implementations.
|
||||
|
||||
This module provides all quantization scheme implementations for Ascend NPU.
|
||||
Schemes are automatically registered via the @register_scheme decorator.
|
||||
|
||||
Usage:
|
||||
from vllm_ascend.quantization.methods import get_scheme_class
|
||||
|
||||
# Get a scheme class by quant_type and layer_type
|
||||
scheme_cls = get_scheme_class("W8A8_DYNAMIC", "linear")
|
||||
scheme = scheme_cls()
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
# Import base classes
|
||||
from .base import (AscendAttentionScheme, AscendLinearScheme, AscendMoEScheme,
|
||||
QuantType)
|
||||
# Import registry functions
|
||||
from .registry import get_scheme_class, register_scheme
|
||||
# Import all scheme classes for external access
|
||||
from .w4a4_flatquant import AscendW4A4FlatQuantDynamicLinearMethod
|
||||
from .w4a4_laos_dynamic import AscendW4A4LaosDynamicLinearMethod
|
||||
from .w4a8 import (AscendW4A8DynamicFusedMoEMethod,
|
||||
AscendW4A8DynamicLinearMethod)
|
||||
from .w4a16 import AscendW4A16FusedMoEMethod
|
||||
from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod,
|
||||
AscendW8A8DynamicLinearMethod)
|
||||
from .w8a8_mxfp8 import AscendW8A8MXFP8DynamicLinearMethod
|
||||
from .w8a8_pdmix import (AscendW8A8PDMixFusedMoeMethod,
|
||||
AscendW8A8PDMixLinearMethod)
|
||||
from .w8a8_static import AscendW8A8LinearMethod
|
||||
from .w8a16 import AscendW8A16LinearMethod
|
||||
|
||||
|
||||
def is_mx_quant_type(instance: Any) -> bool:
|
||||
"""Checks if the quantization method is a microscaling (MX) type."""
|
||||
MX_QUANT_TYPES = (AscendW8A8MXFP8DynamicLinearMethod, )
|
||||
return isinstance(instance, MX_QUANT_TYPES)
|
||||
|
||||
|
||||
__all__ = [
|
||||
# Base classes
|
||||
"AscendAttentionScheme",
|
||||
"AscendLinearScheme",
|
||||
"AscendMoEScheme",
|
||||
"QuantType",
|
||||
# Registry functions
|
||||
"register_scheme",
|
||||
"get_scheme_class",
|
||||
# Utility functions
|
||||
"is_mx_quant_type",
|
||||
# Scheme classes
|
||||
"AscendW8A8LinearMethod",
|
||||
"AscendW8A8DynamicLinearMethod",
|
||||
"AscendW8A8DynamicFusedMoEMethod",
|
||||
"AscendW8A8MXFP8DynamicLinearMethod",
|
||||
"AscendW8A8PDMixLinearMethod",
|
||||
"AscendW8A8PDMixFusedMoeMethod",
|
||||
"AscendW8A16LinearMethod",
|
||||
"AscendW4A8DynamicLinearMethod",
|
||||
"AscendW4A8DynamicFusedMoEMethod",
|
||||
"AscendW4A16FusedMoEMethod",
|
||||
"AscendW4A4FlatQuantDynamicLinearMethod",
|
||||
"AscendW4A4LaosDynamicLinearMethod",
|
||||
]
|
||||
279
vllm_ascend/quantization/methods/base.py
Normal file
279
vllm_ascend/quantization/methods/base.py
Normal file
@@ -0,0 +1,279 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""Abstract base classes for Ascend quantization schemes."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class QuantType(Enum):
|
||||
"""Quantization type enum for MoE schemes."""
|
||||
NONE = 0
|
||||
W8A8 = 1
|
||||
W4A8 = 2
|
||||
|
||||
|
||||
class AscendLinearScheme(ABC):
|
||||
"""Base class for all linear quantization schemes.
|
||||
|
||||
Subclasses must implement get_weight() and apply() methods.
|
||||
Other methods have default implementations that return empty dicts
|
||||
or do nothing.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_weight(self, input_size: int, output_size: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
"""Return weight tensor specifications.
|
||||
|
||||
Args:
|
||||
input_size: Input dimension of the linear layer.
|
||||
output_size: Output dimension of the linear layer.
|
||||
params_dtype: Data type for parameters.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping parameter names to empty tensors with
|
||||
the correct shape and dtype.
|
||||
"""
|
||||
...
|
||||
|
||||
def get_pertensor_param(self, params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
"""Return per-tensor parameter specifications (e.g., input_scale).
|
||||
|
||||
Args:
|
||||
params_dtype: Data type for parameters.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping parameter names to empty tensors.
|
||||
"""
|
||||
return {}
|
||||
|
||||
def get_perchannel_param(self, output_size: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
"""Return per-channel parameter specifications (e.g., weight_scale).
|
||||
|
||||
Args:
|
||||
output_size: Output dimension of the linear layer.
|
||||
params_dtype: Data type for parameters.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping parameter names to empty tensors.
|
||||
"""
|
||||
return {}
|
||||
|
||||
def get_pergroup_param(self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
layer_type: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Return per-group parameter specifications.
|
||||
|
||||
Args:
|
||||
input_size: Input dimension of the linear layer.
|
||||
output_size: Output dimension of the linear layer.
|
||||
params_dtype: Data type for parameters.
|
||||
layer_type: Type of layer (e.g., "row" for RowParallelLinear).
|
||||
|
||||
Returns:
|
||||
Dictionary mapping parameter names to empty tensors.
|
||||
"""
|
||||
return {}
|
||||
|
||||
@abstractmethod
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
tp_rank: Optional[int] = 0) -> torch.Tensor:
|
||||
"""Forward computation.
|
||||
|
||||
Args:
|
||||
layer: The linear layer module.
|
||||
x: Input tensor.
|
||||
bias: Optional bias tensor.
|
||||
tp_rank: Tensor parallel rank.
|
||||
|
||||
Returns:
|
||||
Output tensor after quantized linear operation.
|
||||
"""
|
||||
...
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
"""Post-loading weight processing (transpose, format conversion, etc.).
|
||||
|
||||
Args:
|
||||
layer: The linear layer module.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class AscendAttentionScheme(ABC):
|
||||
"""Base class for all attention quantization schemes.
|
||||
|
||||
Subclasses must implement apply() method.
|
||||
Other methods have default implementations.
|
||||
"""
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module) -> None:
|
||||
"""Create weights for attention quantization.
|
||||
|
||||
Args:
|
||||
layer: The attention layer module.
|
||||
"""
|
||||
pass
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
"""Post-loading weight processing for attention layer.
|
||||
|
||||
Args:
|
||||
layer: The attention layer module.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def apply(self, layer: torch.nn.Module, query: torch.Tensor,
|
||||
key: torch.Tensor, value: torch.Tensor, kv_cache, attn_metadata,
|
||||
attn_type, scale, output) -> torch.Tensor:
|
||||
"""Forward computation for attention layer.
|
||||
|
||||
Args:
|
||||
layer: The attention layer module.
|
||||
query: Query tensor.
|
||||
key: Key tensor.
|
||||
value: Value tensor.
|
||||
kv_cache: KV cache.
|
||||
attn_metadata: Attention metadata.
|
||||
attn_type: Attention type.
|
||||
scale: Scale factor.
|
||||
output: Output tensor.
|
||||
|
||||
Returns:
|
||||
Output tensor after attention computation.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class AscendMoEScheme(ABC):
|
||||
"""Base class for all MoE quantization schemes.
|
||||
|
||||
Subclasses must implement get_weight(), get_dynamic_quant_param(),
|
||||
and apply() methods.
|
||||
|
||||
Attributes:
|
||||
quant_type: The quantization type for this scheme. Subclasses should
|
||||
override this class attribute to declare their quant type.
|
||||
"""
|
||||
|
||||
# Default quant type - subclasses should override this
|
||||
quant_type: QuantType = QuantType.NONE
|
||||
|
||||
@abstractmethod
|
||||
def get_weight(self, num_experts: int,
|
||||
intermediate_size_per_partition: int, hidden_sizes: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
"""Return weight tensor specifications for MoE layer.
|
||||
|
||||
Args:
|
||||
num_experts: Number of experts.
|
||||
intermediate_size_per_partition: Intermediate size per partition.
|
||||
hidden_sizes: Hidden dimension size.
|
||||
params_dtype: Data type for parameters.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping parameter names to empty tensors.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_dynamic_quant_param(self, num_experts: int,
|
||||
intermediate_size_per_partition: int,
|
||||
hidden_sizes: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
"""Return dynamic quantization parameters for MoE layer.
|
||||
|
||||
Args:
|
||||
num_experts: Number of experts.
|
||||
intermediate_size_per_partition: Intermediate size per partition.
|
||||
hidden_sizes: Hidden dimension size.
|
||||
params_dtype: Data type for parameters.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping parameter names to empty tensors.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
is_prefill: bool = True,
|
||||
enable_force_load_balance: bool = False,
|
||||
log2phy: Optional[torch.Tensor] = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""Forward computation for MoE layer.
|
||||
|
||||
Args:
|
||||
layer: The MoE layer module.
|
||||
x: Input hidden states.
|
||||
router_logits: Router logits for expert selection.
|
||||
top_k: Number of experts to select per token.
|
||||
renormalize: Whether to renormalize expert weights.
|
||||
use_grouped_topk: Whether to use grouped top-k selection.
|
||||
global_num_experts: Total number of experts globally.
|
||||
expert_map: Mapping from local to global expert indices.
|
||||
topk_group: Group size for grouped top-k.
|
||||
num_expert_group: Number of expert groups.
|
||||
custom_routing_function: Custom routing function.
|
||||
scoring_func: Scoring function name.
|
||||
routed_scaling_factor: Scaling factor for routed experts.
|
||||
e_score_correction_bias: Expert score correction bias.
|
||||
is_prefill: Whether in prefill phase.
|
||||
enable_force_load_balance: Whether to force load balancing.
|
||||
log2phy: Logical to physical expert mapping.
|
||||
global_redundant_expert_num: Number of redundant experts.
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
Output tensor after MoE computation.
|
||||
"""
|
||||
...
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
"""Post-loading weight processing for MoE layer.
|
||||
|
||||
Args:
|
||||
layer: The MoE layer module.
|
||||
"""
|
||||
pass
|
||||
62
vllm_ascend/quantization/methods/registry.py
Normal file
62
vllm_ascend/quantization/methods/registry.py
Normal file
@@ -0,0 +1,62 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import Any, Dict, Optional, Tuple, Type
|
||||
|
||||
# Registry: maps (quant_type, layer_type) -> SchemeClass
|
||||
_SCHEME_REGISTRY: Dict[Tuple[str, str], Type[Any]] = {}
|
||||
|
||||
|
||||
def register_scheme(quant_type: str, layer_type: str):
|
||||
"""Decorator to register a quantization scheme.
|
||||
|
||||
Args:
|
||||
quant_type: Quantization type (e.g., "W8A8", "W8A8_DYNAMIC").
|
||||
layer_type: Layer type (e.g., "linear", "moe").
|
||||
|
||||
Returns:
|
||||
Decorator function that registers the class.
|
||||
|
||||
Example:
|
||||
@register_scheme("W8A8_DYNAMIC", "linear")
|
||||
class W8A8DynamicLinearScheme(AscendLinearScheme):
|
||||
...
|
||||
"""
|
||||
|
||||
def decorator(cls: Type[Any]) -> Type[Any]:
|
||||
key = (quant_type, layer_type)
|
||||
if key in _SCHEME_REGISTRY:
|
||||
raise ValueError(
|
||||
f"Scheme already registered for {quant_type}/{layer_type}: "
|
||||
f"{_SCHEME_REGISTRY[key].__name__}")
|
||||
_SCHEME_REGISTRY[key] = cls
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def get_scheme_class(quant_type: str, layer_type: str) -> Optional[Type[Any]]:
|
||||
"""Get scheme class for given quant_type and layer_type.
|
||||
|
||||
Args:
|
||||
quant_type: Quantization type (e.g., "W8A8", "W8A8_DYNAMIC").
|
||||
layer_type: Layer type (e.g., "linear", "moe").
|
||||
|
||||
Returns:
|
||||
The registered scheme class, or None if not found.
|
||||
"""
|
||||
return _SCHEME_REGISTRY.get((quant_type, layer_type))
|
||||
@@ -1,277 +1,278 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.forward_context import get_forward_context
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
|
||||
|
||||
|
||||
def unpack_from_int32(
|
||||
weight: torch.Tensor,
|
||||
shape: torch.Size,
|
||||
num_bits: int,
|
||||
packed_dim: int = 1,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Unpacks quantized weights from int32 format back to original bits.
|
||||
|
||||
:param weight: The packed int32 tensor containing quantized weights
|
||||
:param shape: Original shape to restore, defaults to None
|
||||
:param num_bits: The number of bits used for quantization (<= 8)
|
||||
:param packed_dim: Dimension along which weights are packed (0 or 1), defaults to 1
|
||||
:return: Unpacked tensor with int8 dtype after applying offset correction
|
||||
"""
|
||||
assert weight.dtype == torch.int32, f"Expecting `weight.dtype` is torch.int32 but got {weight.dtype}."
|
||||
assert num_bits <= 8, f"Expecting `num_bits` should not be larger than 8 but got {num_bits}."
|
||||
|
||||
pack_factor = 32 // num_bits
|
||||
mask = (1 << num_bits) - 1
|
||||
|
||||
if packed_dim == 1:
|
||||
unpacked_weight = torch.zeros(
|
||||
(weight.shape[0], weight.shape[1] * pack_factor),
|
||||
device=weight.device,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
for i in range(pack_factor):
|
||||
unpacked_weight[:, i::pack_factor] = (weight >>
|
||||
(num_bits * i)) & mask
|
||||
original_row_size = int(shape[1])
|
||||
unpacked_weight = unpacked_weight[:, :original_row_size]
|
||||
else:
|
||||
unpacked_weight = torch.zeros(
|
||||
(weight.shape[0] * pack_factor, weight.shape[1]),
|
||||
device=weight.device,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
for i in range(pack_factor):
|
||||
unpacked_weight[i::pack_factor, :] = (weight >>
|
||||
(num_bits * i)) & mask
|
||||
original_row_size = int(shape[0])
|
||||
unpacked_weight = unpacked_weight[:original_row_size, :]
|
||||
|
||||
offset = pow(2, num_bits) // 2
|
||||
unpacked_weight = (unpacked_weight - offset).to(torch.int8)
|
||||
|
||||
return unpacked_weight
|
||||
|
||||
|
||||
def pack_to_int32(weight: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Packs quantized weights into int32 format for storage.
|
||||
|
||||
:param weight: The 3D tensor to pack, must be int8 or int32 dtype
|
||||
:return: Packed tensor with int32 dtype optimized for storage
|
||||
"""
|
||||
assert weight.dim(
|
||||
) == 3, f"Expecting `weight.dim()` is 3 ([e, n, k] or [e, k, n]) but got {weight.dim()}."
|
||||
assert weight.dtype in [
|
||||
torch.int8, torch.int32
|
||||
], f"Expecting `weight.dtype` is torch.int8 or torch.int32 bug got {weight.dtype}."
|
||||
|
||||
if weight.dtype == torch.int32:
|
||||
assert weight.shape[
|
||||
-1] % 8 == 0, "the last dim of weight needs to be divided by 8."
|
||||
packed_weight = torch_npu.npu_convert_weight_to_int4pack(
|
||||
weight.flatten(0, 1))
|
||||
packed_weight = packed_weight.view(weight.shape[0], weight.shape[1],
|
||||
-1)
|
||||
else:
|
||||
assert weight.shape[
|
||||
-1] % 4 == 0, "the last dim of weight needs to be divided by 4."
|
||||
packed_weight = weight.view(torch.int32).contiguous()
|
||||
|
||||
return packed_weight
|
||||
|
||||
|
||||
class AscendW4A16FusedMoEMethod:
|
||||
"""FusedMoe method for Ascend W4A16.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.transpose_weight = True
|
||||
self.num_bits = 4 # dtype = torch.int4
|
||||
self.pack_factor = 8 # pack 8 of torch.int4 tensors to torch.int32
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.group_size = vllm_config.quant_config.quant_description.get(
|
||||
"group_size", 32)
|
||||
self.dynamic_eplb = get_ascend_config().eplb_config.dynamic_eplb
|
||||
|
||||
def get_weight(
|
||||
self,
|
||||
num_experts: int,
|
||||
intermediate_size_per_partition: int,
|
||||
hidden_sizes: int,
|
||||
params_dtype: torch.dtype,
|
||||
) -> Dict[str, Any]:
|
||||
assert intermediate_size_per_partition % self.pack_factor == 0, f"Expecting `intermediate_size_per_partition` {intermediate_size_per_partition} can be divided by `pack_factor` {self.pack_factor}"
|
||||
assert hidden_sizes % self.pack_factor == 0, f"Expecting `hidden_sizes` {hidden_sizes} can be divided by `pack_factor` {self.pack_factor}"
|
||||
|
||||
param_dict = {}
|
||||
|
||||
param_dict["w13_weight_packed"] = torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_sizes // self.pack_factor,
|
||||
dtype=torch.int32)
|
||||
param_dict["w2_weight_packed"] = torch.empty(
|
||||
num_experts,
|
||||
hidden_sizes,
|
||||
intermediate_size_per_partition // self.pack_factor,
|
||||
dtype=torch.int32)
|
||||
|
||||
return param_dict
|
||||
|
||||
def get_dynamic_quant_param(
|
||||
self,
|
||||
num_experts: int,
|
||||
intermediate_size_per_partition: int,
|
||||
hidden_sizes: int,
|
||||
params_dtype: torch.dtype,
|
||||
) -> Dict[str, Any]:
|
||||
assert intermediate_size_per_partition % self.group_size == 0, f"Expecting `intermediate_size_per_partition` {intermediate_size_per_partition} can be divided by `group_size` {self.group_size}"
|
||||
assert hidden_sizes % self.group_size == 0, f"Expecting `hidden_sizes` {hidden_sizes} can be divided by `group_size` {self.group_size}"
|
||||
|
||||
param_dict = {}
|
||||
|
||||
param_dict["w13_weight_scale"] = torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_sizes // self.group_size,
|
||||
dtype=torch.bfloat16)
|
||||
param_dict["w2_weight_scale"] = torch.empty(
|
||||
num_experts,
|
||||
hidden_sizes,
|
||||
intermediate_size_per_partition // self.group_size,
|
||||
dtype=torch.bfloat16)
|
||||
param_dict["w13_weight_shape"] = torch.empty(num_experts,
|
||||
2,
|
||||
dtype=torch.int32)
|
||||
param_dict["w2_weight_shape"] = torch.empty(num_experts,
|
||||
2,
|
||||
dtype=torch.int32)
|
||||
param_dict["w13_weight_offset"] = torch.zeros(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_sizes // self.group_size,
|
||||
dtype=torch.bfloat16)
|
||||
param_dict["w2_weight_offset"] = torch.zeros(
|
||||
num_experts,
|
||||
hidden_sizes,
|
||||
intermediate_size_per_partition // self.group_size,
|
||||
dtype=torch.bfloat16)
|
||||
|
||||
return param_dict
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
is_prefill: bool = True,
|
||||
enable_force_load_balance: bool = True,
|
||||
log2phy: torch.Tensor = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
assert router_logits.shape[
|
||||
1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"
|
||||
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
global_num_experts=global_num_experts)
|
||||
|
||||
topk_ids = topk_ids.to(torch.int32)
|
||||
topk_weights = topk_weights.to(x.dtype)
|
||||
|
||||
moe_comm_method = get_forward_context().moe_comm_method
|
||||
return moe_comm_method.fused_experts(hidden_states=x,
|
||||
w1=layer.w13_weight_packed,
|
||||
w2=layer.w2_weight_packed,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
w1_offset=layer.w13_weight_offset,
|
||||
w2_offset=layer.w2_weight_offset,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
use_int4_w4a16=True,
|
||||
expert_map=expert_map,
|
||||
log2phy=log2phy,
|
||||
dynamic_eplb=self.dynamic_eplb,
|
||||
mc2_mask=kwargs.get(
|
||||
"mc2_mask", None))
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
if self.transpose_weight:
|
||||
w13_shape = layer.w13_weight_packed.data.shape
|
||||
w2_shape = layer.w2_weight_packed.data.shape
|
||||
unpacked_w13_weight = (unpack_from_int32(
|
||||
layer.w13_weight_packed.data.flatten(0, 1),
|
||||
torch.Size([
|
||||
w13_shape[0] * w13_shape[1],
|
||||
w13_shape[2] * self.pack_factor
|
||||
]),
|
||||
self.num_bits,
|
||||
).view(w13_shape[0], w13_shape[1],
|
||||
-1).transpose(1, 2).contiguous().int())
|
||||
unpacked_w2_weight = (unpack_from_int32(
|
||||
layer.w2_weight_packed.data.flatten(0, 1),
|
||||
torch.Size([
|
||||
w2_shape[0] * w2_shape[1], w2_shape[2] * self.pack_factor
|
||||
]),
|
||||
self.num_bits,
|
||||
).view(w2_shape[0], w2_shape[1],
|
||||
-1).transpose(1, 2).contiguous().int())
|
||||
layer.w13_weight_packed.data = pack_to_int32(unpacked_w13_weight)
|
||||
layer.w2_weight_packed.data = pack_to_int32(unpacked_w2_weight)
|
||||
|
||||
layer.w13_weight_scale.data = layer.w13_weight_scale.data.transpose(
|
||||
1, 2).contiguous()
|
||||
layer.w2_weight_scale.data = layer.w2_weight_scale.data.transpose(
|
||||
1, 2).contiguous()
|
||||
|
||||
layer.w13_weight_offset.data = layer.w13_weight_offset.data.transpose(
|
||||
1, 2).contiguous()
|
||||
layer.w2_weight_offset.data = layer.w2_weight_offset.data.transpose(
|
||||
1, 2).contiguous()
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.forward_context import get_forward_context
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
|
||||
|
||||
from .base import AscendMoEScheme
|
||||
from .registry import register_scheme
|
||||
|
||||
|
||||
def unpack_from_int32(
|
||||
weight: torch.Tensor,
|
||||
shape: torch.Size,
|
||||
num_bits: int,
|
||||
packed_dim: int = 1,
|
||||
) -> torch.Tensor:
|
||||
"""Unpacks quantized weights from int32 format back to original bits.
|
||||
|
||||
:param weight: The packed int32 tensor containing quantized weights
|
||||
:param shape: Original shape to restore, defaults to None
|
||||
:param num_bits: The number of bits used for quantization (<= 8)
|
||||
:param packed_dim: Dimension along which weights are packed (0 or 1), defaults to 1
|
||||
:return: Unpacked tensor with int8 dtype after applying offset correction
|
||||
"""
|
||||
assert weight.dtype == torch.int32, f"Expecting `weight.dtype` is torch.int32 but got {weight.dtype}."
|
||||
assert num_bits <= 8, f"Expecting `num_bits` should not be larger than 8 but got {num_bits}."
|
||||
|
||||
pack_factor = 32 // num_bits
|
||||
mask = (1 << num_bits) - 1
|
||||
|
||||
if packed_dim == 1:
|
||||
unpacked_weight = torch.zeros(
|
||||
(weight.shape[0], weight.shape[1] * pack_factor),
|
||||
device=weight.device,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
for i in range(pack_factor):
|
||||
unpacked_weight[:, i::pack_factor] = (weight >>
|
||||
(num_bits * i)) & mask
|
||||
original_row_size = int(shape[1])
|
||||
unpacked_weight = unpacked_weight[:, :original_row_size]
|
||||
else:
|
||||
unpacked_weight = torch.zeros(
|
||||
(weight.shape[0] * pack_factor, weight.shape[1]),
|
||||
device=weight.device,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
for i in range(pack_factor):
|
||||
unpacked_weight[i::pack_factor, :] = (weight >>
|
||||
(num_bits * i)) & mask
|
||||
original_row_size = int(shape[0])
|
||||
unpacked_weight = unpacked_weight[:original_row_size, :]
|
||||
|
||||
offset = pow(2, num_bits) // 2
|
||||
unpacked_weight = (unpacked_weight - offset).to(torch.int8)
|
||||
|
||||
return unpacked_weight
|
||||
|
||||
|
||||
def pack_to_int32(weight: torch.Tensor) -> torch.Tensor:
|
||||
"""Packs quantized weights into int32 format for storage.
|
||||
|
||||
:param weight: The 3D tensor to pack, must be int8 or int32 dtype
|
||||
:return: Packed tensor with int32 dtype optimized for storage
|
||||
"""
|
||||
assert weight.dim(
|
||||
) == 3, f"Expecting `weight.dim()` is 3 ([e, n, k] or [e, k, n]) but got {weight.dim()}."
|
||||
assert weight.dtype in [
|
||||
torch.int8, torch.int32
|
||||
], f"Expecting `weight.dtype` is torch.int8 or torch.int32 bug got {weight.dtype}."
|
||||
|
||||
if weight.dtype == torch.int32:
|
||||
assert weight.shape[
|
||||
-1] % 8 == 0, "the last dim of weight needs to be divided by 8."
|
||||
packed_weight = torch_npu.npu_convert_weight_to_int4pack(
|
||||
weight.flatten(0, 1))
|
||||
packed_weight = packed_weight.view(weight.shape[0], weight.shape[1],
|
||||
-1)
|
||||
else:
|
||||
assert weight.shape[
|
||||
-1] % 4 == 0, "the last dim of weight needs to be divided by 4."
|
||||
packed_weight = weight.view(torch.int32).contiguous()
|
||||
|
||||
return packed_weight
|
||||
|
||||
|
||||
@register_scheme("W4A16", "moe")
|
||||
class AscendW4A16FusedMoEMethod(AscendMoEScheme):
|
||||
"""FusedMoE method for Ascend W4A16."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.transpose_weight = True
|
||||
self.num_bits = 4 # dtype = torch.int4
|
||||
self.pack_factor = 8 # pack 8 of torch.int4 tensors to torch.int32
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.group_size = vllm_config.quant_config.quant_description.get(
|
||||
"group_size", 32)
|
||||
self.dynamic_eplb = get_ascend_config().eplb_config.dynamic_eplb
|
||||
|
||||
def get_weight(
|
||||
self,
|
||||
num_experts: int,
|
||||
intermediate_size_per_partition: int,
|
||||
hidden_sizes: int,
|
||||
params_dtype: torch.dtype,
|
||||
) -> Dict[str, Any]:
|
||||
assert intermediate_size_per_partition % self.pack_factor == 0, f"Expecting `intermediate_size_per_partition` {intermediate_size_per_partition} can be divided by `pack_factor` {self.pack_factor}"
|
||||
assert hidden_sizes % self.pack_factor == 0, f"Expecting `hidden_sizes` {hidden_sizes} can be divided by `pack_factor` {self.pack_factor}"
|
||||
|
||||
param_dict = {}
|
||||
|
||||
param_dict["w13_weight_packed"] = torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_sizes // self.pack_factor,
|
||||
dtype=torch.int32)
|
||||
param_dict["w2_weight_packed"] = torch.empty(
|
||||
num_experts,
|
||||
hidden_sizes,
|
||||
intermediate_size_per_partition // self.pack_factor,
|
||||
dtype=torch.int32)
|
||||
|
||||
return param_dict
|
||||
|
||||
def get_dynamic_quant_param(
|
||||
self,
|
||||
num_experts: int,
|
||||
intermediate_size_per_partition: int,
|
||||
hidden_sizes: int,
|
||||
params_dtype: torch.dtype,
|
||||
) -> Dict[str, Any]:
|
||||
assert intermediate_size_per_partition % self.group_size == 0, f"Expecting `intermediate_size_per_partition` {intermediate_size_per_partition} can be divided by `group_size` {self.group_size}"
|
||||
assert hidden_sizes % self.group_size == 0, f"Expecting `hidden_sizes` {hidden_sizes} can be divided by `group_size` {self.group_size}"
|
||||
|
||||
param_dict = {}
|
||||
|
||||
param_dict["w13_weight_scale"] = torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_sizes // self.group_size,
|
||||
dtype=torch.bfloat16)
|
||||
param_dict["w2_weight_scale"] = torch.empty(
|
||||
num_experts,
|
||||
hidden_sizes,
|
||||
intermediate_size_per_partition // self.group_size,
|
||||
dtype=torch.bfloat16)
|
||||
param_dict["w13_weight_shape"] = torch.empty(num_experts,
|
||||
2,
|
||||
dtype=torch.int32)
|
||||
param_dict["w2_weight_shape"] = torch.empty(num_experts,
|
||||
2,
|
||||
dtype=torch.int32)
|
||||
param_dict["w13_weight_offset"] = torch.zeros(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_sizes // self.group_size,
|
||||
dtype=torch.bfloat16)
|
||||
param_dict["w2_weight_offset"] = torch.zeros(
|
||||
num_experts,
|
||||
hidden_sizes,
|
||||
intermediate_size_per_partition // self.group_size,
|
||||
dtype=torch.bfloat16)
|
||||
|
||||
return param_dict
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
is_prefill: bool = True,
|
||||
enable_force_load_balance: bool = True,
|
||||
log2phy: Optional[torch.Tensor] = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
assert router_logits.shape[
|
||||
1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"
|
||||
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
global_num_experts=global_num_experts)
|
||||
|
||||
topk_ids = topk_ids.to(torch.int32)
|
||||
topk_weights = topk_weights.to(x.dtype)
|
||||
|
||||
moe_comm_method = get_forward_context().moe_comm_method
|
||||
return moe_comm_method.fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight_packed,
|
||||
w2=layer.w2_weight_packed,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
w1_offset=layer.w13_weight_offset,
|
||||
w2_offset=layer.w2_weight_offset,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
use_int4_w4a16=True,
|
||||
expert_map=expert_map,
|
||||
log2phy=log2phy,
|
||||
dynamic_eplb=self.dynamic_eplb,
|
||||
mc2_mask=kwargs.get("mc2_mask", None))
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
if self.transpose_weight:
|
||||
w13_shape = layer.w13_weight_packed.data.shape
|
||||
w2_shape = layer.w2_weight_packed.data.shape
|
||||
unpacked_w13_weight = (unpack_from_int32(
|
||||
layer.w13_weight_packed.data.flatten(0, 1),
|
||||
torch.Size([
|
||||
w13_shape[0] * w13_shape[1],
|
||||
w13_shape[2] * self.pack_factor
|
||||
]),
|
||||
self.num_bits,
|
||||
).view(w13_shape[0], w13_shape[1],
|
||||
-1).transpose(1, 2).contiguous().int())
|
||||
unpacked_w2_weight = (unpack_from_int32(
|
||||
layer.w2_weight_packed.data.flatten(0, 1),
|
||||
torch.Size([
|
||||
w2_shape[0] * w2_shape[1], w2_shape[2] * self.pack_factor
|
||||
]),
|
||||
self.num_bits,
|
||||
).view(w2_shape[0], w2_shape[1],
|
||||
-1).transpose(1, 2).contiguous().int())
|
||||
layer.w13_weight_packed.data = pack_to_int32(unpacked_w13_weight)
|
||||
layer.w2_weight_packed.data = pack_to_int32(unpacked_w2_weight)
|
||||
|
||||
layer.w13_weight_scale.data = layer.w13_weight_scale.data.transpose(
|
||||
1, 2).contiguous()
|
||||
layer.w2_weight_scale.data = layer.w2_weight_scale.data.transpose(
|
||||
1, 2).contiguous()
|
||||
|
||||
layer.w13_weight_offset.data = layer.w13_weight_offset.data.transpose(
|
||||
1, 2).contiguous()
|
||||
layer.w2_weight_offset.data = layer.w2_weight_offset.data.transpose(
|
||||
1, 2).contiguous()
|
||||
@@ -14,16 +14,21 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import math
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
|
||||
from .base import AscendLinearScheme
|
||||
from .registry import register_scheme
|
||||
|
||||
KRONECKER_QUANT_MAX_BATCH_SIZE = 32768
|
||||
|
||||
|
||||
def pack_int4_weights(weight_tensor: torch.Tensor) -> torch.Tensor:
|
||||
"""Pack int4 weights for NPU."""
|
||||
original_device = weight_tensor.device
|
||||
weight_tensor_npu = weight_tensor.npu()
|
||||
weight_int4_packed = torch_npu.npu_convert_weight_to_int4pack(
|
||||
@@ -32,6 +37,7 @@ def pack_int4_weights(weight_tensor: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
|
||||
def get_decompose_dim(n):
|
||||
"""Get decomposed dimensions for Kronecker quantization."""
|
||||
a = int(math.sqrt(n))
|
||||
if a * a < n:
|
||||
a += 1
|
||||
@@ -53,6 +59,7 @@ def batched_kronecker_quant(
|
||||
right_trans: torch.Tensor,
|
||||
clip_ratio: float,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Batched Kronecker quantization with batch size limit handling."""
|
||||
batch_tokens = x.shape[0]
|
||||
if batch_tokens <= KRONECKER_QUANT_MAX_BATCH_SIZE:
|
||||
return torch_npu.npu_kronecker_quant(x,
|
||||
@@ -75,7 +82,8 @@ def batched_kronecker_quant(
|
||||
return x_quantized_int4, activation_scale
|
||||
|
||||
|
||||
class AscendW4A4FlatQuantDynamicLinearMethod:
|
||||
@register_scheme("W4A4_FLATQUANT_DYNAMIC", "linear")
|
||||
class AscendW4A4FlatQuantDynamicLinearMethod(AscendLinearScheme):
|
||||
"""Linear method for Ascend W4A4_FLATQUANT_DYNAMIC.
|
||||
|
||||
This class implements W4A4 quantization with FlatQuant approach and dynamic activation quantization.
|
||||
@@ -88,8 +96,7 @@ class AscendW4A4FlatQuantDynamicLinearMethod:
|
||||
def __init__(self):
|
||||
self.sym = True
|
||||
|
||||
@staticmethod
|
||||
def get_weight(input_size: int, output_size: int,
|
||||
def get_weight(self, input_size: int, output_size: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
if input_size % 8 != 0:
|
||||
raise ValueError(
|
||||
@@ -101,8 +108,7 @@ class AscendW4A4FlatQuantDynamicLinearMethod:
|
||||
}
|
||||
return params_dict
|
||||
|
||||
@staticmethod
|
||||
def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
def get_pertensor_param(self, params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
params_dict = {}
|
||||
left_trans_dim, right_trans_dim = get_decompose_dim(
|
||||
AscendW4A4FlatQuantDynamicLinearMethod.input_size)
|
||||
@@ -115,8 +121,8 @@ class AscendW4A4FlatQuantDynamicLinearMethod:
|
||||
params_dict["clip_ratio"] = torch.empty(1, dtype=torch.float32)
|
||||
return params_dict
|
||||
|
||||
@staticmethod
|
||||
def get_perchannel_param(
|
||||
self,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
) -> Dict[str, Any]:
|
||||
@@ -129,15 +135,8 @@ class AscendW4A4FlatQuantDynamicLinearMethod:
|
||||
dtype=torch.float32)
|
||||
return params_dict
|
||||
|
||||
def get_pergroup_param(self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
layer_type: Optional[str] = None) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
@@ -14,22 +14,31 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
import torch.nn.functional as F
|
||||
|
||||
class AscendW4A4LaosDynamicLinearMethod:
|
||||
"""Linear method for Ascend W4A4_LAOS_DYNAMIC.
|
||||
from .base import AscendLinearScheme
|
||||
from .registry import register_scheme
|
||||
|
||||
|
||||
@register_scheme("W4A4_DYNAMIC", "linear")
|
||||
class AscendW4A4LaosDynamicLinearMethod(AscendLinearScheme):
|
||||
"""Linear method for Ascend W4A4_DYNAMIC.
|
||||
|
||||
This class implements W4A4 quantization with LAOS approach and dynamic activation quantization.
|
||||
- Weight: 4-bit quantization (per-channel) with scale and offset, stored as int8.
|
||||
- Activation: 4-bit dynamic quantization.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.transpose_weight = True
|
||||
self.rotation_type = None
|
||||
|
||||
def set_rotation_config(self, prefix, metadata):
|
||||
def set_rotation_config(self, prefix: str, metadata: Dict) -> Optional[str]:
|
||||
"""Set rotation config based on prefix and metadata."""
|
||||
layer_idx = prefix.split(".")[2]
|
||||
if prefix.endswith("o_proj"):
|
||||
layers = metadata["quarot"]["heads_rotation"]["layers"]
|
||||
@@ -39,17 +48,17 @@ class AscendW4A4LaosDynamicLinearMethod:
|
||||
layers = metadata["quarot"]["kronecker_rotation"]["layers"]
|
||||
if layer_idx in layers:
|
||||
return "kronecker_rotation"
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_weight(input_size: int, output_size: int, params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
params_dict = {"weight": torch.empty(output_size, input_size, dtype=torch.int8)}
|
||||
def get_weight(self, input_size: int, output_size: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
params_dict = {
|
||||
"weight": torch.empty(output_size, input_size, dtype=torch.int8)
|
||||
}
|
||||
return params_dict
|
||||
|
||||
@staticmethod
|
||||
def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
||||
def get_perchannel_param(self, output_size: int, params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
|
||||
def get_perchannel_param(self, output_size: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
params_dict = {}
|
||||
params_dict["weight_scale"] = torch.empty(output_size,
|
||||
1,
|
||||
@@ -58,20 +67,18 @@ class AscendW4A4LaosDynamicLinearMethod:
|
||||
1,
|
||||
dtype=torch.float32)
|
||||
if self.rotation_type == "heads_rotation":
|
||||
params_dict["heads_rotation"] = torch.zeros((64, 64), dtype=torch.float32)
|
||||
params_dict["heads_rotation"] = torch.zeros((64, 64),
|
||||
dtype=torch.float32)
|
||||
if self.rotation_type == "kronecker_rotation":
|
||||
params_dict["kronecker_rotation_n"] = torch.zeros((160, 160), dtype=torch.float32)
|
||||
params_dict["kronecker_rotation_m"] = torch.zeros((160, 160), dtype=torch.float32)
|
||||
params_dict["kronecker_rotation_n"] = torch.zeros(
|
||||
(160, 160), dtype=torch.float32)
|
||||
params_dict["kronecker_rotation_m"] = torch.zeros(
|
||||
(160, 160), dtype=torch.float32)
|
||||
return params_dict
|
||||
|
||||
def get_pergroup_param(self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
layer_type: Optional[str] = None) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
||||
def apply_rotation(self, layer, x):
|
||||
def apply_rotation(self, layer: torch.nn.Module,
|
||||
x: torch.Tensor) -> torch.Tensor:
|
||||
"""Apply rotation transformation to input tensor."""
|
||||
init_shape = x.shape
|
||||
dtype = x.dtype
|
||||
if self.rotation_type == "heads_rotation":
|
||||
@@ -94,17 +101,26 @@ class AscendW4A4LaosDynamicLinearMethod:
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
tp_rank: Optional[int] = None,
|
||||
tp_rank: Optional[int] = 0,
|
||||
) -> torch.Tensor:
|
||||
dtype = x.dtype
|
||||
x, pertoken_scale = torch_npu.npu_dynamic_quant(x, dst_type=torch.quint4x2)
|
||||
pertoken_scale = pertoken_scale.reshape(-1, 1)
|
||||
pertoken_scale = pertoken_scale.squeeze(-1)
|
||||
y2 = torch_npu.npu_quant_matmul(x, layer.weight.data, scale=layer.weight_scale.data.view(-1), pertoken_scale=pertoken_scale, bias=None, output_dtype=dtype)
|
||||
return y2
|
||||
output = torch_npu.npu_quant_matmul(
|
||||
x,
|
||||
layer.weight.data,
|
||||
scale=layer.weight_scale.data.view(-1),
|
||||
pertoken_scale=pertoken_scale,
|
||||
bias=None,
|
||||
output_dtype=dtype)
|
||||
if bias is not None:
|
||||
output = output + bias.to(dtype)
|
||||
return output
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module):
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
layer.weight_scale.data = layer.weight_scale.data.to(torch.float32)
|
||||
layer.weight.data = torch_npu.npu_convert_weight_to_int4pack(layer.weight.data.to(torch.int32))
|
||||
layer.weight.data = torch_npu.npu_convert_weight_to_int4pack(
|
||||
layer.weight.data.to(torch.int32))
|
||||
if self.transpose_weight:
|
||||
layer.weight.data = layer.weight.data.transpose(-1, -2)
|
||||
@@ -29,10 +29,13 @@ from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
|
||||
from vllm_ascend.utils import maybe_trans_nz
|
||||
|
||||
from .base import AscendLinearScheme, AscendMoEScheme, QuantType
|
||||
from .registry import register_scheme
|
||||
|
||||
class AscendW4A8DynamicLinearMethod:
|
||||
"""Linear method for Ascend W4A8_DYNAMIC
|
||||
"""
|
||||
|
||||
@register_scheme("W4A8_DYNAMIC", "linear")
|
||||
class AscendW4A8DynamicLinearMethod(AscendLinearScheme):
|
||||
"""Linear method for Ascend W4A8_DYNAMIC."""
|
||||
|
||||
def __init__(self):
|
||||
vllm_config = get_current_vllm_config()
|
||||
@@ -72,23 +75,12 @@ class AscendW4A8DynamicLinearMethod:
|
||||
|
||||
return params_dict
|
||||
|
||||
@staticmethod
|
||||
def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def get_perchannel_param(output_size: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
||||
def get_pergroup_param(self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
layer_type: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Create per-group quantization parameters.
|
||||
"""
|
||||
"""Create per-group quantization parameters."""
|
||||
params_dict = {}
|
||||
params_dict["weight_scale"] = torch.empty(output_size,
|
||||
1,
|
||||
@@ -121,8 +113,7 @@ class AscendW4A8DynamicLinearMethod:
|
||||
scale: torch.Tensor,
|
||||
per_group_scale: torch.Tensor,
|
||||
is_new_quant: bool = False):
|
||||
"""
|
||||
Process the scale for second-level quantization.
|
||||
"""Process the scale for second-level quantization.
|
||||
|
||||
Args:
|
||||
weight: weight tensor [k, n] (in new version, n is already compressed to n/2)
|
||||
@@ -207,9 +198,12 @@ class AscendW4A8DynamicLinearMethod:
|
||||
layer.weight.data.to(torch.int32))
|
||||
|
||||
|
||||
class AscendW4A8DynamicFusedMoEMethod:
|
||||
"""FusedMoe method for Ascend W4A8_DYNAMIC.
|
||||
"""
|
||||
@register_scheme("W4A8_DYNAMIC", "moe")
|
||||
class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme):
|
||||
"""FusedMoE method for Ascend W4A8_DYNAMIC."""
|
||||
|
||||
# Declare the quantization type for this scheme
|
||||
quant_type: QuantType = QuantType.W4A8
|
||||
|
||||
def __init__(self):
|
||||
self.ep_group = get_ep_group()
|
||||
@@ -339,7 +333,7 @@ class AscendW4A8DynamicFusedMoEMethod:
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
is_prefill: bool = True,
|
||||
enable_force_load_balance: bool = False,
|
||||
log2phy: torch.Tensor = None,
|
||||
log2phy: Optional[torch.Tensor] = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
@@ -22,17 +22,22 @@ import torch_npu
|
||||
|
||||
from vllm_ascend.utils import maybe_trans_nz
|
||||
|
||||
from .base import AscendLinearScheme
|
||||
from .registry import register_scheme
|
||||
|
||||
class AscendW8A16LinearMethod:
|
||||
|
||||
@register_scheme("W8A16", "linear")
|
||||
class AscendW8A16LinearMethod(AscendLinearScheme):
|
||||
"""Linear method for Ascend W8A16.
|
||||
|
||||
|
||||
This scheme uses 8-bit quantized weights with 16-bit activations.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def get_weight(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype = torch.bfloat16,
|
||||
@@ -42,12 +47,8 @@ class AscendW8A16LinearMethod:
|
||||
}
|
||||
return params_dict
|
||||
|
||||
@staticmethod
|
||||
def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def get_perchannel_param(
|
||||
self,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
) -> Dict[str, Any]:
|
||||
@@ -60,15 +61,8 @@ class AscendW8A16LinearMethod:
|
||||
dtype=params_dtype)
|
||||
return params_dict
|
||||
|
||||
def get_pergroup_param(self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
layer_type: Optional[str] = None) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
@@ -32,28 +32,39 @@ from vllm_ascend.ops.fused_moe.experts_selector import (select_experts,
|
||||
zero_experts_compute)
|
||||
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, maybe_trans_nz
|
||||
|
||||
from .base import AscendLinearScheme, AscendMoEScheme, QuantType
|
||||
from .registry import register_scheme
|
||||
|
||||
class AscendW8A8DynamicLinearMethod:
|
||||
|
||||
def scale_from_float_to_int64(scale):
|
||||
"""Convert float32 scale to int64 representation."""
|
||||
import numpy as np
|
||||
scale = torch.from_numpy(
|
||||
np.frombuffer(scale.cpu().to(torch.float32).numpy().tobytes(),
|
||||
dtype=np.int32).astype(np.int64)).to(scale.device)
|
||||
return scale
|
||||
|
||||
|
||||
@register_scheme("W8A8_DYNAMIC", "linear")
|
||||
class AscendW8A8DynamicLinearMethod(AscendLinearScheme):
|
||||
"""Linear method for Ascend W8A8_DYNAMIC.
|
||||
|
||||
This scheme uses dynamic per-token quantization for activations
|
||||
and per-channel quantization for weights.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def get_weight(input_size: int, output_size: int,
|
||||
def get_weight(self, input_size: int, output_size: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
params_dict = {
|
||||
"weight": torch.empty(output_size, input_size, dtype=torch.int8)
|
||||
}
|
||||
return params_dict
|
||||
|
||||
@staticmethod
|
||||
def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def get_perchannel_param(
|
||||
self,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
) -> Dict[str, Any]:
|
||||
@@ -66,15 +77,8 @@ class AscendW8A8DynamicLinearMethod:
|
||||
dtype=params_dtype)
|
||||
return params_dict
|
||||
|
||||
def get_pergroup_param(self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
layer_type: Optional[str] = None) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
@@ -100,9 +104,12 @@ class AscendW8A8DynamicLinearMethod:
|
||||
layer.weight_offset.data = layer.weight_offset.data.flatten()
|
||||
|
||||
|
||||
class AscendW8A8DynamicFusedMoEMethod:
|
||||
"""FusedMoe method for Ascend W8A8_DYNAMIC.
|
||||
"""
|
||||
@register_scheme("W8A8_DYNAMIC", "moe")
|
||||
class AscendW8A8DynamicFusedMoEMethod(AscendMoEScheme):
|
||||
"""FusedMoE method for Ascend W8A8_DYNAMIC."""
|
||||
|
||||
# Declare the quantization type for this scheme
|
||||
quant_type: QuantType = QuantType.W8A8
|
||||
|
||||
def __init__(self):
|
||||
self.ep_group = get_ep_group()
|
||||
@@ -128,9 +135,8 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
except AttributeError:
|
||||
self.moe_all_to_all_group_name = ""
|
||||
|
||||
@staticmethod
|
||||
def get_weight(num_experts: int, intermediate_size_per_partition: int,
|
||||
hidden_sizes: int,
|
||||
def get_weight(self, num_experts: int,
|
||||
intermediate_size_per_partition: int, hidden_sizes: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
param_dict = {}
|
||||
param_dict["w13_weight"] = torch.empty(num_experts,
|
||||
@@ -144,8 +150,7 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
dtype=torch.int8)
|
||||
return param_dict
|
||||
|
||||
@staticmethod
|
||||
def get_dynamic_quant_param(num_experts: int,
|
||||
def get_dynamic_quant_param(self, num_experts: int,
|
||||
intermediate_size_per_partition: int,
|
||||
hidden_sizes: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
@@ -188,7 +193,7 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
is_prefill: bool = True,
|
||||
enable_force_load_balance: bool = False,
|
||||
log2phy: torch.Tensor = None,
|
||||
log2phy: Optional[torch.Tensor] = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
pertoken_scale: Optional[Any] = None,
|
||||
**kwargs,
|
||||
@@ -324,11 +329,3 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
del layer.w13_weight_scale_fp32
|
||||
del layer.w2_weight_scale
|
||||
torch.npu.empty_cache()
|
||||
|
||||
|
||||
def scale_from_float_to_int64(scale):
|
||||
import numpy as np
|
||||
scale = torch.from_numpy(
|
||||
np.frombuffer(scale.cpu().to(torch.float32).numpy().tobytes(),
|
||||
dtype=np.int32).astype(np.int64)).to(scale.device)
|
||||
return scale
|
||||
@@ -21,9 +21,17 @@ import torch
|
||||
import torch_npu
|
||||
from vllm.config import get_current_vllm_config
|
||||
|
||||
from .base import AscendLinearScheme
|
||||
from .registry import register_scheme
|
||||
|
||||
class AscendW8A8MXFP8DynamicLinearMethod:
|
||||
"""Linear method for Ascend W8A8_DYNAMIC.
|
||||
|
||||
@register_scheme("W8A8_MXFP8", "linear")
|
||||
class AscendW8A8MXFP8DynamicLinearMethod(AscendLinearScheme):
|
||||
"""Linear method for Ascend W8A8_MXFP8 (Microscaling FP8) quantization.
|
||||
|
||||
This scheme uses microscaling FP8 quantization with per-group scales.
|
||||
The activation is dynamically quantized to FP8 (E4M3FN format) with
|
||||
microscaling, and weights are stored in FP8 format with per-group scales.
|
||||
"""
|
||||
model_dtype = None
|
||||
|
||||
@@ -32,8 +40,7 @@ class AscendW8A8MXFP8DynamicLinearMethod:
|
||||
self.group_size = vllm_config.quant_config.quant_description.get(
|
||||
"group_size", 32)
|
||||
|
||||
@staticmethod
|
||||
def get_weight(input_size: int, output_size: int,
|
||||
def get_weight(self, input_size: int, output_size: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
params_dict = {
|
||||
"weight":
|
||||
@@ -41,17 +48,6 @@ class AscendW8A8MXFP8DynamicLinearMethod:
|
||||
}
|
||||
return params_dict
|
||||
|
||||
@staticmethod
|
||||
def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def get_perchannel_param(
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
||||
def get_pergroup_param(self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
117
vllm_ascend/quantization/methods/w8a8_pdmix.py
Normal file
117
vllm_ascend/quantization/methods/w8a8_pdmix.py
Normal file
@@ -0,0 +1,117 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""W8A8 Prefill-Decode Mix quantization methods.
|
||||
|
||||
This module provides quantization methods that use different strategies
|
||||
for prefill and decode phases:
|
||||
- Prefill: Uses dynamic W8A8 quantization
|
||||
- Decode (KV consumer): Uses static W8A8 quantization
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
from vllm.config import get_current_vllm_config
|
||||
|
||||
from .base import AscendLinearScheme
|
||||
from .registry import register_scheme
|
||||
from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod,
|
||||
AscendW8A8DynamicLinearMethod)
|
||||
from .w8a8_static import AscendW8A8LinearMethod
|
||||
|
||||
|
||||
@register_scheme("W8A8_MIX", "linear")
|
||||
class AscendW8A8PDMixLinearMethod(AscendLinearScheme):
|
||||
"""Linear method for W8A8 prefill-decode mix quantization.
|
||||
|
||||
This scheme uses composition to delegate to the appropriate quantization
|
||||
method based on the execution phase:
|
||||
- Static W8A8 for KV consumer (decode phase)
|
||||
- Dynamic W8A8 for prefill phase
|
||||
|
||||
The static method is used for weight/parameter specifications since
|
||||
it requires more parameters (input_scale, deq_scale, etc.) that are
|
||||
needed for static quantization during decode.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._static_method = AscendW8A8LinearMethod()
|
||||
self._dynamic_method = AscendW8A8DynamicLinearMethod()
|
||||
|
||||
kv_transfer_config = get_current_vllm_config().kv_transfer_config
|
||||
self._is_kv_consumer = (kv_transfer_config is not None
|
||||
and kv_transfer_config.is_kv_consumer)
|
||||
|
||||
def get_weight(self, input_size: int, output_size: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
return self._static_method.get_weight(input_size, output_size,
|
||||
params_dtype)
|
||||
|
||||
def get_pertensor_param(self, params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
return self._static_method.get_pertensor_param(params_dtype)
|
||||
|
||||
def get_perchannel_param(
|
||||
self,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
) -> Dict[str, Any]:
|
||||
return self._static_method.get_perchannel_param(
|
||||
output_size, params_dtype)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
tp_rank: Optional[int] = 0,
|
||||
) -> torch.Tensor:
|
||||
if layer.is_kv_consumer:
|
||||
return self._static_method.apply(layer, x, bias, tp_rank)
|
||||
else:
|
||||
return self._dynamic_method.apply(layer, x, bias, tp_rank)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
self._static_method.process_weights_after_loading(layer)
|
||||
layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
|
||||
layer.is_kv_consumer = self._is_kv_consumer
|
||||
|
||||
|
||||
@register_scheme("W8A8_MIX", "moe")
|
||||
class AscendW8A8PDMixFusedMoeMethod(AscendW8A8DynamicFusedMoEMethod):
|
||||
|
||||
def get_dynamic_quant_param(self, num_experts: int,
|
||||
intermediate_size_per_partition: int,
|
||||
hidden_sizes: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
param_dict = super().get_dynamic_quant_param(
|
||||
num_experts, intermediate_size_per_partition, hidden_sizes,
|
||||
params_dtype)
|
||||
param_dict["w2_deq_scale"] = torch.empty(num_experts,
|
||||
hidden_sizes,
|
||||
dtype=torch.float32)
|
||||
param_dict["w13_deq_scale"] = torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
dtype=torch.float32)
|
||||
param_dict["w2_input_offset"] = torch.empty(num_experts,
|
||||
1,
|
||||
dtype=torch.int8)
|
||||
param_dict["w13_input_offset"] = torch.empty(num_experts,
|
||||
1,
|
||||
dtype=torch.int8)
|
||||
|
||||
return param_dict
|
||||
@@ -24,27 +24,23 @@ from vllm_ascend.utils import (COMPRESSED_TENSORS_METHOD, AscendDeviceType,
|
||||
get_ascend_device_type,
|
||||
get_weight_prefetch_method, maybe_trans_nz)
|
||||
|
||||
|
||||
def quant_per_tensor(in_tensor: torch.Tensor,
|
||||
input_scale: torch.Tensor,
|
||||
input_offset: torch.Tensor,
|
||||
function=False):
|
||||
return torch_npu.npu_quantize(in_tensor, input_scale, input_offset,
|
||||
torch.qint8, -1, function)
|
||||
from .base import AscendLinearScheme
|
||||
from .registry import register_scheme
|
||||
|
||||
|
||||
class AscendW8A8LinearMethod:
|
||||
"""Linear method for Ascend W8A8.
|
||||
@register_scheme("W8A8", "linear")
|
||||
class AscendW8A8LinearMethod(AscendLinearScheme):
|
||||
"""Linear method for Ascend W8A8 static quantization.
|
||||
|
||||
Args:
|
||||
w_sym: whether the linear weight is symmetrically quantized.
|
||||
This scheme uses static per-tensor quantization for activations
|
||||
and per-channel quantization for weights.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def get_weight(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype = torch.bfloat16,
|
||||
@@ -54,15 +50,14 @@ class AscendW8A8LinearMethod:
|
||||
}
|
||||
return params_dict
|
||||
|
||||
@staticmethod
|
||||
def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
def get_pertensor_param(self, params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
params_dict = {}
|
||||
params_dict["input_scale"] = torch.empty(1, dtype=params_dtype)
|
||||
params_dict["input_offset"] = torch.empty(1, dtype=torch.int8)
|
||||
return params_dict
|
||||
|
||||
@staticmethod
|
||||
def get_perchannel_param(
|
||||
self,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
) -> Dict[str, Any]:
|
||||
@@ -82,15 +77,8 @@ class AscendW8A8LinearMethod:
|
||||
dtype=params_dtype)
|
||||
return params_dict
|
||||
|
||||
def get_pergroup_param(self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
layer_type: Optional[str] = None) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
471
vllm_ascend/quantization/modelslim_config.py
Normal file
471
vllm_ascend/quantization/modelslim_config.py
Normal file
@@ -0,0 +1,471 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
"""ModelSlim quantization configuration and model mappings for Ascend.
|
||||
|
||||
This module provides the AscendModelSlimConfig class for parsing quantization
|
||||
configs generated by the ModelSlim tool, along with model-specific mappings.
|
||||
"""
|
||||
|
||||
from types import MappingProxyType
|
||||
from typing import Any, Dict, List, Mapping, Optional
|
||||
|
||||
import torch
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.linear import LinearBase
|
||||
from vllm.model_executor.layers.quantization import \
|
||||
register_quantization_config
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
UnquantizedEmbeddingMethod, VocabParallelEmbedding)
|
||||
from vllm.model_executor.models.utils import WeightsMapper
|
||||
|
||||
from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD
|
||||
|
||||
from .methods import get_scheme_class
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# key: model_type
|
||||
# value: orig_to_new_prefix
|
||||
QUANT_MODEL_PREFIX_MAPPINGS: Dict[str, Dict[str, str]] = {
|
||||
"qwen3_vl_moe": {
|
||||
"visual.": "model.visual.",
|
||||
"language_model.lm_head.": "lm_head.",
|
||||
"language_model.model.": "model.language_model.",
|
||||
},
|
||||
"qwen3_vl_text": {
|
||||
"visual.": "model.visual.",
|
||||
"language_model.lm_head.": "lm_head.",
|
||||
"language_model.model.": "model.language_model.",
|
||||
},
|
||||
}
|
||||
|
||||
# key: model_type
|
||||
# value: dict of fused module name -> list of original module names
|
||||
packed_modules_model_mapping: Dict[str, Dict[str, List[str]]] = {
|
||||
"qwen3_moe": {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
"gate_up_proj": [
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
],
|
||||
"experts":
|
||||
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
|
||||
},
|
||||
"deepseek_v2": {
|
||||
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||
"experts":
|
||||
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
|
||||
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]
|
||||
},
|
||||
"deepseek_v3": {
|
||||
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||
"experts":
|
||||
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
|
||||
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]
|
||||
},
|
||||
"pangu_ultra_moe": {
|
||||
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||
"experts":
|
||||
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
|
||||
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]
|
||||
},
|
||||
"kimi_k2": {
|
||||
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||
"experts":
|
||||
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
|
||||
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]
|
||||
},
|
||||
"deepseek_v32": {
|
||||
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||
"experts":
|
||||
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
|
||||
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]
|
||||
},
|
||||
# NOTE 1.The quantized MTP layer of deepseek on the NPU is not quantized;
|
||||
# NOTE 2.The description file generated by the current msmodelslim tool does not have
|
||||
# MTP layer info. Please manually add it and set the value to FLOAT.
|
||||
"deepseek_mtp": {
|
||||
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||
"experts":
|
||||
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
|
||||
},
|
||||
"pangu_ultra_moe_mtp": {
|
||||
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||
"experts":
|
||||
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
|
||||
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]
|
||||
},
|
||||
"qwen3_next": {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||
"in_proj": ["in_proj_qkvz", "in_proj_ba"],
|
||||
"experts":
|
||||
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
|
||||
},
|
||||
"qwen2_5_vl": {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
"gate_up_proj": [
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
],
|
||||
},
|
||||
"qwen3_vl_moe": {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
"gate_up_proj": [
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
],
|
||||
"experts":
|
||||
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
|
||||
},
|
||||
"glm4_moe": {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
"gate_up_proj": [
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
],
|
||||
"experts":
|
||||
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
|
||||
},
|
||||
"longcat_flash": {
|
||||
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||
"experts":
|
||||
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
|
||||
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]
|
||||
},
|
||||
"minimax_m2": {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
"experts": ["experts.0.w1", "experts.0.w2", "experts.0.w3"]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def get_packed_modules_mapping(model_type: str) -> Dict[str, List[str]]:
|
||||
"""Get packed modules mapping for a model type.
|
||||
|
||||
Args:
|
||||
model_type: The model type string (e.g., "deepseek_v3").
|
||||
|
||||
Returns:
|
||||
Dictionary mapping fused module names to their component module names.
|
||||
Returns empty dict if model_type is not found.
|
||||
"""
|
||||
return packed_modules_model_mapping.get(model_type, {})
|
||||
|
||||
|
||||
def get_prefix_mapping(model_type: str) -> Dict[str, str]:
|
||||
"""Get prefix mapping for a model type.
|
||||
|
||||
Args:
|
||||
model_type: The model type string (e.g., "qwen3_vl_moe").
|
||||
|
||||
Returns:
|
||||
Dictionary mapping original prefixes to new prefixes.
|
||||
Returns empty dict if model_type is not found.
|
||||
"""
|
||||
return QUANT_MODEL_PREFIX_MAPPINGS.get(model_type, {})
|
||||
|
||||
|
||||
def get_linear_quant_type(
|
||||
quant_description: Dict[str, Any], prefix: str,
|
||||
packed_modules_mapping: Dict[str, Any]) -> Optional[str]:
|
||||
"""Determine the quantization type for a linear layer.
|
||||
|
||||
Args:
|
||||
quant_description: The quantization description dictionary.
|
||||
prefix: The layer prefix.
|
||||
packed_modules_mapping: Mapping for packed/fused modules.
|
||||
|
||||
Returns:
|
||||
The quantization type string (e.g., "W8A8_DYNAMIC").
|
||||
"""
|
||||
proj_name = prefix.split(".")[-1]
|
||||
if proj_name in packed_modules_mapping:
|
||||
quant_type = None
|
||||
shard_prefixes = [
|
||||
prefix.replace(proj_name, shard_proj_name)
|
||||
for shard_proj_name in packed_modules_mapping[proj_name]
|
||||
]
|
||||
for shard_prefix in shard_prefixes:
|
||||
shard_quant_type = quant_description[shard_prefix + '.weight']
|
||||
|
||||
if quant_type is None:
|
||||
quant_type = shard_quant_type
|
||||
elif shard_quant_type != quant_type:
|
||||
raise ValueError(
|
||||
f"Not all shards of {prefix} are quantized with same quant type."
|
||||
f"Shard {proj_name} uses {shard_quant_type}, but another shard"
|
||||
f"use {quant_type}. Please check quantization config.")
|
||||
else:
|
||||
quant_type = quant_description[prefix + '.weight']
|
||||
return quant_type
|
||||
|
||||
|
||||
def get_quant_type_for_layer(
|
||||
quant_description: Dict[str, Any],
|
||||
prefix: str,
|
||||
layer_type: str,
|
||||
packed_modules_mapping: Optional[Dict[str,
|
||||
Any]] = None) -> Optional[str]:
|
||||
"""Determine the quantization type for a layer.
|
||||
|
||||
Args:
|
||||
quant_description: The quantization description dictionary.
|
||||
prefix: The layer prefix.
|
||||
layer_type: The type of layer ("linear", "moe", "attention").
|
||||
packed_modules_mapping: Mapping for packed/fused modules.
|
||||
|
||||
Returns:
|
||||
The quantization type string (e.g., "W8A8_DYNAMIC").
|
||||
"""
|
||||
if packed_modules_mapping is None:
|
||||
packed_modules_mapping = dict()
|
||||
# Attention
|
||||
if layer_type == "attention" and 'fa_quant_type' in quant_description.keys(
|
||||
):
|
||||
return quant_description['fa_quant_type']
|
||||
# Linear / MoE
|
||||
return get_linear_quant_type(quant_description, prefix,
|
||||
packed_modules_mapping)
|
||||
|
||||
|
||||
def create_scheme_for_layer(
|
||||
quant_description: Dict[str, Any],
|
||||
prefix: str,
|
||||
layer_type: str,
|
||||
packed_modules_mapping: Optional[Dict[str, Any]] = None):
|
||||
"""Create a quantization scheme instance for a layer.
|
||||
|
||||
Args:
|
||||
quant_description: The quantization description dictionary.
|
||||
prefix: The layer prefix.
|
||||
layer_type: The type of layer ("linear", "moe", "attention").
|
||||
packed_modules_mapping: Mapping for packed/fused modules.
|
||||
|
||||
Returns:
|
||||
An instance of the appropriate quantization scheme class.
|
||||
"""
|
||||
logger.info_once("Using the vLLM Ascend modelslim Quantization now!")
|
||||
quant_type = get_quant_type_for_layer(quant_description, prefix,
|
||||
layer_type, packed_modules_mapping)
|
||||
|
||||
if quant_type is None:
|
||||
raise ValueError(
|
||||
f"Could not determine quantization type for layer {prefix}.")
|
||||
|
||||
# Use registry to get scheme class
|
||||
scheme_cls = get_scheme_class(quant_type, layer_type)
|
||||
if scheme_cls is not None:
|
||||
return scheme_cls()
|
||||
|
||||
raise NotImplementedError(
|
||||
f"Currently, vLLM Ascend doesn't support {quant_type} for {layer_type}."
|
||||
)
|
||||
|
||||
|
||||
@register_quantization_config(ASCEND_QUANTIZATION_METHOD)
|
||||
class AscendModelSlimConfig(QuantizationConfig):
|
||||
"""Config class for Ascend ModelSlim quantization.
|
||||
|
||||
This class is a general class that parses quantization configs
|
||||
that are supported on Ascend hardware, specifically for models
|
||||
quantized using the ModelSlim tool.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: Dict[str, Any]):
|
||||
super().__init__()
|
||||
self.quant_description = quant_config
|
||||
# TODO(whx): remove this adaptation after adding "shared_head"
|
||||
# to prefix of DeepSeekShareHead in vLLM.
|
||||
extra_quant_dict = {}
|
||||
for k in self.quant_description.keys():
|
||||
if "shared_head" in k:
|
||||
new_k = k.replace(".shared_head.", ".")
|
||||
extra_quant_dict[new_k] = self.quant_description[k]
|
||||
if "weight_packed" in k:
|
||||
new_k = k.replace("weight_packed", "weight")
|
||||
extra_quant_dict[new_k] = self.quant_description[k]
|
||||
self.quant_description.update(extra_quant_dict)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "AscendModelSlimConfig:\n" + super().__repr__()
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> str:
|
||||
return ASCEND_QUANTIZATION_METHOD
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
return [torch.int8, torch.float16, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
raise NotImplementedError(
|
||||
"Ascend hardware dose not support \"get_min_capability\" feature.")
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
return ["quant_model_description.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "AscendModelSlimConfig":
|
||||
return cls(config)
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(cls, hf_quant_cfg,
|
||||
user_quant) -> Optional[str]:
|
||||
if hf_quant_cfg is not None:
|
||||
quant_method = hf_quant_cfg.get("quant_method", None)
|
||||
if not quant_method and torch.npu.is_available():
|
||||
return ASCEND_QUANTIZATION_METHOD
|
||||
return None
|
||||
|
||||
def quant_prefix_mapper(self, model_type: str, prefix: str) -> str:
|
||||
# TODO (Levi-JQ): will be removed when QuantizationConfig.apply_vllm_mapper is implemented
|
||||
prefix_mapping = QUANT_MODEL_PREFIX_MAPPINGS.get(model_type)
|
||||
if prefix_mapping:
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_prefix=prefix_mapping)
|
||||
return hf_to_vllm_mapper._map_name(prefix)
|
||||
return prefix
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["QuantizeMethodBase"]:
|
||||
from .method_adapters import (AscendEmbeddingMethod, AscendFusedMoEMethod,
|
||||
AscendKVCacheMethod, AscendLinearMethod)
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
model_type = vllm_config.model_config.hf_config.model_type
|
||||
|
||||
if model_type in ["minimax", "minimax_m2"]:
|
||||
# Adapt to Minimax architecture: update layer names to MoE convention
|
||||
prefix = prefix.replace("mlp", "block_sparse_moe")
|
||||
# Normalize the prefix by stripping specific expert indices (e.g., 'experts.0' -> 'experts')
|
||||
parts = prefix.split('.')
|
||||
if "experts" in parts and len(parts) > 2:
|
||||
exp_idx = parts.index("experts")
|
||||
if exp_idx + 1 < len(parts) and parts[exp_idx + 1].isdigit():
|
||||
parts = parts[:exp_idx + 1]
|
||||
prefix = ".".join(parts)
|
||||
|
||||
if model_type in packed_modules_model_mapping:
|
||||
self.packed_modules_mapping = packed_modules_model_mapping[
|
||||
model_type]
|
||||
prefix = self.quant_prefix_mapper(model_type, prefix)
|
||||
from vllm.attention.layer import Attention
|
||||
if prefix.startswith("language_model"):
|
||||
prefix = prefix.split('.', 1)[-1]
|
||||
if isinstance(layer, LinearBase):
|
||||
if self.is_layer_skipped_ascend(prefix,
|
||||
self.packed_modules_mapping):
|
||||
# Delayed import to avoid circular import
|
||||
from vllm_ascend.ops.linear import \
|
||||
AscendUnquantizedLinearMethod
|
||||
return AscendUnquantizedLinearMethod()
|
||||
scheme = create_scheme_for_layer(self.quant_description, prefix,
|
||||
"linear",
|
||||
self.packed_modules_mapping)
|
||||
return AscendLinearMethod(scheme)
|
||||
elif isinstance(layer, Attention) and \
|
||||
'fa_quant_type' in self.quant_description.keys() and \
|
||||
self.quant_description['fa_quant_type'] is not None:
|
||||
scheme = create_scheme_for_layer(self.quant_description, prefix,
|
||||
"attention",
|
||||
self.packed_modules_mapping)
|
||||
return AscendKVCacheMethod(scheme)
|
||||
elif isinstance(layer, FusedMoE):
|
||||
if self.is_layer_skipped_ascend(prefix,
|
||||
self.packed_modules_mapping):
|
||||
# Delayed import to avoid circular import
|
||||
from vllm_ascend.ops.fused_moe.fused_moe import \
|
||||
AscendUnquantizedFusedMoEMethod
|
||||
return AscendUnquantizedFusedMoEMethod(layer.moe_config)
|
||||
scheme = create_scheme_for_layer(self.quant_description, prefix,
|
||||
"moe",
|
||||
self.packed_modules_mapping)
|
||||
return AscendFusedMoEMethod(scheme, layer.moe_config)
|
||||
elif isinstance(layer, VocabParallelEmbedding):
|
||||
if self.is_layer_skipped_ascend(prefix,
|
||||
self.packed_modules_mapping):
|
||||
return UnquantizedEmbeddingMethod()
|
||||
scheme = create_scheme_for_layer(self.quant_description, prefix,
|
||||
"linear",
|
||||
self.packed_modules_mapping)
|
||||
return AscendEmbeddingMethod(scheme)
|
||||
return None
|
||||
|
||||
def is_layer_skipped_ascend(
|
||||
self,
|
||||
prefix: str,
|
||||
fused_mapping: Mapping[str, List[str]] = MappingProxyType({})):
|
||||
# adapted from vllm.model_executor.layers.quantization.utils.quant_utils.is_layer_skipped
|
||||
proj_name = prefix.split(".")[-1]
|
||||
if proj_name in fused_mapping:
|
||||
shard_prefixes = [
|
||||
prefix.replace(proj_name, shard_proj_name)
|
||||
for shard_proj_name in fused_mapping[proj_name]
|
||||
]
|
||||
|
||||
is_skipped = None
|
||||
for shard_prefix in shard_prefixes:
|
||||
is_shard_skipped = self.quant_description[shard_prefix +
|
||||
'.weight'] == "FLOAT"
|
||||
|
||||
if is_skipped is None:
|
||||
is_skipped = is_shard_skipped
|
||||
elif is_shard_skipped != is_skipped:
|
||||
raise ValueError(
|
||||
f"Detected some but not all shards of {prefix} "
|
||||
"are quantized. All shards of fused layers "
|
||||
"to have the same precision.")
|
||||
else:
|
||||
is_skipped = self.quant_description[prefix + '.weight'] == "FLOAT"
|
||||
|
||||
assert is_skipped is not None
|
||||
return is_skipped
|
||||
|
||||
def get_scaled_act_names(self) -> List[str]:
|
||||
return []
|
||||
@@ -1,600 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
from types import MappingProxyType
|
||||
from typing import Any, Callable, Dict, List, Mapping, Optional
|
||||
|
||||
import torch
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.distributed import get_tensor_model_parallel_rank
|
||||
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
|
||||
FusedMoeWeightScaleSupported)
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.quantization import \
|
||||
register_quantization_config
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
UnquantizedEmbeddingMethod, VocabParallelEmbedding)
|
||||
from vllm.model_executor.models.utils import WeightsMapper
|
||||
from vllm.model_executor.parameter import PerTensorScaleParameter
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.distributed.parallel_state import (get_flashcomm2_otp_group,
|
||||
get_mlp_tp_group,
|
||||
get_otp_group)
|
||||
from vllm_ascend.ops.fused_moe.fused_moe import AscendUnquantizedFusedMoEMethod
|
||||
from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod
|
||||
from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, flashcomm2_enable,
|
||||
mlp_tp_enable, oproj_tp_enable)
|
||||
|
||||
from .utils import get_quant_method, is_mx_quant_type
|
||||
|
||||
|
||||
@register_quantization_config(ASCEND_QUANTIZATION_METHOD)
|
||||
class AscendQuantConfig(QuantizationConfig):
|
||||
"""Config class for Ascend
|
||||
|
||||
This class is a general class that parse quantization configs
|
||||
that are supported on ascend hardware.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: Dict[str, Any]):
|
||||
super().__init__()
|
||||
self.quant_description = quant_config
|
||||
# TODO(whx): remove this adaptation after adding "shared_head"
|
||||
# to prefix of DeepSeekShareHead in vLLM.
|
||||
extra_quant_dict = {}
|
||||
for k in self.quant_description.keys():
|
||||
if "shared_head" in k:
|
||||
new_k = k.replace(".shared_head.", ".")
|
||||
extra_quant_dict[new_k] = self.quant_description[k]
|
||||
if "weight_packed" in k:
|
||||
new_k = k.replace("weight_packed", "weight")
|
||||
extra_quant_dict[new_k] = self.quant_description[k]
|
||||
self.quant_description.update(extra_quant_dict)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "AscendQuantConfig:\n" + super().__repr__()
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> str:
|
||||
return ASCEND_QUANTIZATION_METHOD
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
return [torch.int8, torch.float16, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
raise NotImplementedError(
|
||||
"Ascend hardware dose not support \"get_min_capability\" feature.")
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
return ["quant_model_description.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "AscendQuantConfig":
|
||||
return cls(config)
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(cls, hf_quant_cfg,
|
||||
user_quant) -> Optional[str]:
|
||||
if hf_quant_cfg is not None:
|
||||
quant_method = hf_quant_cfg.get("quant_method", None)
|
||||
if not quant_method and torch.npu.is_available():
|
||||
return ASCEND_QUANTIZATION_METHOD
|
||||
return None
|
||||
|
||||
def quant_prefix_mapper(self, model_type: str, prefix: str) -> str:
|
||||
# TODO (Levi-JQ): will be removed when QuantizationConfig.apply_vllm_mapper is implemented
|
||||
prefix_mapping = QUANT_MODEL_PREFIX_MAPPINGS.get(model_type)
|
||||
if prefix_mapping:
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_prefix=prefix_mapping)
|
||||
return hf_to_vllm_mapper._map_name(prefix)
|
||||
return prefix
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["QuantizeMethodBase"]:
|
||||
vllm_config = get_current_vllm_config()
|
||||
model_type = vllm_config.model_config.hf_text_config.model_type
|
||||
|
||||
if model_type in ["minimax", "minimax_m2"]:
|
||||
prefix = prefix.replace("mlp", "block_sparse_moe")
|
||||
|
||||
#To adapt to minimax, modify the prefix of the model layer name
|
||||
parts = prefix.split('.')
|
||||
if "experts" in parts and len(parts) > 2:
|
||||
exp_idx = parts.index("experts")
|
||||
if exp_idx + 1 < len(parts) and parts[exp_idx + 1].isdigit():
|
||||
parts = parts[:exp_idx + 1]
|
||||
prefix = ".".join(parts)
|
||||
|
||||
if model_type in packed_modules_model_mapping:
|
||||
self.packed_modules_mapping = packed_modules_model_mapping[
|
||||
model_type]
|
||||
prefix = self.quant_prefix_mapper(model_type, prefix)
|
||||
from vllm.attention.layer import Attention
|
||||
if prefix.startswith("language_model"):
|
||||
prefix = prefix.split('.', 1)[-1]
|
||||
if isinstance(layer, LinearBase):
|
||||
if self.is_layer_skipped_ascend(prefix,
|
||||
self.packed_modules_mapping):
|
||||
return AscendUnquantizedLinearMethod()
|
||||
return AscendLinearMethod(self, prefix,
|
||||
self.packed_modules_mapping, layer)
|
||||
elif isinstance(layer, Attention) and \
|
||||
'fa_quant_type' in self.quant_description.keys() and \
|
||||
self.quant_description['fa_quant_type'] is not None:
|
||||
return AscendKVCacheMethod(self, prefix)
|
||||
elif isinstance(layer, FusedMoE):
|
||||
if self.is_layer_skipped_ascend(prefix,
|
||||
self.packed_modules_mapping):
|
||||
return AscendUnquantizedFusedMoEMethod(layer.moe_config)
|
||||
return AscendFusedMoEMethod(self, prefix,
|
||||
self.packed_modules_mapping, layer)
|
||||
elif isinstance(layer, VocabParallelEmbedding):
|
||||
if self.is_layer_skipped_ascend(prefix,
|
||||
self.packed_modules_mapping):
|
||||
return UnquantizedEmbeddingMethod()
|
||||
return AscendEmbeddingMethod(self, prefix,
|
||||
self.packed_modules_mapping, layer)
|
||||
return None
|
||||
|
||||
def is_layer_skipped_ascend(
|
||||
self,
|
||||
prefix: str,
|
||||
fused_mapping: Mapping[str, List[str]] = MappingProxyType({})):
|
||||
# adapted from vllm.model_executor.layers.quantization.utils.quant_utils.is_layer_skipped
|
||||
proj_name = prefix.split(".")[-1]
|
||||
if proj_name in fused_mapping:
|
||||
shard_prefixes = [
|
||||
prefix.replace(proj_name, shard_proj_name)
|
||||
for shard_proj_name in fused_mapping[proj_name]
|
||||
]
|
||||
|
||||
is_skipped = None
|
||||
for shard_prefix in shard_prefixes:
|
||||
is_shard_skipped = self.quant_description[shard_prefix +
|
||||
'.weight'] == "FLOAT"
|
||||
|
||||
if is_skipped is None:
|
||||
is_skipped = is_shard_skipped
|
||||
elif is_shard_skipped != is_skipped:
|
||||
raise ValueError(
|
||||
f"Detected some but not all shards of {prefix} "
|
||||
"are quantized. All shards of fused layers "
|
||||
"to have the same precision.")
|
||||
else:
|
||||
# NOTE: In GLM4.6, the MTP draft model shares the same LM head weigthts
|
||||
# with the main model. Therefore, before `load_weights()` runs, some parameter
|
||||
# names may not include the expected prefix and may appear only with the
|
||||
# ".head" suffix. This can trigger a load-time error, so here we replace the
|
||||
# key with "lm_head.weight".
|
||||
key = prefix + '.weight'
|
||||
if key not in self.quant_description and ".head" in prefix:
|
||||
key = 'lm_head.weight'
|
||||
is_skipped = self.quant_description[key] == "FLOAT"
|
||||
|
||||
assert is_skipped is not None
|
||||
return is_skipped
|
||||
|
||||
def get_scaled_act_names(self) -> List[str]:
|
||||
return []
|
||||
|
||||
|
||||
# key: model_type
|
||||
# value: orig_to_new_prefix
|
||||
QUANT_MODEL_PREFIX_MAPPINGS = {
|
||||
"qwen3_vl_moe": {
|
||||
"visual.": "model.visual.",
|
||||
"language_model.lm_head.": "lm_head.",
|
||||
"language_model.model.": "model.language_model.",
|
||||
},
|
||||
"qwen3_vl_text": {
|
||||
"visual.": "model.visual.",
|
||||
"language_model.lm_head.": "lm_head.",
|
||||
"language_model.model.": "model.language_model.",
|
||||
},
|
||||
}
|
||||
|
||||
packed_modules_model_mapping = {
|
||||
"qwen3_moe": {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
"gate_up_proj": [
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
],
|
||||
"experts":
|
||||
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
|
||||
},
|
||||
"deepseek_v2": {
|
||||
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||
"experts":
|
||||
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
|
||||
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]
|
||||
},
|
||||
"deepseek_v3": {
|
||||
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||
"experts":
|
||||
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
|
||||
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]
|
||||
},
|
||||
"pangu_ultra_moe": {
|
||||
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||
"experts":
|
||||
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
|
||||
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]
|
||||
},
|
||||
"kimi_k2": {
|
||||
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||
"experts":
|
||||
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
|
||||
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]
|
||||
},
|
||||
"deepseek_v32": {
|
||||
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||
"experts":
|
||||
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
|
||||
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]
|
||||
},
|
||||
# NOTE 1.The quantized MTP layer of deepseek on the NPU is not quantized;
|
||||
# NOTE 2.The description file generated by the current msmodelslim tool does not have
|
||||
# MTP layer info. Please manually add it and set the value to FLOAT.
|
||||
"deepseek_mtp": {
|
||||
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||
"experts":
|
||||
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
|
||||
},
|
||||
"pangu_ultra_moe_mtp": {
|
||||
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||
"experts":
|
||||
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
|
||||
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]
|
||||
},
|
||||
"qwen3_next": {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||
"in_proj": ["in_proj_qkvz", "in_proj_ba"],
|
||||
"experts":
|
||||
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
|
||||
},
|
||||
"qwen2_5_vl": {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
"gate_up_proj": [
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
],
|
||||
},
|
||||
"qwen3_vl_moe": {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
"gate_up_proj": [
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
],
|
||||
"experts":
|
||||
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
|
||||
},
|
||||
"glm4_moe": {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
"gate_up_proj": [
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
],
|
||||
"experts":
|
||||
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
|
||||
},
|
||||
"longcat_flash": {
|
||||
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||
"experts":
|
||||
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
|
||||
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]
|
||||
},
|
||||
"minimax_m2": {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
"experts": ["experts.0.w1", "experts.0.w2", "experts.0.w3"]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class AscendLinearMethod(LinearMethodBase):
|
||||
"""Linear method for Ascend quantization.
|
||||
|
||||
Args:
|
||||
quant_config: The Ascend quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
quant_config: AscendQuantConfig,
|
||||
prefix: str,
|
||||
packed_modules_mapping: Dict[str, Any] | None,
|
||||
layer: torch.nn.Module = None) -> None:
|
||||
self.quant_method = get_quant_method(quant_config.quant_description,
|
||||
prefix,
|
||||
"linear",
|
||||
packed_modules_mapping,
|
||||
layer=layer)
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: List[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
) -> None:
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
|
||||
weight_dict = self.quant_method.get_weight(input_size_per_partition,
|
||||
output_size_per_partition,
|
||||
params_dtype)
|
||||
|
||||
# Extract packing information (if present)
|
||||
packed_dim = weight_dict.pop("_packed_dim", None)
|
||||
packed_factor = weight_dict.pop("_packed_factor", None)
|
||||
|
||||
for weight_name, weight_param in weight_dict.items():
|
||||
param = torch.nn.Parameter(weight_param, requires_grad=False)
|
||||
set_weight_attrs(param, {"input_dim": 1, "output_dim": 0})
|
||||
|
||||
# Set packing attributes if the weight is packed
|
||||
if packed_dim is not None and packed_factor is not None:
|
||||
set_weight_attrs(param, {
|
||||
"packed_dim": packed_dim,
|
||||
"packed_factor": packed_factor
|
||||
})
|
||||
|
||||
layer.register_parameter(weight_name, param)
|
||||
set_weight_attrs(param, extra_weight_attrs)
|
||||
|
||||
pertensor_dict = self.quant_method.get_pertensor_param(params_dtype)
|
||||
for pertensor_name, pertensor_param in pertensor_dict.items():
|
||||
param = PerTensorScaleParameter(data=pertensor_param,
|
||||
weight_loader=weight_loader)
|
||||
# disable warning
|
||||
param.ignore_warning = True
|
||||
layer.register_parameter(pertensor_name, param)
|
||||
param.weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
|
||||
perchannel_dict = self.quant_method.get_perchannel_param(
|
||||
output_size_per_partition, params_dtype)
|
||||
for perchannel_name, perchannel_param in perchannel_dict.items():
|
||||
param = torch.nn.Parameter(perchannel_param, requires_grad=False)
|
||||
set_weight_attrs(param, {"output_dim": 0})
|
||||
layer.register_parameter(perchannel_name, param)
|
||||
set_weight_attrs(param, extra_weight_attrs)
|
||||
|
||||
# NOTE: In w4a8 quantization implementation,
|
||||
# for down_proj and o_proj scale_bias shape is [output_size, 16],
|
||||
# others are [output_size, 1]
|
||||
layer_type = "row" if isinstance(layer,
|
||||
RowParallelLinear) else "others"
|
||||
|
||||
pergroup_dict = self.quant_method.get_pergroup_param(
|
||||
input_size_per_partition,
|
||||
output_size_per_partition,
|
||||
params_dtype,
|
||||
layer_type=layer_type)
|
||||
for pergroup_name, pergroup_param in pergroup_dict.items():
|
||||
param = torch.nn.Parameter(pergroup_param, requires_grad=False)
|
||||
set_weight_attrs(param, {"output_dim": 0})
|
||||
layer.register_parameter(pergroup_name, param)
|
||||
set_weight_attrs(param, extra_weight_attrs)
|
||||
if "weight_scale_second" in pergroup_name or "weight_offset_second" in pergroup_name \
|
||||
or is_mx_quant_type(self.quant_method):
|
||||
setattr(param, "input_dim", 1)
|
||||
param.input_dim = 1
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
if hasattr(self.quant_method, "process_weights_after_loading"):
|
||||
self.quant_method.process_weights_after_loading(layer)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if isinstance(layer, RowParallelLinear):
|
||||
if layer.prefix.find("o_proj") != -1 and oproj_tp_enable():
|
||||
tp_rank = get_otp_group().rank_in_group
|
||||
elif layer.prefix.find("down_proj") != -1 and mlp_tp_enable():
|
||||
tp_rank = get_mlp_tp_group().rank_in_group
|
||||
elif (layer.prefix.find("o_proj") != -1 or
|
||||
layer.prefix.find("out_proj") != -1) and flashcomm2_enable():
|
||||
if get_ascend_config(
|
||||
).flashcomm2_oproj_tensor_parallel_size == 1:
|
||||
tp_rank = 0
|
||||
else:
|
||||
tp_rank = get_flashcomm2_otp_group().rank_in_group
|
||||
else:
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
else:
|
||||
tp_rank = 0
|
||||
return self.quant_method.apply(layer, x, bias, tp_rank)
|
||||
|
||||
|
||||
class AscendKVCacheMethod(BaseKVCacheMethod):
|
||||
"""KVCache method for Ascend quantization.
|
||||
|
||||
Args:
|
||||
quant_config: The Ascend quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: AscendQuantConfig, prefix: str) -> None:
|
||||
self.quant_method = get_quant_method(quant_config.quant_description,
|
||||
prefix, "attention")
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module) -> None:
|
||||
# Different from linear method, there are no weight processing/slicing
|
||||
# steps for attention in vllm. So the whole process of create weights
|
||||
# is hidden into the specific quant method.
|
||||
self.quant_method.create_weights(layer)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
if hasattr(self.quant_method, "process_weights_after_loading"):
|
||||
self.quant_method.process_weights_after_loading(layer)
|
||||
|
||||
def apply(self, layer: torch.nn.Module, query: torch.Tensor,
|
||||
key: torch.Tensor, value: torch.Tensor, kv_cache, attn_metadata,
|
||||
attn_type, scale, output) -> torch.Tensor:
|
||||
return self.quant_method.apply(layer, query, key, value, kv_cache,
|
||||
attn_metadata, attn_type, scale, output)
|
||||
|
||||
|
||||
class AscendFusedMoEMethod(FusedMoEMethodBase):
|
||||
"""FusedMoE method for Ascend quantization.
|
||||
|
||||
Args:
|
||||
quant_config: The Ascend quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: AscendQuantConfig, prefix: str,
|
||||
packed_modules_mapping: Dict[str,
|
||||
Any], layer: torch.nn.Module):
|
||||
super().__init__(layer.moe_config)
|
||||
self.quant_method = get_quant_method(quant_config.quant_description,
|
||||
prefix,
|
||||
"moe",
|
||||
packed_modules_mapping,
|
||||
layer=layer)
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
) -> None:
|
||||
weight_param = self.quant_method.get_weight(
|
||||
num_experts, intermediate_size_per_partition, hidden_size,
|
||||
params_dtype)
|
||||
for param_key, param_value in weight_param.items():
|
||||
param = torch.nn.Parameter(param_value, requires_grad=False)
|
||||
layer.register_parameter(param_key, param)
|
||||
set_weight_attrs(param, extra_weight_attrs)
|
||||
|
||||
extra_weight_attrs.update(
|
||||
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
|
||||
per_group_param = [
|
||||
"weight_scale_second", "weight_offset_second", "scale_bias"
|
||||
] + ["weight_scale", "weight_offset"] if hasattr(
|
||||
self.quant_method,
|
||||
"group_size") and self.quant_method.group_size > 0 else []
|
||||
dynamic_quant_param = self.quant_method.get_dynamic_quant_param(
|
||||
num_experts, intermediate_size_per_partition, hidden_size,
|
||||
params_dtype)
|
||||
for param_key, param_value in dynamic_quant_param.items():
|
||||
param = torch.nn.Parameter(param_value, requires_grad=False)
|
||||
layer.register_parameter(param_key, param)
|
||||
set_weight_attrs(param, extra_weight_attrs)
|
||||
if any(fields in param_key for fields in per_group_param):
|
||||
setattr(param, "quant_method",
|
||||
FusedMoeWeightScaleSupported.GROUP.value)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
is_prefill: bool = True,
|
||||
enable_force_load_balance: bool = False,
|
||||
log2phy: torch.Tensor = None,
|
||||
global_redundant_expert_num=0,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
return self.quant_method.apply(
|
||||
layer, x, router_logits, top_k, renormalize, use_grouped_topk,
|
||||
global_num_experts, expert_map, topk_group, num_expert_group,
|
||||
custom_routing_function, scoring_func, routed_scaling_factor,
|
||||
e_score_correction_bias, is_prefill, enable_force_load_balance,
|
||||
log2phy, global_redundant_expert_num, **kwargs)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
if hasattr(self.quant_method, "process_weights_after_loading"):
|
||||
self.quant_method.process_weights_after_loading(layer)
|
||||
|
||||
def get_fused_moe_quant_config(self, layer: torch.nn.Module):
|
||||
# TODO: implement this function
|
||||
pass
|
||||
|
||||
@property
|
||||
def supports_eplb(self):
|
||||
supports_eplb = getattr(self.quant_method, "supports_eplb", False)
|
||||
return supports_eplb
|
||||
|
||||
|
||||
class AscendEmbeddingMethod(AscendLinearMethod):
|
||||
"""Embedding method for Ascend quantization.
|
||||
|
||||
Args:
|
||||
quant_config: The Ascend quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: AscendQuantConfig, prefix: str,
|
||||
packed_modules_mapping: Dict[str, Any],
|
||||
layer: torch.nn.Module) -> None:
|
||||
self.quant_method = get_quant_method(quant_config.quant_description,
|
||||
prefix,
|
||||
"linear",
|
||||
packed_modules_mapping,
|
||||
layer=layer)
|
||||
@@ -1,129 +0,0 @@
|
||||
from typing import Any, Dict, Optional, Type
|
||||
|
||||
import torch
|
||||
from vllm.logger import logger
|
||||
|
||||
from vllm_ascend.utils import COMPRESSED_TENSORS_METHOD
|
||||
|
||||
from .w4a4_flatquant_dynamic import AscendW4A4FlatQuantDynamicLinearMethod
|
||||
from .w4a4_laos_dynamic import AscendW4A4LaosDynamicLinearMethod
|
||||
from .w4a8_dynamic import (AscendW4A8DynamicFusedMoEMethod,
|
||||
AscendW4A8DynamicLinearMethod)
|
||||
from .w4a16 import AscendW4A16FusedMoEMethod
|
||||
from .w8a8 import AscendW8A8LinearMethod
|
||||
from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod,
|
||||
AscendW8A8DynamicLinearMethod)
|
||||
from .w8a8_pdmix import (AscendW8A8PDMixFusedMoeMethod,
|
||||
AscendW8A8PDMixLinearMethod)
|
||||
from .w8a8mxfp8 import AscendW8A8MXFP8DynamicLinearMethod
|
||||
from .w8a16 import AscendW8A16LinearMethod
|
||||
|
||||
ASCEND_QUANTIZATION_METHOD_MAP: Dict[str, Dict[str, Type[Any]]] = {
|
||||
"W4A16": {
|
||||
"moe": AscendW4A16FusedMoEMethod,
|
||||
},
|
||||
"W4A8_DYNAMIC": {
|
||||
"linear": AscendW4A8DynamicLinearMethod,
|
||||
"moe": AscendW4A8DynamicFusedMoEMethod,
|
||||
},
|
||||
"W4A4_DYNAMIC": {
|
||||
"linear": AscendW4A4LaosDynamicLinearMethod,
|
||||
},
|
||||
"W4A4_FLATQUANT_DYNAMIC": {
|
||||
"linear": AscendW4A4FlatQuantDynamicLinearMethod,
|
||||
},
|
||||
"W8A8": {
|
||||
"linear": AscendW8A8LinearMethod,
|
||||
},
|
||||
"W8A8_DYNAMIC": {
|
||||
"linear": AscendW8A8DynamicLinearMethod,
|
||||
"moe": AscendW8A8DynamicFusedMoEMethod,
|
||||
},
|
||||
"W8A8_MIX": {
|
||||
"linear": AscendW8A8PDMixLinearMethod,
|
||||
"moe": AscendW8A8PDMixFusedMoeMethod,
|
||||
},
|
||||
"W8A16": {
|
||||
"linear": AscendW8A16LinearMethod,
|
||||
},
|
||||
"W8A8_MXFP8": {
|
||||
"linear": AscendW8A8MXFP8DynamicLinearMethod,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def get_linear_quant_type(quant_description: Dict[str, Any], prefix: str,
|
||||
packed_modules_mapping: Dict[str, Any]):
|
||||
proj_name = prefix.split(".")[-1]
|
||||
if proj_name in packed_modules_mapping:
|
||||
quant_type = None
|
||||
shard_prefixes = [
|
||||
prefix.replace(proj_name, shard_proj_name)
|
||||
for shard_proj_name in packed_modules_mapping[proj_name]
|
||||
]
|
||||
for shard_prefix in shard_prefixes:
|
||||
shard_quant_type = quant_description[shard_prefix + '.weight']
|
||||
|
||||
if quant_type is None:
|
||||
quant_type = shard_quant_type
|
||||
elif shard_quant_type != quant_type:
|
||||
raise ValueError(
|
||||
f"Not all shards of {prefix} are quantized with same quant type."
|
||||
f"Shard {proj_name} uses {shard_quant_type}, but another shard"
|
||||
f"use {quant_type}. Please check quantization config.")
|
||||
else:
|
||||
quant_type = quant_description[prefix + '.weight']
|
||||
return quant_type
|
||||
|
||||
|
||||
def get_quant_method(quant_description: Dict[str, Any],
|
||||
prefix: str,
|
||||
layer_type: str,
|
||||
packed_modules_mapping: Optional[Dict[str, Any]] = None,
|
||||
layer: torch.nn.Module = None):
|
||||
if quant_description.get("quant_method") == COMPRESSED_TENSORS_METHOD:
|
||||
return get_quant_method_llmcompressor(layer)
|
||||
|
||||
return get_quant_method_modelslim(quant_description, prefix, layer_type,
|
||||
packed_modules_mapping)
|
||||
|
||||
|
||||
def get_quant_method_llmcompressor(layer: torch.nn.Module):
|
||||
logger.info_once("Using the vLLM Ascend llmcompressor Quantization now!")
|
||||
if layer.scheme is None:
|
||||
raise ValueError("A scheme must be defined for each layer")
|
||||
return layer.scheme
|
||||
|
||||
|
||||
def get_quant_method_modelslim(
|
||||
quant_description: Dict[str, Any],
|
||||
prefix: str,
|
||||
layer_type: str,
|
||||
packed_modules_mapping: Optional[Dict[str, Any]] = None):
|
||||
logger.info_once("Using the vLLM Ascend modelslim Quantization now!")
|
||||
if packed_modules_mapping is None:
|
||||
packed_modules_mapping = dict()
|
||||
# Attention
|
||||
if '.attn' in prefix and 'fa_quant_type' in quant_description.keys():
|
||||
quant_type = quant_description['fa_quant_type']
|
||||
# Linear
|
||||
else:
|
||||
quant_type = get_linear_quant_type(quant_description, prefix,
|
||||
packed_modules_mapping)
|
||||
if quant_type in ASCEND_QUANTIZATION_METHOD_MAP.keys():
|
||||
method_map = ASCEND_QUANTIZATION_METHOD_MAP[quant_type]
|
||||
if layer_type in method_map.keys():
|
||||
method_cls = method_map[layer_type]
|
||||
return method_cls()
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Currently, vLLM Ascend doesn't support {quant_type} for {layer_type}."
|
||||
)
|
||||
raise NotImplementedError("Currently, vLLM Ascend only supports following quant types:" \
|
||||
f"{list(ASCEND_QUANTIZATION_METHOD_MAP.keys())}")
|
||||
|
||||
|
||||
def is_mx_quant_type(instance: Any) -> bool:
|
||||
"""Checks if the quantization method is a mix-precision type."""
|
||||
MX_QUANT_TYPES = (AscendW8A8MXFP8DynamicLinearMethod, )
|
||||
return isinstance(instance, MX_QUANT_TYPES)
|
||||
@@ -1,70 +0,0 @@
|
||||
from typing import Any, Dict, cast
|
||||
|
||||
import torch
|
||||
from vllm.config import get_current_vllm_config
|
||||
|
||||
from .w8a8 import AscendW8A8LinearMethod
|
||||
from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod,
|
||||
AscendW8A8DynamicLinearMethod)
|
||||
|
||||
|
||||
class AscendW8A8PDMixLinearMethod(AscendW8A8DynamicLinearMethod):
|
||||
|
||||
def __init__(self):
|
||||
self.kv_transfer_config = get_current_vllm_config().kv_transfer_config
|
||||
super().__init__()
|
||||
|
||||
@staticmethod
|
||||
def apply(layer, x, bias=None, tp_rank=0):
|
||||
if layer.is_kv_consumer:
|
||||
return AscendW8A8LinearMethod.apply(layer, x, bias, tp_rank)
|
||||
else:
|
||||
return AscendW8A8DynamicLinearMethod.apply(layer, x, bias, tp_rank)
|
||||
|
||||
@staticmethod
|
||||
def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
return AscendW8A8LinearMethod.get_pertensor_param(params_dtype)
|
||||
|
||||
@staticmethod
|
||||
def get_perchannel_param(
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
) -> Dict[str, Any]:
|
||||
return AscendW8A8LinearMethod.get_perchannel_param(
|
||||
output_size, params_dtype)
|
||||
|
||||
def process_weights_after_loading(self, layer):
|
||||
AscendW8A8LinearMethod.process_weights_after_loading(
|
||||
cast(AscendW8A8LinearMethod, self), layer)
|
||||
layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
|
||||
layer.is_kv_consumer = self.kv_transfer_config is not None and self.kv_transfer_config.is_kv_consumer
|
||||
|
||||
|
||||
class AscendW8A8PDMixFusedMoeMethod(AscendW8A8DynamicFusedMoEMethod):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@staticmethod
|
||||
def get_dynamic_quant_param(num_experts: int,
|
||||
intermediate_size_per_partition: int,
|
||||
hidden_sizes: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
param_dict = AscendW8A8DynamicFusedMoEMethod.get_dynamic_quant_param(
|
||||
num_experts, intermediate_size_per_partition, hidden_sizes,
|
||||
params_dtype)
|
||||
param_dict["w2_deq_scale"] = torch.empty(num_experts,
|
||||
hidden_sizes,
|
||||
dtype=torch.float32)
|
||||
param_dict["w13_deq_scale"] = torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
dtype=torch.float32)
|
||||
param_dict["w2_input_offset"] = torch.empty(num_experts,
|
||||
1,
|
||||
dtype=torch.int8)
|
||||
param_dict["w13_input_offset"] = torch.empty(num_experts,
|
||||
1,
|
||||
dtype=torch.int8)
|
||||
|
||||
return param_dict
|
||||
Reference in New Issue
Block a user