[Feature] Support DSv3.1 PD separation and C8 quantization (#7222)
Co-authored-by: kunpengW-code <1289706727@qq.com>
Co-authored-by: linsheng1 <1950916997@qq.com>
### What this PR does / why we need it?
Currently, chunked prefill is forcibly enabled, and DeepSeek V3.1 W8A8C8
supports only the PD-separation scenario. C8 refers to quantizing the KV
cache to int8, which reduces the memory footprint of the KV cache and
improves inference throughput.
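For intuition, here is a minimal sketch of per-head int8 quantization with a scale and zero-point offset — the same kind of metadata (`fa_k.scale`, `fa_k.offset`) that the new attention scheme below registers. This is illustrative only, not the Ascend kernel implementation:

```python
import torch

def quantize_kv(x: torch.Tensor, scale: torch.Tensor, offset: torch.Tensor) -> torch.Tensor:
    # x: (num_heads, head_dim) KV entries; scale/offset are per-head (num_heads, 1)
    return torch.round(x / scale + offset).clamp(-128, 127).to(torch.int8)

def dequantize_kv(q: torch.Tensor, scale: torch.Tensor, offset: torch.Tensor) -> torch.Tensor:
    # Recover an approximation of the original float values
    return (q.to(torch.float32) - offset) * scale

x = torch.randn(8, 64)  # 8 heads, head_dim 64
scale = x.abs().amax(dim=1, keepdim=True) / 127.0
offset = torch.zeros(8, 1)
x_hat = dequantize_kv(quantize_kv(x, scale, offset), scale, offset)
print(f"max abs error: {(x_hat - x).abs().max().item():.4f}")
```

An int8 cache entry takes one byte instead of two (fp16/bf16), roughly halving KV-cache memory.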
Constraints:
1. Only the PD-separation mode is supported, and the model must be run
with MooncakeLayerwiseConnector.
2. Currently, only activations use dynamic quantization; the KV cache
uses static quantization. C8 quantization combined with MTP is not
supported. You can quantize the model with ModelSlim as follows:
```bash
pip install transformers==4.48.2
git clone https://gitcode.com/Ascend/msmodelslim.git
cd msmodelslim
bash install.sh
cd example/DeepSeek/
python3 quant_deepseek_w8a8.py --model_path <path/weight> --save_path <path/quant_weight> \
    --anti_dataset ../common/deepseek_anti_prompt_50_v3_1.json \
    --calib_dataset ../common/deepseek_calib_prompt_50_v3_1.json --rot \
    --trust_remote_code True --fa_quant --dynamic --anti_method m6
```
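After quantization completes, serve `<path/quant_weight>` as the model path; the quant config that ModelSlim writes alongside the weights (including the `fa_quant_type` field used in the loader changes below) is what switches on the C8 path.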
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
- vLLM version: v0.17.0
- vLLM main: 4034c3d32e
---------
Signed-off-by: pichangping <1337510399@qq.com>
Signed-off-by: Wang Kunpeng <1289706727@qq.com>
Co-authored-by: Wang Kunpeng <1289706727@qq.com>
@@ -32,10 +32,11 @@ from typing import Any
# Import base classes
from .base import AscendAttentionScheme, AscendLinearScheme, AscendMoEScheme, QuantType

# Import all scheme classes for external access
from .kv_c8 import AscendFAQuantAttentionMethod

# Import registry functions
from .registry import get_scheme_class, register_scheme

# Import all scheme classes for external access
from .w4a4_flatquant import AscendW4A4FlatQuantDynamicLinearMethod
from .w4a4_laos_dynamic import AscendW4A4LaosDynamicLinearMethod
from .w4a8 import AscendW4A8DynamicFusedMoEMethod, AscendW4A8DynamicLinearMethod
@@ -77,4 +78,5 @@ __all__ = [
    "AscendW4A16FusedMoEMethod",
    "AscendW4A4FlatQuantDynamicLinearMethod",
    "AscendW4A4LaosDynamicLinearMethod",
    "AscendFAQuantAttentionMethod",
]
vllm_ascend/quantization/methods/kv_c8.py (new file, 65 lines)
@@ -0,0 +1,65 @@
import torch
from vllm.config import get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size

from .registry import register_scheme


def weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor):
    """fa_q weight loader."""
    if param.numel() == 1 and loaded_weight.numel() == 1:
        param.data.fill_(loaded_weight.item())
    else:
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        shard_size = loaded_weight.shape[0] // tp_size
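        # e.g. with tp_size=2 and a (128, 1) per-head scale, shard_size is 64
        # and rank r keeps rows [64*r, 64*(r+1)).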
        loaded_weight = loaded_weight.narrow(0, shard_size * tp_rank, shard_size)
        assert param.size() == loaded_weight.size(), (
            f"Attempted to load weight ({loaded_weight.size()}) into parameter ({param.size()}) when TP is ({tp_size})"
        )

        param.data.copy_(loaded_weight)


@register_scheme("FAKQuant", "attention")
class AscendFAQuantAttentionMethod:
    def __init__(self):
        self.transpose_weight = True
        self.printFlag = False
        vllm_config = get_current_vllm_config()
        config = vllm_config.model_config.hf_config
        self.kv_lora_rank = getattr(config, "kv_lora_rank", 0)
        self.qk_rope_head_dim = getattr(config, "qk_rope_head_dim", 0)

    def create_weights(self, layer: torch.nn.Module) -> None:
        extra_module_names = ["fa_q", "fa_k", "fa_v"]
        for name in extra_module_names:
            setattr(layer, name, torch.nn.Module())
        params_dict = {}
        dtype = torch.get_default_dtype()
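        # Per-head quantization scales (default dtype) and int8 zero-point
        # offsets for Q, K and V, filled in later from the quantized checkpoint.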
        params_dict["fa_q.scale"] = torch.empty((layer.num_heads, 1), dtype=dtype)
        params_dict["fa_k.scale"] = torch.empty((layer.num_kv_heads, 1), dtype=dtype)
        params_dict["fa_v.scale"] = torch.empty((layer.num_kv_heads, 1), dtype=dtype)
        params_dict["fa_q.offset"] = torch.empty((layer.num_heads, 1), dtype=torch.int8)
        params_dict["fa_k.offset"] = torch.empty((layer.num_kv_heads, 1), dtype=torch.int8)
        params_dict["fa_v.offset"] = torch.empty((layer.num_kv_heads, 1), dtype=torch.int8)

        for name, weight in params_dict.items():
            module_name, weight_name = name.rsplit(".", 1)
            module = getattr(layer, module_name)
            weight_param = torch.nn.Parameter(weight, requires_grad=False)
            module.register_parameter(weight_name, weight_param)
            # When loading weights, segment them according to TP
            weight_param.weight_loader = weight_loader

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        fa_k_scale = torch.squeeze(layer.fa_k.scale).unsqueeze(0)
        layer.fak_descale_float = torch.nn.Parameter(fa_k_scale.to(torch.float), requires_grad=False)
        layer.fak_descale = torch.nn.Parameter(fa_k_scale, requires_grad=False)
        layer.fak_descale_reciprocal = 1.0 / torch.nn.Parameter(fa_k_scale, requires_grad=False)
        fa_k_offset = torch.squeeze(layer.fa_k.offset).unsqueeze(0)
        layer.fak_offset = torch.nn.Parameter(fa_k_offset.to(layer.fak_descale.dtype), requires_grad=False)

        repeated_quant_kscale = fa_k_scale.repeat(self.kv_lora_rank)
        layer.quant_kscale = repeated_quant_kscale.view(1, self.kv_lora_rank)
        layer.quant_kscale = 1.0 / torch.nn.Parameter(layer.quant_kscale.to(torch.float), requires_grad=False)
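As a quick sanity check of the TP split in `weight_loader` above, a self-contained sketch (assumed shapes, no vLLM distributed state) showing that the per-rank `narrow` slices tile the full tensor:

```python
import torch

def shard_rows(full: torch.Tensor, tp_rank: int, tp_size: int) -> torch.Tensor:
    # Mirrors weight_loader's split: each rank takes a contiguous slice of rows.
    shard_size = full.shape[0] // tp_size
    return full.narrow(0, shard_size * tp_rank, shard_size)

full_scale = torch.arange(8, dtype=torch.float32).reshape(8, 1)  # 8 heads in total
tp_size = 4
shards = [shard_rows(full_scale, r, tp_size) for r in range(tp_size)]
assert torch.equal(torch.cat(shards, dim=0), full_scale)  # shards tile the original
print([s.flatten().tolist() for s in shards])  # [[0.0, 1.0], [2.0, 3.0], ...]
```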
@@ -24,6 +24,7 @@ configs generated by the ModelSlim tool, along with model-specific mappings.
import glob
import json
import os
import re
from collections.abc import Mapping
from types import MappingProxyType
from typing import Any, Optional

@@ -31,6 +32,7 @@ from typing import Any, Optional
import torch
from vllm.config import get_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.quantization import register_quantization_config

@@ -38,7 +40,7 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import UnquantizedEmbeddingMethod, VocabParallelEmbedding
from vllm.model_executor.models.utils import WeightsMapper

-from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD
+from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD, calc_split_factor

from .methods import get_scheme_class

@@ -438,6 +440,7 @@ class AscendModelSlimConfig(QuantizationConfig):
            new_k = k.replace("weight_packed", "weight")
            extra_quant_dict[new_k] = self.quant_description[k]
        self.quant_description.update(extra_quant_dict)
        self._add_kvcache_quant_metadata()

    def __repr__(self) -> str:
        return "AscendModelSlimConfig:\n" + super().__repr__()

@@ -509,8 +512,6 @@ class AscendModelSlimConfig(QuantizationConfig):
        self.packed_modules_mapping = packed_modules_model_mapping[model_type]
        prefix = self.quant_prefix_mapper(model_type, prefix)

-        from vllm.model_executor.layers.attention import Attention
-
        if model_type != "kimi_k2":
            if prefix.startswith("language_model"):
                prefix = prefix.split(".", 1)[-1]

@@ -522,11 +523,7 @@ class AscendModelSlimConfig(QuantizationConfig):
                return AscendUnquantizedLinearMethod()
            scheme = create_scheme_for_layer(self.quant_description, prefix, "linear", self.packed_modules_mapping)
            return AscendLinearMethod(scheme)
-        elif (
-            isinstance(layer, Attention)
-            and "fa_quant_type" in self.quant_description
-            and self.quant_description["fa_quant_type"] is not None
-        ):
+        elif isinstance(layer, AttentionLayerBase) and self.is_fa_quant_layer(prefix):
            scheme = create_scheme_for_layer(self.quant_description, prefix, "attention", self.packed_modules_mapping)
            return AscendKVCacheMethod(scheme)
        elif isinstance(layer, FusedMoE):

@@ -573,6 +570,39 @@ class AscendModelSlimConfig(QuantizationConfig):
        assert is_skipped is not None
        return is_skipped

    def is_fa_quant_layer(self, prefix):
        if self.enable_fa_quant:
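            # Extract the layer index from the module prefix, e.g.
            # "model.layers.3.self_attn.attn" -> re.findall(...) == ["3"].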
            layer_id_str = "".join(re.findall(r"\.(\d+)\.", prefix))
            if layer_id_str.isdigit() and int(layer_id_str) in self.kvcache_quant_layers:
                return True
        return False

    def enabling_fa_quant(self, vllm_config, layer_name) -> bool:
        is_decode_instance = (
            vllm_config.kv_transfer_config is not None
            and vllm_config.kv_transfer_config.is_kv_consumer
            and not vllm_config.kv_transfer_config.is_kv_producer
        )
        return bool(is_decode_instance and self.is_fa_quant_layer(layer_name))

    def get_kv_quant_dtype(self, layer_name, cache_dtype, model_config):
        if self.enable_fa_quant and self.is_fa_quant_layer(layer_name):
            ori_dtype = model_config.dtype
            quant_dtype = torch.int8
            # For MLA models like DeepSeek, we only quantize the K cache to ensure accuracy
            if model_config.use_mla:
                return quant_dtype, ori_dtype
            else:
                return quant_dtype, quant_dtype
        return cache_dtype, cache_dtype

    def get_kv_quant_split_factor(self, layer_name, kv_head_dim_list):
        if self.enable_fa_quant and self.is_fa_quant_layer(layer_name):
            k_quant_head_dim = kv_head_dim_list[0]
            v_quant_head_dim = kv_head_dim_list[1] * 2
            kv_head_dim_list = [k_quant_head_dim, v_quant_head_dim]
        return calc_split_factor(kv_head_dim_list)

    def maybe_update_config(self, model_name: str, revision: str | None = None) -> None:
        """Load the ModelSlim quantization config from model directory.

@@ -606,6 +636,7 @@ class AscendModelSlimConfig(QuantizationConfig):
            with open(config_path) as f:
                self.quant_description = json.load(f)
            self._apply_extra_quant_adaptations()
            self._add_kvcache_quant_metadata()
            return

        # Collect diagnostic info for the error message

@@ -678,3 +709,13 @@ class AscendModelSlimConfig(QuantizationConfig):

    def get_scaled_act_names(self) -> list[str]:
        return []

    def _add_kvcache_quant_metadata(self):
        fa_quant_type = self.quant_description.get("fa_quant_type", "")
        self.enable_fa_quant = fa_quant_type != ""
        self.kvcache_quant_layers = []
        if self.enable_fa_quant:
            for key in self.quant_description:
                if "fa_k.scale" in key:
                    _id = "".join(re.findall(r"\.(\d+)\.", key))
                    self.kvcache_quant_layers.append(int(_id))
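For reference, a minimal illustrative shape of the quant config this metadata pass consumes; the key names come from the code above, while the values and layer prefixes are hypothetical:

```python
# Hypothetical quant_description as produced by ModelSlim (values illustrative):
quant_description = {
    "fa_quant_type": "FAKQuant",  # any non-empty value sets enable_fa_quant = True
    "model.layers.0.self_attn.fa_k.scale": "...",  # layer 0 KV cache is quantized
    "model.layers.1.self_attn.fa_k.scale": "...",  # layer 1 likewise
}
# _add_kvcache_quant_metadata would then record kvcache_quant_layers == [0, 1].
```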