From c6b6d2e71b2b5626097d8e0d8c18f810e828d58e Mon Sep 17 00:00:00 2001
From: Ke Bao
Date: Tue, 17 Sep 2024 19:42:48 +0800
Subject: [PATCH] Enable MLA by default (#1447)

---
 README.md                                        |  1 -
 docs/en/backend.md                               |  1 -
 python/sglang/srt/managers/schedule_batch.py     |  2 +-
 python/sglang/srt/model_executor/model_runner.py |  6 +++---
 python/sglang/srt/models/deepseek_v2.py          |  4 ++--
 python/sglang/srt/models/minicpm3.py             |  4 ++--
 python/sglang/srt/server_args.py                 | 14 +++++++-------
 test/srt/test_nightly_gsm8k_eval.py              |  2 +-
 8 files changed, 16 insertions(+), 18 deletions(-)

diff --git a/README.md b/README.md
index 18f36fb62..6f23ca1f7 100644
--- a/README.md
+++ b/README.md
@@ -225,7 +225,6 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct
 - To enable torch.compile acceleration, add `--enable-torch-compile`. It accelerates small models on small batch sizes.
 - To enable fp8 weight quantization, add `--quantization fp8` on a fp16 checkpoint or directly load a fp8 checkpoint without specifying any arguments.
 - To enable fp8 kv cache quantization, add `--kv-cache-dtype fp8_e5m2`.
-- To enable DeepSeek MLA acceleration, add `--enable-mla`.
 - If the model does not have a chat template in the Hugging Face tokenizer, you can specify a [custom chat template](docs/en/custom_chat_template.md).
 - To run tensor parallelism on multiple nodes, add `--nnodes 2`. If you have two nodes with two GPUs on each node and want to run TP=4, let `sgl-dev-0` be the hostname of the first node and `50000` be an available port.
 ```
diff --git a/docs/en/backend.md b/docs/en/backend.md
index d8c3c7fb1..9e4fc7c26 100644
--- a/docs/en/backend.md
+++ b/docs/en/backend.md
@@ -81,7 +81,6 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct
 - To enable torch.compile acceleration, add `--enable-torch-compile`. It accelerates small models on small batch sizes.
 - To enable fp8 weight quantization, add `--quantization fp8` on a fp16 checkpoint or directly load a fp8 checkpoint without specifying any arguments.
 - To enable fp8 kv cache quantization, add `--kv-cache-dtype fp8_e5m2`.
-- To enable DeepSeek MLA acceleration, add `--enable-mla`.
 - If the model does not have a chat template in the Hugging Face tokenizer, you can specify a [custom chat template](docs/en/custom_chat_template.md).
 - To run tensor parallelism on multiple nodes, add `--nnodes 2`. If you have two nodes with two GPUs on each node and want to run TP=4, let `sgl-dev-0` be the hostname of the first node and `50000` be an available port.
 ```
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index edcf39026..2ab041726 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -40,7 +40,7 @@ global_server_args_dict = {
     "attention_backend": ServerArgs.attention_backend,
     "sampling_backend": ServerArgs.sampling_backend,
     "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
-    "enable_mla": ServerArgs.enable_mla,
+    "disable_mla": ServerArgs.disable_mla,
     "torchao_config": ServerArgs.torchao_config,
 }
 
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 123c98f56..9e614b81d 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -91,7 +91,7 @@ class ModelRunner:
                 "attention_backend": server_args.attention_backend,
                 "sampling_backend": server_args.sampling_backend,
                 "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
-                "enable_mla": server_args.enable_mla,
+                "disable_mla": server_args.disable_mla,
                 "torchao_config": server_args.torchao_config,
             }
         )
@@ -329,7 +329,7 @@ class ModelRunner:
         )
         if (
             self.model_config.attention_arch == AttentionArch.MLA
-            and self.server_args.enable_mla
+            and not self.server_args.disable_mla
         ):
             cell_size = (
                 (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
@@ -397,7 +397,7 @@ class ModelRunner:
         )
         if (
             self.model_config.attention_arch == AttentionArch.MLA
-            and self.server_args.enable_mla
+            and not self.server_args.disable_mla
         ):
             self.token_to_kv_pool = MLATokenToKVPool(
                 self.max_total_num_tokens,
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 30fa465d8..22b434a85 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -507,7 +507,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
-        if global_server_args_dict["enable_mla"]:
+        if not global_server_args_dict["disable_mla"]:
             self.self_attn = DeepseekV2AttentionMLA(
                 config=config,
                 hidden_size=self.hidden_size,
@@ -732,7 +732,7 @@ class DeepseekV2ForCausalLM(nn.Module):
                 )
                 weight_loader(param, loaded_weight)
 
-        if global_server_args_dict["enable_mla"]:
+        if not global_server_args_dict["disable_mla"]:
             for layer_id in range(self.config.num_hidden_layers):
                 self_attn = self.model.layers[layer_id].self_attn
                 w_kc, w_vc = self_attn.kv_b_proj.weight.unflatten(
diff --git a/python/sglang/srt/models/minicpm3.py b/python/sglang/srt/models/minicpm3.py
index 40278f45d..a025d074e 100644
--- a/python/sglang/srt/models/minicpm3.py
+++ b/python/sglang/srt/models/minicpm3.py
@@ -419,7 +419,7 @@ class MiniCPM3DecoderLayer(nn.Module):
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
-        if global_server_args_dict["enable_mla"]:
+        if not global_server_args_dict["disable_mla"]:
             self.self_attn = MiniCPM3AttentionMLA(
                 config=config,
                 hidden_size=self.hidden_size,
@@ -653,7 +653,7 @@ class MiniCPM3ForCausalLM(nn.Module):
                 )
                 weight_loader(param, loaded_weight)
 
-        if global_server_args_dict["enable_mla"]:
+        if not global_server_args_dict["disable_mla"]:
             for layer_id in range(self.config.num_hidden_layers):
                 self_attn = self.model.layers[layer_id].self_attn
                 w_kc, w_vc = self_attn.kv_b_proj.weight.unflatten(
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 33536ec16..a741d43a2 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -108,12 +108,12 @@ class ServerArgs:
     disable_cuda_graph_padding: bool = False
     disable_disk_cache: bool = False
     disable_custom_all_reduce: bool = False
+    disable_mla: bool = False
     enable_mixed_chunk: bool = False
     enable_torch_compile: bool = False
     max_torch_compile_bs: int = 32
     torchao_config: str = ""
     enable_p2p_check: bool = False
-    enable_mla: bool = False
     triton_attention_reduce_in_fp32: bool = False
 
     # LoRA
@@ -173,7 +173,7 @@ class ServerArgs:
             self.sampling_backend = "pytorch"
 
         # Default kernel backends
-        if self.enable_mla:
+        if not self.disable_mla:
             logger.info("MLA optimization is tunred on. Use triton backend.")
             self.attention_backend = "triton"
 
@@ -514,6 +514,11 @@
             default=False,
             help="Disable the custom all-reduce kernel and fall back to NCCL.",
         )
+        parser.add_argument(
+            "--disable-mla",
+            action="store_true",
+            help="Disable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
+        )
         parser.add_argument(
             "--enable-mixed-chunk",
             action="store_true",
@@ -541,11 +546,6 @@ class ServerArgs:
             action="store_true",
             help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
         )
-        parser.add_argument(
-            "--enable-mla",
-            action="store_true",
-            help="Enable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
-        )
         parser.add_argument(
             "--triton-attention-reduce-in-fp32",
             action="store_true",
diff --git a/test/srt/test_nightly_gsm8k_eval.py b/test/srt/test_nightly_gsm8k_eval.py
index 08faae3f4..32a80dbd4 100644
--- a/test/srt/test_nightly_gsm8k_eval.py
+++ b/test/srt/test_nightly_gsm8k_eval.py
@@ -52,7 +52,7 @@ class TestEvalAccuracyLarge(unittest.TestCase):
         if is_tp2:
             other_args.extend(["--tp", "2"])
         if "DeepSeek" in model:
-            other_args.extend(["--enable-mla", "--mem-frac", "0.85"])
+            other_args.extend(["--mem-frac", "0.85"])
 
         self.process = popen_launch_server(
            model,
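
A brief usage sketch of the new default (not part of the patch itself): MLA-capable checkpoints such as DeepSeek-V2 now get MLA without any extra flag, and the new `--disable-mla` switch restores the non-MLA path. The model path, `--trust-remote-code`, and the `--tp`/`--mem-frac` values below are illustrative assumptions rather than values taken from this diff.

```
# MLA is now on by default for DeepSeek-V2-style models; no MLA flag is needed.
python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V2 --trust-remote-code --tp 2 --mem-frac 0.85

# Opt out of MLA and fall back to the standard attention path.
python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V2 --trust-remote-code --tp 2 --mem-frac 0.85 --disable-mla
```

Since `--enable-mla` is removed outright rather than kept as a deprecated no-op, launch scripts that still pass it will fail argument parsing and should simply drop the flag.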