[Bugfix] Fix DeepSeek 3.2 C8 precision by rotary tensor (#7537)
### What this PR does / why we need it?
During the attention quantization process of DeepSeek V3.2, it is
necessary to retrieve the Hadamard matrix from the weights to facilitate
the computation.
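
For context, a minimal sketch of the idea, with assumed names and shapes (this is not the actual vLLM Ascend weight layout): a Hadamard rotation is applied before INT8 ("C8") quantization to spread activation outliers, so the rotation matrix has to be stored in, and read back from, the quantized checkpoint alongside the scales.

```python
import torch

def sylvester_hadamard(n: int) -> torch.Tensor:
    """Normalized Hadamard matrix of size n (n must be a power of two)."""
    H = torch.ones(1, 1)
    while H.shape[0] < n:
        H = torch.cat([torch.cat([H, H], dim=1), torch.cat([H, -H], dim=1)], dim=0)
    return H / (n ** 0.5)

def rotate_then_quant_c8(x: torch.Tensor, hadamard: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Rotate the last dim with the Hadamard matrix, then symmetric INT8 quantization."""
    x_rot = x @ hadamard
    return torch.clamp(torch.round(x_rot / scale), -128, 127).to(torch.int8)

# Hypothetical usage: H and scale would come from checkpoint tensors such as
# "...fa_k.hadamard" / "...fa_k.scale" (these key names are assumptions).
head_dim = 128
H = sylvester_hadamard(head_dim)
k = torch.randn(2, 16, head_dim)          # (batch, seq, head_dim), illustrative shapes
scale = (k @ H).abs().amax() / 127.0
k_q = rotate_then_quant_c8(k, H, scale)
```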
### Does this PR introduce _any_ user-facing change?
No. But there will be two new tensors in the quantized weights.
### How was this patch tested?
- vLLM version: v0.18.0
- vLLM main: 8b6325758c
---------
Signed-off-by: mayumeng <m30059191@china.huawei.com>
Co-authored-by: mayumeng <m30059191@china.huawei.com>
```diff
@@ -379,6 +379,8 @@ def get_quant_type_for_layer(
     # Attention
     if layer_type == "attention" and "fa_quant_type" in quant_description:
         return quant_description["fa_quant_type"]
+    if layer_type == "attention" and "indexer_quant_type" in quant_description:
+        return quant_description["indexer_quant_type"]
     # Linear / MoE
     return get_linear_quant_type(quant_description, prefix, packed_modules_mapping)

```
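For illustration, a standalone sketch of the new dispatch behaviour; the dict value and the helper name below are hypothetical, only the key names match the hunk above:

```python
quant_description = {
    "indexer_quant_type": "C8",   # hypothetical value; new key handled by this PR
}

def pick_attention_quant_type(quant_description: dict) -> str | None:
    # Mirrors the branch order above: fa_quant_type first, then the new indexer_quant_type.
    if "fa_quant_type" in quant_description:
        return quant_description["fa_quant_type"]
    if "indexer_quant_type" in quant_description:
        return quant_description["indexer_quant_type"]
    return None

print(pick_attention_quant_type(quant_description))  # -> "C8"
```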
```diff
@@ -582,7 +584,9 @@ class AscendModelSlimConfig(QuantizationConfig):
                 return AscendUnquantizedLinearMethod()
             scheme = create_scheme_for_layer(self.quant_description, prefix, "linear", self.packed_modules_mapping)
             return AscendLinearMethod(scheme)
-        elif isinstance(layer, AttentionLayerBase) and self.is_fa_quant_layer(prefix):
+        elif isinstance(layer, AttentionLayerBase) and (
+            self.is_fa_quant_layer(prefix) or self.is_indexer_quant_layer(prefix)
+        ):
             scheme = create_scheme_for_layer(self.quant_description, prefix, "attention", self.packed_modules_mapping)
             return AscendKVCacheMethod(scheme)
         elif isinstance(layer, FusedMoE):
```
```diff
@@ -636,6 +640,13 @@ class AscendModelSlimConfig(QuantizationConfig):
                 return True
         return False

+    def is_indexer_quant_layer(self, prefix):
+        if self.enable_indexer_quant:
+            layer_id_str = "".join(re.findall(r"\.(\d+)\.", prefix))
+            if layer_id_str.isdigit() and int(layer_id_str) in self.indexer_quant_layers:
+                return True
+        return False
+
     def enabling_fa_quant(self, vllm_config, layer_name) -> bool:
         is_decode_instance = (
             vllm_config.kv_transfer_config is not None
```
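The new `is_indexer_quant_layer` relies on pulling the numeric layer index out of the module prefix; a quick standalone check of that regex (the prefix string and list contents here are hypothetical):

```python
import re

prefix = "model.layers.61.self_attn.indexer"             # hypothetical module prefix
indexer_quant_layers = [60, 61]                           # hypothetical, built in __init__

layer_id_str = "".join(re.findall(r"\.(\d+)\.", prefix))  # -> "61"
print(layer_id_str.isdigit() and int(layer_id_str) in indexer_quant_layers)  # True
```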
```diff
@@ -773,8 +784,13 @@ class AscendModelSlimConfig(QuantizationConfig):
         fa_quant_type = self.quant_description.get("fa_quant_type", "")
         self.enable_fa_quant = fa_quant_type != ""
         self.kvcache_quant_layers = []
-        if self.enable_fa_quant:
+        indexer_quant_type = self.quant_description.get("indexer_quant_type", "")
+        self.enable_indexer_quant = indexer_quant_type != ""
+        self.indexer_quant_layers = []
+        if self.enable_fa_quant or self.enable_indexer_quant:
             for key in self.quant_description:
+                _id = "".join(re.findall(r"\.(\d+)\.", key))
                 if "fa_k.scale" in key:
-                    _id = "".join(re.findall(r"\.(\d+)\.", key))
                     self.kvcache_quant_layers.append(int(_id))
+                if "indexer.quant_type" in key:
+                    self.indexer_quant_layers.append(int(_id))
```
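And the constructor-side scan that fills those layer lists, run against a couple of hypothetical checkpoint keys (real keys follow the model's own module paths):

```python
import re

quant_description = {
    "model.layers.3.self_attn.fa_k.scale": "...",
    "model.layers.3.self_attn.indexer.quant_type": "C8",
    "model.layers.7.self_attn.fa_k.scale": "...",
}

kvcache_quant_layers, indexer_quant_layers = [], []
for key in quant_description:
    _id = "".join(re.findall(r"\.(\d+)\.", key))
    if "fa_k.scale" in key:
        kvcache_quant_layers.append(int(_id))
    if "indexer.quant_type" in key:
        indexer_quant_layers.append(int(_id))

print(kvcache_quant_layers)    # [3, 7]
print(indexer_quant_layers)    # [3]
```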