### What this PR does / why we need it?
The DeepSeek w4a8 weights we supported before were in the MindIE
format, which stores each int4 value in an int8, so the weight size is similar to
w8a8, and we need to do a few extra steps to make vllm-ascend load it
normally.
Now we can directly use the new weight format, which packs two int4
values into each int8. This reduces the weight size and removes most of the
extra processing previously needed before vllm-ascend could use the weights,
while weights in the previous MindIE format remain supported.
The weight changes in the new version:
1. The weight is packed (two int4 values are packed into one int8).
2. The bias required in the apply method is generated directly by
msModelSlim.
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
Added unit test cases in `tests/ut/quantization/test_w4a8_dynamic.py`.
#### 1. How to get weights using msModelSlim
##### Installation steps
Use the branch br_release_MindStudio_8.1.RC2_TR5_20260624:
git clone -b br_release_MindStudio_8.1.RC2_TR5_20260624 https://gitee.com/ascend/msit.git
cd msit/msmodelslim
bash install.sh
##### Generate w4a8 weights
cd example/DeepSeek
Command reference: msmodelslim/example/DeepSeek/README.md. Execute the
[pre-check](https://gitee.com/ascend/msit/blob/br_release_MindStudio_8.1.RC2_TR5_20260624/msmodelslim/example/DeepSeek/README.md#%E8%BF%90%E8%A1%8C%E5%89%8D%E5%BF%85%E6%A3%80)
and [DeepSeek-R1 w4a8 mix
quantization](https://gitee.com/ascend/msit/blob/br_release_MindStudio_8.1.RC2_TR5_20260624/msmodelslim/example/DeepSeek/README.md#deepseek-r1-w4a8-%E6%B7%B7%E5%90%88%E9%87%8F%E5%8C%96%E5%89%8D%E4%B8%89%E5%B1%82-mlpw8a8-dynamic-%E9%87%8F%E5%8C%96mla%E5%85%B1%E4%BA%AB%E4%B8%93%E5%AE%B6w8a8%E9%87%8F%E5%8C%96%E8%B7%AF%E7%94%B1%E4%B8%93%E5%AE%B6w4a8-dynamic%E9%87%8F%E5%8C%96)
chapters.
Reference command: `python3 quant_deepseek_w4a8.py --model_path {original
weight path} --save_path {generated weight path}`
##### Adapt to vllm-ascend
In `config.json`, change `"model_type": "deepseekv2"` to
`"model_type": "deepseek_v3"`.
#### 2. How to run w4a8
##### a. How to run eager mode
export VLLM_ASCEND_MLA_PA=1
python -m vllm.entrypoints.openai.api_server --model=$1 \
  --trust-remote-code -tp $2 -dp $3 --enable_expert_parallel \
  --quantization ascend --port $4 --max-model-len $5 --max-num-seqs $6 \
  --enforce-eager
e.g.: python -m vllm.entrypoints.openai.api_server \
  --model=/weightpath/w4a8_4_layer --trust-remote-code -tp 4 -dp 4 \
  --enable_expert_parallel --quantization ascend --port 8002 \
  --max-model-len 5120 --max-num-seqs 128 --enforce-eager
##### b. How to run graph mode
##### b.How to run graph mode
export HCCL_BUFFSIZE=1024
python -m vllm.entrypoints.openai.api_server --model=$1 \
  --trust-remote-code -tp $2 -dp $3 --enable_expert_parallel \
  --quantization ascend --port $4 --max-model-len $5 \
  --additional_config='{"ascend_scheduler_config":{"enabled":true},"torchair_graph_config":{"enabled":true}}'
e.g.: python -m vllm.entrypoints.openai.api_server \
  --model=/weight/dsr1_w4a8_vllm --trust-remote-code -tp 4 -dp 4 \
  --enable_expert_parallel --quantization ascend --port 8002 \
  --max-model-len 5120 \
  --additional_config='{"ascend_scheduler_config":{"enabled":true},"torchair_graph_config":{"enabled":true}}'
- vLLM version: v0.10.0
- vLLM main: 103f1ec8d3
---------
Signed-off-by: Wang Kunpeng <1289706727@qq.com>
358 lines
15 KiB
Python
358 lines
15 KiB
Python
#
|
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
|
# Copyright 2023 The vLLM team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# This file is a part of the vllm-ascend project.
|
|
#
|
|
from types import MappingProxyType
|
|
from typing import Any, Callable, Dict, List, Mapping, Optional
|
|
|
|
import torch
|
|
from vllm.distributed import get_tensor_model_parallel_rank
|
|
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
|
|
FusedMoeWeightScaleSupported)
|
|
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
|
RowParallelLinear,
|
|
UnquantizedLinearMethod)
|
|
from vllm.model_executor.layers.quantization import \
|
|
register_quantization_config
|
|
from vllm.model_executor.layers.quantization.base_config import (
|
|
QuantizationConfig, QuantizeMethodBase)
|
|
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
|
UnquantizedEmbeddingMethod, VocabParallelEmbedding)
|
|
from vllm.model_executor.parameter import PerTensorScaleParameter
|
|
from vllm.model_executor.utils import set_weight_attrs
|
|
|
|
from vllm_ascend.ops.fused_moe import AscendUnquantizedFusedMoEMethod
|
|
from vllm_ascend.utils import ASCEND_QUATIZATION_METHOD
|
|
|
|
from .quantizer import AscendQuantizer
|
|
|
|
|
|
@register_quantization_config(ASCEND_QUATIZATION_METHOD)
class AscendQuantConfig(QuantizationConfig):
    """Config class for Ascend

    This class is a general class that parse quantization configs
    that are supported on ascend hardware.

    The raw contents of ``quant_model_description.json`` are stored in
    ``self.quant_description``. Entries keyed ``"<prefix>.weight"`` are
    compared against ``"FLOAT"`` to decide whether a layer is left
    unquantized, and the optional ``"fa_quant_type"`` / ``"kv_quant_type"``
    keys enable a quantized attention / KV-cache method.
    """

    def __init__(self, quant_config: Dict[str, Any]):
        # Run the base initializer so any state it owns exists before models
        # start populating it. NOTE(review): in vLLM this is where
        # ``packed_modules_mapping`` (read by get_quant_method below) gets its
        # empty default -- confirm against the pinned vLLM version.
        super().__init__()
        self.quant_description = quant_config

    def __repr__(self) -> str:
        return "AscendQuantConfig:\n" + super().__repr__()

    @classmethod
    def get_name(cls) -> str:
        return ASCEND_QUATIZATION_METHOD

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.int8, torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        raise NotImplementedError(
            "Ascend hardware dose not support \"get_min_capability\" feature.")

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["quant_model_description.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "AscendQuantConfig":
        return cls(config)

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg,
                                     user_quant) -> Optional[str]:
        # Claim the checkpoint for the Ascend method whenever an NPU is
        # available, regardless of hf_quant_cfg / user_quant.
        if torch.npu.is_available():
            return ASCEND_QUATIZATION_METHOD
        return None

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        """Return the quantize method to attach to ``layer``.

        Linear, FusedMoE and VocabParallelEmbedding layers get the Ascend
        quantized method unless the description marks them as ``"FLOAT"``, in
        which case the corresponding unquantized method is returned.
        Attention layers get AscendKVCacheMethod when ``fa_quant_type`` is set
        or ``kv_quant_type`` is ``'C8'``. Any other layer returns None.
        """
        # Imported locally; presumably to avoid an import cycle with the
        # attention package -- TODO confirm.
        from vllm.attention.layer import Attention

        if isinstance(layer, LinearBase):
            if self.is_layer_skipped_ascend(prefix,
                                            self.packed_modules_mapping):
                return UnquantizedLinearMethod()
            return AscendLinearMethod(self, prefix,
                                      self.packed_modules_mapping)
        elif isinstance(layer, Attention) and (
                self.quant_description.get('fa_quant_type') is not None
                or self.quant_description.get('kv_quant_type') == 'C8'):
            return AscendKVCacheMethod(self, prefix)
        elif isinstance(layer, FusedMoE):
            if self.is_layer_skipped_ascend(prefix,
                                            self.packed_modules_mapping):
                return AscendUnquantizedFusedMoEMethod(layer.moe)
            return AscendFusedMoEMethod(self, prefix,
                                        self.packed_modules_mapping)
        elif isinstance(layer, VocabParallelEmbedding):
            if self.is_layer_skipped_ascend(prefix,
                                            self.packed_modules_mapping):
                return UnquantizedEmbeddingMethod()
            return AscendEmbeddingMethod(self, prefix,
                                         self.packed_modules_mapping)
        return None

    def is_layer_skipped_ascend(
        self,
        prefix: str,
        fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
    ) -> bool:
        """Return True if the layer at ``prefix`` is left unquantized.

        A layer is skipped when its ``"<prefix>.weight"`` entry in the quant
        description equals ``"FLOAT"``. For a fused layer (its last name
        component is a key of ``fused_mapping``) every shard is checked and
        all shards must agree.

        Args:
            prefix: Dotted module name of the layer.
            fused_mapping: Maps a fused module name to its shard names.

        Raises:
            ValueError: if the shards of a fused layer disagree, or if a
                fused layer maps to no shards at all.
            KeyError: if a required ``"<name>.weight"`` entry is missing.
        """
        # adapted from vllm.model_executor.layers.quantization.utils.quant_utils.is_layer_skipped
        proj_name = prefix.split(".")[-1]
        if proj_name in fused_mapping:
            # Swap only the trailing component; str.replace would also
            # rewrite any earlier occurrence of proj_name inside the prefix.
            parent_prefix = prefix[:len(prefix) - len(proj_name)]
            shard_prefixes = [
                parent_prefix + shard_proj_name
                for shard_proj_name in fused_mapping[proj_name]
            ]

            is_skipped = None
            for shard_prefix in shard_prefixes:
                is_shard_skipped = self.quant_description[shard_prefix +
                                                          '.weight'] == "FLOAT"

                if is_skipped is None:
                    is_skipped = is_shard_skipped
                elif is_shard_skipped != is_skipped:
                    raise ValueError(
                        f"Detected some but not all shards of {prefix} "
                        "are quantized. All shards of fused layers "
                        "to have the same precision.")

            # An empty shard list used to surface only through an assert,
            # which is stripped under ``python -O``.
            if is_skipped is None:
                raise ValueError(
                    f"Fused layer {prefix} has no shards listed in the "
                    "packed modules mapping.")
        else:
            is_skipped = self.quant_description[prefix + '.weight'] == "FLOAT"

        return is_skipped

    def get_scaled_act_names(self) -> List[str]:
        return []
|
|
|
|
|
|
class AscendLinearMethod(LinearMethodBase):
    """Linear method for Ascend quantization.

    This class calls AscendQuantizer to search a specific quantization
    implementations supported on ascend hardware for linear methods.

    Args:
        quant_config: The Ascend quantization config.
        prefix: Dotted module name of the layer this method is attached to.
        packed_modules_mapping: Fused-module name mapping, forwarded to
            AscendQuantizer.get_quantizer.
    """

    def __init__(self, quant_config: AscendQuantConfig, prefix: str,
                 packed_modules_mapping: Dict[str, Any]) -> None:
        self.quantizer = AscendQuantizer.get_quantizer(
            quant_config.quant_description, prefix, packed_modules_mapping)
        self.quant_method = self.quantizer.build_linear_method()

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        """Create and register every parameter the quant method asks for.

        The quant method supplies four dicts of tensors (weights, per-tensor,
        per-channel and per-group params). Each tensor is wrapped in a
        non-trainable parameter, registered on ``layer`` under its dict key
        and tagged with ``input_dim`` / ``output_dim`` attributes plus
        ``extra_weight_attrs``. Per-tensor params instead become
        PerTensorScaleParameter objects bound to the ``weight_loader`` from
        ``extra_weight_attrs``.
        """
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")

        weight_dict = self.quant_method.get_weight(input_size_per_partition,
                                                   output_size_per_partition,
                                                   params_dtype)
        for weight_name, weight_param in weight_dict.items():
            param = torch.nn.Parameter(weight_param, requires_grad=False)
            set_weight_attrs(param, {"input_dim": 1, "output_dim": 0})
            layer.register_parameter(weight_name, param)
            set_weight_attrs(param, extra_weight_attrs)

        pertensor_dict = self.quant_method.get_pertensor_param(params_dtype)
        for pertensor_name, pertensor_param in pertensor_dict.items():
            param = PerTensorScaleParameter(data=pertensor_param,
                                            weight_loader=weight_loader)
            # disable warning
            param.ignore_warning = True
            layer.register_parameter(pertensor_name, param)

        perchannel_dict = self.quant_method.get_perchannel_param(
            output_size_per_partition, params_dtype)
        for perchannel_name, perchannel_param in perchannel_dict.items():
            param = torch.nn.Parameter(perchannel_param, requires_grad=False)
            set_weight_attrs(param, {"output_dim": 0})
            layer.register_parameter(perchannel_name, param)
            set_weight_attrs(param, extra_weight_attrs)

        pergroup_dict = self.quant_method.get_pergroup_param(
            input_size_per_partition, output_size_per_partition, params_dtype)
        for pergroup_name, pergroup_param in pergroup_dict.items():
            param = torch.nn.Parameter(pergroup_param, requires_grad=False)
            set_weight_attrs(param, {"output_dim": 0})
            layer.register_parameter(pergroup_name, param)
            set_weight_attrs(param, extra_weight_attrs)
            # Second-level scales/offsets are also split along the input
            # dimension, so they additionally carry ``input_dim``.
            if ("weight_scale_second" in pergroup_name
                    or "weight_offset_second" in pergroup_name):
                param.input_dim = 1

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # The post-load hook is optional on the underlying quant method.
        if hasattr(self.quant_method, "process_weights_after_loading"):
            self.quant_method.process_weights_after_loading(layer)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the quantized forward pass.

        Row-parallel layers additionally pass the tensor-parallel rank to the
        quant method.
        """
        if isinstance(layer, RowParallelLinear):
            tp_rank = get_tensor_model_parallel_rank()
            return self.quant_method.apply(layer, x, bias, tp_rank)
        return self.quant_method.apply(layer, x, bias)
|
|
|
|
|
|
class AscendKVCacheMethod(BaseKVCacheMethod):
    """KVCache method for Ascend quantization.

    Uses AscendQuantizer to look up the Ascend-specific attention quant
    method for the layer at ``prefix`` and forwards every hook to it.

    Args:
        quant_config: The Ascend quantization config.
        prefix: Dotted module name of the attention layer.
    """

    def __init__(self, quant_config: AscendQuantConfig, prefix: str) -> None:
        quantizer = AscendQuantizer.get_quantizer(
            quant_config.quant_description, prefix)
        self.quantizer = quantizer
        self.quant_method = quantizer.build_attention_method()

    def create_weights(self, layer: torch.nn.Module) -> None:
        # Attention has no vLLM-side weight processing/slicing step the way
        # linear layers do, so weight creation is handed entirely to the
        # specific quant method.
        self.quant_method.create_weights(layer)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # The post-load hook is optional on the underlying quant method.
        if not hasattr(self.quant_method, "process_weights_after_loading"):
            return
        self.quant_method.process_weights_after_loading(layer)

    def apply(self, layer: torch.nn.Module, query: torch.Tensor,
              key: torch.Tensor, value: torch.Tensor, kv_cache, attn_metadata,
              attn_type, scale, output) -> torch.Tensor:
        method = self.quant_method
        return method.apply(layer, query, key, value, kv_cache,
                            attn_metadata, attn_type, scale, output)
|
|
|
|
|
|
class AscendFusedMoEMethod(FusedMoEMethodBase):
    """FusedMoE method for Ascend quantization.

    This class calls AscendQuantizer to search a specific quantization
    implementations supported on ascend hardware for FusedMoE methods.

    Args:
        quant_config: The Ascend quantization config.
        prefix: Dotted module name of the FusedMoE layer.
        packed_modules_mapping: Fused-module name mapping, forwarded to
            AscendQuantizer.get_quantizer.
    """

    def __init__(self, quant_config: AscendQuantConfig, prefix: str,
                 packed_modules_mapping: Dict[str, Any]):
        self.quantizer = AscendQuantizer.get_quantizer(
            quant_config.quant_description, prefix, packed_modules_mapping)
        self.quant_method = self.quantizer.build_moe_method()

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        """Create and register the expert weights and their quant params.

        Weight tensors from the quant method are registered with
        ``extra_weight_attrs``. Dynamic-quant params are additionally tagged
        with ``quant_method`` = CHANNEL, except those whose name contains a
        per-group field, which are re-tagged as GROUP.
        """
        weight_param = self.quant_method.get_weight(
            num_experts, intermediate_size_per_partition, hidden_size,
            params_dtype)
        for param_key, param_value in weight_param.items():
            param = torch.nn.Parameter(param_value, requires_grad=False)
            layer.register_parameter(param_key, param)
            set_weight_attrs(param, extra_weight_attrs)

        # ``extra_weight_attrs`` is this call's own kwargs dict, so updating
        # it only affects the dynamic-quant params registered below.
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
        # Name fragments identifying params that are quantized per group
        # rather than per channel.
        per_group_param = ("weight_scale_second", "weight_offset_second",
                           "scale_bias")
        dynamic_quant_param = self.quant_method.get_dynamic_quant_param(
            num_experts, intermediate_size_per_partition, hidden_size,
            params_dtype)
        for param_key, param_value in dynamic_quant_param.items():
            param = torch.nn.Parameter(param_value, requires_grad=False)
            layer.register_parameter(param_key, param)
            set_weight_attrs(param, extra_weight_attrs)
            if any(field in param_key for field in per_group_param):
                setattr(param, "quant_method",
                        FusedMoeWeightScaleSupported.GROUP.value)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        is_prefill: bool = True,
        enable_force_load_balance: bool = False,
        log2phy: Optional[torch.Tensor] = None,
        global_redundant_expert_num=0,
        **kwargs,
    ) -> torch.Tensor:
        """Forward all routing/expert arguments to the quant method.

        Arguments are passed positionally in declaration order, followed by
        any extra ``kwargs``.
        """
        return self.quant_method.apply(
            layer, x, router_logits, top_k, renormalize, use_grouped_topk,
            global_num_experts, expert_map, topk_group, num_expert_group,
            custom_routing_function, scoring_func, e_score_correction_bias,
            is_prefill, enable_force_load_balance, log2phy,
            global_redundant_expert_num, **kwargs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # The post-load hook is optional on the underlying quant method.
        if hasattr(self.quant_method, "process_weights_after_loading"):
            self.quant_method.process_weights_after_loading(layer)
|
|
|
|
|
|
class AscendEmbeddingMethod(AscendLinearMethod):
    """Embedding method for Ascend quantization.

    Reuses the AscendLinearMethod machinery: the quantizer that
    AscendQuantizer selects for ``prefix`` builds a *linear* quant method,
    which then backs this embedding layer.

    Args:
        quant_config: The Ascend quantization config.
        prefix: Dotted module name of the embedding layer.
        packed_modules_mapping: Fused-module name mapping, forwarded to
            AscendQuantizer.get_quantizer.
    """

    def __init__(self, quant_config: AscendQuantConfig, prefix: str,
                 packed_modules_mapping: Dict[str, Any]) -> None:
        # AscendLinearMethod.__init__ performs exactly this setup (quantizer
        # lookup + build_linear_method), so delegate instead of repeating it.
        super().__init__(quant_config, prefix, packed_modules_mapping)
|