Files

90 lines
4.4 KiB
Python
Raw Permalink Normal View History

[Feature] Supports DSv3.1 PD separation and C8 quantization (#7222) Co-authored-by: kunpengW-code <1289706727@qq.com> Co-authored-by: linsheng1 <1950916997@qq.com> ### What this PR does / why we need it? Currently, chunked prefill is forcibly enabled. DeepSeek V3.1 W8A8C8 supports only the PD separation scenario. C8 refers to quantizing the KV cache to int8, which aims to reduce the GPU memory usage of the KV cache and improve the inference throughput. Constraints: 1. Only the PD separation mode can be used and MooncakeLayerwiseConnector can be used to run the model. 2. Currently, only the activation value supports dynamic quantization, and the KV cache supports static quantization. C8 quantization with MTP is not supported. You can use ModelSlim for quantization. The quantization procedure is as follows: pip install transformers==4.48.2 git clone https://gitcode.com/Ascend/msmodelslim.git cd msmodelslim bash install.sh cd example/DeepSeek/ python3 quant_deepseek_w8a8.py --model_path <path/weight> --save_path <path/quant_weight> --anti_dataset ../common/deepseek_anti_prompt_50_v3_1.json --calib_dataset ../common/deepseek_calib_prompt_50_v3_1.json --rot --trust_remote_code True --fa_quant --dynamic --anti_method m6 ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? - vLLM version: v0.17.0 - vLLM main: https://github.com/vllm-project/vllm/commit/4034c3d32e30d01639459edd3ab486f56993876d --------- Signed-off-by: pichangping <1337510399@qq.com> Signed-off-by: Wang Kunpeng <1289706727@qq.com> Co-authored-by: Wang Kunpeng <1289706727@qq.com>
2026-03-16 22:49:05 +08:00
import torch
from vllm.config import get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size
from .registry import register_scheme
def weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor):
    """Copy a checkpoint tensor into ``param``, slicing it by tensor-parallel rank.

    When both tensors hold exactly one element, the value is written with
    ``fill_`` and no sharding takes place. Otherwise ``loaded_weight`` is cut
    along dim 0 into ``tp_size`` slices of ``shape[0] // tp_size`` rows and the
    slice belonging to the current TP rank is copied into ``param``.

    Args:
        param: Destination tensor; its ``.data`` is modified in place.
        loaded_weight: Source tensor read from the checkpoint.

    Raises:
        AssertionError: If the rank-local slice and ``param`` differ in size.
    """
    both_single_element = param.numel() == 1 and loaded_weight.numel() == 1
    if both_single_element:
        # Single value: nothing to shard, just broadcast it into the parameter.
        param.data.fill_(loaded_weight.item())
        return

    rank = get_tensor_model_parallel_rank()
    world_size = get_tensor_model_parallel_world_size()
    rows_per_rank = loaded_weight.shape[0] // world_size
    # Keep only the rows owned by this rank (dim 0 is the sharded dimension).
    local_shard = loaded_weight.narrow(0, rows_per_rank * rank, rows_per_rank)
    assert param.size() == local_shard.size(), (
        f"Attempted to load weight ({local_shard.size()}) into parameter ({param.size()}) when TP is ({world_size})"
    )
    param.data.copy_(local_shard)
@register_scheme("FAKQuant", "attention")
class AscendFAQuantAttentionMethod:
    """Attention quantization method registered as ``("FAKQuant", "attention")``.

    ``create_weights`` attaches ``fa_q`` / ``fa_k`` / ``fa_v`` sub-modules to the
    layer, each holding a ``scale`` and an ``offset`` parameter.
    ``process_weights_after_loading`` derives several K-scale tensors from
    ``fa_k`` and stores them directly on the layer.
    """

    def __init__(self):
        """Read ``kv_lora_rank`` and ``qk_rope_head_dim`` from the current HF config.

        Both values fall back to 0 when the config does not define them.
        """
        self.transpose_weight = True
        self.printFlag = False
        hf_config = get_current_vllm_config().model_config.hf_config
        self.kv_lora_rank = getattr(hf_config, "kv_lora_rank", 0)
        self.qk_rope_head_dim = getattr(hf_config, "qk_rope_head_dim", 0)

    def create_weights(self, layer: torch.nn.Module) -> None:
        """Register per-head ``scale`` / ``offset`` parameters for q, k and v.

        Args:
            layer: Attention layer exposing ``num_heads`` and ``num_kv_heads``.
                It gains ``fa_q``, ``fa_k`` and ``fa_v`` child modules.
        """
        holder_names = ("fa_q", "fa_k", "fa_v")
        for holder_name in holder_names:
            setattr(layer, holder_name, torch.nn.Module())

        float_dtype = torch.get_default_dtype()
        # q uses the query head count; k and v use the KV head count.
        heads_by_holder = dict(
            zip(holder_names, (layer.num_heads, layer.num_kv_heads, layer.num_kv_heads))
        )

        # Scales (default float dtype) are registered before offsets (int8),
        # so each holder ends up with "scale" followed by "offset".
        for attr_name, attr_dtype in (("scale", float_dtype), ("offset", torch.int8)):
            for holder_name, head_count in heads_by_holder.items():
                holder = getattr(layer, holder_name)
                new_param = torch.nn.Parameter(
                    torch.empty((head_count, 1), dtype=attr_dtype),
                    requires_grad=False,
                )
                holder.register_parameter(attr_name, new_param)
                # When loading weights, segment them according to TP
                new_param.weight_loader = weight_loader

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Derive the K-scale / K-offset tensors used at runtime and store them on ``layer``.

        Sets ``fak_descale_float``, ``fak_descale``, ``fak_descale_reciprocal``,
        ``fak_offset`` and ``quant_kscale`` on the layer.
        """
        # Collapse size-1 dims, then add a leading dim of size 1.
        k_scale_row = layer.fa_k.scale.squeeze().unsqueeze(0)

        layer.fak_descale_float = torch.nn.Parameter(k_scale_row.to(torch.float), requires_grad=False)
        layer.fak_descale = torch.nn.Parameter(k_scale_row, requires_grad=False)
        # The division yields a plain tensor, not a Parameter.
        layer.fak_descale_reciprocal = 1.0 / torch.nn.Parameter(k_scale_row, requires_grad=False)

        k_offset_row = layer.fa_k.offset.squeeze().unsqueeze(0)
        layer.fak_offset = torch.nn.Parameter(k_offset_row.to(layer.fak_descale.dtype), requires_grad=False)

        # NOTE(review): a single-argument repeat followed by view(1, kv_lora_rank)
        # presumably relies on fa_k.scale holding one element - confirm.
        tiled_k_scale = k_scale_row.repeat(self.kv_lora_rank).view(1, self.kv_lora_rank)
        layer.quant_kscale = 1.0 / torch.nn.Parameter(tiled_k_scale.to(torch.float), requires_grad=False)
@register_scheme("INT8_DYNAMIC", "attention")
class AscendSFAQuantAttentionMethod:
    """Attention quantization method registered as ``("INT8_DYNAMIC", "attention")``.

    Adds an ``indexer`` sub-module to the layer that holds two square float32
    parameters, ``q_rot`` and ``k_rot``, of side ``index_head_dim``.
    """

    def __init__(self):
        """Read ``index_head_dim`` from the current HF config (required attribute)."""
        hf_config = get_current_vllm_config().model_config.hf_config
        self.index_head_dim = hf_config.index_head_dim

    def create_weights(self, layer: torch.nn.Module) -> None:
        """Attach ``layer.indexer`` and register its ``q_rot`` / ``k_rot`` parameters.

        Args:
            layer: Module that receives the new ``indexer`` child module.
        """
        setattr(layer, "indexer", torch.nn.Module())

        rot_shape = (self.index_head_dim, self.index_head_dim)
        for rot_name in ("q_rot", "k_rot"):
            indexer = getattr(layer, "indexer")
            rot_param = torch.nn.Parameter(
                torch.empty(rot_shape, dtype=torch.float32),
                requires_grad=False,
            )
            indexer.register_parameter(rot_name, rot_param)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """No post-load processing is performed for this scheme."""
        return None