Enable MLA by default (#1447)

This commit is contained in:
Ke Bao
2024-09-17 19:42:48 +08:00
committed by GitHub
parent 90a26be31c
commit c6b6d2e71b
8 changed files with 16 additions and 18 deletions

View File

@@ -419,7 +419,7 @@ class MiniCPM3DecoderLayer(nn.Module):
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
-        if global_server_args_dict["enable_mla"]:
+        if not global_server_args_dict["disable_mla"]:
self.self_attn = MiniCPM3AttentionMLA(
config=config,
hidden_size=self.hidden_size,
@@ -653,7 +653,7 @@ class MiniCPM3ForCausalLM(nn.Module):
)
weight_loader(param, loaded_weight)
-        if global_server_args_dict["enable_mla"]:
+        if not global_server_args_dict["disable_mla"]:
for layer_id in range(self.config.num_hidden_layers):
self_attn = self.model.layers[layer_id].self_attn
w_kc, w_vc = self_attn.kv_b_proj.weight.unflatten(