minor: determine mm attn backend based on platforms (#9303)
@@ -12,7 +12,12 @@ import torch.nn.functional as F
 from einops import rearrange
 
 from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
-from sglang.srt.utils import is_cuda, print_info_once
+from sglang.srt.utils import (
+    get_device_capability,
+    is_blackwell,
+    is_cuda,
+    print_info_once,
+)
 
 _is_cuda = is_cuda()
 
@@ -20,7 +25,6 @@ if _is_cuda:
     from sgl_kernel.flash_attn import flash_attn_varlen_func
 
 from sglang.srt.distributed import (
-    parallel_state,
     split_tensor_along_last_dim,
     tensor_model_parallel_all_gather,
 )
@@ -402,18 +406,14 @@ class VisionAttention(nn.Module):
             self.dummy_dim, eps=layer_norm_eps, var_hidden_size=embed_dim
         )
 
-        # priority: server_args > passed qkv_backend > sdpa
-        if global_server_args_dict["mm_attention_backend"] is None:
-            if qkv_backend is None:
-                if is_cuda():
-                    # Double prefill throughput by setting attn backend to Triton on CUDA
-                    qkv_backend = "triton_attn"
-                else:
-                    qkv_backend = "sdpa"
-            print_info_once(f"Multimodal attention backend not set. Use {qkv_backend}.")
-        else:
-            qkv_backend = global_server_args_dict["mm_attention_backend"]
-
+        # Select attention backend via a unified method
+        _passed_backend = qkv_backend
+        qkv_backend = self._determine_attention_backend(_passed_backend)
+        if (
+            global_server_args_dict["mm_attention_backend"] is None
+            and _passed_backend is None
+        ):
+            print_info_once(f"Multimodal attention backend not set. Use {qkv_backend}.")
         print_info_once(f"Using {qkv_backend} as multimodal attention backend.")
 
         self.customized_position_embedding_applier = (
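
For orientation, here is a minimal standalone sketch of the resolution order the rewritten constructor code now delegates to: the server-args override wins, then the qkv_backend constructor argument, then a platform default. The helper name pick_mm_attention_backend and its parameters are illustrative stand-ins, not sglang APIs.

from typing import Optional

def pick_mm_attention_backend(
    server_arg: Optional[str],  # stand-in for global_server_args_dict["mm_attention_backend"]
    passed: Optional[str],      # stand-in for the qkv_backend constructor argument
    on_cuda: bool,              # stand-in for is_cuda()
    cuda_major: Optional[int],  # stand-in for get_device_capability()[0]
) -> str:
    # Priority: server args override > constructor arg > platform default.
    if server_arg is not None:
        return server_arg
    if passed is not None:
        return passed
    if on_cuda:
        # Compute capability 9.x (Hopper) defaults to "fa3"; other CUDA GPUs to "triton_attn".
        return "fa3" if cuda_major == 9 else "triton_attn"
    return "sdpa"

# Example resolutions under this sketch:
assert pick_mm_attention_backend("sdpa", "triton_attn", True, 9) == "sdpa"       # server args win
assert pick_mm_attention_backend(None, "triton_attn", True, 9) == "triton_attn"  # then the constructor arg
assert pick_mm_attention_backend(None, None, True, 9) == "fa3"                   # Hopper default
assert pick_mm_attention_backend(None, None, True, 8) == "triton_attn"           # other CUDA default
assert pick_mm_attention_backend(None, None, False, None) == "sdpa"              # non-CUDA default

Note that the constructor itself no longer branches on the platform; it only records whether a backend was explicitly requested, so the "not set" log line still fires only when neither the server args nor the constructor supplied one.
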
@@ -461,6 +461,33 @@ class VisionAttention(nn.Module):
             prefix=add_prefix("proj", prefix),
         )
 
+    def _determine_attention_backend(self, passed_backend: Optional[str]) -> str:
+        """Decide the multimodal attention backend string.
+
+        Priority: server args override > constructor arg > platform default.
+
+        Platform defaults:
+        - CUDA: "triton_attn"
+        - Non-CUDA: "sdpa"
+        """
+        override_backend = global_server_args_dict["mm_attention_backend"]
+        if override_backend is not None:
+            backend = override_backend
+        elif passed_backend is not None:
+            backend = passed_backend
+        elif is_cuda():
+            major, minor = get_device_capability()
+            if major == 9:
+                backend = "fa3"
+            else:
+                backend = "triton_attn"
+        else:
+            backend = "sdpa"
+        if backend == "fa3" and is_blackwell():
+            raise ValueError("The 'fa3' backend is not supported on Blackwell GPUs")
+
+        return backend
+
     def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor):
         """apply qk norm for internvl vit attn"""
         q = q.flatten(1, 2)
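
One subtlety of the new method: the platform default only selects "fa3" when the detected compute capability major is 9, so the final Blackwell guard can only trigger when "fa3" was requested explicitly through the server args or the constructor argument. Below is a small illustrative sketch of that guard; the helper and the assumption that is_blackwell() corresponds to compute capability major 10 or higher are mine, not sglang's.

def reject_fa3_on_blackwell(backend: str, cc_major: int) -> str:
    # Assumption for illustration: Blackwell GPUs report compute capability major >= 10.
    if backend == "fa3" and cc_major >= 10:
        raise ValueError("The 'fa3' backend is not supported on Blackwell GPUs")
    return backend

reject_fa3_on_blackwell("fa3", 9)           # fine: the Hopper default path
reject_fa3_on_blackwell("triton_attn", 10)  # fine: Blackwell with the Triton backend
try:
    reject_fa3_on_blackwell("fa3", 10)      # explicit fa3 override on Blackwell
except ValueError as err:
    print(err)

Any other override string is passed through unchanged; beyond the fa3-on-Blackwell case, validating the backend name is left to the attention implementations themselves.
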