Enable MLA by default (#1447)
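This flips the MLA (multi-head latent attention) code path in the MiniCPM3 model from opt-in to opt-out: the server arg enable_mla is replaced by disable_mla, so MLA runs unless explicitly turned off. A minimal sketch of the inverted gate, using a stand-in dict rather than the real sglang server-args plumbing:

    # Stand-in for sglang's global_server_args_dict; not the real plumbing.
    global_server_args_dict = {"disable_mla": False}  # default: MLA stays on

    # Before this commit the gate was opt-in:
    #     use_mla = global_server_args_dict["enable_mla"]
    # After it, the flag only turns MLA off:
    use_mla = not global_server_args_dict["disable_mla"]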
@@ -419,7 +419,7 @@ class MiniCPM3DecoderLayer(nn.Module):
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
-        if global_server_args_dict["enable_mla"]:
+        if not global_server_args_dict["disable_mla"]:
             self.self_attn = MiniCPM3AttentionMLA(
                 config=config,
                 hidden_size=self.hidden_size,
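With the new gate, the decoder-layer constructor above builds MiniCPM3AttentionMLA by default. A self-contained sketch of the selection pattern, assuming the non-MLA fallback class (MiniCPM3Attention) that the rest of the file provides but this hunk does not show; build_attention is a hypothetical helper for illustration:

    class MiniCPM3Attention:       # assumed fallback, not shown in this hunk
        pass

    class MiniCPM3AttentionMLA:    # MLA variant constructed in the hunk
        pass

    def build_attention(server_args):  # hypothetical helper
        # After this commit: MLA unless explicitly disabled.
        if not server_args["disable_mla"]:
            return MiniCPM3AttentionMLA()
        return MiniCPM3Attention()

    assert isinstance(build_attention({"disable_mla": False}), MiniCPM3AttentionMLA)
    assert isinstance(build_attention({"disable_mla": True}), MiniCPM3Attention)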
@@ -653,7 +653,7 @@ class MiniCPM3ForCausalLM(nn.Module):
             )
             weight_loader(param, loaded_weight)
 
-        if global_server_args_dict["enable_mla"]:
+        if not global_server_args_dict["disable_mla"]:
             for layer_id in range(self.config.num_hidden_layers):
                 self_attn = self.model.layers[layer_id].self_attn
                 w_kc, w_vc = self_attn.kv_b_proj.weight.unflatten(
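The second hunk applies the same default to the post-load weight-absorption step: after checkpoint loading, each layer's fused kv_b_proj weight is split into separate key (w_kc) and value (w_vc) projections that MLA applies directly to the compressed KV latent. The unflatten call is cut off in this excerpt; the sketch below reproduces the shape manipulation it presumably performs, with made-up dimensions in place of the model config:

    import torch

    # Hypothetical sizes standing in for the MiniCPM3 config values.
    num_heads, qk_nope_head_dim, v_head_dim, kv_lora_rank = 4, 32, 32, 64

    # kv_b_proj maps the compressed KV latent (rank kv_lora_rank) to fused
    # per-head key/value features.
    kv_b_proj_weight = torch.randn(
        num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank
    )

    # Recover a per-head axis, then split the fused features into the key
    # and value halves.
    w_kc, w_vc = kv_b_proj_weight.unflatten(
        0, (num_heads, qk_nope_head_dim + v_head_dim)
    ).split([qk_nope_head_dim, v_head_dim], dim=1)

    assert w_kc.shape == (num_heads, qk_nope_head_dim, kv_lora_rank)
    assert w_vc.shape == (num_heads, v_head_dim, kv_lora_rank)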