diff --git a/sgl-kernel/python/sgl_kernel/flash_attn.py b/sgl-kernel/python/sgl_kernel/flash_attn.py
index acf0807b0..a14173568 100644
--- a/sgl-kernel/python/sgl_kernel/flash_attn.py
+++ b/sgl-kernel/python/sgl_kernel/flash_attn.py
@@ -10,15 +10,9 @@ except:
 
 
 def is_fa3_supported(device=None) -> bool:
-    # FA3 can fail without a enough shared memory for a some shapes, currently
-    # only 8.0 and 8.7 have enough shared memory for all shapes
-    # https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x
-    # now sgl-kernel only build fa3 for sm90a && cuda >= 12.4
-    return (
-        (torch.cuda.get_device_capability(device)[0] >= 9)
-        and (torch.version.cuda >= "12.4")
-        # or torch.cuda.get_device_capability(device) == (8, 0)
-        # or torch.cuda.get_device_capability(device) == (8, 7)
+    # currently sgl-kernel only builds fa3 for sm90a && cuda >= 12.3
+    return (torch.cuda.get_device_capability(device)[0] == 9) and (
+        torch.version.cuda >= "12.3"
     )
 
 
@@ -144,7 +138,7 @@ def flash_attn_with_kvcache(
     """
     if not is_fa3_supported():
         raise NotImplementedError(
-            "flash_attn at sgl-kernel is only supported on sm90 and above"
+            "flash_attn at sgl-kernel is only supported on sm90 with CUDA >= 12.3"
         )
     assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
     assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
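
For reference, below is a minimal standalone sketch (an illustration, not part of the patch) of the support check introduced above. It keeps the same sm90-only policy, but parses torch.version.cuda into integers, since the patch compares version strings lexicographically and a hypothetical "12.10" would otherwise compare as lower than "12.3".

# Standalone sketch (illustration only, not part of the patch above): the same
# sm90 + CUDA >= 12.3 guard, with the CUDA version parsed into integers.
# torch.version.cuda is a string, so a string comparison against "12.3" is
# lexicographic; it holds for "12.3" to "12.9" but would treat "12.10" as older.
import torch


def _cuda_version_tuple() -> tuple:
    # torch.version.cuda is e.g. "12.4", or None for CPU-only builds of torch
    if torch.version.cuda is None:
        return (0, 0)
    return tuple(int(x) for x in torch.version.cuda.split("."))


def is_fa3_supported_sketch(device=None) -> bool:
    # fa3 in sgl-kernel is built for sm90a only, matching the check in the diff
    if not torch.cuda.is_available():
        return False
    return (
        torch.cuda.get_device_capability(device)[0] == 9
        and _cuda_version_tuple() >= (12, 3)
    )


if __name__ == "__main__":
    print("FA3 supported:", is_fa3_supported_sketch())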