Quick fix so that DeepGEMM UE8M0 weight requantization also covers the MTP (NextN) model. (#7378)
This commit is contained in:
@@ -1988,11 +1988,9 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
and hasattr(self.quant_config, "weight_block_size")
|
||||
and self.quant_config.weight_block_size is not None
|
||||
):
|
||||
self._weight_requant_ue8m0()
|
||||
self._weight_requant_ue8m0(is_nextn)
|
||||
|
||||
def _weight_requant_ue8m0(self):
|
||||
if self.config.architectures[0] == "DeepseekV3ForCausalLMNextN":
|
||||
return
|
||||
def _weight_requant_ue8m0(self, is_nextn=False):
|
||||
weight_block_size = self.quant_config.weight_block_size
|
||||
|
||||
moe_layers = list(
|
||||
@@ -2003,8 +2001,12 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
)
|
||||
)
|
||||
|
||||
for layer_id in range(self.config.num_hidden_layers):
|
||||
layer = self.model.layers[layer_id]
|
||||
num_hidden_layers = 1 if is_nextn else self.config.num_hidden_layers
|
||||
for layer_id in range(num_hidden_layers):
|
||||
if is_nextn:
|
||||
layer = self.model.decoder
|
||||
else:
|
||||
layer = self.model.layers[layer_id]
|
||||
|
||||
for module in [
|
||||
layer.self_attn.fused_qkv_a_proj_with_mqa,
|
||||
@@ -2016,7 +2018,7 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
module.weight, module.weight_scale_inv, weight_block_size
|
||||
)
|
||||
|
||||
if layer_id in moe_layers:
|
||||
if layer_id in moe_layers or is_nextn:
|
||||
shared_experts = getattr(layer.mlp, "shared_experts", None)
|
||||
if shared_experts is not None:
|
||||
for module in [
|
||||
|
||||
Reference in New Issue
Block a user