2026-01-09 16:26:31 +08:00
|
|
|
#
|
|
|
|
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
|
|
|
|
# This file is a part of the vllm-ascend project.
|
|
|
|
|
#
|
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
|
#
|
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
#
|
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
#
|
|
|
|
|
|
2026-03-02 11:04:06 +08:00
|
|
|
from collections.abc import Callable
|
2026-02-06 14:56:53 +08:00
|
|
|
from typing import Any
|
2026-01-09 16:26:31 +08:00
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
import torch_npu
|
2026-03-02 11:04:06 +08:00
|
|
|
from vllm.config import CompilationMode, get_current_vllm_config
|
|
|
|
|
from vllm.distributed import get_ep_group
|
2026-01-09 16:26:31 +08:00
|
|
|
|
2026-03-02 11:04:06 +08:00
|
|
|
from vllm_ascend.ascend_config import get_ascend_config
|
2026-03-13 09:11:46 +08:00
|
|
|
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
|
2026-03-02 18:17:01 +08:00
|
|
|
from vllm_ascend.device.mxfp_compat import (
|
2026-03-02 11:04:06 +08:00
|
|
|
FLOAT8_E8M0FNU_DTYPE,
|
|
|
|
|
ensure_mxfp8_linear_available,
|
|
|
|
|
ensure_mxfp8_moe_available,
|
|
|
|
|
)
|
2026-03-02 18:17:01 +08:00
|
|
|
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
|
2026-03-20 23:23:57 +08:00
|
|
|
from vllm_ascend.ops.fused_moe.moe_runtime_args import build_fused_experts_input
|
2026-03-02 11:04:06 +08:00
|
|
|
|
|
|
|
|
from .base import AscendLinearScheme, AscendMoEScheme, QuantType
|
[Refactor] Quantization Module Refactor (#5738)
### Summary
This PR refactors the `vllm_ascend/quantization` module to improve code
organization, maintainability, and extensibility. The refactoring
introduces a clear separation of concerns with a registry-based scheme
discovery pattern, abstract base classes for quantization schemes, and
dedicated wrapper classes.
### Key Changes
#### 1. **Modular Directory Structure**
| Before | After |
|--------|-------|
| Flat file structure with mixed responsibilities | Organized into
`methods/` subpackage for schemes |
| Single `quant_config.py` (600+ lines) | Separate config files:
`modelslim_config.py`, `compressed_tensors_config.py` |
| `utils.py` with scheme lookup logic | `methods/registry.py` with
decorator-based registration |
#### 2. **Registry-Based Scheme Discovery**
Replaced hardcoded `ASCEND_QUANTIZATION_METHOD_MAP` dictionary with a
decorator-based registry pattern:
```python
# Before: Manual dictionary mapping
ASCEND_QUANTIZATION_METHOD_MAP = {
"W8A8_DYNAMIC": {"linear": AscendW8A8DynamicLinearMethod, ...},
...
}
# After: Decorator-based registration
@register_scheme("W8A8_DYNAMIC", "linear")
class AscendW8A8DynamicLinearMethod(AscendLinearScheme):
...
```
#### 3. **Abstract Base Classes**
Introduced three abstract base classes in `methods/base.py`:
- `AscendLinearScheme` - Base for linear layer quantization
- `AscendMoEScheme` - Base for MoE layer quantization
- `AscendAttentionScheme` - Base for attention layer quantization
#### 4. **Separated Config and Wrapper Classes**
- **Config classes** (`AscendModelSlimConfig`,
`AscendCompressedTensorsConfig`): Handle config parsing and scheme
selection
- **Wrapper classes** (`AscendLinearMethod`, `AscendFusedMoEMethod`,
etc.): Implement vLLM interfaces and delegate to schemes
#### 5. **Cleaner Public API**
```python
# New clean module interface
from vllm_ascend.quantization import (
AscendModelSlimConfig,
AscendCompressedTensorsConfig,
)
from vllm_ascend.quantization.methods import get_scheme_class
```
### Architecture Diagram
```mermaid
classDiagram
direction TB
class QuantizationConfig {
<<vLLM Interface>>
+get_quant_method()
}
class AscendModelSlimConfig {
+quant_description
+get_quant_method()
-create_scheme_for_layer()
}
class AscendCompressedTensorsConfig {
+target_scheme_map
+get_quant_method()
-_get_scheme_from_parts()
}
class AscendLinearMethod {
<<Wrapper>>
+quant_method: AscendLinearScheme
+create_weights()
+apply()
}
class AscendFusedMoEMethod {
<<Wrapper>>
+quant_method: AscendMoEScheme
+create_weights()
+apply()
}
class AscendLinearScheme {
<<Abstract>>
+get_weight()*
+apply()*
+get_pertensor_param()
+get_perchannel_param()
}
class AscendMoEScheme {
<<Abstract>>
+get_weight()*
+get_dynamic_quant_param()*
+apply()*
}
class W8A8DynamicLinear {
+get_weight()
+apply()
}
class W8A8DynamicMoE {
+get_weight()
+apply()
}
QuantizationConfig <|-- AscendModelSlimConfig
QuantizationConfig <|-- AscendCompressedTensorsConfig
AscendModelSlimConfig ..> AscendLinearMethod : creates
AscendModelSlimConfig ..> AscendFusedMoEMethod : creates
AscendCompressedTensorsConfig ..> AscendLinearMethod : creates
AscendCompressedTensorsConfig ..> AscendFusedMoEMethod : creates
AscendLinearMethod o-- AscendLinearScheme : delegates to
AscendFusedMoEMethod o-- AscendMoEScheme : delegates to
AscendLinearScheme <|-- W8A8DynamicLinear
AscendMoEScheme <|-- W8A8DynamicMoE
```
### Scheme Registration Flow
```mermaid
sequenceDiagram
participant Module as Scheme Module
participant Registry as _SCHEME_REGISTRY
participant Config as QuantConfig
participant Wrapper as Wrapper Class
Note over Module: At import time
Module->>Registry: @register_scheme("W8A8_DYNAMIC", "linear")
Registry->>Registry: Store (quant_type, layer_type) -> Class
Note over Config: At runtime
Config->>Config: Determine quant_type from description
Config->>Registry: get_scheme_class(quant_type, layer_type)
Registry-->>Config: Return scheme class
Config->>Config: scheme = scheme_cls()
Config->>Wrapper: Create wrapper with scheme
Wrapper-->>Config: Return wrapper instance
```
### File Changes Summary
| Original Files | Refactored Files |
|----------------|------------------|
| `__init__.py` (empty) | `__init__.py` (exports public API) |
| `quant_config.py` | `modelslim_config.py` + `wrappers.py` |
| `compressed_tensors/` | `compressed_tensors_config.py` |
| `utils.py` | `methods/registry.py` |
| `w8a8_dynamic.py` | `methods/w8a8_dynamic.py` |
| `w8a8.py` | `methods/w8a8_static.py` |
| `w4a4_flatquant_dynamic.py` | `methods/w4a4_flatquant.py` |
| ... | `methods/base.py` (new) |
### Benefits
1. **Extensibility**: Adding new quantization schemes only requires
implementing the base class and adding `@register_scheme` decorator
2. **Maintainability**: Clear separation between config parsing, wrapper
logic, and scheme implementation
3. **Testability**: Abstract base classes enable easier unit testing and
mocking
4. **Discoverability**: Registry pattern makes it easy to list all
supported schemes
5. **Reduced Coupling**: Config classes no longer need to know about all
scheme implementations
___
- vLLM version: v0.13.0
- vLLM main:
https://github.com/vllm-project/vllm/commit/2f4e6548efec402b913ffddc8726230d9311948d
---------
Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
2026-01-23 14:13:47 +08:00
|
|
|
from .registry import register_scheme
|
2026-01-09 16:26:31 +08:00
|
|
|
|
[Refactor] Quantization Module Refactor (#5738)
### Summary
This PR refactors the `vllm_ascend/quantization` module to improve code
organization, maintainability, and extensibility. The refactoring
introduces a clear separation of concerns with a registry-based scheme
discovery pattern, abstract base classes for quantization schemes, and
dedicated wrapper classes.
### Key Changes
#### 1. **Modular Directory Structure**
| Before | After |
|--------|-------|
| Flat file structure with mixed responsibilities | Organized into
`methods/` subpackage for schemes |
| Single `quant_config.py` (600+ lines) | Separate config files:
`modelslim_config.py`, `compressed_tensors_config.py` |
| `utils.py` with scheme lookup logic | `methods/registry.py` with
decorator-based registration |
#### 2. **Registry-Based Scheme Discovery**
Replaced hardcoded `ASCEND_QUANTIZATION_METHOD_MAP` dictionary with a
decorator-based registry pattern:
```python
# Before: Manual dictionary mapping
ASCEND_QUANTIZATION_METHOD_MAP = {
"W8A8_DYNAMIC": {"linear": AscendW8A8DynamicLinearMethod, ...},
...
}
# After: Decorator-based registration
@register_scheme("W8A8_DYNAMIC", "linear")
class AscendW8A8DynamicLinearMethod(AscendLinearScheme):
...
```
#### 3. **Abstract Base Classes**
Introduced three abstract base classes in `methods/base.py`:
- `AscendLinearScheme` - Base for linear layer quantization
- `AscendMoEScheme` - Base for MoE layer quantization
- `AscendAttentionScheme` - Base for attention layer quantization
#### 4. **Separated Config and Wrapper Classes**
- **Config classes** (`AscendModelSlimConfig`,
`AscendCompressedTensorsConfig`): Handle config parsing and scheme
selection
- **Wrapper classes** (`AscendLinearMethod`, `AscendFusedMoEMethod`,
etc.): Implement vLLM interfaces and delegate to schemes
#### 5. **Cleaner Public API**
```python
# New clean module interface
from vllm_ascend.quantization import (
AscendModelSlimConfig,
AscendCompressedTensorsConfig,
)
from vllm_ascend.quantization.methods import get_scheme_class
```
### Architecture Diagram
```mermaid
classDiagram
direction TB
class QuantizationConfig {
<<vLLM Interface>>
+get_quant_method()
}
class AscendModelSlimConfig {
+quant_description
+get_quant_method()
-create_scheme_for_layer()
}
class AscendCompressedTensorsConfig {
+target_scheme_map
+get_quant_method()
-_get_scheme_from_parts()
}
class AscendLinearMethod {
<<Wrapper>>
+quant_method: AscendLinearScheme
+create_weights()
+apply()
}
class AscendFusedMoEMethod {
<<Wrapper>>
+quant_method: AscendMoEScheme
+create_weights()
+apply()
}
class AscendLinearScheme {
<<Abstract>>
+get_weight()*
+apply()*
+get_pertensor_param()
+get_perchannel_param()
}
class AscendMoEScheme {
<<Abstract>>
+get_weight()*
+get_dynamic_quant_param()*
+apply()*
}
class W8A8DynamicLinear {
+get_weight()
+apply()
}
class W8A8DynamicMoE {
+get_weight()
+apply()
}
QuantizationConfig <|-- AscendModelSlimConfig
QuantizationConfig <|-- AscendCompressedTensorsConfig
AscendModelSlimConfig ..> AscendLinearMethod : creates
AscendModelSlimConfig ..> AscendFusedMoEMethod : creates
AscendCompressedTensorsConfig ..> AscendLinearMethod : creates
AscendCompressedTensorsConfig ..> AscendFusedMoEMethod : creates
AscendLinearMethod o-- AscendLinearScheme : delegates to
AscendFusedMoEMethod o-- AscendMoEScheme : delegates to
AscendLinearScheme <|-- W8A8DynamicLinear
AscendMoEScheme <|-- W8A8DynamicMoE
```
### Scheme Registration Flow
```mermaid
sequenceDiagram
participant Module as Scheme Module
participant Registry as _SCHEME_REGISTRY
participant Config as QuantConfig
participant Wrapper as Wrapper Class
Note over Module: At import time
Module->>Registry: @register_scheme("W8A8_DYNAMIC", "linear")
Registry->>Registry: Store (quant_type, layer_type) -> Class
Note over Config: At runtime
Config->>Config: Determine quant_type from description
Config->>Registry: get_scheme_class(quant_type, layer_type)
Registry-->>Config: Return scheme class
Config->>Config: scheme = scheme_cls()
Config->>Wrapper: Create wrapper with scheme
Wrapper-->>Config: Return wrapper instance
```
### File Changes Summary
| Original Files | Refactored Files |
|----------------|------------------|
| `__init__.py` (empty) | `__init__.py` (exports public API) |
| `quant_config.py` | `modelslim_config.py` + `wrappers.py` |
| `compressed_tensors/` | `compressed_tensors_config.py` |
| `utils.py` | `methods/registry.py` |
| `w8a8_dynamic.py` | `methods/w8a8_dynamic.py` |
| `w8a8.py` | `methods/w8a8_static.py` |
| `w4a4_flatquant_dynamic.py` | `methods/w4a4_flatquant.py` |
| ... | `methods/base.py` (new) |
### Benefits
1. **Extensibility**: Adding new quantization schemes only requires
implementing the base class and adding `@register_scheme` decorator
2. **Maintainability**: Clear separation between config parsing, wrapper
logic, and scheme implementation
3. **Testability**: Abstract base classes enable easier unit testing and
mocking
4. **Discoverability**: Registry pattern makes it easy to list all
supported schemes
5. **Reduced Coupling**: Config classes no longer need to know about all
scheme implementations
___
- vLLM version: v0.13.0
- vLLM main:
https://github.com/vllm-project/vllm/commit/2f4e6548efec402b913ffddc8726230d9311948d
---------
Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
2026-01-23 14:13:47 +08:00
|
|
|
|
|
|
|
|
@register_scheme("W8A8_MXFP8", "linear")
class AscendW8A8MXFP8DynamicLinearMethod(AscendLinearScheme):
    """Linear method for Ascend W8A8_MXFP8 (Microscaling FP8) quantization.

    This scheme uses microscaling FP8 quantization with per-group scales.
    The activation is dynamically quantized to FP8 (E4M3FN format) with
    microscaling, and weights are stored in FP8 format with per-group scales.
    """

    model_dtype = None

    def __init__(self):
        # Fail early if the runtime lacks MXFP8 linear support.
        ensure_mxfp8_linear_available("W8A8_MXFP8 linear quantization")
        vllm_config = get_current_vllm_config()
        # Number of input-dim elements that share one scale; defaults to 32
        # when the quantization description does not specify "group_size".
        self.group_size = vllm_config.quant_config.quant_description.get("group_size", 32)

    def get_weight(self, input_size: int, output_size: int, params_dtype: torch.dtype) -> dict[str, Any]:
        """Return the weight placeholder, shape (output_size, input_size).

        ``params_dtype`` is ignored: the weight is always allocated as
        ``torch.float8_e4m3fn``.
        """
        return {"weight": torch.empty(output_size, input_size, dtype=torch.float8_e4m3fn)}

    def get_pergroup_param(
        self, input_size: int, output_size: int, params_dtype: torch.dtype, layer_type: str | None = None
    ) -> dict[str, Any]:
        """Return the per-group weight-scale placeholder.

        Shape is (output_size, input_size // group_size), allocated as
        ``torch.uint8``. ``apply`` later hands this tensor to
        ``npu_quant_matmul`` with ``scale_dtype=FLOAT8_E8M0FNU_DTYPE``.
        ``params_dtype`` and ``layer_type`` are unused.

        NOTE(review): floor division assumes input_size is a multiple of
        group_size — confirm against the checkpoints this scheme loads.
        """
        return {"weight_scale": torch.empty(output_size, input_size // self.group_size, dtype=torch.uint8)}

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
        tp_rank: int | None = 0,
    ) -> torch.Tensor:
        """Dynamically MX-quantize ``x`` to FP8 and run the quantized matmul.

        Args:
            layer: module holding ``weight`` and ``weight_scale``, as
                re-laid-out by ``process_weights_after_loading``.
            x: activations; inputs with more than 2 dims (e.g. from Qwen VL
                models) are flattened to 2-D and the output is reshaped back.
            bias: optional bias; cast to float32 before the matmul.
            tp_rank: unused; kept for interface compatibility.

        Returns:
            The matmul output in ``x``'s dtype, with ``x``'s leading dims
            restored when ``x`` had more than 2 dims.
        """
        original_shape = x.shape
        if x.dim() > 2:
            # ``reshape`` (unlike ``view``) also accepts non-contiguous
            # activations; for contiguous inputs it returns the same view.
            x = x.reshape(-1, x.shape[-1])

        quantized_x, dynamic_scale = torch_npu.npu_dynamic_mx_quant(x, dst_type=torch.float8_e4m3fn)
        output_dtype = x.dtype

        # The bias is always handed to npu_quant_matmul as float32.
        if bias is not None and bias.dtype != torch.float32:
            bias = bias.to(torch.float32)

        output = torch_npu.npu_quant_matmul(
            quantized_x,
            layer.weight,
            layer.weight_scale,
            scale_dtype=FLOAT8_E8M0FNU_DTYPE,
            pertoken_scale=dynamic_scale,
            pertoken_scale_dtype=FLOAT8_E8M0FNU_DTYPE,
            bias=bias,
            output_dtype=output_dtype,
            group_sizes=[1, 1, self.group_size],
        )

        # Restore the leading dims that were flattened above.
        if len(original_shape) > 2:
            output = output.view(*original_shape[:-1], -1)
        return output

    def process_weights_after_loading(self, layer):
        """Re-layout loaded weight and scales into the form ``apply`` passes to the NPU.

        - ``weight``: (N, K) -> transposed to (K, N).
        - ``weight_scale``: (N, S) -> reshaped to (N, S // 2, 2) ->
          transposed to (S // 2, N, 2).

        The transposes are views; no contiguous copy is made here.
        """
        n_dim, k_dim = layer.weight_scale.data.shape
        # Pair adjacent scale groups into a trailing dim of 2; the reshape
        # fails if the number of scale groups S is odd.
        layer.weight_scale.data = layer.weight_scale.data.reshape(n_dim, k_dim // 2, 2)
        layer.weight.data = layer.weight.data.transpose(0, 1)
        layer.weight_scale.data = layer.weight_scale.data.transpose(0, 1)
|
2026-03-02 11:04:06 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_scheme("W8A8_MXFP8", "moe")
class AscendW8A8MXFP8DynamicFusedMoEMethod(AscendMoEScheme):
    """FusedMoE method for Ascend W8A8_MXFP8 (Microscaling FP8) quantization.

    Expert weights are stored as ``torch.float8_e4m3fn`` with per-group
    ``torch.uint8`` scales. ``apply`` dispatches to the current
    ``moe_comm_method.fused_experts`` with MXFP8 activation/weight types and
    E8M0 scale dtypes.
    """

    model_dtype = None
    quant_type: QuantType = QuantType.MXFP8

    def __init__(self):
        # Fail early if the runtime lacks MXFP8 MoE support.
        ensure_mxfp8_moe_available("W8A8_MXFP8 MoE quantization")
        self.ep_group = get_ep_group()

        vllm_config = get_current_vllm_config()
        # Number of elements that share one scale along the reduction dim;
        # defaults to 32 when the quantization description omits it.
        self.group_size = vllm_config.quant_config.quant_description.get("group_size", 32)
        ascend_config = get_ascend_config()
        # True when compiling with VLLM_COMPILE and eager mode is not enforced.
        # Not read inside this class.
        self.use_aclgraph = (
            vllm_config.compilation_config.mode == CompilationMode.VLLM_COMPILE
            and not vllm_config.model_config.enforce_eager
        )
        self.dynamic_eplb = ascend_config.eplb_config.dynamic_eplb

    @staticmethod
    def get_weight(
        num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype
    ) -> dict[str, Any]:
        """Return FP8 expert-weight placeholders.

        - ``w13_weight``: (num_experts, 2 * intermediate_size_per_partition, hidden_sizes)
        - ``w2_weight``: (num_experts, hidden_sizes, intermediate_size_per_partition)

        ``params_dtype`` is ignored; both are allocated as ``torch.float8_e4m3fn``.
        """
        return {
            "w13_weight": torch.empty(
                num_experts, 2 * intermediate_size_per_partition, hidden_sizes, dtype=torch.float8_e4m3fn
            ),
            "w2_weight": torch.empty(
                num_experts, hidden_sizes, intermediate_size_per_partition, dtype=torch.float8_e4m3fn
            ),
        }

    def get_dynamic_quant_param(
        self, num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype
    ) -> dict[str, Any]:
        """Return per-group ``torch.uint8`` scale placeholders for both expert weights.

        The last dim of each scale is the weight's reduction dim floor-divided
        by ``group_size``. ``params_dtype`` is unused.
        """
        return {
            "w13_weight_scale": torch.empty(
                num_experts, 2 * intermediate_size_per_partition, hidden_sizes // self.group_size, dtype=torch.uint8
            ),
            "w2_weight_scale": torch.empty(
                num_experts, hidden_sizes, intermediate_size_per_partition // self.group_size, dtype=torch.uint8
            ),
        }

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        is_prefill: bool = True,
        enable_force_load_balance: bool = True,
        log2phy: torch.Tensor | None = None,
        global_redundant_expert_num: int = 0,
        pertoken_scale: Any | None = None,
        activation: str = "silu",
        apply_router_weight_on_input: bool = False,
        mc2_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Select experts for each token and run the MXFP8 fused-experts path.

        Routing is done by ``select_experts``. The result is forwarded, together
        with the layer's FP8 weights and scales, to
        ``_EXTRA_CTX.moe_comm_method.fused_experts``.

        NOTE(review): ``routed_scaling_factor`` and ``is_prefill`` are accepted
        but not used in this method. Confirm whether ``routed_scaling_factor``
        should be forwarded to ``select_experts``.

        Returns:
            Whatever ``moe_comm_method.fused_experts`` returns.

        Raises:
            AssertionError: if ``router_logits.shape[1]`` differs from
                ``global_num_experts - global_redundant_expert_num``.
        """
        # Router logits cover only the logical experts (redundant replicas excluded).
        num_logical_experts = global_num_experts - global_redundant_expert_num
        assert router_logits.shape[1] == num_logical_experts, (
            "Number of global experts mismatch (excluding redundancy): "
            f"router_logits has {router_logits.shape[1]}, expected {num_logical_experts}"
        )

        topk_weights, topk_ids = select_experts(
            hidden_states=x,
            router_logits=router_logits,
            top_k=top_k,
            use_grouped_topk=use_grouped_topk,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            global_num_experts=global_num_experts,
        )

        # Naive load balancing: replace the routed experts with a random,
        # duplicate-free choice of top_k logical experts per token. This avoids
        # piling too many tokens onto a single rank. The caller is expected to
        # enable it only during profile runs.
        if enable_force_load_balance:
            random_matrix = torch.rand(topk_ids.size(0), num_logical_experts, device=topk_ids.device)
            topk_ids = torch.argsort(random_matrix, dim=1)[:, : topk_ids.size(1)].to(topk_ids.dtype)

        topk_weights = topk_weights.to(x.dtype)

        moe_comm_method = _EXTRA_CTX.moe_comm_method
        return moe_comm_method.fused_experts(
            fused_experts_input=build_fused_experts_input(
                hidden_states=x,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                quant_type=self.quant_type,
                dynamic_eplb=self.dynamic_eplb,
                expert_map=expert_map,
                global_redundant_expert_num=global_redundant_expert_num,
                mc2_mask=mc2_mask,
                apply_router_weight_on_input=apply_router_weight_on_input,
                log2phy=log2phy,
                pertoken_scale=pertoken_scale,
                activation=activation,
                mxfp_act_quant_type=torch.float8_e4m3fn,
                mxfp_weight_quant_type=torch.float8_e4m3fn,
                mxfp_scale_dtype=FLOAT8_E8M0FNU_DTYPE,
                mxfp_per_token_scale_dtype=FLOAT8_E8M0FNU_DTYPE,
                mxfp_use_bf16=(x.dtype == torch.bfloat16),
                w1_scale=layer.w13_weight_scale,
                w2_scale=layer.w2_weight_scale,
            )
        )

    def process_weights_after_loading(self, layer):
        """Re-layout loaded expert weights and scales for the fused-experts kernel.

        - Each scale (E, N, S) is reshaped to (E, N, S // 2, 2) and then
          transposed to (E, S // 2, N, 2).
        - Each weight (E, N, K) is transposed to (E, K, N).

        The transposes are views; no contiguous copy is made here.
        """
        # Pair adjacent scale groups into a trailing dim of 2; the reshape
        # fails if the number of scale groups is odd.
        g_num, n_size, k_size = layer.w13_weight_scale.shape
        layer.w13_weight_scale.data = layer.w13_weight_scale.data.reshape(g_num, n_size, k_size // 2, 2)
        g_num, n_size, k_size = layer.w2_weight_scale.shape
        layer.w2_weight_scale.data = layer.w2_weight_scale.data.reshape(g_num, n_size, k_size // 2, 2)

        layer.w13_weight.data = layer.w13_weight.data.transpose(1, 2)
        layer.w2_weight.data = layer.w2_weight.data.transpose(1, 2)
        layer.w13_weight_scale.data = layer.w13_weight_scale.data.transpose(1, 2)
        layer.w2_weight_scale.data = layer.w2_weight_scale.data.transpose(1, 2)
|