Remove unecessary is_fa3_supported check (#6112)
This commit is contained in:
@@ -144,10 +144,6 @@ def flash_attn_with_kvcache(
|
|||||||
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
|
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
|
||||||
normalization factor).
|
normalization factor).
|
||||||
"""
|
"""
|
||||||
if not is_fa3_supported():
|
|
||||||
raise NotImplementedError(
|
|
||||||
"flash_attn at sgl-kernel is only supported on sm90 and cu123 above"
|
|
||||||
)
|
|
||||||
assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
|
assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
|
||||||
assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
|
assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
|
||||||
if softmax_scale is None:
|
if softmax_scale is None:
|
||||||
|
|||||||
Reference in New Issue
Block a user