support qwen3_next blackwell (#10403)

This commit is contained in:
Yi Zhang
2025-09-13 17:18:26 +08:00
committed by GitHub
parent 31e9d3a5aa
commit 297d374510
4 changed files with 26 additions and 3 deletions

View File

@@ -80,7 +80,13 @@ class TritonAttnBackend(AttentionBackend):
self.num_kv_head = model_runner.model_config.get_num_kv_heads(
get_attention_tp_size()
)
self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
if model_runner.is_hybrid_gdn:
# For hybrid linear models, layer_id = 0 may not be full attention
self.v_head_dim = model_runner.token_to_kv_pool.get_v_head_dim()
else:
self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[
-1
]
self.max_context_len = model_runner.model_config.context_len
self.device = model_runner.device
self.device_core_count = get_device_core_count(model_runner.gpu_id)

View File

@@ -728,6 +728,9 @@ class HybridLinearKVPool(KVCache):
layer_id_override=layer_id,
)
def get_v_head_dim(self):
return self.full_kv_pool.get_value_buffer(0).shape[-1]
class SWAKVPool(KVCache):
"""KV cache with separate pools for full and SWA attention layers."""

View File

@@ -127,6 +127,7 @@ from sglang.srt.utils import (
get_bool_env_var,
get_cpu_ids_by_node,
init_custom_process_group,
is_blackwell,
is_fa3_default_architecture,
is_flashinfer_available,
is_hip,
@@ -1832,6 +1833,10 @@ class ModelRunner:
from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend
full_attn_backend = AscendAttnBackend(self)
elif is_blackwell():
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
full_attn_backend = TritonAttnBackend(self)
else:
from sglang.srt.layers.attention.flashattention_backend import (
FlashAttentionBackend,

View File

@@ -48,6 +48,7 @@ from sglang.srt.utils import (
empty_context,
get_available_gpu_memory,
get_bool_env_var,
is_blackwell,
is_cuda,
next_power_of_2,
)
@@ -214,7 +215,11 @@ class EAGLEWorker(TpModelWorker):
"triton": self._create_triton_decode_backend,
"aiter": self._create_aiter_decode_backend,
"fa3": self._create_fa3_decode_backend,
"hybrid_linear_attn": self._create_fa3_decode_backend,
"hybrid_linear_attn": (
self._create_fa3_decode_backend
if not is_blackwell()
else self._create_triton_decode_backend
),
"flashmla": self._create_flashmla_decode_backend,
"trtllm_mha": self._create_trtllm_mha_decode_backend,
"trtllm_mla": self._create_trtllm_mla_decode_backend,
@@ -232,7 +237,11 @@ class EAGLEWorker(TpModelWorker):
"triton": self._create_triton_prefill_backend,
"aiter": self._create_aiter_prefill_backend,
"fa3": self._create_fa3_prefill_backend,
"hybrid_linear_attn": self._create_fa3_prefill_backend,
"hybrid_linear_attn": (
self._create_fa3_prefill_backend
if not is_blackwell()
else self._create_triton_prefill_backend
),
"trtllm_mha": self._create_trtllm_mha_prefill_backend,
"trtllm_mla": self._create_trtllm_mla_prefill_backend,
}