From 481f608b8ec1528eac6ad427f941a86a82eab2a7 Mon Sep 17 00:00:00 2001
From: lambert0312
Date: Wed, 12 Mar 2025 16:37:16 +0800
Subject: [PATCH] Add INT8 support for MTP NextN function (#3911)

---
 python/sglang/srt/models/deepseek_nextn.py | 20 ++++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/python/sglang/srt/models/deepseek_nextn.py b/python/sglang/srt/models/deepseek_nextn.py
index 753e1ba1f..9d3159326 100644
--- a/python/sglang/srt/models/deepseek_nextn.py
+++ b/python/sglang/srt/models/deepseek_nextn.py
@@ -30,6 +30,9 @@ from sglang.srt.layers.quantization.fp8_utils import (
     block_quant_to_tensor_quant,
     normalize_e4m3fn_to_e4m3fnuz,
 )
+from sglang.srt.layers.quantization.int8_utils import (
+    block_dequant as int8_block_dequant,
+)
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -291,6 +294,23 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
                         weight, weight_scale, weight_block_size
                     )
                     self_attn.w_scale = scale
+            if w.dtype == torch.int8:
+                if hasattr(self.quant_config, "weight_block_size"):
+                    # Block-wise INT8: dequantize with the per-block scales.
+                    weight_block_size = self.quant_config.weight_block_size
+                    if weight_block_size is not None:
+                        assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
+                        weight = w
+                        weight_scale = self_attn.kv_b_proj.weight_scale_inv
+                        w = int8_block_dequant(
+                            weight, weight_scale, weight_block_size
+                        ).to(torch.bfloat16)
+                else:
+                    # Channel-wise INT8: rescale with the per-channel scale.
+                    assert hasattr(self_attn.kv_b_proj, "weight_scale")
+                    w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
+                        torch.bfloat16
+                    )
             w_kc, w_vc = w.unflatten(
                 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
             ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
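
Note for reviewers: the channel-wise branch above is a plain per-output-channel
rescale, while the block-wise branch delegates to int8_block_dequant from
int8_utils. The following is a minimal, self-contained sketch of what a
block-wise INT8 dequantization of this shape typically computes; the function
body is illustrative only and is not the actual int8_utils implementation
(block_dequant_sketch, bn, and bk are hypothetical names introduced here).

    # Minimal sketch of block-wise INT8 dequantization, assuming each
    # (bn, bk) tile of the int8 weight shares one float scale. This is a
    # hypothetical stand-in for int8_block_dequant, not the real code.
    import torch

    def block_dequant_sketch(
        weight: torch.Tensor,  # int8 tensor of shape (N, K)
        weight_scale: torch.Tensor,  # scales of shape (ceil(N/bn), ceil(K/bk))
        weight_block_size,  # [bn, bk], e.g. [128, 128]
    ) -> torch.Tensor:
        bn, bk = weight_block_size
        n, k = weight.shape
        # Expand each per-block scale over its (bn, bk) tile, then trim to
        # the weight's shape so partial edge blocks are handled.
        scale = weight_scale.repeat_interleave(bn, dim=0)
        scale = scale.repeat_interleave(bk, dim=1)
        return weight.to(torch.float32) * scale[:n, :k]

    # Usage mirroring the patch: dequantize, then cast to bfloat16.
    w_int8 = torch.randint(-128, 128, (256, 192), dtype=torch.int8)
    scales = torch.rand(2, 2)  # ceil(256/128) x ceil(192/128) blocks
    w = block_dequant_sketch(w_int8, scales, [128, 128]).to(torch.bfloat16)

The .to(torch.bfloat16) cast matches the patch: the dequantized kv_b_proj
weight is immediately unflattened and split into w_kc/w_vc, presumably because
the downstream attention path operates on floating-point weights rather than
raw INT8.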