minor: determine mm attn backend based on platforms (#9303)
@@ -12,7 +12,12 @@ import torch.nn.functional as F
from einops import rearrange

from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
from sglang.srt.utils import is_cuda, print_info_once
from sglang.srt.utils import (
    get_device_capability,
    is_blackwell,
    is_cuda,
    print_info_once,
)

_is_cuda = is_cuda()
@@ -20,7 +25,6 @@ if _is_cuda:
    from sgl_kernel.flash_attn import flash_attn_varlen_func

from sglang.srt.distributed import (
    parallel_state,
    split_tensor_along_last_dim,
    tensor_model_parallel_all_gather,
)
@@ -402,18 +406,14 @@ class VisionAttention(nn.Module):
            self.dummy_dim, eps=layer_norm_eps, var_hidden_size=embed_dim
        )

        # priority: server_args > passed qkv_backend > sdpa
        if global_server_args_dict["mm_attention_backend"] is None:
            if qkv_backend is None:
                if is_cuda():
                    # Double prefill throughput by setting attn backend to Triton on CUDA
                    qkv_backend = "triton_attn"
                else:
                    qkv_backend = "sdpa"
        # Select attention backend via a unified method
        _passed_backend = qkv_backend
        qkv_backend = self._determine_attention_backend(_passed_backend)
        if (
            global_server_args_dict["mm_attention_backend"] is None
            and _passed_backend is None
        ):
            print_info_once(f"Multimodal attention backend not set. Use {qkv_backend}.")
        else:
            qkv_backend = global_server_args_dict["mm_attention_backend"]

        print_info_once(f"Using {qkv_backend} as multimodal attention backend.")

        self.customized_position_embedding_applier = (
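The constructor change above snapshots the incoming qkv_backend before delegating to the new helper, so the "not set" notice only fires when neither the server-level mm_attention_backend override nor the constructor argument supplied a backend. A minimal sketch of that pattern, using a hypothetical resolve() helper in place of _determine_attention_backend() and a plain argument in place of global_server_args_dict:

from typing import Optional

def init_backend(server_override: Optional[str], qkv_backend: Optional[str]) -> str:
    def resolve(passed: Optional[str]) -> str:
        # Hypothetical stand-in: server override > passed argument > platform default.
        return server_override or passed or "sdpa"

    _passed_backend = qkv_backend  # snapshot before the variable is reassigned
    qkv_backend = resolve(_passed_backend)
    if server_override is None and _passed_backend is None:
        print(f"Multimodal attention backend not set. Use {qkv_backend}.")
    print(f"Using {qkv_backend} as multimodal attention backend.")
    return qkv_backend

init_backend(None, None)   # both notices, falls back to the platform default
init_backend(None, "fa3")  # only the "Using ..." notice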
@@ -461,6 +461,33 @@ class VisionAttention(nn.Module):
            prefix=add_prefix("proj", prefix),
        )

    def _determine_attention_backend(self, passed_backend: Optional[str]) -> str:
        """Decide the multimodal attention backend string.

        Priority: server args override > constructor arg > platform default.

        Platform defaults:
        - CUDA: "fa3" on SM 9.x (Hopper), otherwise "triton_attn"
        - Non-CUDA: "sdpa"
        """
        override_backend = global_server_args_dict["mm_attention_backend"]
        if override_backend is not None:
            backend = override_backend
        elif passed_backend is not None:
            backend = passed_backend
        elif is_cuda():
            major, minor = get_device_capability()
            if major == 9:
                backend = "fa3"
            else:
                backend = "triton_attn"
        else:
            backend = "sdpa"
        if backend == "fa3" and is_blackwell():
            raise ValueError("The 'fa3' backend is not supported on Blackwell GPUs")

        return backend

    def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor):
        """apply qk norm for internvl vit attn"""
        q = q.flatten(1, 2)
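Taken together, the new helper resolves the backend in this order: the server-level mm_attention_backend override, then the qkv_backend constructor argument, then a platform default (fa3 on SM 9.x CUDA devices, triton_attn on other CUDA devices, sdpa elsewhere), and it rejects fa3 on Blackwell. A self-contained sketch of that priority, with a hypothetical (major, minor) capability tuple standing in for get_device_capability() and a boolean standing in for is_blackwell():

from typing import Optional, Tuple

def pick_backend(
    server_override: Optional[str],
    passed_backend: Optional[str],
    cuda_capability: Optional[Tuple[int, int]],  # None means a non-CUDA platform
    blackwell: bool = False,
) -> str:
    # Priority: server args override > constructor arg > platform default.
    if server_override is not None:
        backend = server_override
    elif passed_backend is not None:
        backend = passed_backend
    elif cuda_capability is not None:
        major, _minor = cuda_capability
        # SM 9.x (Hopper) defaults to FlashAttention-3; other CUDA parts use Triton.
        backend = "fa3" if major == 9 else "triton_attn"
    else:
        backend = "sdpa"
    if backend == "fa3" and blackwell:
        raise ValueError("The 'fa3' backend is not supported on Blackwell GPUs")
    return backend

print(pick_backend(None, None, (9, 0)))             # fa3 on Hopper
print(pick_backend(None, None, (8, 0)))             # triton_attn on Ampere
print(pick_backend(None, None, None))               # sdpa off CUDA
print(pick_backend(None, "sdpa", (9, 0)))           # constructor argument beats the platform default
print(pick_backend("triton_attn", "fa3", (9, 0)))   # server override beats both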