feat: support fa cute in sgl-kernel (#10205)
Co-authored-by: cicirori <32845984+cicirori@users.noreply.github.com>
This commit is contained in:
376
sgl-kernel/python/sgl_kernel/_fa4_interface.py
Normal file
376
sgl-kernel/python/sgl_kernel/_fa4_interface.py
Normal file
@@ -0,0 +1,376 @@
|
||||
# Adapted from https://github.com/Dao-AILab/flash-attention/blob/203b9b3dba39d5d08dffb49c09aa622984dff07d/flash_attn/cute/interface.py
|
||||
|
||||
# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
||||
# [2025-07-04] Version in Cute-DSL, for Hopper and Blackwell. You'd need to install nvidia-cutlass-dsl==4.1.0.
|
||||
|
||||
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import cuda.bindings.driver as cuda
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
import torch
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
from flash_attn.cute.flash_fwd import FlashAttentionForwardSm90
|
||||
from flash_attn.cute.flash_fwd_sm100 import FlashAttentionForwardSm100
|
||||
|
||||
|
||||
def maybe_contiguous(x):
    """Return ``x`` with a unit-stride last dimension, copying only if needed.

    Args:
        x: A tensor, or ``None``.

    Returns:
        ``x`` itself when it is ``None``, has no dimensions, or already has
        ``stride(-1) == 1``; otherwise ``x.contiguous()``.
    """
    # A 0-dim tensor has no last dimension to inspect, so hand it back as-is
    # rather than letting ``stride(-1)`` raise.
    if x is None or x.dim() == 0:
        return x
    return x if x.stride(-1) == 1 else x.contiguous()
|
||||
|
||||
|
||||
# Lookup from a torch element type to the corresponding CUTLASS DSL numeric
# type.
torch2cute_dtype_map = {
    torch.float16: cutlass.Float16,
    torch.bfloat16: cutlass.BFloat16,
    torch.float32: cutlass.Float32,
}
|
||||
|
||||
|
||||
def _to_cute_tensor(t, assumed_align, leading_dim):
    """Wrap a torch tensor via DLPack with a dynamic layout; ``None`` passes through."""
    if t is None:
        return None
    return from_dlpack(t.detach(), assumed_align=assumed_align).mark_layout_dynamic(
        leading_dim=leading_dim
    )


def _flash_attn_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    page_table: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    softcap: Optional[float] = None,
    window_size_left: Optional[int] = None,
    window_size_right: Optional[int] = None,
    learnable_sink: Optional[torch.Tensor] = None,
    m_block_size: int = 128,
    n_block_size: int = 128,
    num_threads: int = 384,
    pack_gqa: Optional[bool] = None,
    _compute_capability: Optional[int] = None,
    return_softmax_lse: Optional[bool] = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Validate inputs, compile (once per configuration) and launch the forward kernel.

    Args:
        q: Query. Read as ``(batch, seqlen_q, num_head, head_dim)`` when
            ``cu_seqlens_q`` is None, otherwise ``(total_q, num_head, head_dim)``.
        k, v: Key / value. Asserted to be ``(batch, seqlen_k, num_head_kv, dim)``
            in the dense case, ``(num_pages, page_size, num_head_kv, dim)`` when
            ``page_table`` is given, and ``(total_k, num_head_kv, dim)`` when
            ``cu_seqlens_k`` is given.
        cu_seqlens_q, cu_seqlens_k: Optional int32 tensors of shape
            ``(batch + 1,)``.
        seqused_q, seqused_k: Optional int32 tensors of shape ``(batch,)``.
        page_table: Optional int32 tensor of shape ``(batch, pages_per_seq)``;
            cannot be combined with ``cu_seqlens_k``.
        softmax_scale: Defaults to ``1 / sqrt(head_dim)`` when None.
        causal: When truthy, ``window_size_right`` is overridden with 0.
        softcap: ``0.0`` is treated the same as None.
        window_size_left, window_size_right: Optional window bounds forwarded
            to the kernel (exact semantics are defined by the kernel, not here).
        learnable_sink: Optional bfloat16 tensor of shape ``(num_head,)``.
        m_block_size, n_block_size, num_threads: Tiling parameters passed to
            the SM90 kernel constructor.
        pack_gqa: Defaults to ``num_head // num_head_kv > 1`` when None.
        _compute_capability: Overrides the detected major compute capability.
        return_softmax_lse: When truthy, a float32 LSE buffer is allocated.

    Returns:
        ``(out, lse)``; ``lse`` is None unless ``return_softmax_lse`` is truthy.

    Raises:
        AssertionError: On any failed shape / dtype / device / capability check.
        ValueError: If the compute capability is neither 9 nor 10 at compile time.
    """
    q, k, v = (maybe_contiguous(t) for t in (q, k, v))
    num_head, head_dim = q.shape[-2:]

    # Packed (cu_seqlens) queries drop the batch dimension from q / out / lse.
    packed_q = cu_seqlens_q is not None
    if packed_q:
        batch_size = cu_seqlens_q.shape[0] - 1
        out_lead_shape = (q.shape[0],)
        lse_shape = (num_head, q.shape[0])
    else:
        batch_size, seqlen_q = q.shape[:2]
        out_lead_shape = (batch_size, seqlen_q)
        lse_shape = (batch_size, num_head, seqlen_q)

    paged = page_table is not None
    if paged:
        assert cu_seqlens_k is None, "page_table is not supported with cu_seqlens_k"
        assert page_table.dtype == torch.int32, "page_table must be int32"
        assert (
            page_table.stride(-1) == 1
        ), "page_table must be contiguous in the last dimension"
        assert page_table.shape == (batch_size, page_table.shape[1])
        num_pages, page_size = k.shape[:2]
        seqlen_k = num_pages * page_size
    else:
        num_pages = page_size = None
        seqlen_k = k.shape[-3]

    num_head_kv = k.shape[-2]
    head_dim_v = v.shape[-1]

    if cu_seqlens_k is not None:
        assert k.shape == (seqlen_k, num_head_kv, head_dim)
        assert v.shape == (seqlen_k, num_head_kv, head_dim_v)
        assert cu_seqlens_k.shape == (
            batch_size + 1,
        ), "cu_seqlens_k must have shape (batch_size + 1,)"
    elif paged:
        assert k.shape == (num_pages, page_size, num_head_kv, head_dim)
        assert v.shape == (num_pages, page_size, num_head_kv, head_dim_v)
    else:
        assert k.shape == (batch_size, seqlen_k, num_head_kv, head_dim)
        assert v.shape == (batch_size, seqlen_k, num_head_kv, head_dim_v)

    if packed_q:
        assert cu_seqlens_q.shape == (
            batch_size + 1,
        ), "cu_seqlens_q must have shape (batch_size + 1,)"
    assert seqused_q is None or seqused_q.shape == (
        batch_size,
    ), "seqused_q must have shape (batch_size,)"
    assert seqused_k is None or seqused_k.shape == (
        batch_size,
    ), "seqused_k must have shape (batch_size,)"

    assert q.dtype in [
        torch.float16,
        torch.bfloat16,
    ], "inputs must be float16 or bfloat16"
    assert q.dtype == k.dtype == v.dtype, "inputs must have the same dtype"

    for index_tensor in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k):
        if index_tensor is None:
            continue
        assert (
            index_tensor.dtype == torch.int32
        ), "cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k must be int32"
        assert (
            index_tensor.stride(0) == 1
        ), "cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k must be contiguous"

    if learnable_sink is not None:
        assert learnable_sink.shape == (num_head,)
        assert learnable_sink.dtype == torch.bfloat16, "learnable_sink must be bfloat16"

    every_tensor = (
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_q,
        seqused_k,
        page_table,
        learnable_sink,
    )
    assert all(
        t is None or t.is_cuda for t in every_tensor
    ), "inputs must be on CUDA device"

    assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv"
    assert head_dim <= 256, "head_dim must be less than or equal to 256"
    # Number of elements that fit in 16 bytes.
    alignment = 16 // q.element_size()
    assert head_dim % alignment == 0, f"head_dim must be divisible by {alignment}"
    assert head_dim_v % alignment == 0, f"head_dim_v must be divisible by {alignment}"

    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(head_dim)
    if softcap == 0.0:
        softcap = None
    qhead_per_kvhead = num_head // num_head_kv
    if pack_gqa is None:
        pack_gqa = qhead_per_kvhead > 1

    # Output buffers live on q's device; lse only exists when requested.
    out = torch.empty(
        *out_lead_shape,
        num_head,
        head_dim_v,
        dtype=q.dtype,
        device=q.device,
    )
    lse = None
    if return_softmax_lse:
        lse = torch.empty(lse_shape, dtype=torch.float32, device=q.device)

    dtype = torch2cute_dtype_map[q.dtype]

    q_tensor, k_tensor, v_tensor, o_tensor = (
        _to_cute_tensor(t, 16, t.ndim - 1) for t in (q, k, v, out)
    )
    lse_tensor = None if lse is None else _to_cute_tensor(lse, 4, lse.ndim - 1)
    (
        cu_seqlens_q_tensor,
        cu_seqlens_k_tensor,
        seqused_q_tensor,
        seqused_k_tensor,
        learnable_sink_tensor,
    ) = (
        _to_cute_tensor(t, 4, 0)
        for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, learnable_sink)
    )
    page_table_tensor = _to_cute_tensor(page_table, 4, 1)

    # Resolve the masking mode. A window that is open on the left and closed
    # at 0 on the right is treated as plain causal; any other window is local.
    if causal:
        window_size_right = 0
    if window_size_left is not None or window_size_right is not None:
        causal = bool(window_size_left is None and window_size_right == 0)
        local = not causal
    else:
        local = False

    if _compute_capability is None:
        compute_capability = torch.cuda.get_device_capability()[0]
    else:
        compute_capability = _compute_capability
    assert compute_capability in [
        9,
        10,
    ], "Unsupported compute capability. Supported: 9.x, 10.x"
    current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    ragged_q = cu_seqlens_q is not None or seqused_q is not None

    # TODO: tune block size according to hdim
    if (
        compute_capability == 9
        and head_dim == head_dim_v == 128
        and not causal
        and not local
    ):
        n_block_size = 192
    if compute_capability == 10:
        # TODO: fix the varlen case
        if (pack_gqa and 128 % qhead_per_kvhead != 0) or ragged_q:
            pack_gqa = False

    compile_key = (
        dtype,
        head_dim,
        head_dim_v,
        qhead_per_kvhead,
        causal,
        softcap is not None,
        lse is None,
        cu_seqlens_q is None,
        cu_seqlens_k is None,
        seqused_q is None,
        seqused_k is None,
        page_table is not None,
        window_size_left is not None,
        window_size_right is not None,
        learnable_sink is not None,
        m_block_size,
        n_block_size,
        num_threads,
        pack_gqa,
        compute_capability,
    )

    # The same argument list is used to specialise the kernel and to launch it.
    kernel_args = (
        q_tensor,
        k_tensor,
        v_tensor,
        o_tensor,
        lse_tensor,
        softmax_scale,
        current_stream,
        cu_seqlens_q_tensor,
        cu_seqlens_k_tensor,
        seqused_q_tensor,
        seqused_k_tensor,
        page_table_tensor,
        softcap,
        window_size_left,
        window_size_right,
        learnable_sink_tensor,
    )

    cache = _flash_attn_fwd.compile_cache
    if compile_key not in cache:
        if compute_capability == 9:
            assert page_table is None, "paged KV not supported on SM 9.0"
            fa_fwd = FlashAttentionForwardSm90(
                dtype,
                head_dim,
                head_dim_v,
                qhead_per_kvhead,
                is_causal=causal,
                is_local=local,
                pack_gqa=pack_gqa,
                m_block_size=m_block_size,
                n_block_size=n_block_size,
                num_stages=2,
                num_threads=num_threads,
                Q_in_regs=False,
            )
        elif compute_capability == 10:
            assert page_size in [
                None,
                128,
            ], "Only page_size=128 is supported for paged KV on SM 10.0"
            fa_fwd = FlashAttentionForwardSm100(
                head_dim,
                head_dim_v,
                qhead_per_kvhead=qhead_per_kvhead,
                is_causal=causal,
                is_local=local,
                pack_gqa=pack_gqa,
                is_persistent=not (causal or local or ragged_q),
            )
        else:
            raise ValueError(
                f"Unsupported compute capability: {compute_capability}. Supported: 9.x, 10.x"
            )
        # TODO: check @can_implement
        cache[compile_key] = cute.compile(fa_fwd, *kernel_args)

    cache[compile_key](*kernel_args)
    return out, lse


# Compiled kernels, keyed by the configuration tuple built in _flash_attn_fwd.
_flash_attn_fwd.compile_cache = {}
|
||||
|
||||
|
||||
def flash_attn_varlen_func(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    page_table: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size: Tuple[Optional[int], Optional[int]] = (None, None),
    learnable_sink: Optional[torch.Tensor] = None,
    softcap: float = 0.0,
    pack_gqa: Optional[bool] = None,
    return_softmax_lse: Optional[bool] = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Public forward entry point; a thin wrapper around ``_flash_attn_fwd``.

    ``window_size`` is split into its first and second element and forwarded
    as ``window_size_left`` / ``window_size_right``; every other argument is
    passed through unchanged.

    Returns:
        ``(out, lse)`` when ``return_softmax_lse`` is truthy, otherwise just
        ``out``.
    """
    out, lse = _flash_attn_fwd(
        q=q,
        k=k,
        v=v,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        seqused_q=seqused_q,
        seqused_k=seqused_k,
        page_table=page_table,
        softmax_scale=softmax_scale,
        causal=causal,
        softcap=softcap,
        window_size_left=window_size[0],
        window_size_right=window_size[1],
        learnable_sink=learnable_sink,
        pack_gqa=pack_gqa,
        return_softmax_lse=return_softmax_lse,
    )

    if not return_softmax_lse:
        return out
    return out, lse
|
||||
@@ -9,6 +9,11 @@ try:
|
||||
except:
|
||||
raise ImportError("Can not import sgl_kernel. Please check your installation.")
|
||||
|
||||
try:
|
||||
from ._fa4_interface import flash_attn_varlen_func as flash_attn_varlen_func_v4
|
||||
except ImportError:
|
||||
flash_attn_varlen_func_v4 = None
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def is_fa3_supported(device=None) -> bool:
|
||||
@@ -61,6 +66,7 @@ def flash_attn_with_kvcache(
|
||||
sm_margin=0, # Can be tuned if some SMs are used for communication
|
||||
return_softmax_lse=False,
|
||||
sinks=None,
|
||||
ver=3,
|
||||
):
|
||||
"""
|
||||
If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
|
||||
@@ -147,6 +153,9 @@ def flash_attn_with_kvcache(
|
||||
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
|
||||
normalization factor).
|
||||
"""
|
||||
if ver == 4:
|
||||
raise NotImplementedError("haven't implemented flash_attn_with_kvcache for fa4")
|
||||
|
||||
assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
|
||||
assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
|
||||
if softmax_scale is None:
|
||||
@@ -237,7 +246,40 @@ def flash_attn_varlen_func(
|
||||
sm_margin=0,
|
||||
return_softmax_lse=False,
|
||||
sinks=None,
|
||||
ver=3,
|
||||
):
|
||||
if ver == 4:
|
||||
assert (
|
||||
flash_attn_varlen_func_v4 is not None
|
||||
), "FA4 is not available, please check your installation."
|
||||
# Using `(-1, -1)` as no sliding window causes correctness issues for FA4.
|
||||
if window_size == (-1, -1):
|
||||
window_size = (None, None)
|
||||
return flash_attn_varlen_func_v4(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
# max_seqlen_q,
|
||||
# max_seqlen_k,
|
||||
seqused_q=seqused_q,
|
||||
seqused_k=seqused_k,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=causal,
|
||||
# qv=qv,
|
||||
# q_descale=q_descale,
|
||||
# k_descale=k_descale,
|
||||
# v_descale=v_descale,
|
||||
window_size=window_size,
|
||||
softcap=softcap,
|
||||
# num_splits=num_splits,
|
||||
pack_gqa=pack_gqa,
|
||||
# sm_margin=sm_margin,
|
||||
return_softmax_lse=return_softmax_lse,
|
||||
learnable_sink=sinks,
|
||||
)
|
||||
|
||||
if not is_fa3_supported():
|
||||
raise NotImplementedError(
|
||||
"flash_attn at sgl-kernel is only supported on sm90 and above"
|
||||
|
||||
Reference in New Issue
Block a user