fix: update flash attn (#5308)
@@ -10,15 +10,9 @@ except:
 
 
 def is_fa3_supported(device=None) -> bool:
-    # FA3 can fail without a enough shared memory for a some shapes, currently
-    # only 8.0 and 8.7 have enough shared memory for all shapes
-    # https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x
-    # now sgl-kernel only build fa3 for sm90a && cuda >= 12.4
-    return (
-        (torch.cuda.get_device_capability(device)[0] >= 9)
-        and (torch.version.cuda >= "12.4")
-        # or torch.cuda.get_device_capability(device) == (8, 0)
-        # or torch.cuda.get_device_capability(device) == (8, 7)
+    # now sgl-kernel only build fa3 for sm90a && cuda >= 12.3
+    return (torch.cuda.get_device_capability(device)[0] == 9) and (
+        torch.version.cuda >= "12.3"
     )
 
 
@@ -144,7 +138,7 @@ def flash_attn_with_kvcache(
     """
     if not is_fa3_supported():
         raise NotImplementedError(
-            "flash_attn at sgl-kernel is only supported on sm90 and above"
+            "flash_attn at sgl-kernel is only supported on sm90 and cu123 above"
        )
     assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
     assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
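
Below is a minimal usage sketch, not part of this commit, showing what the tightened check means for callers: after this change is_fa3_supported() returns True only on a device with compute capability major version 9 and CUDA >= 12.3, and flash_attn_with_kvcache raises NotImplementedError otherwise. The sgl_kernel.flash_attn import path is an assumption; the printed strings are illustrative.

# Minimal sketch, not part of this commit; import path assumed.
import torch
from sgl_kernel.flash_attn import is_fa3_supported

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    print(f"device: sm{major}{minor}, CUDA runtime: {torch.version.cuda}")
    # After this commit: compute capability major == 9 and CUDA >= 12.3
    # (previously major >= 9 and CUDA >= 12.4).
    if is_fa3_supported():
        print("FA3 path available; flash_attn_with_kvcache can be used")
    else:
        print("FA3 unavailable; flash_attn_with_kvcache raises NotImplementedError")
else:
    print("no CUDA device visible; is_fa3_supported() cannot be evaluated")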