[Feat] Scale up fa3 kernel to sm8x arch (#5912)

Co-authored-by: zhyncs <me@zhyncs.com>
This commit is contained in:
PGFLMG
2025-05-01 04:59:36 +08:00
committed by GitHub
parent 2afba1b1c1
commit 08acdb5c3d
4 changed files with 52 additions and 22 deletions

View File

@@ -10,10 +10,18 @@ except:
def is_fa3_supported(device=None) -> bool:
# now sgl-kernel only build fa3 for sm90a && cuda >= 12.3
return (torch.cuda.get_device_capability(device)[0] == 9) and (
torch.version.cuda >= "12.3"
)
# There some fa3 FYI
# FA3 can fail without a enough shared memory for a some shapes, such as higher
# hidden_dim or some special cases.
# Right now, fa3 is supported for sm80/sm87 and sm86/sm89. The main different
# Between sm80/sm87 and sm86/sm89 is the shared memory size. you can follow the link below for more information
# https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x
# And for sgl-kernel right now, we can build fa3 on sm80/sm86/sm89/sm90a.
# Thats mean if you use A100/A*0/L20/L40/L40s/4090 you can use fa3.
return (
torch.cuda.get_device_capability(device)[0] == 9
or torch.cuda.get_device_capability(device)[0] == 8
) and (torch.version.cuda >= "12.3")
def maybe_contiguous(x):