xc-llm-ascend/vllm_ascend/quantization/method_adapters.py
Cao Yi a69ef10c3a [Refactor] Quantization Module Refactor (#5738)
### Summary

This PR refactors the `vllm_ascend/quantization` module to improve code
organization, maintainability, and extensibility. The refactoring
introduces a clear separation of concerns with a registry-based scheme
discovery pattern, abstract base classes for quantization schemes, and
dedicated wrapper classes.

### Key Changes

#### 1. **Modular Directory Structure**

| Before | After |
|--------|-------|
| Flat file structure with mixed responsibilities | Organized into `methods/` subpackage for schemes |
| Single `quant_config.py` (600+ lines) | Separate config files: `modelslim_config.py`, `compressed_tensors_config.py` |
| `utils.py` with scheme lookup logic | `methods/registry.py` with decorator-based registration |

#### 2. **Registry-Based Scheme Discovery**

Replaced the hardcoded `ASCEND_QUANTIZATION_METHOD_MAP` dictionary with a
decorator-based registry pattern:

```python
# Before: Manual dictionary mapping
ASCEND_QUANTIZATION_METHOD_MAP = {
    "W8A8_DYNAMIC": {"linear": AscendW8A8DynamicLinearMethod, ...},
    ...
}

# After: Decorator-based registration
@register_scheme("W8A8_DYNAMIC", "linear")
class AscendW8A8DynamicLinearMethod(AscendLinearScheme):
    ...
```

#### 3. **Abstract Base Classes**

Introduced three abstract base classes in `methods/base.py`:
- `AscendLinearScheme` - Base for linear layer quantization
- `AscendMoEScheme` - Base for MoE layer quantization  
- `AscendAttentionScheme` - Base for attention layer quantization
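
The sketch below shows the base-class contract concretely. It is a guess at the shape of `AscendLinearScheme` and uses only the standard `abc` module. Method names and argument order follow the calls made from `AscendLinearMethod` in this file, with torch types replaced by `Any`. The empty-dict defaults for the optional hooks are assumptions, and `IncompleteScheme` and `PassThroughScheme` are hypothetical classes added for illustration.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional


class AscendLinearScheme(ABC):
    """Simplified stand-in for the real base class (torch types replaced by Any)."""

    @abstractmethod
    def get_weight(self, input_size: int, output_size: int,
                   params_dtype: Any) -> Dict[str, Any]:
        """Return the weight tensors this scheme needs, keyed by name."""

    @abstractmethod
    def apply(self, layer: Any, x: Any, bias: Optional[Any] = None,
              tp_rank: int = 0) -> Any:
        """Run the quantized forward computation."""

    # Non-abstract hooks. Assumption: the default is "no extra parameters".
    def get_pertensor_param(self, params_dtype: Any) -> Dict[str, Any]:
        return {}

    def get_perchannel_param(self, output_size: int,
                             params_dtype: Any) -> Dict[str, Any]:
        return {}


class IncompleteScheme(AscendLinearScheme):
    """Implements get_weight() but not apply(), so it stays abstract."""

    def get_weight(self, input_size, output_size, params_dtype):
        # Only the shape is recorded; a real scheme would allocate a tensor.
        return {"weight": (output_size, input_size)}


class PassThroughScheme(IncompleteScheme):
    """Hypothetical concrete scheme: apply() simply returns its input."""

    def apply(self, layer, x, bias=None, tp_rank=0):
        return x
```

Calling `IncompleteScheme()` raises `TypeError` because `apply` is still abstract. A scheme that is missing a required method therefore fails when it is constructed, not partway through inference.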

#### 4. **Separated Config and Wrapper Classes**

- **Config classes** (`AscendModelSlimConfig`,
`AscendCompressedTensorsConfig`): Handle config parsing and scheme
selection
- **Wrapper classes** (`AscendLinearMethod`, `AscendFusedMoEMethod`,
etc.): Implement vLLM interfaces and delegate to schemes
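
The wrapper/scheme split is the delegation (adapter) pattern. The sketch below is a plain-Python illustration, not the project's code. `FakeLayer`, `DemoScheme` and `LinearWrapper` are hypothetical stand-ins: a dict replaces `torch.nn.Module` parameters and lists replace tensors. The sketch mirrors how `AscendLinearMethod` keeps the scheme in `self.quant_method`, forwards `create_weights` and `apply`, and forwards `process_weights_after_loading` only when the scheme defines it.

```python
from typing import Any, Dict, List


class FakeLayer:
    """Hypothetical stand-in for torch.nn.Module: just records parameters."""

    def __init__(self) -> None:
        self.params: Dict[str, Any] = {}

    def register_parameter(self, name: str, value: Any) -> None:
        self.params[name] = value


class DemoScheme:
    """Hypothetical scheme: knows WHAT to allocate and HOW to compute."""

    def get_weight(self, in_size: int, out_size: int) -> Dict[str, Any]:
        return {"weight": [[0] * in_size for _ in range(out_size)]}

    def apply(self, layer: FakeLayer, x: List[int]) -> List[int]:
        # Plain matrix-vector product using the registered weight.
        w = layer.params["weight"]
        return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


class LinearWrapper:
    """Adapter: exposes the framework-facing interface, delegates the details."""

    def __init__(self, scheme: Any) -> None:
        self.quant_method = scheme

    def create_weights(self, layer: FakeLayer, in_size: int, out_size: int) -> None:
        for name, value in self.quant_method.get_weight(in_size, out_size).items():
            layer.register_parameter(name, value)

    def process_weights_after_loading(self, layer: FakeLayer) -> None:
        # Optional hook: only forwarded if the scheme implements it.
        if hasattr(self.quant_method, "process_weights_after_loading"):
            self.quant_method.process_weights_after_loading(layer)

    def apply(self, layer: FakeLayer, x: List[int]) -> List[int]:
        return self.quant_method.apply(layer, x)
```

For example, after `create_weights(layer, 3, 2)` and loading `[[1, 0, 0], [0, 1, 1]]` into `layer.params["weight"]`, `wrapper.apply(layer, [1, 2, 3])` returns `[1, 5]`. The wrapper never inspects the weights itself.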

#### 5. **Cleaner Public API**

```python
# New clean module interface
from vllm_ascend.quantization import (
    AscendModelSlimConfig,
    AscendCompressedTensorsConfig,
)
from vllm_ascend.quantization.methods import get_scheme_class
```

### Architecture Diagram

```mermaid
classDiagram
    direction TB
    
    class QuantizationConfig {
        <<vLLM Interface>>
        +get_quant_method()
    }
    
    class AscendModelSlimConfig {
        +quant_description
        +get_quant_method()
        -create_scheme_for_layer()
    }
    
    class AscendCompressedTensorsConfig {
        +target_scheme_map
        +get_quant_method()
        -_get_scheme_from_parts()
    }
    
    class AscendLinearMethod {
        <<Wrapper>>
        +quant_method: AscendLinearScheme
        +create_weights()
        +apply()
    }
    
    class AscendFusedMoEMethod {
        <<Wrapper>>
        +quant_method: AscendMoEScheme
        +create_weights()
        +apply()
    }
    
    class AscendLinearScheme {
        <<Abstract>>
        +get_weight()*
        +apply()*
        +get_pertensor_param()
        +get_perchannel_param()
    }
    
    class AscendMoEScheme {
        <<Abstract>>
        +get_weight()*
        +get_dynamic_quant_param()*
        +apply()*
    }
    
    class W8A8DynamicLinear {
        +get_weight()
        +apply()
    }
    
    class W8A8DynamicMoE {
        +get_weight()
        +apply()
    }
    
    QuantizationConfig <|-- AscendModelSlimConfig
    QuantizationConfig <|-- AscendCompressedTensorsConfig
    
    AscendModelSlimConfig ..> AscendLinearMethod : creates
    AscendModelSlimConfig ..> AscendFusedMoEMethod : creates
    AscendCompressedTensorsConfig ..> AscendLinearMethod : creates
    AscendCompressedTensorsConfig ..> AscendFusedMoEMethod : creates
    
    AscendLinearMethod o-- AscendLinearScheme : delegates to
    AscendFusedMoEMethod o-- AscendMoEScheme : delegates to
    
    AscendLinearScheme <|-- W8A8DynamicLinear
    AscendMoEScheme <|-- W8A8DynamicMoE
```

### Scheme Registration Flow

```mermaid
sequenceDiagram
    participant Module as Scheme Module
    participant Registry as _SCHEME_REGISTRY
    participant Config as QuantConfig
    participant Wrapper as Wrapper Class
    
    Note over Module: At import time
    Module->>Registry: @register_scheme("W8A8_DYNAMIC", "linear")
    Registry->>Registry: Store (quant_type, layer_type) -> Class
    
    Note over Config: At runtime
    Config->>Config: Determine quant_type from description
    Config->>Registry: get_scheme_class(quant_type, layer_type)
    Registry-->>Config: Return scheme class
    Config->>Config: scheme = scheme_cls()
    Config->>Wrapper: Create wrapper with scheme
    Wrapper-->>Config: Return wrapper instance
```
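
The runtime half of this flow can be sketched end to end. This sketch is self-contained, so it repeats a tiny registry. `ToyQuantConfig`, its `quant_description` layout (`"<prefix>.weight"` mapped to a quant type string) and the `"FLOAT"` marker for unquantized layers are assumptions for illustration. The real `AscendModelSlimConfig.get_quant_method()` takes different arguments and handles more layer types.

```python
from typing import Dict, Optional, Tuple

_SCHEME_REGISTRY: Dict[Tuple[str, str], type] = {}


def register_scheme(quant_type: str, layer_type: str):
    def decorator(cls: type) -> type:
        _SCHEME_REGISTRY[(quant_type, layer_type)] = cls
        return cls
    return decorator


def get_scheme_class(quant_type: str, layer_type: str) -> type:
    return _SCHEME_REGISTRY[(quant_type, layer_type)]


# "At import time": defining the decorated class registers it.
@register_scheme("W8A8_DYNAMIC", "linear")
class W8A8DynamicLinear:
    pass


class LinearWrapper:
    def __init__(self, scheme: object) -> None:
        self.quant_method = scheme


class ToyQuantConfig:
    """Hypothetical config: quant_description maps '<prefix>.weight' -> quant_type."""

    def __init__(self, quant_description: Dict[str, str]) -> None:
        self.quant_description = quant_description

    def get_quant_method(self, prefix: str,
                         layer_type: str) -> Optional[LinearWrapper]:
        # 1. Determine quant_type from the description.
        quant_type = self.quant_description.get(f"{prefix}.weight", "FLOAT")
        if quant_type == "FLOAT":
            return None  # unquantized layer: caller keeps its default method
        # 2. Look up the scheme class, 3. instantiate it, 4. wrap it.
        scheme_cls = get_scheme_class(quant_type, layer_type)
        scheme = scheme_cls()
        return LinearWrapper(scheme)
```

In this sketch, a prefix listed as `W8A8_DYNAMIC` yields a `LinearWrapper` around a fresh `W8A8DynamicLinear` instance. A prefix missing from the description yields `None`. The config only knows the lookup key, never the concrete scheme classes.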

### File Changes Summary

| Original Files | Refactored Files |
|----------------|------------------|
| `__init__.py` (empty) | `__init__.py` (exports public API) |
| `quant_config.py` | `modelslim_config.py` + `wrappers.py` |
| `compressed_tensors/` | `compressed_tensors_config.py` |
| `utils.py` | `methods/registry.py` |
| `w8a8_dynamic.py` | `methods/w8a8_dynamic.py` |
| `w8a8.py` | `methods/w8a8_static.py` |
| `w4a4_flatquant_dynamic.py` | `methods/w4a4_flatquant.py` |
| ... | `methods/base.py` (new) |

### Benefits

1. **Extensibility**: Adding a new quantization scheme only requires
subclassing the appropriate base class and applying the `@register_scheme` decorator
2. **Maintainability**: Clear separation between config parsing, wrapper
logic, and scheme implementation
3. **Testability**: Abstract base classes enable easier unit testing and
mocking
4. **Discoverability**: Registry pattern makes it easy to list all
supported schemes
5. **Reduced Coupling**: Config classes no longer need to know about all
scheme implementations
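
Listing the supported schemes reduces to iterating over the registry keys. This sketch assumes the `(quant_type, layer_type)` keying shown in the registration flow. `list_supported_schemes` is a hypothetical helper. The registry contents below, including the `"moe"` layer-type string, are placeholders and not the project's actual registrations.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

# Hypothetical registry contents, for illustration only.
_SCHEME_REGISTRY: Dict[Tuple[str, str], type] = {
    ("W8A8_DYNAMIC", "linear"): object,
    ("W8A8_DYNAMIC", "moe"): object,
    ("W8A8", "linear"): object,
}


def list_supported_schemes() -> Dict[str, List[str]]:
    """Group registered layer types by quant_type, sorted for stable output."""
    supported: Dict[str, List[str]] = defaultdict(list)
    for quant_type, layer_type in _SCHEME_REGISTRY:
        supported[quant_type].append(layer_type)
    return {qt: sorted(lts) for qt, lts in sorted(supported.items())}


print(list_supported_schemes())
# {'W8A8': ['linear'], 'W8A8_DYNAMIC': ['linear', 'moe']}
```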

___

- vLLM version: v0.13.0
- vLLM main: 2f4e6548ef

---------

Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
2026-01-23 14:13:47 +08:00


#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from typing import Callable, List, Optional

import torch
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.model_executor.layers.fused_moe import (FusedMoEMethodBase,
                                                  FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               RowParallelLinear)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.parameter import PerTensorScaleParameter
from vllm.model_executor.utils import set_weight_attrs

from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import (get_flashcomm2_otp_group,
                                                    get_mlp_tp_group,
                                                    get_otp_group)
from vllm_ascend.utils import flashcomm2_enable, mlp_tp_enable, oproj_tp_enable

from .methods import (AscendAttentionScheme, AscendLinearScheme,
                      AscendMoEScheme, is_mx_quant_type)


class AscendLinearMethod(LinearMethodBase):
    """Linear method for Ascend quantization.

    This wrapper class delegates to the actual quantization scheme implementation.
    The scheme is determined by the Config class and passed directly to this wrapper.

    Args:
        scheme: The quantization scheme instance (e.g., AscendW8A8DynamicLinearMethod).
    """

    def __init__(self, scheme: AscendLinearScheme) -> None:
        self.quant_method = scheme

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")

        weight_dict = self.quant_method.get_weight(input_size_per_partition,
                                                   output_size_per_partition,
                                                   params_dtype)

        # Extract packing information (if present)
        packed_dim = weight_dict.pop("_packed_dim", None)
        packed_factor = weight_dict.pop("_packed_factor", None)

        for weight_name, weight_param in weight_dict.items():
            param = torch.nn.Parameter(weight_param, requires_grad=False)
            set_weight_attrs(param, {"input_dim": 1, "output_dim": 0})

            # Set packing attributes if the weight is packed
            if packed_dim is not None and packed_factor is not None:
                set_weight_attrs(param, {
                    "packed_dim": packed_dim,
                    "packed_factor": packed_factor
                })

            layer.register_parameter(weight_name, param)
            set_weight_attrs(param, extra_weight_attrs)

        pertensor_dict = self.quant_method.get_pertensor_param(params_dtype)
        for pertensor_name, pertensor_param in pertensor_dict.items():
            param = PerTensorScaleParameter(data=pertensor_param,
                                            weight_loader=weight_loader)
            # disable warning
            param.ignore_warning = True
            layer.register_parameter(pertensor_name, param)
            param.weight_loader = extra_weight_attrs.get("weight_loader")

        perchannel_dict = self.quant_method.get_perchannel_param(
            output_size_per_partition, params_dtype)
        for perchannel_name, perchannel_param in perchannel_dict.items():
            param = torch.nn.Parameter(perchannel_param, requires_grad=False)
            set_weight_attrs(param, {"output_dim": 0})
            layer.register_parameter(perchannel_name, param)
            set_weight_attrs(param, extra_weight_attrs)

        # NOTE: In the w4a8 quantization implementation, the scale_bias shape
        # is [output_size, 16] for down_proj and o_proj, and [output_size, 1]
        # for all other layers.
        layer_type = "row" if isinstance(layer,
                                         RowParallelLinear) else "others"

        pergroup_dict = self.quant_method.get_pergroup_param(
            input_size_per_partition,
            output_size_per_partition,
            params_dtype,
            layer_type=layer_type)
        for pergroup_name, pergroup_param in pergroup_dict.items():
            param = torch.nn.Parameter(pergroup_param, requires_grad=False)
            set_weight_attrs(param, {"output_dim": 0})
            layer.register_parameter(pergroup_name, param)
            set_weight_attrs(param, extra_weight_attrs)
            if ("weight_scale_second" in pergroup_name
                    or "weight_offset_second" in pergroup_name
                    or is_mx_quant_type(self.quant_method)):
                param.input_dim = 1

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        if hasattr(self.quant_method, "process_weights_after_loading"):
            self.quant_method.process_weights_after_loading(layer)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if isinstance(layer, RowParallelLinear):
            if "o_proj" in layer.prefix and oproj_tp_enable():
                tp_rank = get_otp_group().rank_in_group
            elif "down_proj" in layer.prefix and mlp_tp_enable():
                tp_rank = get_mlp_tp_group().rank_in_group
            elif ("o_proj" in layer.prefix
                  or "out_proj" in layer.prefix) and flashcomm2_enable():
                if get_ascend_config().flashcomm2_oproj_tensor_parallel_size == 1:
                    tp_rank = 0
                else:
                    tp_rank = get_flashcomm2_otp_group().rank_in_group
            else:
                tp_rank = get_tensor_model_parallel_rank()
        else:
            tp_rank = 0
        return self.quant_method.apply(layer, x, bias, tp_rank)


class AscendKVCacheMethod(BaseKVCacheMethod):
    """KVCache method for Ascend quantization.

    This wrapper class delegates to the actual attention quantization scheme.

    Args:
        scheme: The attention quantization scheme instance.
    """

    def __init__(self, scheme: AscendAttentionScheme) -> None:
        self.quant_method = scheme

    def create_weights(self, layer: torch.nn.Module) -> None:
        # Unlike the linear method, vLLM has no weight processing/slicing
        # steps for attention, so weight creation is delegated entirely to
        # the specific quant method.
        self.quant_method.create_weights(layer)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        self.quant_method.process_weights_after_loading(layer)

    def apply(self, layer: torch.nn.Module, query: torch.Tensor,
              key: torch.Tensor, value: torch.Tensor, kv_cache, attn_metadata,
              attn_type, scale, output) -> torch.Tensor:
        return self.quant_method.apply(layer, query, key, value, kv_cache,
                                       attn_metadata, attn_type, scale, output)


class AscendFusedMoEMethod(FusedMoEMethodBase):
    """FusedMoE method for Ascend quantization.

    This wrapper class delegates to the actual MoE quantization scheme.

    Args:
        scheme: The MoE quantization scheme instance.
        moe_config: The FusedMoE configuration.
    """

    def __init__(self, scheme: AscendMoEScheme,
                 moe_config: FusedMoEConfig) -> None:
        super().__init__(moe_config)
        self.quant_method = scheme

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        weight_param = self.quant_method.get_weight(
            num_experts, intermediate_size_per_partition, hidden_size,
            params_dtype)
        for param_key, param_value in weight_param.items():
            param = torch.nn.Parameter(param_value, requires_grad=False)
            layer.register_parameter(param_key, param)
            set_weight_attrs(param, extra_weight_attrs)

        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
        # Names of parameters tagged as group-wise (GROUP) when the scheme
        # uses group quantization (group_size > 0); all others stay CHANNEL.
        per_group_param = ([
            "weight_scale_second", "weight_offset_second", "scale_bias",
            "weight_scale", "weight_offset"
        ] if getattr(self.quant_method, "group_size", 0) > 0 else [])
        dynamic_quant_param = self.quant_method.get_dynamic_quant_param(
            num_experts, intermediate_size_per_partition, hidden_size,
            params_dtype)
        for param_key, param_value in dynamic_quant_param.items():
            param = torch.nn.Parameter(param_value, requires_grad=False)
            layer.register_parameter(param_key, param)
            set_weight_attrs(param, extra_weight_attrs)
            if any(fields in param_key for fields in per_group_param):
                setattr(param, "quant_method",
                        FusedMoeWeightScaleSupported.GROUP.value)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: Optional[torch.Tensor] = None,
        is_prefill: bool = True,
        enable_force_load_balance: bool = False,
        log2phy: Optional[torch.Tensor] = None,
        global_redundant_expert_num=0,
        **kwargs,
    ) -> torch.Tensor:
        return self.quant_method.apply(
            layer, x, router_logits, top_k, renormalize, use_grouped_topk,
            global_num_experts, expert_map, topk_group, num_expert_group,
            custom_routing_function, scoring_func, routed_scaling_factor,
            e_score_correction_bias, is_prefill, enable_force_load_balance,
            log2phy, global_redundant_expert_num, **kwargs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        if hasattr(self.quant_method, "process_weights_after_loading"):
            self.quant_method.process_weights_after_loading(layer)

    def get_fused_moe_quant_config(self, layer: torch.nn.Module):
        pass

    @property
    def supports_eplb(self):
        return getattr(self.quant_method, "supports_eplb", False)


class AscendEmbeddingMethod(AscendLinearMethod):
    """Embedding method for Ascend quantization.

    This is essentially the same as AscendLinearMethod, just with a different name
    for clarity when used with VocabParallelEmbedding layers.

    Args:
        scheme: The quantization scheme instance.
    """

    def __init__(self, scheme: AscendLinearScheme) -> None:
        self.quant_method = scheme