support qwen3_next blackwell (#10403)
@@ -80,7 +80,13 @@ class TritonAttnBackend(AttentionBackend):
         self.num_kv_head = model_runner.model_config.get_num_kv_heads(
             get_attention_tp_size()
         )
-        self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
+        if model_runner.is_hybrid_gdn:
+            # For hybrid linear models, layer_id = 0 may not be full attention
+            self.v_head_dim = model_runner.token_to_kv_pool.get_v_head_dim()
+        else:
+            self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[
+                -1
+            ]
         self.max_context_len = model_runner.model_config.context_len
         self.device = model_runner.device
         self.device_core_count = get_device_core_count(model_runner.gpu_id)
@@ -728,6 +728,9 @@ class HybridLinearKVPool(KVCache):
             layer_id_override=layer_id,
         )
 
+    def get_v_head_dim(self):
+        return self.full_kv_pool.get_value_buffer(0).shape[-1]
+
 
 class SWAKVPool(KVCache):
     """KV cache with separate pools for full and SWA attention layers."""
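Context for the two hunks above: in a hybrid Gated DeltaNet model only a subset of layers are full attention, so probing the value buffer of global layer 0 can hit a linear-attention layer that owns no KV buffer at all. The new `get_v_head_dim()` accessor reads the head dim from the pool's own full-attention buffers instead. A minimal sketch of the idea, with a hypothetical class and buffer shapes rather than the actual pool implementation:

```python
import torch

class HybridLinearPoolSketch:
    """Hypothetical stand-in: only full-attention layers get value buffers."""

    def __init__(self, full_attention_layer_ids, num_kv_heads, head_dim):
        # Buffers are indexed by position within the full-attention subset,
        # not by the model's global layer id; shape: [tokens, kv_heads, head_dim].
        self._v_buffers = [
            torch.empty(16, num_kv_heads, head_dim)
            for _ in full_attention_layer_ids
        ]

    def get_value_buffer(self, idx):
        return self._v_buffers[idx]

    def get_v_head_dim(self):
        # Mirrors the new accessor: read the head dim from the pool itself
        # instead of assuming global layer 0 is a full-attention layer.
        return self.get_value_buffer(0).shape[-1]

pool = HybridLinearPoolSketch([3, 7, 11], num_kv_heads=4, head_dim=128)
print(pool.get_v_head_dim())  # 128
```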
@@ -127,6 +127,7 @@ from sglang.srt.utils import (
     get_bool_env_var,
     get_cpu_ids_by_node,
     init_custom_process_group,
+    is_blackwell,
     is_fa3_default_architecture,
     is_flashinfer_available,
     is_hip,
@@ -1832,6 +1833,10 @@ class ModelRunner:
             from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend
 
             full_attn_backend = AscendAttnBackend(self)
+        elif is_blackwell():
+            from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
+
+            full_attn_backend = TritonAttnBackend(self)
         else:
             from sglang.srt.layers.attention.flashattention_backend import (
                 FlashAttentionBackend,
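The new `elif is_blackwell():` branch routes full attention to the Triton backend on Blackwell GPUs, where the FlashAttention-3 path used in the `else` branch is not supported. A hedged sketch of the kind of check behind `is_blackwell()` (the real helper lives in `sglang.srt.utils` and may differ in detail):

```python
import torch

def is_blackwell_sketch() -> bool:
    # Hypothetical stand-in for sglang.srt.utils.is_blackwell: Blackwell-class
    # GPUs (e.g. B200) report CUDA compute capability major version 10.
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major == 10
```

Because the check runs once during backend construction, the same code path picks FA3 on earlier architectures and Triton on Blackwell without any configuration change.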
@@ -48,6 +48,7 @@ from sglang.srt.utils import (
     empty_context,
     get_available_gpu_memory,
     get_bool_env_var,
+    is_blackwell,
     is_cuda,
     next_power_of_2,
 )
@@ -214,7 +215,11 @@ class EAGLEWorker(TpModelWorker):
             "triton": self._create_triton_decode_backend,
             "aiter": self._create_aiter_decode_backend,
             "fa3": self._create_fa3_decode_backend,
-            "hybrid_linear_attn": self._create_fa3_decode_backend,
+            "hybrid_linear_attn": (
+                self._create_fa3_decode_backend
+                if not is_blackwell()
+                else self._create_triton_decode_backend
+            ),
             "flashmla": self._create_flashmla_decode_backend,
             "trtllm_mha": self._create_trtllm_mha_decode_backend,
             "trtllm_mla": self._create_trtllm_mla_decode_backend,
@@ -232,7 +237,11 @@ class EAGLEWorker(TpModelWorker):
             "triton": self._create_triton_prefill_backend,
             "aiter": self._create_aiter_prefill_backend,
             "fa3": self._create_fa3_prefill_backend,
-            "hybrid_linear_attn": self._create_fa3_prefill_backend,
+            "hybrid_linear_attn": (
+                self._create_fa3_prefill_backend
+                if not is_blackwell()
+                else self._create_triton_prefill_backend
+            ),
             "trtllm_mha": self._create_trtllm_mha_prefill_backend,
             "trtllm_mla": self._create_trtllm_mla_prefill_backend,
         }
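The last two hunks apply the same change to the EAGLE worker's decode and prefill dispatch tables. Note that the `if not is_blackwell() else` expression is evaluated once, while the dict literal is built, so the fa3-vs-triton choice for `hybrid_linear_attn` is frozen at worker init rather than decided per lookup. A small sketch of the pattern, with hypothetical factory names standing in for the `_create_*_backend` methods:

```python
def make_fa3_backend():
    return "fa3"

def make_triton_backend():
    return "triton"

def build_dispatch_table(on_blackwell: bool) -> dict:
    # The conditional expression runs now, during dict construction,
    # not later when the entry is looked up.
    return {
        "fa3": make_fa3_backend,
        "triton": make_triton_backend,
        "hybrid_linear_attn": (
            make_fa3_backend if not on_blackwell else make_triton_backend
        ),
    }

table = build_dispatch_table(on_blackwell=True)
print(table["hybrid_linear_attn"]())  # triton
```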