Add flash_attn_varlen_func to sgl-kernel (#5315)
This commit is contained in:
@@ -204,3 +204,75 @@ def flash_attn_with_kvcache(
|
||||
)
|
||||
# return (out, softmax_lse) if return_softmax_lse else out
|
||||
return (out, softmax_lse, *rest) if return_softmax_lse else out
|
||||
|
||||
|
||||
def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    seqused_q=None,
    seqused_k=None,
    softmax_scale=None,
    causal=False,
    qv=None,
    q_descale=None,
    k_descale=None,
    v_descale=None,
    window_size=(-1, -1),
    softcap=0.0,
    num_splits=1,
    pack_gqa=None,
    sm_margin=0,
    return_softmax_lse=False,
):
    """Call the sgl-kernel attention forward op without any KV-cache inputs.

    Thin wrapper around ``torch.ops.sgl_kernel.fwd.default``. The slots the op
    reserves for new K/V, a preallocated output, the page table, the KV batch
    index, left padding and the rotary tensors are all passed as ``None``;
    ``is_rotary_interleaved`` is fixed to ``False`` and ``scheduler_metadata``
    to ``None``.

    NOTE(review): the meaning of the forwarded tensors is defined by the op
    schema, which is not visible here. The descriptions below only state how
    this wrapper handles each argument -- confirm details against the kernel.

    Args:
        q, k, v: Forwarded unchanged as the first three positional arguments.
        cu_seqlens_q, cu_seqlens_k: Forwarded unchanged (presumably cumulative
            sequence lengths for the query and key sides -- TODO confirm).
        max_seqlen_q, max_seqlen_k: Forwarded unchanged.
        seqused_q, seqused_k: Optional; forwarded unchanged.
        softmax_scale: Scale handed to the op. When ``None`` it defaults to
            ``(q.shape[-1] + qv.shape[-1]) ** -0.5``, where the ``qv`` term is
            omitted if ``qv`` is ``None``.
        causal: Forwarded unchanged.
        qv: Optional; forwarded unchanged and, when given, included in the
            default ``softmax_scale`` computation.
        q_descale, k_descale, v_descale: Optional; forwarded unchanged.
        window_size: Indexable with at least two items; items 0 and 1 are
            passed to the op as two separate positional arguments.
        softcap: Forwarded unchanged.
        num_splits, pack_gqa, sm_margin: Forwarded as keyword arguments.
        return_softmax_lse: Selects the return shape (see Returns).

    Returns:
        ``out`` alone when ``return_softmax_lse`` is falsy; otherwise the tuple
        ``(out, softmax_lse, *rest)`` where ``rest`` holds whatever additional
        values the op returned after the first two.

    Raises:
        NotImplementedError: If ``is_fa3_supported()`` returns a falsy value.
    """
    if not is_fa3_supported():
        raise NotImplementedError(
            "flash_attn at sgl-kernel is only supported on sm90 and above"
        )

    if softmax_scale is None:
        # Default scale uses the combined last dimension of q and (if any) qv.
        scale_dim = q.shape[-1]
        if qv is not None:
            scale_dim = scale_dim + qv.shape[-1]
        softmax_scale = scale_dim ** (-0.5)

    # Resolve the op first, then evaluate the arguments in the op's positional
    # order; the order of the positional arguments must not be changed.
    fwd_op = torch.ops.sgl_kernel.fwd.default
    outputs = fwd_op(
        q,
        k,
        v,
        None,  # k_new
        None,  # v_new
        qv,  # qv
        None,  # out
        cu_seqlens_q,
        cu_seqlens_k,
        None,  # cu_seqlens_k_new
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        None,  # page_table
        None,  # kv_batch_idx
        None,  # leftpad_k
        None,  # rotary cos
        None,  # rotary sin
        None,  # seqlens_rotary
        q_descale,
        k_descale,
        v_descale,
        softmax_scale,
        causal,
        window_size[0],
        window_size[1],
        softcap,
        is_rotary_interleaved=False,
        scheduler_metadata=None,
        num_splits=num_splits,
        pack_gqa=pack_gqa,
        sm_margin=sm_margin,
    )
    out, softmax_lse, *rest = outputs

    if return_softmax_lse:
        return (out, softmax_lse, *rest)
    return out
|
||||
|
||||
Reference in New Issue
Block a user