diff --git a/sgl-kernel/python/sgl_kernel/flash_attn.py b/sgl-kernel/python/sgl_kernel/flash_attn.py
index a14173568..2d1a79489 100644
--- a/sgl-kernel/python/sgl_kernel/flash_attn.py
+++ b/sgl-kernel/python/sgl_kernel/flash_attn.py
@@ -204,3 +204,75 @@ def flash_attn_with_kvcache(
     )
     # return (out, softmax_lse) if return_softmax_lse else out
     return (out, softmax_lse, *rest) if return_softmax_lse else out
+
+
+def flash_attn_varlen_func(
+    q,
+    k,
+    v,
+    cu_seqlens_q,
+    cu_seqlens_k,
+    max_seqlen_q,
+    max_seqlen_k,
+    seqused_q=None,
+    seqused_k=None,
+    softmax_scale=None,
+    causal=False,
+    qv=None,
+    q_descale=None,
+    k_descale=None,
+    v_descale=None,
+    window_size=(-1, -1),
+    softcap=0.0,
+    num_splits=1,
+    pack_gqa=None,
+    sm_margin=0,
+    return_softmax_lse=False,
+):
+    if not is_fa3_supported():
+        raise NotImplementedError(
+            "flash_attn at sgl-kernel is only supported on sm90 and above"
+        )
+
+    if softmax_scale is None:
+        softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (
+            -0.5
+        )
+
+    out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default(
+        q,
+        k,
+        v,
+        None,  # k_new
+        None,  # v_new
+        qv,  # qv
+        None,  # out
+        cu_seqlens_q,
+        cu_seqlens_k,
+        None,  # cu_seqlens_k_new
+        seqused_q,
+        seqused_k,
+        max_seqlen_q,
+        max_seqlen_k,
+        None,  # page_table
+        None,  # kv_batch_idx
+        None,  # leftpad_k
+        None,  # rotary cos
+        None,  # rotary sin
+        None,  # seqlens_rotary
+        q_descale,
+        k_descale,
+        v_descale,
+        softmax_scale,
+        causal,
+        window_size[0],
+        window_size[1],
+        softcap,
+        is_rotary_interleaved=False,
+        scheduler_metadata=None,
+        num_splits=num_splits,
+        pack_gqa=pack_gqa,
+        sm_margin=sm_margin,
+    )
+
+    return (out, softmax_lse, *rest) if return_softmax_lse else out
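
For review context, a minimal usage sketch of the new entry point, assuming the standard FlashAttention varlen calling convention (packed `(total_tokens, num_heads, head_dim)` tensors and int32 cumulative sequence-length offsets). The shapes, dtypes, and self-attention setup below are illustrative assumptions, not part of this change:

```python
# Usage sketch (not part of the diff); assumes an sm90+ GPU, since the
# function raises NotImplementedError otherwise.
import torch

from sgl_kernel.flash_attn import flash_attn_varlen_func

# Two sequences of lengths 5 and 3, packed along the token dimension.
seqlens = [5, 3]
total_tokens = sum(seqlens)  # 8
num_heads, head_dim = 8, 128

# Packed q/k/v: (total_tokens, num_heads, head_dim), fp16 on GPU.
q = torch.randn(
    total_tokens, num_heads, head_dim, dtype=torch.float16, device="cuda"
)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Cumulative offsets delimiting each sequence: [0, 5, 8].
cu_seqlens = torch.tensor([0, 5, 8], dtype=torch.int32, device="cuda")

out = flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max(seqlens),
    max_seqlen_k=max(seqlens),
    causal=True,
)
# out: (total_tokens, num_heads, head_dim); passing
# return_softmax_lse=True additionally returns the softmax log-sum-exp.
```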