fix: use fa3 unit test on hopper only (#5304)
This commit is contained in:
@@ -17,7 +17,7 @@ def is_fa3_supported(device=None) -> bool:
|
|||||||
# https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x
|
# https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x
|
||||||
# now sgl-kernel only build fa3 for sm90a && cuda >= 12.4
|
# now sgl-kernel only build fa3 for sm90a && cuda >= 12.4
|
||||||
return (
|
return (
|
||||||
(torch.cuda.get_device_capability(device)[0] >= 9)
|
(torch.cuda.get_device_capability(device)[0] == 9)
|
||||||
and (torch.version.cuda >= "12.4")
|
and (torch.version.cuda >= "12.4")
|
||||||
# or torch.cuda.get_device_capability(device) == (8, 0)
|
# or torch.cuda.get_device_capability(device) == (8, 0)
|
||||||
# or torch.cuda.get_device_capability(device) == (8, 7)
|
# or torch.cuda.get_device_capability(device) == (8, 7)
|
||||||
|
|||||||
Reference in New Issue
Block a user