[Feature] DeepSeek V3/R1 INT8 Quantization (channel-wise) (#3888)

Co-authored-by: yych0745 <1398089567@qq.com>
Co-authored-by: sleepcoo <sleepcoo@gmail.com>
Co-authored-by: b0urnee <2769086541@qq.com>
This commit is contained in:
HandH1998
2025-03-07 12:54:52 +08:00
committed by GitHub
parent 63ee26d162
commit c7f254468f
5 changed files with 369 additions and 21 deletions

View File

@@ -1202,18 +1202,22 @@ class DeepseekV2ForCausalLM(nn.Module):
weight, weight_scale, weight_block_size
)
self_attn.w_scale = scale
if (
hasattr(self.quant_config, "weight_block_size")
and w.dtype == torch.int8
):
weight_block_size = self.quant_config.weight_block_size
if weight_block_size is not None:
assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
weight = w
weight_scale = self_attn.kv_b_proj.weight_scale_inv
w = int8_block_dequant(
weight, weight_scale, weight_block_size
).to(torch.bfloat16)
if w.dtype == torch.int8:
if hasattr(self.quant_config, "weight_block_size"):
# block-wise int8 need it
weight_block_size = self.quant_config.weight_block_size
if weight_block_size is not None:
assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
weight = w
weight_scale = self_attn.kv_b_proj.weight_scale_inv
w = int8_block_dequant(
weight, weight_scale, weight_block_size
).to(torch.bfloat16)
else:
# channel-wise int8 need it
w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
torch.bfloat16
)
w_kc, w_vc = w.unflatten(
0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)