From b4ac2b9c0cba38cbfff6655a55ffd3b656a86b04 Mon Sep 17 00:00:00 2001
From: DarkSharpness <76582120+DarkSharpness@users.noreply.github.com>
Date: Mon, 11 Aug 2025 23:50:23 -0700
Subject: [PATCH] [Fix] Fix dual chunk model default behavior (#9032)

---
 .../attention/dual_chunk_flashattention_backend.py |  2 +-
 python/sglang/srt/model_executor/model_runner.py   | 13 +++++++++++++
 2 files changed, 14 insertions(+), 1 deletion(-)

diff --git a/python/sglang/srt/layers/attention/dual_chunk_flashattention_backend.py b/python/sglang/srt/layers/attention/dual_chunk_flashattention_backend.py
index ea97ada22..84876b438 100644
--- a/python/sglang/srt/layers/attention/dual_chunk_flashattention_backend.py
+++ b/python/sglang/srt/layers/attention/dual_chunk_flashattention_backend.py
@@ -483,7 +483,7 @@ class DualChunkFlashAttentionBackend(AttentionBackend):
         ).squeeze(1)
         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
 
-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         """Initialize CUDA graph state for the attention backend.
 
         Args:
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 7681d5fe0..c865daf6b 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -388,6 +388,19 @@ class ModelRunner:
         ):
             # override the default attention backend
             server_args.attention_backend = server_args.prefill_attention_backend
+        if (
+            getattr(self.model_config.hf_config, "dual_chunk_attention_config", None)
+            is not None
+        ):
+            if server_args.attention_backend is None:
+                server_args.attention_backend = "dual_chunk_flash_attn"
+                logger.info("Dual chunk attention is turned on by default.")
+            elif server_args.attention_backend != "dual_chunk_flash_attn":
+                raise ValueError(
+                    "Dual chunk attention is enabled, but attention backend is set to "
+                    f"{server_args.attention_backend}. Please set it to 'dual_chunk_flash_attn'."
+                )
+
         if server_args.attention_backend is None:
             """
             Auto select the fastest attention backend.