Feature DeepSeek V3/R1 INT8 Quantization (block-wise) (#3730)
Co-authored-by: HandH1998 <1335248067@qq.com>
This commit is contained in:
@@ -47,6 +47,9 @@ from sglang.srt.layers.quantization.fp8_utils import (
|
||||
input_to_float8,
|
||||
normalize_e4m3fn_to_e4m3fnuz,
|
||||
)
|
||||
from sglang.srt.layers.quantization.int8_utils import (
|
||||
block_dequant as int8_block_dequant,
|
||||
)
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper
|
||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
@@ -994,6 +997,18 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
weight, weight_scale, weight_block_size
|
||||
)
|
||||
self_attn.w_scale = scale
|
||||
if (
|
||||
hasattr(self.quant_config, "weight_block_size")
|
||||
and w.dtype == torch.int8
|
||||
):
|
||||
weight_block_size = self.quant_config.weight_block_size
|
||||
if weight_block_size is not None:
|
||||
assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
|
||||
weight = w
|
||||
weight_scale = self_attn.kv_b_proj.weight_scale_inv
|
||||
w = int8_block_dequant(
|
||||
weight, weight_scale, weight_block_size
|
||||
).to(torch.bfloat16)
|
||||
w_kc, w_vc = w.unflatten(
|
||||
0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
|
||||
).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
|
||||
|
||||
Reference in New Issue
Block a user