diff --git a/sgl-kernel/python/sgl_kernel/flash_attn.py b/sgl-kernel/python/sgl_kernel/flash_attn.py
index 36951325e..cbdcf35cb 100644
--- a/sgl-kernel/python/sgl_kernel/flash_attn.py
+++ b/sgl-kernel/python/sgl_kernel/flash_attn.py
@@ -1,4 +1,5 @@
-from typing import List, Optional, Tuple, Union
+from functools import lru_cache
+from typing import Optional, Union
 
 import torch
 import torch.nn as nn
@@ -9,6 +10,7 @@ except:
     raise ImportError("Can not import sgl_kernel. Please check your installation.")
 
 
+@lru_cache(maxsize=1)
 def is_fa3_supported(device=None) -> bool:
     # There some fa3 FYI
     # FA3 can fail without a enough shared memory for a some shapes, such as higher
@@ -18,10 +20,10 @@ def is_fa3_supported(device=None) -> bool:
     # https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x
     # And for sgl-kernel right now, we can build fa3 on sm80/sm86/sm89/sm90a.
     # That means if you use A100/A*0/L20/L40/L40s/4090 you can use fa3.
-    return (
+    return (torch.version.cuda >= "12.3") and (
         torch.cuda.get_device_capability(device)[0] == 9
         or torch.cuda.get_device_capability(device)[0] == 8
-    ) and (torch.version.cuda >= "12.3")
+    )
 
 
 def maybe_contiguous(x):
diff --git a/sgl-kernel/tests/test_flash_attention.py b/sgl-kernel/tests/test_flash_attention.py
index 0900e5940..159390e54 100644
--- a/sgl-kernel/tests/test_flash_attention.py
+++ b/sgl-kernel/tests/test_flash_attention.py
@@ -25,10 +25,10 @@ def is_fa3_supported(device=None) -> bool:
     # https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x
     # And for sgl-kernel right now, we can build fa3 on sm80/sm86/sm89/sm90a.
     # That means if you use A100/A*0/L20/L40/L40s/4090 you can use fa3.
-    return (
+    return (torch.version.cuda >= "12.3") and (
         torch.cuda.get_device_capability(device)[0] == 9
         or torch.cuda.get_device_capability(device)[0] == 8
-    ) and (torch.version.cuda >= "12.3")
+    )
 
 
 DISABLE_BACKWARD = True
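For context, a minimal sketch (not part of the PR) of what the changed check does after this diff: `@lru_cache(maxsize=1)` memoizes the result so repeated calls do not re-query the device, and placing the `torch.version.cuda >= "12.3"` comparison first lets `and` short-circuit before `torch.cuda.get_device_capability` is called. The `__main__` usage below is illustrative only and assumes a CUDA-enabled PyTorch build.

```python
# Sketch of the post-PR behavior; not the PR's test code.
from functools import lru_cache

import torch


@lru_cache(maxsize=1)
def is_fa3_supported(device=None) -> bool:
    # Version check first: if it is False, the capability query is skipped entirely.
    return (torch.version.cuda >= "12.3") and (
        torch.cuda.get_device_capability(device)[0] == 9
        or torch.cuda.get_device_capability(device)[0] == 8
    )


if __name__ == "__main__":
    is_fa3_supported()                     # first call does the real work
    is_fa3_supported()                     # second call is served from the cache
    print(is_fa3_supported.cache_info())   # e.g. hits=1, misses=1, maxsize=1
```

Note that `maxsize=1` keeps only the most recent `device` argument in the cache, which is enough for the common case of calling the check with a single (or default) device.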