Quick fix for DeepGemm requant to also cover MTP. (#7378)
This commit is contained in:
@@ -1988,11 +1988,9 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
and hasattr(self.quant_config, "weight_block_size")
|
and hasattr(self.quant_config, "weight_block_size")
|
||||||
and self.quant_config.weight_block_size is not None
|
and self.quant_config.weight_block_size is not None
|
||||||
):
|
):
|
||||||
self._weight_requant_ue8m0()
|
self._weight_requant_ue8m0(is_nextn)
|
||||||
|
|
||||||
def _weight_requant_ue8m0(self):
|
def _weight_requant_ue8m0(self, is_nextn=False):
|
||||||
if self.config.architectures[0] == "DeepseekV3ForCausalLMNextN":
|
|
||||||
return
|
|
||||||
weight_block_size = self.quant_config.weight_block_size
|
weight_block_size = self.quant_config.weight_block_size
|
||||||
|
|
||||||
moe_layers = list(
|
moe_layers = list(
|
||||||
@@ -2003,8 +2001,12 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
for layer_id in range(self.config.num_hidden_layers):
|
num_hidden_layers = 1 if is_nextn else self.config.num_hidden_layers
|
||||||
layer = self.model.layers[layer_id]
|
for layer_id in range(num_hidden_layers):
|
||||||
|
if is_nextn:
|
||||||
|
layer = self.model.decoder
|
||||||
|
else:
|
||||||
|
layer = self.model.layers[layer_id]
|
||||||
|
|
||||||
for module in [
|
for module in [
|
||||||
layer.self_attn.fused_qkv_a_proj_with_mqa,
|
layer.self_attn.fused_qkv_a_proj_with_mqa,
|
||||||
@@ -2016,7 +2018,7 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
module.weight, module.weight_scale_inv, weight_block_size
|
module.weight, module.weight_scale_inv, weight_block_size
|
||||||
)
|
)
|
||||||
|
|
||||||
if layer_id in moe_layers:
|
if layer_id in moe_layers or is_nextn:
|
||||||
shared_experts = getattr(layer.mlp, "shared_experts", None)
|
shared_experts = getattr(layer.mlp, "shared_experts", None)
|
||||||
if shared_experts is not None:
|
if shared_experts is not None:
|
||||||
for module in [
|
for module in [
|
||||||
|
|||||||
Reference in New Issue
Block a user