Quick fix for DeepGemm requant to also cover MTP. (#7378)

This commit is contained in:
Charles Chen
2025-06-23 12:08:54 -07:00
committed by GitHub
parent bdbb8d009a
commit e5ddeb04d5

View File

@@ -1988,11 +1988,9 @@ class DeepseekV2ForCausalLM(nn.Module):
and hasattr(self.quant_config, "weight_block_size")
and self.quant_config.weight_block_size is not None
):
-self._weight_requant_ue8m0()
+self._weight_requant_ue8m0(is_nextn)
-def _weight_requant_ue8m0(self):
-if self.config.architectures[0] == "DeepseekV3ForCausalLMNextN":
-return
+def _weight_requant_ue8m0(self, is_nextn=False):
weight_block_size = self.quant_config.weight_block_size
moe_layers = list(
@@ -2003,8 +2001,12 @@ class DeepseekV2ForCausalLM(nn.Module):
)
)
-for layer_id in range(self.config.num_hidden_layers):
-layer = self.model.layers[layer_id]
+num_hidden_layers = 1 if is_nextn else self.config.num_hidden_layers
+for layer_id in range(num_hidden_layers):
+if is_nextn:
+layer = self.model.decoder
+else:
+layer = self.model.layers[layer_id]
for module in [
layer.self_attn.fused_qkv_a_proj_with_mqa,
@@ -2016,7 +2018,7 @@ class DeepseekV2ForCausalLM(nn.Module):
module.weight, module.weight_scale_inv, weight_block_size
)
-if layer_id in moe_layers:
+if layer_id in moe_layers or is_nextn:
shared_experts = getattr(layer.mlp, "shared_experts", None)
if shared_experts is not None:
for module in [