[Bugfix] Fix DeepSeek V3.2 C8 precision with rotary tensors (#7537)
### What this PR does / why we need it?
During the attention (C8) quantization of DeepSeek V3.2, the Hadamard rotation
matrices must be retrieved from the quantized weights and used in the
computation.
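
For reference, the rotation-then-quantization pattern this touches looks roughly like the sketch below; `rotate_then_quantize` and the plain-PyTorch int8 step are illustrative stand-ins for the `torch_npu.npu_dynamic_quant` kernel used in the actual code.

```python
# Illustrative sketch only: build the normalized 128x128 Hadamard rotation and
# apply it before per-token dynamic INT8 (C8) quantization.
import scipy.linalg
import torch

head_dim = 128
# H / sqrt(128) is orthonormal, so the rotation spreads outliers across the
# head dimension without changing vector norms.
hadamard = torch.tensor(scipy.linalg.hadamard(head_dim), dtype=torch.float32) / (head_dim**0.5)

def rotate_then_quantize(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    x = x @ hadamard                                    # rotate each 128-dim row
    scale = x.abs().amax(dim=-1, keepdim=True).clamp_min(1e-8) / 127.0  # per-token scale
    q = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
    return q, scale

q_int8, q_scale = rotate_then_quantize(torch.randn(4, head_dim))
```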
### Does this PR introduce _any_ user-facing change?
No, but the quantized weights will contain two new tensors (`indexer.q_rot` and `indexer.k_rot`).
### How was this patch tested?
- vLLM version: v0.18.0
- vLLM main: 8b6325758c
---------
Signed-off-by: mayumeng <m30059191@china.huawei.com>
Co-authored-by: mayumeng <m30059191@china.huawei.com>
@@ -356,8 +356,9 @@ class AscendSFAImpl(MLAAttentionImpl):
     # Supports forward using the all-gather o_proj weight for decode requests when Sharded CP is enabled.
     o_proj_full_pool: torch.Tensor | None = None

-    # qk_hadamard tensor shared when dsa c8 enabled
-    qk_hadamard: torch.Tensor | None = None
+    # q_hadamard and k_hadamard tensor shared when dsa c8 enabled
+    q_hadamard: torch.Tensor | None = None
+    k_hadamard: torch.Tensor | None = None

     def __init__(
         self,
@@ -525,8 +526,12 @@ class AscendSFAImpl(MLAAttentionImpl):
             # if mlapo, W_UK_T can't trans nz
             self.W_UK_T = maybe_trans_nz(self.W_UK_T)

-        if self.use_sparse_c8_indexer and AscendSFAImpl.qk_hadamard is None:
-            AscendSFAImpl.qk_hadamard = torch.tensor(scipy.linalg.hadamard(128), dtype=torch.bfloat16, device="npu") / (
+        if self.use_sparse_c8_indexer and AscendSFAImpl.q_hadamard is None:
+            AscendSFAImpl.q_hadamard = torch.tensor(scipy.linalg.hadamard(128), dtype=torch.bfloat16, device="npu") / (
                 128**0.5
             )
+        if self.use_sparse_c8_indexer and AscendSFAImpl.k_hadamard is None:
+            AscendSFAImpl.k_hadamard = torch.tensor(scipy.linalg.hadamard(128), dtype=torch.bfloat16, device="npu") / (
+                128**0.5
+            )

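As background on why a Hadamard rotation is harmless to accuracy in exact arithmetic: the normalized matrix is orthogonal, so rotating q and k by the same matrix leaves the q·k scores unchanged while spreading outliers. A minimal check of that property (this is general background on the shared-matrix case, not the PR's code):

```python
# Sketch: H / sqrt(n) built from scipy.linalg.hadamard(n) is orthogonal,
# so (q @ H) @ (k @ H).T == q @ k.T up to floating-point error.
import scipy.linalg
import torch

n = 128
H = torch.tensor(scipy.linalg.hadamard(n), dtype=torch.float64) / (n**0.5)
print(torch.allclose(H @ H.T, torch.eye(n, dtype=torch.float64)))  # True

q = torch.randn(5, n, dtype=torch.float64)
k = torch.randn(7, n, dtype=torch.float64)
print(torch.allclose((q @ H) @ (k @ H).T, q @ k.T))  # True
```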
@@ -890,7 +895,7 @@ class AscendSFAImpl(MLAAttentionImpl):
         k_li = torch.cat([k_li_pe, k_li_nope], dim=-1)  # [b*s,128]

         if self.use_sparse_c8_indexer:
-            k_li = k_li @ AscendSFAImpl.qk_hadamard
+            k_li = k_li @ AscendSFAImpl.k_hadamard
             k_li, k_li_scale = torch_npu.npu_dynamic_quant(k_li.view(-1, self.head_dim), dst_type=self.c8_k_cache_dtype)
             k_li_scale = k_li_scale.to(self.c8_k_scale_cache_dtype)  # [b*s,]
             k_li_scale = k_li_scale.unsqueeze(-1)  # [b*s,1]
@@ -930,7 +935,7 @@ class AscendSFAImpl(MLAAttentionImpl):

         if self.use_sparse_c8_indexer:
             q_li_shape_ori = q_li.shape
-            q_li = q_li @ AscendSFAImpl.qk_hadamard
+            q_li = q_li @ AscendSFAImpl.q_hadamard
             q_li, q_li_scale = torch_npu.npu_dynamic_quant(q_li.view(-1, self.head_dim), dst_type=self.c8_k_cache_dtype)
             q_li_scale = q_li_scale.to(self.c8_k_scale_cache_dtype)

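The two hunks above only swap which class-level rotation matrix is multiplied in; the surrounding shape handling is unchanged. A rough illustration of that bookkeeping (the identity matrix stands in for `q_hadamard` and the shapes are made up):

```python
# Illustrative shapes only: the last dim is the 128-wide index head, so one
# [128, 128] matmul rotates every head of every token, and view(-1, head_dim)
# flattens the leading dims into the per-row layout the dynamic quant expects.
import torch

head_dim, n_heads, tokens = 128, 4, 3
q_li = torch.randn(tokens, n_heads, head_dim)
hadamard = torch.eye(head_dim)              # stand-in for AscendSFAImpl.q_hadamard

q_li_shape_ori = q_li.shape                 # torch.Size([3, 4, 128])
q_li = q_li @ hadamard                      # still [3, 4, 128]
rows = q_li.view(-1, head_dim)              # [12, 128], one row per (token, head)
print(q_li_shape_ori, rows.shape)
```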
@@ -39,7 +39,16 @@ def patch_deepseek(module):
     def new_remap(name: str, params_dict: dict):
         name = ori_maybe_remap_kv_scale_name(name, params_dict)

-        replace_scale_names = ["fa_q.scale", "fa_k.scale", "fa_v.scale", "fa_q.offset", "fa_k.offset", "fa_v.offset"]
+        replace_scale_names = [
+            "fa_q.scale",
+            "fa_k.scale",
+            "fa_v.scale",
+            "fa_q.offset",
+            "fa_k.offset",
+            "fa_v.offset",
+            "indexer.q_rot",
+            "indexer.k_rot",
+        ]

         for scale_name in replace_scale_names:
             if name.endswith(scale_name):
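The remapping body after the `endswith` check sits outside this hunk, so the practical change is just that the two new rotary tensor names are recognized alongside the FA scale/offset names. A small illustration of the matching step (the checkpoint tensor names are hypothetical):

```python
# Illustrative only: which checkpoint tensor names the extended list matches.
replace_scale_names = [
    "fa_q.scale", "fa_k.scale", "fa_v.scale",
    "fa_q.offset", "fa_k.offset", "fa_v.offset",
    "indexer.q_rot", "indexer.k_rot",
]

def matches(name: str) -> bool:
    return any(name.endswith(s) for s in replace_scale_names)

print(matches("model.layers.3.self_attn.indexer.q_rot"))                # True (hypothetical name)
print(matches("model.layers.3.self_attn.indexer.weights_proj.weight"))  # False
```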
@@ -63,3 +63,27 @@ class AscendFAQuantAttentionMethod:
         repeated_quant_kscale = fa_k_scale.repeat(self.kv_lora_rank)
         layer.quant_kscale = repeated_quant_kscale.view(1, self.kv_lora_rank)
         layer.quant_kscale = 1.0 / torch.nn.Parameter(layer.quant_kscale.to(torch.float), requires_grad=False)
+
+
+@register_scheme("INT8_DYNAMIC", "attention")
+class AscendSFAQuantAttentionMethod:
+    def __init__(self):
+        vllm_config = get_current_vllm_config()
+        config = vllm_config.model_config.hf_config
+        self.index_head_dim = config.index_head_dim
+
+    def create_weights(self, layer: torch.nn.Module) -> None:
+        extra_module_names = ["indexer"]
+        for name in extra_module_names:
+            setattr(layer, name, torch.nn.Module())
+        params_dict = {}
+        params_dict["indexer.q_rot"] = torch.empty((self.index_head_dim, self.index_head_dim), dtype=torch.float32)
+        params_dict["indexer.k_rot"] = torch.empty((self.index_head_dim, self.index_head_dim), dtype=torch.float32)
+        for name, weight in params_dict.items():
+            module_name, weight_name = name.split(".")
+            module = getattr(layer, module_name)
+            weight_param = torch.nn.Parameter(weight, requires_grad=False)
+            module.register_parameter(weight_name, weight_param)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        pass
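A minimal standalone sketch of the `create_weights` pattern in the new `AscendSFAQuantAttentionMethod`: an empty `indexer` sub-module is attached to the attention layer and the two rotation matrices are registered as parameters so the weight loader can fill `indexer.q_rot` / `indexer.k_rot` from the checkpoint (`index_head_dim = 128` is assumed here for illustration):

```python
import torch

index_head_dim = 128  # assumed value; the real one comes from hf_config.index_head_dim

layer = torch.nn.Module()
setattr(layer, "indexer", torch.nn.Module())
for weight_name in ("q_rot", "k_rot"):
    param = torch.nn.Parameter(
        torch.empty((index_head_dim, index_head_dim), dtype=torch.float32),
        requires_grad=False,
    )
    layer.indexer.register_parameter(weight_name, param)

# After checkpoint loading, the rotation matrices are reachable as attributes:
print(layer.indexer.q_rot.shape)  # torch.Size([128, 128])
print(layer.indexer.k_rot.dtype)  # torch.float32
```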
@@ -379,6 +379,8 @@ def get_quant_type_for_layer(
     # Attention
     if layer_type == "attention" and "fa_quant_type" in quant_description:
         return quant_description["fa_quant_type"]
+    if layer_type == "attention" and "indexer_quant_type" in quant_description:
+        return quant_description["indexer_quant_type"]
     # Linear / MoE
     return get_linear_quant_type(quant_description, prefix, packed_modules_mapping)

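Note the ordering above: for attention layers, `fa_quant_type` is checked first, so it wins when a quantization description carries both keys. A toy illustration of that lookup (the helper name is made up):

```python
# Illustrative: fa_quant_type takes precedence over indexer_quant_type for
# attention layers because its check comes first.
def attention_quant_type(quant_description: dict) -> str | None:
    if "fa_quant_type" in quant_description:
        return quant_description["fa_quant_type"]
    if "indexer_quant_type" in quant_description:
        return quant_description["indexer_quant_type"]
    return None

print(attention_quant_type({"fa_quant_type": "C8", "indexer_quant_type": "C8"}))  # C8 (from fa_quant_type)
print(attention_quant_type({"indexer_quant_type": "C8"}))                         # C8 (from indexer_quant_type)
```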
@@ -582,7 +584,9 @@ class AscendModelSlimConfig(QuantizationConfig):
                 return AscendUnquantizedLinearMethod()
             scheme = create_scheme_for_layer(self.quant_description, prefix, "linear", self.packed_modules_mapping)
             return AscendLinearMethod(scheme)
-        elif isinstance(layer, AttentionLayerBase) and self.is_fa_quant_layer(prefix):
+        elif isinstance(layer, AttentionLayerBase) and (
+            self.is_fa_quant_layer(prefix) or self.is_indexer_quant_layer(prefix)
+        ):
             scheme = create_scheme_for_layer(self.quant_description, prefix, "attention", self.packed_modules_mapping)
             return AscendKVCacheMethod(scheme)
         elif isinstance(layer, FusedMoE):
@@ -636,6 +640,13 @@ class AscendModelSlimConfig(QuantizationConfig):
                 return True
         return False

+    def is_indexer_quant_layer(self, prefix):
+        if self.enable_indexer_quant:
+            layer_id_str = "".join(re.findall(r"\.(\d+)\.", prefix))
+            if layer_id_str.isdigit() and int(layer_id_str) in self.indexer_quant_layers:
+                return True
+        return False
+
     def enabling_fa_quant(self, vllm_config, layer_name) -> bool:
         is_decode_instance = (
             vllm_config.kv_transfer_config is not None
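For readers unfamiliar with the prefix parsing in `is_indexer_quant_layer` above, the regex pulls the layer index out of the module prefix; a quick example with a made-up prefix:

```python
import re

prefix = "model.layers.12.self_attn.attn"   # hypothetical module prefix
layer_id_str = "".join(re.findall(r"\.(\d+)\.", prefix))
print(layer_id_str)                                                  # "12"
print(layer_id_str.isdigit() and int(layer_id_str) in {0, 12, 60})   # True
```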
@@ -773,8 +784,13 @@ class AscendModelSlimConfig(QuantizationConfig):
         fa_quant_type = self.quant_description.get("fa_quant_type", "")
         self.enable_fa_quant = fa_quant_type != ""
         self.kvcache_quant_layers = []
-        if self.enable_fa_quant:
+        indexer_quant_type = self.quant_description.get("indexer_quant_type", "")
+        self.enable_indexer_quant = indexer_quant_type != ""
+        self.indexer_quant_layers = []
+        if self.enable_fa_quant or self.enable_indexer_quant:
             for key in self.quant_description:
+                _id = "".join(re.findall(r"\.(\d+)\.", key))
                 if "fa_k.scale" in key:
-                    _id = "".join(re.findall(r"\.(\d+)\.", key))
                     self.kvcache_quant_layers.append(int(_id))
+                if "indexer.quant_type" in key:
+                    self.indexer_quant_layers.append(int(_id))
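To make the bookkeeping above concrete, here is a toy `quant_description` and the layer lists it produces; the keys are invented examples in the style implied by the code (per-layer `fa_k.scale` and `indexer.quant_type` entries plus the two global type flags):

```python
import re

quant_description = {
    "fa_quant_type": "C8",
    "indexer_quant_type": "C8",
    "model.layers.0.self_attn.fa_k.scale": "...",
    "model.layers.1.self_attn.indexer.quant_type": "C8",
}

kvcache_quant_layers, indexer_quant_layers = [], []
for key in quant_description:
    _id = "".join(re.findall(r"\.(\d+)\.", key))
    if "fa_k.scale" in key:
        kvcache_quant_layers.append(int(_id))
    if "indexer.quant_type" in key:
        indexer_quant_layers.append(int(_id))

print(kvcache_quant_layers)    # [0]
print(indexer_quant_layers)    # [1]
```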