Add flash_attn_varlen_func to sgl-kernel (#5315)
This commit is contained in:
@@ -204,3 +204,75 @@ def flash_attn_with_kvcache(
|
|||||||
)
|
)
|
||||||
# return (out, softmax_lse) if return_softmax_lse else out
|
# return (out, softmax_lse) if return_softmax_lse else out
|
||||||
return (out, softmax_lse, *rest) if return_softmax_lse else out
|
return (out, softmax_lse, *rest) if return_softmax_lse else out
|
||||||
|
|
||||||
|
|
||||||
|
def flash_attn_varlen_func(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
max_seqlen_q,
|
||||||
|
max_seqlen_k,
|
||||||
|
seqused_q=None,
|
||||||
|
seqused_k=None,
|
||||||
|
softmax_scale=None,
|
||||||
|
causal=False,
|
||||||
|
qv=None,
|
||||||
|
q_descale=None,
|
||||||
|
k_descale=None,
|
||||||
|
v_descale=None,
|
||||||
|
window_size=(-1, -1),
|
||||||
|
softcap=0.0,
|
||||||
|
num_splits=1,
|
||||||
|
pack_gqa=None,
|
||||||
|
sm_margin=0,
|
||||||
|
return_softmax_lse=False,
|
||||||
|
):
|
||||||
|
if not is_fa3_supported():
|
||||||
|
raise NotImplementedError(
|
||||||
|
"flash_attn at sgl-kernel is only supported on sm90 and above"
|
||||||
|
)
|
||||||
|
|
||||||
|
if softmax_scale is None:
|
||||||
|
softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (
|
||||||
|
-0.5
|
||||||
|
)
|
||||||
|
|
||||||
|
out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
None, # k_new
|
||||||
|
None, # v_new
|
||||||
|
qv, # qv
|
||||||
|
None, # out
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
None, # cu_seqlens_k_new
|
||||||
|
seqused_q,
|
||||||
|
seqused_k,
|
||||||
|
max_seqlen_q,
|
||||||
|
max_seqlen_k,
|
||||||
|
None, # page_table,
|
||||||
|
None, # kv_batch_idx
|
||||||
|
None, # leftpad_k
|
||||||
|
None, # rotary cos
|
||||||
|
None, # rotary sin
|
||||||
|
None, # seqlens_rotary
|
||||||
|
q_descale,
|
||||||
|
k_descale,
|
||||||
|
v_descale,
|
||||||
|
softmax_scale,
|
||||||
|
causal,
|
||||||
|
window_size[0],
|
||||||
|
window_size[1],
|
||||||
|
softcap,
|
||||||
|
is_rotary_interleaved=False,
|
||||||
|
scheduler_metadata=None,
|
||||||
|
num_splits=num_splits,
|
||||||
|
pack_gqa=pack_gqa,
|
||||||
|
sm_margin=sm_margin,
|
||||||
|
)
|
||||||
|
|
||||||
|
return (out, softmax_lse, *rest) if return_softmax_lse else out
|
||||||
|
|||||||
Reference in New Issue
Block a user