Enable MLA by default (#1447)
This commit is contained in:
@@ -40,7 +40,7 @@ global_server_args_dict = {
|
||||
"attention_backend": ServerArgs.attention_backend,
|
||||
"sampling_backend": ServerArgs.sampling_backend,
|
||||
"triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
|
||||
"enable_mla": ServerArgs.enable_mla,
|
||||
"disable_mla": ServerArgs.disable_mla,
|
||||
"torchao_config": ServerArgs.torchao_config,
|
||||
}
|
||||
|
||||
|
||||
@@ -91,7 +91,7 @@ class ModelRunner:
|
||||
"attention_backend": server_args.attention_backend,
|
||||
"sampling_backend": server_args.sampling_backend,
|
||||
"triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
|
||||
"enable_mla": server_args.enable_mla,
|
||||
"disable_mla": server_args.disable_mla,
|
||||
"torchao_config": server_args.torchao_config,
|
||||
}
|
||||
)
|
||||
@@ -329,7 +329,7 @@ class ModelRunner:
|
||||
)
|
||||
if (
|
||||
self.model_config.attention_arch == AttentionArch.MLA
|
||||
and self.server_args.enable_mla
|
||||
and not self.server_args.disable_mla
|
||||
):
|
||||
cell_size = (
|
||||
(self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
|
||||
@@ -397,7 +397,7 @@ class ModelRunner:
|
||||
)
|
||||
if (
|
||||
self.model_config.attention_arch == AttentionArch.MLA
|
||||
and self.server_args.enable_mla
|
||||
and not self.server_args.disable_mla
|
||||
):
|
||||
self.token_to_kv_pool = MLATokenToKVPool(
|
||||
self.max_total_num_tokens,
|
||||
|
||||
@@ -507,7 +507,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
|
||||
if global_server_args_dict["enable_mla"]:
|
||||
if not global_server_args_dict["disable_mla"]:
|
||||
self.self_attn = DeepseekV2AttentionMLA(
|
||||
config=config,
|
||||
hidden_size=self.hidden_size,
|
||||
@@ -732,7 +732,7 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
if global_server_args_dict["enable_mla"]:
|
||||
if not global_server_args_dict["disable_mla"]:
|
||||
for layer_id in range(self.config.num_hidden_layers):
|
||||
self_attn = self.model.layers[layer_id].self_attn
|
||||
w_kc, w_vc = self_attn.kv_b_proj.weight.unflatten(
|
||||
|
||||
@@ -419,7 +419,7 @@ class MiniCPM3DecoderLayer(nn.Module):
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
|
||||
if global_server_args_dict["enable_mla"]:
|
||||
if not global_server_args_dict["disable_mla"]:
|
||||
self.self_attn = MiniCPM3AttentionMLA(
|
||||
config=config,
|
||||
hidden_size=self.hidden_size,
|
||||
@@ -653,7 +653,7 @@ class MiniCPM3ForCausalLM(nn.Module):
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
if global_server_args_dict["enable_mla"]:
|
||||
if not global_server_args_dict["disable_mla"]:
|
||||
for layer_id in range(self.config.num_hidden_layers):
|
||||
self_attn = self.model.layers[layer_id].self_attn
|
||||
w_kc, w_vc = self_attn.kv_b_proj.weight.unflatten(
|
||||
|
||||
@@ -108,12 +108,12 @@ class ServerArgs:
|
||||
disable_cuda_graph_padding: bool = False
|
||||
disable_disk_cache: bool = False
|
||||
disable_custom_all_reduce: bool = False
|
||||
disable_mla: bool = False
|
||||
enable_mixed_chunk: bool = False
|
||||
enable_torch_compile: bool = False
|
||||
max_torch_compile_bs: int = 32
|
||||
torchao_config: str = ""
|
||||
enable_p2p_check: bool = False
|
||||
enable_mla: bool = False
|
||||
triton_attention_reduce_in_fp32: bool = False
|
||||
|
||||
# LoRA
|
||||
@@ -173,7 +173,7 @@ class ServerArgs:
|
||||
self.sampling_backend = "pytorch"
|
||||
|
||||
# Default kernel backends
|
||||
if self.enable_mla:
|
||||
if not self.disable_mla:
|
||||
logger.info("MLA optimization is turned on. Use triton backend.")
|
||||
self.attention_backend = "triton"
|
||||
|
||||
@@ -514,6 +514,11 @@ class ServerArgs:
|
||||
default=False,
|
||||
help="Disable the custom all-reduce kernel and fall back to NCCL.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable-mla",
|
||||
action="store_true",
|
||||
help="Disable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-mixed-chunk",
|
||||
action="store_true",
|
||||
@@ -541,11 +546,6 @@ class ServerArgs:
|
||||
action="store_true",
|
||||
help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-mla",
|
||||
action="store_true",
|
||||
help="Enable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--triton-attention-reduce-in-fp32",
|
||||
action="store_true",
|
||||
|
||||
Reference in New Issue
Block a user