diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 9de754d35..b6074f86b 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -44,6 +44,7 @@ from sglang.srt.utils import ( is_remote_url, is_sm90_supported, is_sm100_supported, + is_sm120_supported, is_triton_kernels_available, is_valid_ipv6_address, json_list_type, @@ -1411,9 +1412,23 @@ class ServerArgs: ) # Check attention backend - if self.attention_backend not in DETERMINISTIC_ATTENTION_BACKEND_CHOICES: + if self.attention_backend is None: + # User didn't specify attention backend, fallback based on GPU architecture + if is_sm100_supported() or is_sm120_supported(): + # Blackwell and newer architectures + self.attention_backend = "flashinfer" + else: + # Hopper (SM90) and older architectures + self.attention_backend = "fa3" + logger.warning( + f"Attention backend not specified. Falling back to '{self.attention_backend}' for deterministic inference. " + f"You can explicitly set --attention-backend to one of {DETERMINISTIC_ATTENTION_BACKEND_CHOICES}." + ) + elif self.attention_backend not in DETERMINISTIC_ATTENTION_BACKEND_CHOICES: + # User explicitly specified an incompatible attention backend raise ValueError( - f"Currently only {DETERMINISTIC_ATTENTION_BACKEND_CHOICES} attention backends are supported for deterministic inference." + f"Currently only {DETERMINISTIC_ATTENTION_BACKEND_CHOICES} attention backends are supported for deterministic inference, " + f"but you explicitly specified '{self.attention_backend}'." ) # Currently, only FA3 supports radix cache. 
Support for other backends is in progress diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py index e898b8923..3436b2682 100644 --- a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py @@ -174,6 +174,15 @@ def is_blackwell(): return torch.cuda.get_device_capability()[0] == 10 +@lru_cache(maxsize=1) +def is_sm120_supported(device=None) -> bool: + if not is_cuda_alike(): + return False + return (torch.cuda.get_device_capability(device)[0] == 12) and ( + (torch.version.cuda or "0") >= "12.8" + ) + + @lru_cache(maxsize=1) def is_sm100_supported(device=None) -> bool: if not is_cuda_alike():