From e5ddeb04d5cf0ffa8f06e2feb7fe931664c44f50 Mon Sep 17 00:00:00 2001
From: Charles Chen
Date: Mon, 23 Jun 2025 12:08:54 -0700
Subject: [PATCH] Quick fix for DeepGemm requant to also cover MTP. (#7378)

---
 python/sglang/srt/models/deepseek_v2.py | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 83b5c833e..c73200400 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -1988,11 +1988,9 @@ class DeepseekV2ForCausalLM(nn.Module):
             and hasattr(self.quant_config, "weight_block_size")
             and self.quant_config.weight_block_size is not None
         ):
-            self._weight_requant_ue8m0()
+            self._weight_requant_ue8m0(is_nextn)
 
-    def _weight_requant_ue8m0(self):
-        if self.config.architectures[0] == "DeepseekV3ForCausalLMNextN":
-            return
+    def _weight_requant_ue8m0(self, is_nextn=False):
         weight_block_size = self.quant_config.weight_block_size
 
         moe_layers = list(
@@ -2003,8 +2001,12 @@ class DeepseekV2ForCausalLM(nn.Module):
             )
         )
 
-        for layer_id in range(self.config.num_hidden_layers):
-            layer = self.model.layers[layer_id]
+        num_hidden_layers = 1 if is_nextn else self.config.num_hidden_layers
+        for layer_id in range(num_hidden_layers):
+            if is_nextn:
+                layer = self.model.decoder
+            else:
+                layer = self.model.layers[layer_id]
 
             for module in [
                 layer.self_attn.fused_qkv_a_proj_with_mqa,
@@ -2016,7 +2018,7 @@ class DeepseekV2ForCausalLM(nn.Module):
                     module.weight, module.weight_scale_inv, weight_block_size
                 )
 
-            if layer_id in moe_layers:
+            if layer_id in moe_layers or is_nextn:
                 shared_experts = getattr(layer.mlp, "shared_experts", None)
                 if shared_experts is not None:
                     for module in [