diff --git a/python/sglang/srt/layers/attention/vision.py b/python/sglang/srt/layers/attention/vision.py
index 5c8200f57..2be3e450b 100644
--- a/python/sglang/srt/layers/attention/vision.py
+++ b/python/sglang/srt/layers/attention/vision.py
@@ -12,7 +12,12 @@ import torch.nn.functional as F
 from einops import rearrange
 
 from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
-from sglang.srt.utils import is_cuda, print_info_once
+from sglang.srt.utils import (
+    get_device_capability,
+    is_blackwell,
+    is_cuda,
+    print_info_once,
+)
 
 _is_cuda = is_cuda()
 
@@ -20,7 +25,6 @@ if _is_cuda:
     from sgl_kernel.flash_attn import flash_attn_varlen_func
 
 from sglang.srt.distributed import (
-    parallel_state,
     split_tensor_along_last_dim,
     tensor_model_parallel_all_gather,
 )
@@ -402,18 +406,14 @@ class VisionAttention(nn.Module):
                 self.dummy_dim, eps=layer_norm_eps, var_hidden_size=embed_dim
             )
 
-        # priority: server_args > passed qkv_backend > sdpa
-        if global_server_args_dict["mm_attention_backend"] is None:
-            if qkv_backend is None:
-                if is_cuda():
-                    # Double prefill throughput by setting attn backend to Triton on CUDA
-                    qkv_backend = "triton_attn"
-                else:
-                    qkv_backend = "sdpa"
+        # Select the attention backend via a unified method
+        _passed_backend = qkv_backend
+        qkv_backend = self._determine_attention_backend(_passed_backend)
+        if (
+            global_server_args_dict["mm_attention_backend"] is None
+            and _passed_backend is None
+        ):
             print_info_once(f"Multimodal attention backend not set. Use {qkv_backend}.")
-        else:
-            qkv_backend = global_server_args_dict["mm_attention_backend"]
-            print_info_once(f"Using {qkv_backend} as multimodal attention backend.")
 
         self.customized_position_embedding_applier = (
@@ -461,6 +461,33 @@ class VisionAttention(nn.Module):
             prefix=add_prefix("proj", prefix),
         )
 
+    def _determine_attention_backend(self, passed_backend: Optional[str]) -> str:
+        """Decide the multimodal attention backend string.
+
+        Priority: server-args override > constructor argument > platform default.
+
+        Platform defaults:
+        - CUDA on SM 9.x (Hopper): "fa3"; other CUDA: "triton_attn"
+        - Non-CUDA: "sdpa"
+        """
+        override_backend = global_server_args_dict["mm_attention_backend"]
+        if override_backend is not None:
+            backend = override_backend
+        elif passed_backend is not None:
+            backend = passed_backend
+        elif is_cuda():
+            major, minor = get_device_capability()
+            if major == 9:
+                backend = "fa3"
+            else:
+                backend = "triton_attn"
+        else:
+            backend = "sdpa"
+        if backend == "fa3" and is_blackwell():
+            raise ValueError("The 'fa3' backend is not supported on Blackwell GPUs")
+
+        return backend
+
     def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor):
         """apply qk norm for internvl vit attn"""
         q = q.flatten(1, 2)
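
The new helper resolves the backend by a fixed priority: the server-args override, then the constructor argument, then a platform default, and it rejects an explicit "fa3" request on Blackwell GPUs. Below is a minimal standalone sketch of that resolution order, not part of the diff; server_override, cuda, device_major, and on_blackwell are hypothetical stand-ins for global_server_args_dict["mm_attention_backend"], is_cuda(), get_device_capability(), and is_blackwell().

    from typing import Optional

    def resolve_backend(
        server_override: Optional[str],
        passed_backend: Optional[str],
        cuda: bool,
        device_major: Optional[int],
        on_blackwell: bool,
    ) -> str:
        # Priority: server-args override > constructor argument > platform default.
        if server_override is not None:
            backend = server_override
        elif passed_backend is not None:
            backend = passed_backend
        elif cuda:
            # Hopper (SM 9.x) defaults to FlashAttention-3; other CUDA GPUs use Triton.
            backend = "fa3" if device_major == 9 else "triton_attn"
        else:
            backend = "sdpa"
        # fa3 is rejected on Blackwell even when explicitly requested.
        if backend == "fa3" and on_blackwell:
            raise ValueError("The 'fa3' backend is not supported on Blackwell GPUs")
        return backend

    # The server-args override wins even over an explicit constructor argument:
    assert resolve_backend("sdpa", "triton_attn", True, 9, False) == "sdpa"
    # With neither set, a Hopper (SM 9.x) device defaults to fa3:
    assert resolve_backend(None, None, True, 9, False) == "fa3"
    # Blackwell falls through to the generic CUDA default rather than fa3:
    assert resolve_backend(None, None, True, 10, False) == "triton_attn"

Note that the error path only triggers when fa3 is requested explicitly: with no override and no constructor argument, a Blackwell device never selects fa3 in the first place, since its compute-capability major is not 9.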