Update fa3 interface and add unit test (#9150)
This commit is contained in:
@@ -58,6 +58,7 @@ def flash_attn_with_kvcache(
|
||||
pack_gqa=None, # Can be tuned for speed
|
||||
sm_margin=0, # Can be tuned if some SMs are used for communication
|
||||
return_softmax_lse=False,
|
||||
sinks=None,
|
||||
):
|
||||
"""
|
||||
If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
|
||||
@@ -205,6 +206,7 @@ def flash_attn_with_kvcache(
|
||||
num_splits,
|
||||
pack_gqa,
|
||||
sm_margin,
|
||||
sinks,
|
||||
)
|
||||
# return (out, softmax_lse) if return_softmax_lse else out
|
||||
return (out, softmax_lse, *rest) if return_softmax_lse else out
|
||||
@@ -232,6 +234,7 @@ def flash_attn_varlen_func(
|
||||
pack_gqa=None,
|
||||
sm_margin=0,
|
||||
return_softmax_lse=False,
|
||||
sinks=None,
|
||||
):
|
||||
if not is_fa3_supported():
|
||||
raise NotImplementedError(
|
||||
@@ -277,6 +280,7 @@ def flash_attn_varlen_func(
|
||||
num_splits=num_splits,
|
||||
pack_gqa=pack_gqa,
|
||||
sm_margin=sm_margin,
|
||||
sinks=sinks,
|
||||
)
|
||||
|
||||
return (out, softmax_lse, *rest) if return_softmax_lse else out
|
||||
|
||||
Reference in New Issue
Block a user