From 297d374510269dc91197bb3a4aed6925226e7542 Mon Sep 17 00:00:00 2001 From: Yi Zhang <1109276519@qq.com> Date: Sat, 13 Sep 2025 17:18:26 +0800 Subject: [PATCH] support qwen3_next blackwell (#10403) --- .../sglang/srt/layers/attention/triton_backend.py | 8 +++++++- python/sglang/srt/mem_cache/memory_pool.py | 3 +++ python/sglang/srt/model_executor/model_runner.py | 5 +++++ python/sglang/srt/speculative/eagle_worker.py | 13 +++++++++++-- 4 files changed, 26 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index 26241d849..1c6da6934 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -80,7 +80,13 @@ class TritonAttnBackend(AttentionBackend): self.num_kv_head = model_runner.model_config.get_num_kv_heads( get_attention_tp_size() ) - self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1] + if model_runner.is_hybrid_gdn: + # For hybrid linear models, layer_id = 0 may not be full attention + self.v_head_dim = model_runner.token_to_kv_pool.get_v_head_dim() + else: + self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[ + -1 + ] self.max_context_len = model_runner.model_config.context_len self.device = model_runner.device self.device_core_count = get_device_core_count(model_runner.gpu_id) diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 1bd684ad3..7de38aabd 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -728,6 +728,9 @@ class HybridLinearKVPool(KVCache): layer_id_override=layer_id, ) + def get_v_head_dim(self): + return self.full_kv_pool.get_value_buffer(0).shape[-1] + class SWAKVPool(KVCache): """KV cache with separate pools for full and SWA attention layers.""" diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 36785d86f..c9aae4d2b 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -127,6 +127,7 @@ from sglang.srt.utils import ( get_bool_env_var, get_cpu_ids_by_node, init_custom_process_group, + is_blackwell, is_fa3_default_architecture, is_flashinfer_available, is_hip, @@ -1832,6 +1833,10 @@ class ModelRunner: from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend full_attn_backend = AscendAttnBackend(self) + elif is_blackwell(): + from sglang.srt.layers.attention.triton_backend import TritonAttnBackend + + full_attn_backend = TritonAttnBackend(self) else: from sglang.srt.layers.attention.flashattention_backend import ( FlashAttentionBackend, diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index f454971ca..d3adec7b7 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -48,6 +48,7 @@ from sglang.srt.utils import ( empty_context, get_available_gpu_memory, get_bool_env_var, + is_blackwell, is_cuda, next_power_of_2, ) @@ -214,7 +215,11 @@ class EAGLEWorker(TpModelWorker): "triton": self._create_triton_decode_backend, "aiter": self._create_aiter_decode_backend, "fa3": self._create_fa3_decode_backend, - "hybrid_linear_attn": self._create_fa3_decode_backend, + "hybrid_linear_attn": ( + self._create_fa3_decode_backend + if not is_blackwell() + else self._create_triton_decode_backend + ), "flashmla": self._create_flashmla_decode_backend, "trtllm_mha": self._create_trtllm_mha_decode_backend, "trtllm_mla": self._create_trtllm_mla_decode_backend, @@ -232,7 +237,11 @@ class EAGLEWorker(TpModelWorker): "triton": self._create_triton_prefill_backend, "aiter": self._create_aiter_prefill_backend, "fa3": self._create_fa3_prefill_backend, - "hybrid_linear_attn": self._create_fa3_prefill_backend, + "hybrid_linear_attn": ( + self._create_fa3_prefill_backend + if not is_blackwell() + else self._create_triton_prefill_backend + ), "trtllm_mha": self._create_trtllm_mha_prefill_backend, "trtllm_mla": self._create_trtllm_mla_prefill_backend, }