### Summary
This PR refactors the `vllm_ascend/quantization` module to improve code
organization, maintainability, and extensibility. The refactoring
introduces a clear separation of concerns with a registry-based scheme
discovery pattern, abstract base classes for quantization schemes, and
dedicated wrapper classes.
### Key Changes
#### 1. **Modular Directory Structure**
| Before | After |
|--------|-------|
| Flat file structure with mixed responsibilities | Organized into `methods/` subpackage for schemes |
| Single `quant_config.py` (600+ lines) | Separate config files: `modelslim_config.py`, `compressed_tensors_config.py` |
| `utils.py` with scheme lookup logic | `methods/registry.py` with decorator-based registration |
#### 2. **Registry-Based Scheme Discovery**
Replaced the hardcoded `ASCEND_QUANTIZATION_METHOD_MAP` dictionary with a
decorator-based registry pattern:
```python
# Before: Manual dictionary mapping
ASCEND_QUANTIZATION_METHOD_MAP = {
    "W8A8_DYNAMIC": {"linear": AscendW8A8DynamicLinearMethod, ...},
    ...
}

# After: Decorator-based registration
@register_scheme("W8A8_DYNAMIC", "linear")
class AscendW8A8DynamicLinearMethod(AscendLinearScheme):
    ...
```
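For readers unfamiliar with the pattern, here is a minimal sketch of how such a decorator-based registry could work. The names `register_scheme`, `get_scheme_class`, and `_SCHEME_REGISTRY` appear in this PR, but the function bodies, the duplicate-registration check, and the placeholder class `DemoLinearScheme` are illustrative assumptions, not the contents of `methods/registry.py`.

```python
from typing import Callable, Dict, Optional, Tuple, Type

# Maps (quant_type, layer_type) -> scheme class. Assumed layout, for illustration.
_SCHEME_REGISTRY: Dict[Tuple[str, str], Type] = {}


def register_scheme(quant_type: str, layer_type: str) -> Callable[[Type], Type]:
    """Class decorator that records a scheme under (quant_type, layer_type)."""

    def decorator(cls: Type) -> Type:
        key = (quant_type, layer_type)
        if key in _SCHEME_REGISTRY:
            # Assumption: a duplicate registration is treated as an error.
            raise ValueError(f"Scheme already registered for {key}")
        _SCHEME_REGISTRY[key] = cls
        return cls  # the class itself is returned unchanged

    return decorator


def get_scheme_class(quant_type: str, layer_type: str) -> Optional[Type]:
    """Return the registered class, or None when nothing matches."""
    return _SCHEME_REGISTRY.get((quant_type, layer_type))


@register_scheme("W8A8_DYNAMIC", "linear")
class DemoLinearScheme:  # hypothetical placeholder, not the real scheme
    pass
```

Registration happens as a side effect of importing the module that defines the class, so a lookup only succeeds once that module has been imported.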
#### 3. **Abstract Base Classes**
Introduced three abstract base classes in `methods/base.py`:
- `AscendLinearScheme` - Base for linear layer quantization
- `AscendMoEScheme` - Base for MoE layer quantization
- `AscendAttentionScheme` - Base for attention layer quantization
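As a rough illustration of the contract such a base class sets up, the sketch below mirrors the abstract `get_weight()` / `apply()` pair and one optional hook from the architecture diagram. The class names `LinearSchemeSketch` and `IdentityScheme` and all method signatures are assumptions made for this example; the real signatures in `methods/base.py` may differ.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List


class LinearSchemeSketch(ABC):
    """Hypothetical stand-in for a scheme base class such as AscendLinearScheme."""

    @abstractmethod
    def get_weight(self, input_size: int, output_size: int) -> Dict[str, Any]:
        """Describe the weight parameters the layer must allocate."""

    @abstractmethod
    def apply(self, x: List[float]) -> List[float]:
        """Run the scheme-specific computation."""

    def get_pertensor_param(self) -> Dict[str, Any]:
        # Optional hook with a default, so subclasses override it only if needed.
        return {}


class IdentityScheme(LinearSchemeSketch):
    """Toy concrete scheme that implements both abstract methods."""

    def get_weight(self, input_size: int, output_size: int) -> Dict[str, Any]:
        return {"weight_shape": (output_size, input_size)}

    def apply(self, x: List[float]) -> List[float]:
        return x
```

Instantiating `LinearSchemeSketch` directly raises `TypeError`, so an incomplete scheme fails at construction time rather than on its first forward pass.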
#### 4. **Separated Config and Wrapper Classes**
- **Config classes** (`AscendModelSlimConfig`,
`AscendCompressedTensorsConfig`): Handle config parsing and scheme
selection
- **Wrapper classes** (`AscendLinearMethod`, `AscendFusedMoEMethod`,
etc.): Implement vLLM interfaces and delegate to schemes
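The split between wrapper and scheme can be sketched as follows. The attribute name `quant_method` and the methods `create_weights()` / `apply()` come from the architecture diagram in this PR; the classes `ToyScheme`, `ToyLayer`, and `ToyLinearMethod`, and their signatures, are hypothetical and use plain Python lists instead of tensors.

```python
from typing import Any, Dict, List


class ToyScheme:
    """Hypothetical scheme: owns the quantization-specific logic."""

    def get_weight(self, input_size: int, output_size: int) -> Dict[str, Any]:
        # Describe the parameters the layer needs; zeros stand in for real weights.
        return {"weight": [[0] * input_size for _ in range(output_size)]}

    def apply(self, layer: Any, x: List[int]) -> List[int]:
        # Plain matrix-vector product y = W x over Python lists.
        return [sum(w * v for w, v in zip(row, x)) for row in layer.weight]


class ToyLayer:
    """Hypothetical layer object that parameters get attached to."""


class ToyLinearMethod:
    """Hypothetical wrapper: framework-facing interface that forwards to the scheme."""

    def __init__(self, scheme: ToyScheme) -> None:
        self.quant_method = scheme

    def create_weights(self, layer: ToyLayer, input_size: int, output_size: int) -> None:
        # Ask the scheme what to allocate, then register it on the layer.
        for name, value in self.quant_method.get_weight(input_size, output_size).items():
            setattr(layer, name, value)

    def apply(self, layer: ToyLayer, x: List[int]) -> List[int]:
        return self.quant_method.apply(layer, x)


layer = ToyLayer()
method = ToyLinearMethod(ToyScheme())
method.create_weights(layer, input_size=2, output_size=3)
layer.weight = [[1, 2], [3, 4], [5, 6]]  # pretend a checkpoint was loaded
output = method.apply(layer, [1, 1])
```

Because the wrapper only forwards calls, swapping in a different scheme does not require touching the wrapper.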
#### 5. **Cleaner Public API**
```python
# New clean module interface
from vllm_ascend.quantization import (
    AscendModelSlimConfig,
    AscendCompressedTensorsConfig,
)
from vllm_ascend.quantization.methods import get_scheme_class
```
### Architecture Diagram
```mermaid
classDiagram
direction TB
class QuantizationConfig {
<<vLLM Interface>>
+get_quant_method()
}
class AscendModelSlimConfig {
+quant_description
+get_quant_method()
-create_scheme_for_layer()
}
class AscendCompressedTensorsConfig {
+target_scheme_map
+get_quant_method()
-_get_scheme_from_parts()
}
class AscendLinearMethod {
<<Wrapper>>
+quant_method: AscendLinearScheme
+create_weights()
+apply()
}
class AscendFusedMoEMethod {
<<Wrapper>>
+quant_method: AscendMoEScheme
+create_weights()
+apply()
}
class AscendLinearScheme {
<<Abstract>>
+get_weight()*
+apply()*
+get_pertensor_param()
+get_perchannel_param()
}
class AscendMoEScheme {
<<Abstract>>
+get_weight()*
+get_dynamic_quant_param()*
+apply()*
}
class W8A8DynamicLinear {
+get_weight()
+apply()
}
class W8A8DynamicMoE {
+get_weight()
+apply()
}
QuantizationConfig <|-- AscendModelSlimConfig
QuantizationConfig <|-- AscendCompressedTensorsConfig
AscendModelSlimConfig ..> AscendLinearMethod : creates
AscendModelSlimConfig ..> AscendFusedMoEMethod : creates
AscendCompressedTensorsConfig ..> AscendLinearMethod : creates
AscendCompressedTensorsConfig ..> AscendFusedMoEMethod : creates
AscendLinearMethod o-- AscendLinearScheme : delegates to
AscendFusedMoEMethod o-- AscendMoEScheme : delegates to
AscendLinearScheme <|-- W8A8DynamicLinear
AscendMoEScheme <|-- W8A8DynamicMoE
```
### Scheme Registration Flow
```mermaid
sequenceDiagram
participant Module as Scheme Module
participant Registry as _SCHEME_REGISTRY
participant Config as QuantConfig
participant Wrapper as Wrapper Class
Note over Module: At import time
Module->>Registry: @register_scheme("W8A8_DYNAMIC", "linear")
Registry->>Registry: Store (quant_type, layer_type) -> Class
Note over Config: At runtime
Config->>Config: Determine quant_type from description
Config->>Registry: get_scheme_class(quant_type, layer_type)
Registry-->>Config: Return scheme class
Config->>Config: scheme = scheme_cls()
Config->>Wrapper: Create wrapper with scheme
Wrapper-->>Config: Return wrapper instance
```
### File Changes Summary
| Original Files | Refactored Files |
|----------------|------------------|
| `__init__.py` (empty) | `__init__.py` (exports public API) |
| `quant_config.py` | `modelslim_config.py` + `wrappers.py` |
| `compressed_tensors/` | `compressed_tensors_config.py` |
| `utils.py` | `methods/registry.py` |
| `w8a8_dynamic.py` | `methods/w8a8_dynamic.py` |
| `w8a8.py` | `methods/w8a8_static.py` |
| `w4a4_flatquant_dynamic.py` | `methods/w4a4_flatquant.py` |
| ... | `methods/base.py` (new) |
### Benefits
1. **Extensibility**: Adding a new quantization scheme only requires
implementing the base class and adding the `@register_scheme` decorator
2. **Maintainability**: Clear separation between config parsing, wrapper
logic, and scheme implementation
3. **Testability**: Abstract base classes enable easier unit testing and
mocking
4. **Discoverability**: Registry pattern makes it easy to list all
supported schemes
5. **Reduced Coupling**: Config classes no longer need to know about all
scheme implementations
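
To make the testability point concrete, here is a small sketch of unit-testing a wrapper against a stubbed scheme. `MoESchemeSketch` and `WrapperSketch` are hypothetical stand-ins invented for this example; only the idea of a wrapper delegating to an abstract scheme comes from this PR.

```python
from abc import ABC, abstractmethod
from unittest import mock


class MoESchemeSketch(ABC):
    """Hypothetical abstract scheme with a single method."""

    @abstractmethod
    def apply(self, x):
        ...


class WrapperSketch:
    """Hypothetical wrapper that forwards apply() to its scheme."""

    def __init__(self, scheme: MoESchemeSketch) -> None:
        self.quant_method = scheme

    def apply(self, x):
        return self.quant_method.apply(x)


# The spec limits the stub to attributes that exist on the base class.
stub = mock.Mock(spec=MoESchemeSketch)
stub.apply.return_value = "routed"

wrapper = WrapperSketch(stub)
result = wrapper.apply([1, 2, 3])

# Verify the wrapper delegated exactly once with the same argument.
stub.apply.assert_called_once_with([1, 2, 3])
```

The wrapper logic is exercised here without any NPU kernels or real weights, which is what the base-class boundary is meant to allow.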
___
- vLLM version: v0.13.0
- vLLM main: 2f4e6548ef
---------
Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
279 lines
11 KiB
Python
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Any, Callable, Dict, Optional

import torch
import torch_npu
from vllm.config import get_current_vllm_config
from vllm.forward_context import get_forward_context

from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ops.fused_moe.experts_selector import select_experts

from .base import AscendMoEScheme
from .registry import register_scheme


def unpack_from_int32(
    weight: torch.Tensor,
    shape: torch.Size,
    num_bits: int,
    packed_dim: int = 1,
) -> torch.Tensor:
    """Unpacks quantized weights from int32 format back to original bits.

    :param weight: The packed int32 tensor containing quantized weights
    :param shape: Original (unpacked) shape to restore
    :param num_bits: The number of bits used for quantization (<= 8)
    :param packed_dim: Dimension along which weights are packed (0 or 1), defaults to 1
    :return: Unpacked tensor with int8 dtype after applying offset correction
    """
    assert weight.dtype == torch.int32, f"Expecting `weight.dtype` is torch.int32 but got {weight.dtype}."
    assert num_bits <= 8, f"Expecting `num_bits` should not be larger than 8 but got {num_bits}."

    pack_factor = 32 // num_bits
    mask = (1 << num_bits) - 1

    if packed_dim == 1:
        unpacked_weight = torch.zeros(
            (weight.shape[0], weight.shape[1] * pack_factor),
            device=weight.device,
            dtype=torch.int32,
        )
        for i in range(pack_factor):
            unpacked_weight[:, i::pack_factor] = (weight >>
                                                  (num_bits * i)) & mask
        original_row_size = int(shape[1])
        unpacked_weight = unpacked_weight[:, :original_row_size]
    else:
        unpacked_weight = torch.zeros(
            (weight.shape[0] * pack_factor, weight.shape[1]),
            device=weight.device,
            dtype=torch.int32,
        )
        for i in range(pack_factor):
            unpacked_weight[i::pack_factor, :] = (weight >>
                                                  (num_bits * i)) & mask
        original_row_size = int(shape[0])
        unpacked_weight = unpacked_weight[:original_row_size, :]

    # Shift from the unsigned storage range to the signed range.
    offset = pow(2, num_bits) // 2
    unpacked_weight = (unpacked_weight - offset).to(torch.int8)

    return unpacked_weight


def pack_to_int32(weight: torch.Tensor) -> torch.Tensor:
    """Packs quantized weights into int32 format for storage.

    :param weight: The 3D tensor to pack, must be int8 or int32 dtype
    :return: Packed tensor with int32 dtype optimized for storage
    """
    assert weight.dim(
    ) == 3, f"Expecting `weight.dim()` is 3 ([e, n, k] or [e, k, n]) but got {weight.dim()}."
    assert weight.dtype in [
        torch.int8, torch.int32
    ], f"Expecting `weight.dtype` is torch.int8 or torch.int32 but got {weight.dtype}."

    if weight.dtype == torch.int32:
        assert weight.shape[
            -1] % 8 == 0, "the last dim of weight must be divisible by 8."
        packed_weight = torch_npu.npu_convert_weight_to_int4pack(
            weight.flatten(0, 1))
        packed_weight = packed_weight.view(weight.shape[0], weight.shape[1],
                                           -1)
    else:
        assert weight.shape[
            -1] % 4 == 0, "the last dim of weight must be divisible by 4."
        packed_weight = weight.view(torch.int32).contiguous()

    return packed_weight


@register_scheme("W4A16", "moe")
class AscendW4A16FusedMoEMethod(AscendMoEScheme):
    """FusedMoE method for Ascend W4A16."""

    def __init__(self) -> None:
        self.transpose_weight = True
        self.num_bits = 4  # dtype = torch.int4
        self.pack_factor = 8  # pack 8 of torch.int4 tensors to torch.int32

        vllm_config = get_current_vllm_config()
        self.group_size = vllm_config.quant_config.quant_description.get(
            "group_size", 32)
        self.dynamic_eplb = get_ascend_config().eplb_config.dynamic_eplb

    def get_weight(
        self,
        num_experts: int,
        intermediate_size_per_partition: int,
        hidden_sizes: int,
        params_dtype: torch.dtype,
    ) -> Dict[str, Any]:
        assert intermediate_size_per_partition % self.pack_factor == 0, f"Expecting `intermediate_size_per_partition` {intermediate_size_per_partition} can be divided by `pack_factor` {self.pack_factor}"
        assert hidden_sizes % self.pack_factor == 0, f"Expecting `hidden_sizes` {hidden_sizes} can be divided by `pack_factor` {self.pack_factor}"

        param_dict = {}

        param_dict["w13_weight_packed"] = torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            hidden_sizes // self.pack_factor,
            dtype=torch.int32)
        param_dict["w2_weight_packed"] = torch.empty(
            num_experts,
            hidden_sizes,
            intermediate_size_per_partition // self.pack_factor,
            dtype=torch.int32)

        return param_dict

    def get_dynamic_quant_param(
        self,
        num_experts: int,
        intermediate_size_per_partition: int,
        hidden_sizes: int,
        params_dtype: torch.dtype,
    ) -> Dict[str, Any]:
        assert intermediate_size_per_partition % self.group_size == 0, f"Expecting `intermediate_size_per_partition` {intermediate_size_per_partition} can be divided by `group_size` {self.group_size}"
        assert hidden_sizes % self.group_size == 0, f"Expecting `hidden_sizes` {hidden_sizes} can be divided by `group_size` {self.group_size}"

        param_dict = {}

        param_dict["w13_weight_scale"] = torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            hidden_sizes // self.group_size,
            dtype=torch.bfloat16)
        param_dict["w2_weight_scale"] = torch.empty(
            num_experts,
            hidden_sizes,
            intermediate_size_per_partition // self.group_size,
            dtype=torch.bfloat16)
        param_dict["w13_weight_shape"] = torch.empty(num_experts,
                                                     2,
                                                     dtype=torch.int32)
        param_dict["w2_weight_shape"] = torch.empty(num_experts,
                                                    2,
                                                    dtype=torch.int32)
        param_dict["w13_weight_offset"] = torch.zeros(
            num_experts,
            2 * intermediate_size_per_partition,
            hidden_sizes // self.group_size,
            dtype=torch.bfloat16)
        param_dict["w2_weight_offset"] = torch.zeros(
            num_experts,
            hidden_sizes,
            intermediate_size_per_partition // self.group_size,
            dtype=torch.bfloat16)

        return param_dict

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: Optional[torch.Tensor] = None,
        is_prefill: bool = True,
        enable_force_load_balance: bool = True,
        log2phy: Optional[torch.Tensor] = None,
        global_redundant_expert_num: int = 0,
        **kwargs,
    ) -> torch.Tensor:
        assert router_logits.shape[
            1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"

        topk_weights, topk_ids = select_experts(
            hidden_states=x,
            router_logits=router_logits,
            top_k=top_k,
            use_grouped_topk=use_grouped_topk,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            global_num_experts=global_num_experts)

        topk_ids = topk_ids.to(torch.int32)
        topk_weights = topk_weights.to(x.dtype)

        moe_comm_method = get_forward_context().moe_comm_method
        return moe_comm_method.fused_experts(
            hidden_states=x,
            w1=layer.w13_weight_packed,
            w2=layer.w2_weight_packed,
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            w1_offset=layer.w13_weight_offset,
            w2_offset=layer.w2_weight_offset,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            use_int4_w4a16=True,
            expert_map=expert_map,
            log2phy=log2phy,
            dynamic_eplb=self.dynamic_eplb,
            mc2_mask=kwargs.get("mc2_mask", None))

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        if self.transpose_weight:
            w13_shape = layer.w13_weight_packed.data.shape
            w2_shape = layer.w2_weight_packed.data.shape
            unpacked_w13_weight = (unpack_from_int32(
                layer.w13_weight_packed.data.flatten(0, 1),
                torch.Size([
                    w13_shape[0] * w13_shape[1],
                    w13_shape[2] * self.pack_factor
                ]),
                self.num_bits,
            ).view(w13_shape[0], w13_shape[1],
                   -1).transpose(1, 2).contiguous().int())
            unpacked_w2_weight = (unpack_from_int32(
                layer.w2_weight_packed.data.flatten(0, 1),
                torch.Size([
                    w2_shape[0] * w2_shape[1], w2_shape[2] * self.pack_factor
                ]),
                self.num_bits,
            ).view(w2_shape[0], w2_shape[1],
                   -1).transpose(1, 2).contiguous().int())
            layer.w13_weight_packed.data = pack_to_int32(unpacked_w13_weight)
            layer.w2_weight_packed.data = pack_to_int32(unpacked_w2_weight)

            layer.w13_weight_scale.data = layer.w13_weight_scale.data.transpose(
                1, 2).contiguous()
            layer.w2_weight_scale.data = layer.w2_weight_scale.data.transpose(
                1, 2).contiguous()

            layer.w13_weight_offset.data = layer.w13_weight_offset.data.transpose(
                1, 2).contiguous()
            layer.w2_weight_offset.data = layer.w2_weight_offset.data.transpose(
                1, 2).contiguous()
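
The shift-and-mask arithmetic in `unpack_from_int32` above can be illustrated on a single 32-bit word with plain Python integers. In this sketch `unpack_word` follows the same steps as the function (shift by `num_bits * i`, mask, subtract the offset), while `pack_word` is a hypothetical inverse added only so the round trip can be checked. It does not model `pack_to_int32`, whose int32 path relies on `torch_npu.npu_convert_weight_to_int4pack` and whose packed layout is not shown here.

```python
from typing import List

NUM_BITS = 4
PACK_FACTOR = 32 // NUM_BITS   # 8 int4 values per 32-bit word
MASK = (1 << NUM_BITS) - 1     # 0b1111
OFFSET = (1 << NUM_BITS) // 2  # 8, shifts [0, 15] to [-8, 7]


def pack_word(values: List[int]) -> int:
    """Hypothetical inverse: pack eight signed int4 values into one word."""
    assert len(values) == PACK_FACTOR
    word = 0
    for i, value in enumerate(values):
        assert -OFFSET <= value < OFFSET, "value out of int4 range"
        # Store value + OFFSET so every nibble is unsigned.
        word |= ((value + OFFSET) & MASK) << (NUM_BITS * i)
    return word


def unpack_word(word: int) -> List[int]:
    """Same steps as unpack_from_int32, applied to a single word."""
    return [((word >> (NUM_BITS * i)) & MASK) - OFFSET
            for i in range(PACK_FACTOR)]


values = [-8, -1, 0, 7, 3, -4, 5, -2]
assert unpack_word(pack_word(values)) == values
```

Value `i` occupies bits `4 * i` through `4 * i + 3`, which is why the tensor version writes its results to the strided slice `i::pack_factor`.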