From 8977be1df3eb60a3c6f7e2f593ef356ea7201106 Mon Sep 17 00:00:00 2001
From: Yaphets24 <44045681+Yaphets24@users.noreply.github.com>
Date: Wed, 25 Mar 2026 09:18:00 +0800
Subject: [PATCH] [Bugfix] Fix DeepSeek V3.2 C8 precision by rotary tensor
 (#7537)

### What this PR does / why we need it?
During the attention quantization process of DeepSeek V3.2, the Hadamard (rotation) matrices need to be retrieved from the quantized weights so they can be applied to the indexer query and key tensors before quantization.
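
For reference, the C8 indexer path rotates the indexer q/k activations with a normalized 128x128 Hadamard matrix right before per-token INT8 dynamic quantization (done on device by `torch_npu.npu_dynamic_quant` in `sfa_v1.py`). Below is a minimal CPU-only sketch of that flow; the helper name, shapes, and the explicit absmax/127 quantization are illustrative and may differ in detail from the NPU kernel:

```python
import torch
from scipy.linalg import hadamard

head_dim = 128
# Normalized Hadamard rotation, built the same way as in sfa_v1.py (float32 here instead of bfloat16).
h = torch.tensor(hadamard(head_dim), dtype=torch.float32) / (head_dim**0.5)

def rotate_and_quant(x: torch.Tensor):
    """Rotate [num_tokens, head_dim] activations, then quantize with one scale per token."""
    x = x @ h
    scale = x.abs().amax(dim=-1, keepdim=True) / 127.0
    q = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
    return q, scale

q_li, q_li_scale = rotate_and_quant(torch.randn(4, head_dim))
print(q_li.shape, q_li_scale.shape)  # torch.Size([4, 128]) torch.Size([4, 1])
```

Without the rotation, per-token scales are dominated by outlier channels; spreading that energy across the head dimension is the usual motivation for Hadamard rotations in activation quantization.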

### Does this PR introduce _any_ user-facing change?
No, but the quantized weights now contain two new tensors, `indexer.q_rot` and `indexer.k_rot`.
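
The new tensors are only consumed when indexer quantization is enabled in the ModelSlim quant description. A small sketch of how the new `indexer_quant_type` / per-layer `indexer.quant_type` keys are interpreted (the example key and values are hypothetical; the regex and key substrings mirror the new logic in `modelslim_config.py`):

```python
import re

# Hypothetical excerpt of a ModelSlim quant description; only keys read by this patch are shown.
quant_description = {
    "indexer_quant_type": "INT8_DYNAMIC",  # non-empty -> enable_indexer_quant is True
    "model.layers.3.self_attn.indexer.quant_type": "INT8_DYNAMIC",
}

indexer_quant_layers = []
for key in quant_description:
    layer_id = "".join(re.findall(r"\.(\d+)\.", key))  # extracts "3" from the layer prefix
    if "indexer.quant_type" in key and layer_id.isdigit():
        indexer_quant_layers.append(int(layer_id))

print(indexer_quant_layers)  # [3] -> layers whose indexer is treated as quantized
```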

### How was this patch tested?
- vLLM version: v0.18.0
- vLLM main: https://github.com/vllm-project/vllm/commit/8b6325758cce5f9c36d38f2462edbd368b97a07c

---------

Signed-off-by: mayumeng
Co-authored-by: mayumeng
---
 vllm_ascend/attention/sfa_v1.py              | 17 ++++++++-----
 .../patch/worker/patch_weight_utils.py       | 11 ++++++++-
 vllm_ascend/quantization/methods/kv_c8.py    | 24 +++++++++++++++++++
 vllm_ascend/quantization/modelslim_config.py | 22 ++++++++++++++---
 4 files changed, 64 insertions(+), 10 deletions(-)

diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py
index da708060..7d787648 100644
--- a/vllm_ascend/attention/sfa_v1.py
+++ b/vllm_ascend/attention/sfa_v1.py
@@ -356,8 +356,9 @@ class AscendSFAImpl(MLAAttentionImpl):
     # Supports forward using the all-gather o_proj weight for decode requests when Sharded CP is enabled.
     o_proj_full_pool: torch.Tensor | None = None
 
-    # qk_hadamard tensor shared when dsa c8 enabled
-    qk_hadamard: torch.Tensor | None = None
+    # q_hadamard and k_hadamard tensor shared when dsa c8 enabled
+    q_hadamard: torch.Tensor | None = None
+    k_hadamard: torch.Tensor | None = None
 
     def __init__(
         self,
@@ -525,8 +526,12 @@ class AscendSFAImpl(MLAAttentionImpl):
             # if mlapo, W_UK_T can't trans nz
             self.W_UK_T = maybe_trans_nz(self.W_UK_T)
 
-        if self.use_sparse_c8_indexer and AscendSFAImpl.qk_hadamard is None:
-            AscendSFAImpl.qk_hadamard = torch.tensor(scipy.linalg.hadamard(128), dtype=torch.bfloat16, device="npu") / (
+        if self.use_sparse_c8_indexer and AscendSFAImpl.q_hadamard is None:
+            AscendSFAImpl.q_hadamard = torch.tensor(scipy.linalg.hadamard(128), dtype=torch.bfloat16, device="npu") / (
+                128**0.5
+            )
+        if self.use_sparse_c8_indexer and AscendSFAImpl.k_hadamard is None:
+            AscendSFAImpl.k_hadamard = torch.tensor(scipy.linalg.hadamard(128), dtype=torch.bfloat16, device="npu") / (
                 128**0.5
             )
 
@@ -890,7 +895,7 @@ class AscendSFAImpl(MLAAttentionImpl):
 
         k_li = torch.cat([k_li_pe, k_li_nope], dim=-1)  # [b*s,128]
         if self.use_sparse_c8_indexer:
-            k_li = k_li @ AscendSFAImpl.qk_hadamard
+            k_li = k_li @ AscendSFAImpl.k_hadamard
             k_li, k_li_scale = torch_npu.npu_dynamic_quant(k_li.view(-1, self.head_dim), dst_type=self.c8_k_cache_dtype)
             k_li_scale = k_li_scale.to(self.c8_k_scale_cache_dtype)  # [b*s,]
             k_li_scale = k_li_scale.unsqueeze(-1)  # [b*s,1]
@@ -930,7 +935,7 @@ class AscendSFAImpl(MLAAttentionImpl):
 
         if self.use_sparse_c8_indexer:
             q_li_shape_ori = q_li.shape
-            q_li = q_li @ AscendSFAImpl.qk_hadamard
+            q_li = q_li @ AscendSFAImpl.q_hadamard
             q_li, q_li_scale = torch_npu.npu_dynamic_quant(q_li.view(-1, self.head_dim), dst_type=self.c8_k_cache_dtype)
             q_li_scale = q_li_scale.to(self.c8_k_scale_cache_dtype)
 
diff --git a/vllm_ascend/patch/worker/patch_weight_utils.py b/vllm_ascend/patch/worker/patch_weight_utils.py
index 809b168b..68570905 100644
--- a/vllm_ascend/patch/worker/patch_weight_utils.py
+++ b/vllm_ascend/patch/worker/patch_weight_utils.py
@@ -39,7 +39,16 @@ def patch_deepseek(module):
 
     def new_remap(name: str, params_dict: dict):
         name = ori_maybe_remap_kv_scale_name(name, params_dict)
-        replace_scale_names = ["fa_q.scale", "fa_k.scale", "fa_v.scale", "fa_q.offset", "fa_k.offset", "fa_v.offset"]
+        replace_scale_names = [
+            "fa_q.scale",
+            "fa_k.scale",
+            "fa_v.scale",
+            "fa_q.offset",
+            "fa_k.offset",
+            "fa_v.offset",
+            "indexer.q_rot",
+            "indexer.k_rot",
+        ]
 
         for scale_name in replace_scale_names:
             if name.endswith(scale_name):
diff --git a/vllm_ascend/quantization/methods/kv_c8.py b/vllm_ascend/quantization/methods/kv_c8.py
index 8a700484..10056741 100644
--- a/vllm_ascend/quantization/methods/kv_c8.py
+++ b/vllm_ascend/quantization/methods/kv_c8.py
@@ -63,3 +63,27 @@ class AscendFAQuantAttentionMethod:
         repeated_quant_kscale = fa_k_scale.repeat(self.kv_lora_rank)
         layer.quant_kscale = repeated_quant_kscale.view(1, self.kv_lora_rank)
         layer.quant_kscale = 1.0 / torch.nn.Parameter(layer.quant_kscale.to(torch.float), requires_grad=False)
+
+
+@register_scheme("INT8_DYNAMIC", "attention")
+class AscendSFAQuantAttentionMethod:
+    def __init__(self):
+        vllm_config = get_current_vllm_config()
+        config = vllm_config.model_config.hf_config
+        self.index_head_dim = config.index_head_dim
+
+    def create_weights(self, layer: torch.nn.Module) -> None:
+        extra_module_names = ["indexer"]
+        for name in extra_module_names:
+            setattr(layer, name, torch.nn.Module())
+        params_dict = {}
+        params_dict["indexer.q_rot"] = torch.empty((self.index_head_dim, self.index_head_dim), dtype=torch.float32)
+        params_dict["indexer.k_rot"] = torch.empty((self.index_head_dim, self.index_head_dim), dtype=torch.float32)
+        for name, weight in params_dict.items():
+            module_name, weight_name = name.split(".")
+            module = getattr(layer, module_name)
+            weight_param = torch.nn.Parameter(weight, requires_grad=False)
+            module.register_parameter(weight_name, weight_param)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        pass
diff --git a/vllm_ascend/quantization/modelslim_config.py b/vllm_ascend/quantization/modelslim_config.py
index 120f2191..0ad0bde4 100644
--- a/vllm_ascend/quantization/modelslim_config.py
+++ b/vllm_ascend/quantization/modelslim_config.py
@@ -379,6 +379,8 @@ def get_quant_type_for_layer(
     # Attention
     if layer_type == "attention" and "fa_quant_type" in quant_description:
         return quant_description["fa_quant_type"]
+    if layer_type == "attention" and "indexer_quant_type" in quant_description:
+        return quant_description["indexer_quant_type"]
 
     # Linear / MoE
     return get_linear_quant_type(quant_description, prefix, packed_modules_mapping)
@@ -582,7 +584,9 @@ class AscendModelSlimConfig(QuantizationConfig):
                 return AscendUnquantizedLinearMethod()
             scheme = create_scheme_for_layer(self.quant_description, prefix, "linear", self.packed_modules_mapping)
             return AscendLinearMethod(scheme)
-        elif isinstance(layer, AttentionLayerBase) and self.is_fa_quant_layer(prefix):
+        elif isinstance(layer, AttentionLayerBase) and (
+            self.is_fa_quant_layer(prefix) or self.is_indexer_quant_layer(prefix)
+        ):
             scheme = create_scheme_for_layer(self.quant_description, prefix, "attention", self.packed_modules_mapping)
             return AscendKVCacheMethod(scheme)
         elif isinstance(layer, FusedMoE):
@@ -636,6 +640,13 @@ class AscendModelSlimConfig(QuantizationConfig):
                 return True
         return False
 
+    def is_indexer_quant_layer(self, prefix):
+        if self.enable_indexer_quant:
+            layer_id_str = "".join(re.findall(r"\.(\d+)\.", prefix))
+            if layer_id_str.isdigit() and int(layer_id_str) in self.indexer_quant_layers:
+                return True
+        return False
+
     def enabling_fa_quant(self, vllm_config, layer_name) -> bool:
         is_decode_instance = (
             vllm_config.kv_transfer_config is not None
@@ -773,8 +784,13 @@ class AscendModelSlimConfig(QuantizationConfig):
         fa_quant_type = self.quant_description.get("fa_quant_type", "")
         self.enable_fa_quant = fa_quant_type != ""
         self.kvcache_quant_layers = []
-        if self.enable_fa_quant:
+        indexer_quant_type = self.quant_description.get("indexer_quant_type", "")
+        self.enable_indexer_quant = indexer_quant_type != ""
+        self.indexer_quant_layers = []
+        if self.enable_fa_quant or self.enable_indexer_quant:
             for key in self.quant_description:
+                _id = "".join(re.findall(r"\.(\d+)\.", key))
                 if "fa_k.scale" in key:
-                    _id = "".join(re.findall(r"\.(\d+)\.", key))
                     self.kvcache_quant_layers.append(int(_id))
+                if "indexer.quant_type" in key:
+                    self.indexer_quant_layers.append(int(_id))