#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#

from types import MappingProxyType
from typing import Any, Callable, Dict, List, Mapping, Optional

import torch
from vllm.config import get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
                                                  FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               RowParallelLinear)
from vllm.model_executor.layers.quantization import \
    register_quantization_config
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.vocab_parallel_embedding import (
    UnquantizedEmbeddingMethod, VocabParallelEmbedding)
from vllm.model_executor.parameter import PerTensorScaleParameter
from vllm.model_executor.utils import set_weight_attrs

from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
                                                    get_otp_group)
from vllm_ascend.ops.common_fused_moe import AscendUnquantizedFusedMoEMethod
from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod
from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, mlp_tp_enable,
                               oproj_tp_enable)

from .utils import get_quant_method


@register_quantization_config(ASCEND_QUANTIZATION_METHOD)
class AscendQuantConfig(QuantizationConfig):
    """Config class for Ascend quantization.

    This is a general class that parses the quantization configs
    supported on Ascend hardware.
    """

    def __init__(self, quant_config: Dict[str, Any]):
        super().__init__()
        self.quant_description = quant_config
        # TODO(whx): remove this adaptation after adding "shared_head"
        # to prefix of DeepSeekShareHead in vLLM.
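        # Illustrative example (layer index is hypothetical): a description
        # key such as "model.layers.61.shared_head.head.weight" is mirrored
        # as "model.layers.61.head.weight", the prefix vLLM actually uses.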
        extra_quant_dict = {}
        for k in self.quant_description.keys():
            if "shared_head" in k:
                new_k = k.replace(".shared_head.", ".")
                extra_quant_dict[new_k] = self.quant_description[k]
        self.quant_description.update(extra_quant_dict)

    def __repr__(self) -> str:
        return "AscendQuantConfig:\n" + super().__repr__()

    @classmethod
    def get_name(cls) -> str:
        return ASCEND_QUANTIZATION_METHOD

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.int8, torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        raise NotImplementedError(
            "Ascend hardware does not support \"get_min_capability\" feature.")

    @classmethod
    def get_config_filenames(cls) -> List[str]:
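        # The description file maps each weight name to a quant-type string,
        # e.g. (from a W8A8_DYNAMIC DeepSeek checkpoint; layer names are
        # model-specific):
        #   "model.layers.3.mlp.experts.255.gate_proj.weight": "W8A8_DYNAMIC"
        #   "model.layers.3.mlp.experts.255.gate_proj.weight_scale": "W8A8_DYNAMIC"
        # A value of "FLOAT" marks a layer left unquantized.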
        return ["quant_model_description.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "AscendQuantConfig":
        return cls(config)

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg,
                                     user_quant) -> Optional[str]:
        if torch.npu.is_available():
            return ASCEND_QUANTIZATION_METHOD
        return None

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        vllm_config = get_current_vllm_config()
        model_type = vllm_config.model_config.hf_config.model_type
        if model_type in packed_modules_model_mapping:
            self.packed_modules_mapping = packed_modules_model_mapping[
                model_type]
        from vllm.attention.layer import Attention
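        # Multimodal models expose the text backbone under a
        # "language_model." prefix; strip it so prefixes line up with the
        # keys in the quant description.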
        if prefix.startswith("language_model"):
            prefix = prefix.split('.', 1)[-1]
        if isinstance(layer, LinearBase):
            if self.is_layer_skipped_ascend(prefix,
                                            self.packed_modules_mapping):
                return AscendUnquantizedLinearMethod()
            return AscendLinearMethod(self, prefix,
                                      self.packed_modules_mapping)
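        # Quantized KV cache / attention: enabled either via "fa_quant_type"
        # or via "kv_quant_type" == "C8"; both routes use
        # AscendKVCacheMethod.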
        elif isinstance(layer, Attention) and \
                'fa_quant_type' in self.quant_description.keys() and \
                self.quant_description['fa_quant_type'] is not None:
            return AscendKVCacheMethod(self, prefix)
        elif isinstance(layer, Attention) and self.quant_description.get(
                'kv_quant_type') == 'C8':
            return AscendKVCacheMethod(self, prefix)
        elif isinstance(layer, FusedMoE):
            if self.is_layer_skipped_ascend(prefix,
                                            self.packed_modules_mapping):
                return AscendUnquantizedFusedMoEMethod(layer.moe_config)
            return AscendFusedMoEMethod(self, prefix,
                                        self.packed_modules_mapping)
        elif isinstance(layer, VocabParallelEmbedding):
            if self.is_layer_skipped_ascend(prefix,
                                            self.packed_modules_mapping):
                return UnquantizedEmbeddingMethod()
            return AscendEmbeddingMethod(self, prefix,
                                         self.packed_modules_mapping)
        return None

    def is_layer_skipped_ascend(
            self,
            prefix: str,
            fused_mapping: Mapping[str, List[str]] = MappingProxyType({})):
        # adapted from vllm.model_executor.layers.quantization.utils.quant_utils.is_layer_skipped
        proj_name = prefix.split(".")[-1]
        if proj_name in fused_mapping:
            shard_prefixes = [
                prefix.replace(proj_name, shard_proj_name)
                for shard_proj_name in fused_mapping[proj_name]
            ]
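
            # e.g. a prefix ending in "qkv_proj" with the mapping
            # {"qkv_proj": ["q_proj", "k_proj", "v_proj"]} expands to the
            # three per-shard prefixes that actually appear in the
            # description file.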
            is_skipped = None
            for shard_prefix in shard_prefixes:
                is_shard_skipped = self.quant_description[shard_prefix +
                                                          '.weight'] == "FLOAT"

                if is_skipped is None:
                    is_skipped = is_shard_skipped
                elif is_shard_skipped != is_skipped:
                    raise ValueError(
                        f"Detected some but not all shards of {prefix} "
                        "are quantized. All shards of fused layers "
                        "are expected to have the same precision.")
elif "experts" in prefix:
|
|
|
|
|
# For the experts' prefix (e.g., "model.layers.3.mlp.experts")
|
|
|
|
|
# Assume all experts within the same MLP use the same quantization method
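            # The description file only has per-expert keys such as
            # "model.layers.3.mlp.experts.255.gate_proj.weight", never the
            # bare experts prefix, so membership is tested by substring.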
            experts_quant_description = [
                self.quant_description[layer]
                for layer in self.quant_description if prefix in layer
            ]
            is_skipped = any(quantization == "FLOAT"
                             for quantization in experts_quant_description)
        else:
            is_skipped = self.quant_description[prefix + '.weight'] == "FLOAT"

        assert is_skipped is not None
        return is_skipped

    def get_scaled_act_names(self) -> List[str]:
        return []


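# Maps hf_config.model_type to the model's fused-module layout: each fused
# vLLM module name (e.g. "qkv_proj") lists the per-shard names that appear
# in quant_model_description.json.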
packed_modules_model_mapping = {
    "qwen3_moe": {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
        "experts":
        ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
    },
    "deepseek_v2": {
        "gate_up_proj": ["gate_proj", "up_proj"],
        "experts":
        ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
        "fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]
    },
    "deepseek_v3": {
        "gate_up_proj": ["gate_proj", "up_proj"],
        "experts":
        ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
        "fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]
    },
    "kimi_k2": {
        "gate_up_proj": ["gate_proj", "up_proj"],
        "experts":
        ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
    },
    "deepseek_v32": {
        "gate_up_proj": ["gate_proj", "up_proj"],
        "experts":
        ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
    },
    # NOTE 1. For quantized DeepSeek models, the MTP layer itself is not
    # quantized on the NPU;
    # NOTE 2. The description file generated by the current msmodelslim tool
    # does not have MTP layer info. Please manually add it and set the value
    # to FLOAT.
    "deepseek_mtp": {
        "gate_up_proj": ["gate_proj", "up_proj"],
        "experts":
        ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
    },
    "qwen3_next": {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": ["gate_proj", "up_proj"],
        "in_proj": ["in_proj_qkvz", "in_proj_ba"],
    },
    "qwen2_5_vl": {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    },
    "glm4_moe": {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
        "experts":
        ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
    },
}


class AscendLinearMethod(LinearMethodBase):
    """Linear method for Ascend quantization.

    Args:
        quant_config: The Ascend quantization config.
    """

    def __init__(self, quant_config: AscendQuantConfig, prefix: str,
                 packed_modules_mapping: Dict[str, Any]) -> None:
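        # get_quant_method resolves the scheme recorded for this prefix in
        # the description file (e.g. W8A8 static or W8A8_DYNAMIC) and
        # returns the matching Ascend implementation.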
        self.quant_method = get_quant_method(quant_config.quant_description,
                                             prefix, "linear",
                                             packed_modules_mapping)

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")

        weight_dict = self.quant_method.get_weight(input_size_per_partition,
                                                   output_size_per_partition,
                                                   params_dtype)

        # Extract packing information (if present)
        packed_dim = weight_dict.pop("_packed_dim", None)
        packed_factor = weight_dict.pop("_packed_factor", None)

        for weight_name, weight_param in weight_dict.items():
            param = torch.nn.Parameter(weight_param, requires_grad=False)
            set_weight_attrs(param, {"input_dim": 1, "output_dim": 0})

            # Set packing attributes if the weight is packed
            if packed_dim is not None and packed_factor is not None:
                set_weight_attrs(param, {
                    "packed_dim": packed_dim,
                    "packed_factor": packed_factor
                })

            layer.register_parameter(weight_name, param)
            set_weight_attrs(param, extra_weight_attrs)

        pertensor_dict = self.quant_method.get_pertensor_param(params_dtype)
        for pertensor_name, pertensor_param in pertensor_dict.items():
            param = PerTensorScaleParameter(data=pertensor_param,
                                            weight_loader=weight_loader)
            # disable warning
            param.ignore_warning = True
            layer.register_parameter(pertensor_name, param)
            param.weight_loader = extra_weight_attrs.get("weight_loader")

        perchannel_dict = self.quant_method.get_perchannel_param(
            output_size_per_partition, params_dtype)
        for perchannel_name, perchannel_param in perchannel_dict.items():
            param = torch.nn.Parameter(perchannel_param, requires_grad=False)
            set_weight_attrs(param, {"output_dim": 0})
            layer.register_parameter(perchannel_name, param)
            set_weight_attrs(param, extra_weight_attrs)

        # NOTE: In the w4a8 quantization implementation, the scale_bias
        # shape for down_proj and o_proj is [output_size, 16]; for other
        # layers it is [output_size, 1].
        layer_type = "row" if isinstance(layer,
                                         RowParallelLinear) else "others"

        pergroup_dict = self.quant_method.get_pergroup_param(
            input_size_per_partition,
            output_size_per_partition,
            params_dtype,
            layer_type=layer_type)
        for pergroup_name, pergroup_param in pergroup_dict.items():
            param = torch.nn.Parameter(pergroup_param, requires_grad=False)
            set_weight_attrs(param, {"output_dim": 0})
            layer.register_parameter(pergroup_name, param)
            set_weight_attrs(param, extra_weight_attrs)
            if "weight_scale_second" in pergroup_name or "weight_offset_second" in pergroup_name:
                param.input_dim = 1

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        if hasattr(self.quant_method, "process_weights_after_loading"):
            self.quant_method.process_weights_after_loading(layer)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
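        # Row-parallel layers must pass the rank within whichever TP group
        # they were sharded over: o_proj and down_proj may live in the
        # dedicated oproj/MLP TP groups rather than the default TP group.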
        if isinstance(layer, RowParallelLinear):
            if layer.prefix.find("o_proj") != -1 and oproj_tp_enable():
                tp_rank = get_otp_group().rank_in_group
            elif layer.prefix.find("down_proj") != -1 and mlp_tp_enable():
                tp_rank = get_mlp_tp_group().rank_in_group
            else:
                tp_rank = get_tensor_model_parallel_rank()
        else:
            tp_rank = 0
        return self.quant_method.apply(layer, x, bias, tp_rank)


class AscendKVCacheMethod(BaseKVCacheMethod):
    """KVCache method for Ascend quantization.

    Args:
        quant_config: The Ascend quantization config.
    """

    def __init__(self, quant_config: AscendQuantConfig, prefix: str) -> None:
        self.quant_method = get_quant_method(quant_config.quant_description,
                                             prefix, "attention")

    def create_weights(self, layer: torch.nn.Module) -> None:
        # Unlike the linear method, attention in vLLM has no weight
        # processing/slicing steps, so weight creation is delegated
        # entirely to the specific quant method.
        self.quant_method.create_weights(layer)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        if hasattr(self.quant_method, "process_weights_after_loading"):
            self.quant_method.process_weights_after_loading(layer)

    def apply(self, layer: torch.nn.Module, query: torch.Tensor,
              key: torch.Tensor, value: torch.Tensor, kv_cache, attn_metadata,
              attn_type, scale, output) -> torch.Tensor:
        return self.quant_method.apply(layer, query, key, value, kv_cache,
                                       attn_metadata, attn_type, scale,
                                       output)


class AscendFusedMoEMethod(FusedMoEMethodBase):
    """FusedMoE method for Ascend quantization.

    Args:
        quant_config: The Ascend quantization config.
    """

    def __init__(self, quant_config: AscendQuantConfig, prefix: str,
                 packed_modules_mapping: Dict[str, Any]):
        self.quant_method = get_quant_method(quant_config.quant_description,
                                             prefix, "moe",
                                             packed_modules_mapping)

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        weight_param = self.quant_method.get_weight(
            num_experts, intermediate_size_per_partition, hidden_size,
            params_dtype)
        for param_key, param_value in weight_param.items():
            param = torch.nn.Parameter(param_value, requires_grad=False)
            layer.register_parameter(param_key, param)
            set_weight_attrs(param, extra_weight_attrs)
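
        # Scale parameters default to per-channel loading; second-level
        # scales and scale_bias are per-group, so their quant_method attr
        # is switched to GROUP below.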
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
        per_group_param = [
            "weight_scale_second", "weight_offset_second", "scale_bias"
        ]
        dynamic_quant_param = self.quant_method.get_dynamic_quant_param(
            num_experts, intermediate_size_per_partition, hidden_size,
            params_dtype)
        for param_key, param_value in dynamic_quant_param.items():
            param = torch.nn.Parameter(param_value, requires_grad=False)
            layer.register_parameter(param_key, param)
            set_weight_attrs(param, extra_weight_attrs)
            if any(fields in param_key for fields in per_group_param):
                setattr(param, "quant_method",
                        FusedMoeWeightScaleSupported.GROUP.value)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        is_prefill: bool = True,
        enable_force_load_balance: bool = False,
        log2phy: Optional[torch.Tensor] = None,
        global_redundant_expert_num=0,
        **kwargs,
    ) -> torch.Tensor:
        return self.quant_method.apply(
            layer, x, router_logits, top_k, renormalize, use_grouped_topk,
            global_num_experts, expert_map, topk_group, num_expert_group,
            custom_routing_function, scoring_func, e_score_correction_bias,
            is_prefill, enable_force_load_balance, log2phy,
            global_redundant_expert_num, **kwargs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        if hasattr(self.quant_method, "process_weights_after_loading"):
            self.quant_method.process_weights_after_loading(layer)

    def get_fused_moe_quant_config(self, layer: torch.nn.Module):
        # TODO: implement this function
        pass


class AscendEmbeddingMethod(AscendLinearMethod):
    """Embedding method for Ascend quantization.

    Args:
        quant_config: The Ascend quantization config.
    """

    def __init__(self, quant_config: AscendQuantConfig, prefix: str,
                 packed_modules_mapping: Dict[str, Any]) -> None:
        self.quant_method = get_quant_method(quant_config.quant_description,
                                             prefix, "linear",
                                             packed_modules_mapping)