Reduce overhead for fa by not calling heavy CUDA property check (#7375)
This commit is contained in:
@@ -1,4 +1,5 @@
|
|||||||
from typing import List, Optional, Tuple, Union
|
from functools import lru_cache
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@@ -9,6 +10,7 @@ except:
|
|||||||
raise ImportError("Can not import sgl_kernel. Please check your installation.")
|
raise ImportError("Can not import sgl_kernel. Please check your installation.")
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=1)
def is_fa3_supported(device=None) -> bool:
    """Return True if FlashAttention-3 can run on the given CUDA device.

    The cheap check (CUDA toolkit version) runs first so the expensive
    device-property query is skipped whenever possible, and the result is
    memoized with ``lru_cache`` so the CUDA calls happen at most once per
    distinct ``device`` argument.

    Args:
        device: Optional device index/object forwarded to
            ``torch.cuda.get_device_capability``; ``None`` means the
            current CUDA device.

    Returns:
        True only when the CUDA toolkit is >= 12.3 and the device's
        compute-capability major version is 8 or 9.
    """
    # FA3 can fail without enough shared memory for some shapes, such as
    # higher hidden_dim:
    # https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x
    # And for sgl-kernel right now, we can build fa3 on sm80/sm86/sm89/sm90a.
    # That means if you use A100/A*0/L20/L40/L40s/4090 you can use fa3.
    cuda_version = torch.version.cuda
    if cuda_version is None:
        # CPU-only (or ROCm) torch builds report no CUDA toolkit version;
        # the original string comparison would raise TypeError here.
        return False
    # Compare numerically: the original lexicographic string comparison
    # (cuda_version >= "12.3") would wrongly reject e.g. "12.10".
    major, minor = (int(part) for part in cuda_version.split(".")[:2])
    if (major, minor) < (12, 3):
        return False
    # Query the heavy device property only after the cheap checks pass,
    # and only once (the original called it twice).
    return torch.cuda.get_device_capability(device)[0] in (8, 9)
|
||||||
|
|
||||||
|
|
||||||
def maybe_contiguous(x):
|
def maybe_contiguous(x):
|
||||||
|
|||||||
@@ -25,10 +25,10 @@ def is_fa3_supported(device=None) -> bool:
|
|||||||
# https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x
|
# https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x
|
||||||
# And for sgl-kernel right now, we can build fa3 on sm80/sm86/sm89/sm90a.
|
# And for sgl-kernel right now, we can build fa3 on sm80/sm86/sm89/sm90a.
|
||||||
# That means if you use A100/A*0/L20/L40/L40s/4090 you can use fa3.
|
# That means if you use A100/A*0/L20/L40/L40s/4090 you can use fa3.
|
||||||
return (
|
return (torch.version.cuda >= "12.3") and (
|
||||||
torch.cuda.get_device_capability(device)[0] == 9
|
torch.cuda.get_device_capability(device)[0] == 9
|
||||||
or torch.cuda.get_device_capability(device)[0] == 8
|
or torch.cuda.get_device_capability(device)[0] == 8
|
||||||
) and (torch.version.cuda >= "12.3")
|
)
|
||||||
|
|
||||||
|
|
||||||
DISABLE_BACKWARD = True
|
DISABLE_BACKWARD = True
|
||||||
|
|||||||
Reference in New Issue
Block a user