Update fa3 interface and add unit test (#9150)

This commit is contained in:
Ke Bao
2025-08-13 20:05:02 +08:00
committed by GitHub
parent 3b3b3baf9f
commit 94f44b88d1
4 changed files with 54 additions and 12 deletions

View File

@@ -58,6 +58,7 @@ def flash_attn_with_kvcache(
pack_gqa=None, # Can be tuned for speed
sm_margin=0, # Can be tuned if some SMs are used for communication
return_softmax_lse=False,
sinks=None,
):
"""
If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
@@ -205,6 +206,7 @@ def flash_attn_with_kvcache(
num_splits,
pack_gqa,
sm_margin,
sinks,
)
# return (out, softmax_lse) if return_softmax_lse else out
return (out, softmax_lse, *rest) if return_softmax_lse else out
@@ -232,6 +234,7 @@ def flash_attn_varlen_func(
pack_gqa=None,
sm_margin=0,
return_softmax_lse=False,
sinks=None,
):
if not is_fa3_supported():
raise NotImplementedError(
@@ -277,6 +280,7 @@ def flash_attn_varlen_func(
num_splits=num_splits,
pack_gqa=pack_gqa,
sm_margin=sm_margin,
sinks=sinks,
)
return (out, softmax_lse, *rest) if return_softmax_lse else out