diff --git a/sgl-kernel/tests/test_flash_attention.py b/sgl-kernel/tests/test_flash_attention.py index ff60b7710..402dc81d7 100644 --- a/sgl-kernel/tests/test_flash_attention.py +++ b/sgl-kernel/tests/test_flash_attention.py @@ -17,7 +17,7 @@ def is_fa3_supported(device=None) -> bool: # https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x # now sgl-kernel only build fa3 for sm90a && cuda >= 12.4 return ( - (torch.cuda.get_device_capability(device)[0] >= 9) + (torch.cuda.get_device_capability(device)[0] == 9) and (torch.version.cuda >= "12.4") # or torch.cuda.get_device_capability(device) == (8, 0) # or torch.cuda.get_device_capability(device) == (8, 7)