set default attention backend for deterministic inference (#11801)
```diff
@@ -44,6 +44,7 @@ from sglang.srt.utils import (
     is_remote_url,
     is_sm90_supported,
     is_sm100_supported,
+    is_sm120_supported,
     is_triton_kernels_available,
     is_valid_ipv6_address,
     json_list_type,
```
```diff
@@ -1411,9 +1412,23 @@ class ServerArgs:
         )

         # Check attention backend
-        if self.attention_backend not in DETERMINISTIC_ATTENTION_BACKEND_CHOICES:
+        if self.attention_backend is None:
+            # User didn't specify attention backend, fallback based on GPU architecture
+            if is_sm100_supported() or is_sm120_supported():
+                # Blackwell and newer architectures
+                self.attention_backend = "flashinfer"
+            else:
+                # Hopper (SM90) and older architectures
+                self.attention_backend = "fa3"
+            logger.warning(
+                f"Attention backend not specified. Falling back to '{self.attention_backend}' for deterministic inference. "
+                f"You can explicitly set --attention-backend to one of {DETERMINISTIC_ATTENTION_BACKEND_CHOICES}."
+            )
+        elif self.attention_backend not in DETERMINISTIC_ATTENTION_BACKEND_CHOICES:
+            # User explicitly specified an incompatible attention backend
             raise ValueError(
-                f"Currently only {DETERMINISTIC_ATTENTION_BACKEND_CHOICES} attention backends are supported for deterministic inference."
+                f"Currently only {DETERMINISTIC_ATTENTION_BACKEND_CHOICES} attention backends are supported for deterministic inference, "
+                f"but you explicitly specified '{self.attention_backend}'."
             )

         # Currently, only FA3 supports radix cache. Support for other backends is in progress
```
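Boiled down, the new fallback is a pure decision on GPU architecture; an explicitly passed backend is still validated against `DETERMINISTIC_ATTENTION_BACKEND_CHOICES`, and only the unset case gets the architecture-based default plus a warning. A minimal sketch of the policy in isolation (the function name `default_deterministic_backend` is hypothetical, not part of the patch):

```python
# Hypothetical standalone restatement of the fallback policy above;
# in the patch, the two flags come from is_sm100_supported()/is_sm120_supported().
def default_deterministic_backend(sm100: bool, sm120: bool) -> str:
    # Blackwell (SM100) and newer (SM120) default to flashinfer;
    # Hopper (SM90) and older fall back to fa3.
    return "flashinfer" if (sm100 or sm120) else "fa3"


assert default_deterministic_backend(sm100=True, sm120=False) == "flashinfer"
assert default_deterministic_backend(sm100=False, sm120=False) == "fa3"
```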
```diff
@@ -174,6 +174,15 @@ def is_blackwell():
     return torch.cuda.get_device_capability()[0] == 10


+@lru_cache(maxsize=1)
+def is_sm120_supported(device=None) -> bool:
+    if not is_cuda_alike():
+        return False
+    return (torch.cuda.get_device_capability(device)[0] == 12) and (
+        torch.version.cuda >= "12.8"
+    )
+
+
 @lru_cache(maxsize=1)
 def is_sm100_supported(device=None) -> bool:
     if not is_cuda_alike():
```