support qwen3_next blackwell (#10403)
This commit is contained in:
@@ -80,7 +80,13 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
self.num_kv_head = model_runner.model_config.get_num_kv_heads(
|
self.num_kv_head = model_runner.model_config.get_num_kv_heads(
|
||||||
get_attention_tp_size()
|
get_attention_tp_size()
|
||||||
)
|
)
|
||||||
self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
|
if model_runner.is_hybrid_gdn:
|
||||||
|
# For hybrid linear models, layer_id = 0 may not be full attention
|
||||||
|
self.v_head_dim = model_runner.token_to_kv_pool.get_v_head_dim()
|
||||||
|
else:
|
||||||
|
self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[
|
||||||
|
-1
|
||||||
|
]
|
||||||
self.max_context_len = model_runner.model_config.context_len
|
self.max_context_len = model_runner.model_config.context_len
|
||||||
self.device = model_runner.device
|
self.device = model_runner.device
|
||||||
self.device_core_count = get_device_core_count(model_runner.gpu_id)
|
self.device_core_count = get_device_core_count(model_runner.gpu_id)
|
||||||
|
|||||||
@@ -728,6 +728,9 @@ class HybridLinearKVPool(KVCache):
|
|||||||
layer_id_override=layer_id,
|
layer_id_override=layer_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_v_head_dim(self):
|
||||||
|
return self.full_kv_pool.get_value_buffer(0).shape[-1]
|
||||||
|
|
||||||
|
|
||||||
class SWAKVPool(KVCache):
|
class SWAKVPool(KVCache):
|
||||||
"""KV cache with separate pools for full and SWA attention layers."""
|
"""KV cache with separate pools for full and SWA attention layers."""
|
||||||
|
|||||||
@@ -127,6 +127,7 @@ from sglang.srt.utils import (
|
|||||||
get_bool_env_var,
|
get_bool_env_var,
|
||||||
get_cpu_ids_by_node,
|
get_cpu_ids_by_node,
|
||||||
init_custom_process_group,
|
init_custom_process_group,
|
||||||
|
is_blackwell,
|
||||||
is_fa3_default_architecture,
|
is_fa3_default_architecture,
|
||||||
is_flashinfer_available,
|
is_flashinfer_available,
|
||||||
is_hip,
|
is_hip,
|
||||||
@@ -1832,6 +1833,10 @@ class ModelRunner:
|
|||||||
from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend
|
from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend
|
||||||
|
|
||||||
full_attn_backend = AscendAttnBackend(self)
|
full_attn_backend = AscendAttnBackend(self)
|
||||||
|
elif is_blackwell():
|
||||||
|
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
|
||||||
|
|
||||||
|
full_attn_backend = TritonAttnBackend(self)
|
||||||
else:
|
else:
|
||||||
from sglang.srt.layers.attention.flashattention_backend import (
|
from sglang.srt.layers.attention.flashattention_backend import (
|
||||||
FlashAttentionBackend,
|
FlashAttentionBackend,
|
||||||
|
|||||||
@@ -48,6 +48,7 @@ from sglang.srt.utils import (
|
|||||||
empty_context,
|
empty_context,
|
||||||
get_available_gpu_memory,
|
get_available_gpu_memory,
|
||||||
get_bool_env_var,
|
get_bool_env_var,
|
||||||
|
is_blackwell,
|
||||||
is_cuda,
|
is_cuda,
|
||||||
next_power_of_2,
|
next_power_of_2,
|
||||||
)
|
)
|
||||||
@@ -214,7 +215,11 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
"triton": self._create_triton_decode_backend,
|
"triton": self._create_triton_decode_backend,
|
||||||
"aiter": self._create_aiter_decode_backend,
|
"aiter": self._create_aiter_decode_backend,
|
||||||
"fa3": self._create_fa3_decode_backend,
|
"fa3": self._create_fa3_decode_backend,
|
||||||
"hybrid_linear_attn": self._create_fa3_decode_backend,
|
"hybrid_linear_attn": (
|
||||||
|
self._create_fa3_decode_backend
|
||||||
|
if not is_blackwell()
|
||||||
|
else self._create_triton_decode_backend
|
||||||
|
),
|
||||||
"flashmla": self._create_flashmla_decode_backend,
|
"flashmla": self._create_flashmla_decode_backend,
|
||||||
"trtllm_mha": self._create_trtllm_mha_decode_backend,
|
"trtllm_mha": self._create_trtllm_mha_decode_backend,
|
||||||
"trtllm_mla": self._create_trtllm_mla_decode_backend,
|
"trtllm_mla": self._create_trtllm_mla_decode_backend,
|
||||||
@@ -232,7 +237,11 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
"triton": self._create_triton_prefill_backend,
|
"triton": self._create_triton_prefill_backend,
|
||||||
"aiter": self._create_aiter_prefill_backend,
|
"aiter": self._create_aiter_prefill_backend,
|
||||||
"fa3": self._create_fa3_prefill_backend,
|
"fa3": self._create_fa3_prefill_backend,
|
||||||
"hybrid_linear_attn": self._create_fa3_prefill_backend,
|
"hybrid_linear_attn": (
|
||||||
|
self._create_fa3_prefill_backend
|
||||||
|
if not is_blackwell()
|
||||||
|
else self._create_triton_prefill_backend
|
||||||
|
),
|
||||||
"trtllm_mha": self._create_trtllm_mha_prefill_backend,
|
"trtllm_mha": self._create_trtllm_mha_prefill_backend,
|
||||||
"trtllm_mla": self._create_trtllm_mla_prefill_backend,
|
"trtllm_mla": self._create_trtllm_mla_prefill_backend,
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user