set default attention backend for deterministic inference (#11801)
This commit is contained in:
@@ -44,6 +44,7 @@ from sglang.srt.utils import (
|
||||
is_remote_url,
|
||||
is_sm90_supported,
|
||||
is_sm100_supported,
|
||||
is_sm120_supported,
|
||||
is_triton_kernels_available,
|
||||
is_valid_ipv6_address,
|
||||
json_list_type,
|
||||
@@ -1411,9 +1412,23 @@ class ServerArgs:
|
||||
)
|
||||
|
||||
# Check attention backend
|
||||
if self.attention_backend not in DETERMINISTIC_ATTENTION_BACKEND_CHOICES:
|
||||
if self.attention_backend is None:
|
||||
# User didn't specify attention backend, fallback based on GPU architecture
|
||||
if is_sm100_supported() or is_sm120_supported():
|
||||
# Blackwell and newer architectures
|
||||
self.attention_backend = "flashinfer"
|
||||
else:
|
||||
# Hopper (SM90) and older architectures
|
||||
self.attention_backend = "fa3"
|
||||
logger.warning(
|
||||
f"Attention backend not specified. Falling back to '{self.attention_backend}' for deterministic inference. "
|
||||
f"You can explicitly set --attention-backend to one of {DETERMINISTIC_ATTENTION_BACKEND_CHOICES}."
|
||||
)
|
||||
elif self.attention_backend not in DETERMINISTIC_ATTENTION_BACKEND_CHOICES:
|
||||
# User explicitly specified an incompatible attention backend
|
||||
raise ValueError(
|
||||
f"Currently only {DETERMINISTIC_ATTENTION_BACKEND_CHOICES} attention backends are supported for deterministic inference."
|
||||
f"Currently only {DETERMINISTIC_ATTENTION_BACKEND_CHOICES} attention backends are supported for deterministic inference, "
|
||||
f"but you explicitly specified '{self.attention_backend}'."
|
||||
)
|
||||
|
||||
# Currently, only FA3 supports radix cache. Support for other backends is in progress
|
||||
|
||||
@@ -174,6 +174,15 @@ def is_blackwell():
|
||||
return torch.cuda.get_device_capability()[0] == 10
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
def is_sm120_supported(device=None) -> bool:
    """Return True if *device* is an SM 12.x (Blackwell consumer) GPU on CUDA >= 12.8.

    Args:
        device: Optional device index/handle forwarded to
            ``torch.cuda.get_device_capability``; ``None`` means the current device.

    Returns:
        bool: True only when running on a CUDA-alike backend, the device's
        major compute capability is 12, and the CUDA toolkit version is at
        least 12.8. False otherwise (including non-CUDA builds).
    """
    if not is_cuda_alike():
        return False
    if torch.cuda.get_device_capability(device)[0] != 12:
        return False
    cuda_version = torch.version.cuda
    # torch.version.cuda is None on ROCm/CPU builds; the original string
    # comparison (`cuda_version >= "12.8"`) would raise TypeError here and
    # would also mis-order versions lexicographically (e.g. "12.10" < "12.8").
    if cuda_version is None:
        return False
    major, minor = (int(part) for part in cuda_version.split(".")[:2])
    return (major, minor) >= (12, 8)
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def is_sm100_supported(device=None) -> bool:
|
||||
if not is_cuda_alike():
|
||||
|
||||
Reference in New Issue
Block a user