[Feat] Scale up fa3 kernel to sm8x arch (#5912)
Co-authored-by: zhyncs <me@zhyncs.com>
@@ -11,17 +11,24 @@ from einops import rearrange, repeat
     apply_rotary_emb = None


+def is_hopper():
+    # Only Hopper supports a V head dim different from the QK head dim.
+    return torch.cuda.get_device_properties(0).major >= 9
+
+
 def is_fa3_supported(device=None) -> bool:
-    # FA3 can fail without enough shared memory for some shapes; currently
-    # only 8.0 and 8.7 have enough shared memory for all shapes.
+    # Some FA3 notes:
+    # FA3 can fail without enough shared memory for some shapes, such as a
+    # higher hidden_dim or some special cases.
+    # Right now, FA3 is supported on sm80/sm87 and sm86/sm89. The main difference
+    # between sm80/sm87 and sm86/sm89 is the shared memory size; follow the link
+    # below for more information:
+    # https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x
-    # For now, sgl-kernel only builds FA3 for sm90a && cuda >= 12.4.
+    # For sgl-kernel right now, we can build FA3 on sm80/sm86/sm89/sm90a.
+    # That means you can use FA3 if you have an A100/A*0/L20/L40/L40s/4090.
     return (
-        (torch.cuda.get_device_capability(device)[0] == 9)
-        and (torch.version.cuda >= "12.4")
-        # or torch.cuda.get_device_capability(device) == (8, 0)
-        # or torch.cuda.get_device_capability(device) == (8, 7)
-    )
+        torch.cuda.get_device_capability(device)[0] == 9
+        or torch.cuda.get_device_capability(device)[0] == 8
+    ) and (torch.version.cuda >= "12.3")


 DISABLE_BACKWARD = True
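For context on the shared-memory comments in the new is_fa3_supported: the CUDA C programming guide page linked in the diff documents the per-SM shared memory capacity of each 8.x part, which is why sm80/sm87 can fit tile shapes that sm86/sm89 cannot. Below is a minimal sketch, not part of this commit, of how one could gate on those documented capacities; the table values come from the linked guide, and the helper name max_smem_kb is hypothetical.

import torch

# Maximum shared memory per SM in KB, per the CUDA C programming guide
# section linked in the diff. Hypothetical helper, not part of this commit.
_SMEM_PER_SM_KB = {
    (8, 0): 164,  # A100/A30
    (8, 6): 100,  # A10/A40/RTX 30xx
    (8, 7): 164,  # Jetson Orin
    (8, 9): 100,  # L20/L40/L40s/4090
    (9, 0): 228,  # H100/H200 (sm90a)
}


def max_smem_kb(device=None) -> int:
    # Fall back to the smallest 8.x budget for capabilities not listed.
    return _SMEM_PER_SM_KB.get(torch.cuda.get_device_capability(device), 100)

A kernel build could consult such a table to reject tile configurations whose combined Q/K/V tiles exceed the budget; the diff takes the coarser route of whitelisting major capability versions, which is presumably why the comment warns that some shapes can still fail.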
@@ -558,7 +565,8 @@ def test_flash_attn_kvcache(
     assert nheads % nheads_k == 0
     dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
     dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d])
-    if dtype == torch.float8_e4m3fn:
+    if dtype == torch.float8_e4m3fn or not is_hopper():
+        # For fp8 and the Ampere arch, we do not support v head dim != qk head dim.
         dv_vals = [d]
     for dv in dv_vals:
         has_qv = d == 64 and dv >= 256
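To make the dv_vals selection above concrete, here is a small sketch, not part of this commit, that factors the branch into a standalone function (dv_candidates is a hypothetical name; the fp8/hopper flags stand in for the dtype check and is_hopper() so it runs without a GPU) and prints the V head dims the test would sweep for a few QK head dims d:

# Sketch of the dv_vals selection logic from the hunk above.
def dv_candidates(d: int, fp8: bool, hopper: bool) -> list[int]:
    dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d])
    if fp8 or not hopper:
        # fp8 and Ampere do not support v head dim != qk head dim.
        dv_vals = [d]
    return dv_vals


for d in (64, 128, 192, 256):
    print(d, dv_candidates(d, fp8=False, hopper=True), dv_candidates(d, fp8=False, hopper=False))
# 64 [256, 512, 64] [64]
# 128 [128] [128]
# 192 [128, 192] [192]
# 256 [256] [256]

The new `or not is_hopper()` guard collapses the sweep to dv == d on Ampere-class parts, matching the is_hopper comment that only Hopper supports a V head dim different from the QK head dim.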