[NVIDIA] FA3/FA4 Fix (#11606)

Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>
Johnny
2025-10-20 02:10:10 +02:00
committed by GitHub
parent cbb5fc2edc
commit 252dc4e112
10 changed files with 382 additions and 219 deletions


@@ -1,14 +1,14 @@
# Adapted from https://github.com/Dao-AILab/flash-attention/blob/203b9b3dba39d5d08dffb49c09aa622984dff07d/flash_attn/cute/interface.py
# Adapted from https://github.com/Dao-AILab/flash-attention/blob/54d8aa6751fc9d5f0357854079261913d5df1f9d/flash_attn/cute/interface.py
# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
# [2025-07-04] Version in Cute-DSL, for Hopper and Blackwell. You'd need to install nvidia-cutlass-dsl==4.1.0.
# [2025-10-14] Version in Cute-DSL, for Hopper and Blackwell. You'd need to install nvidia-cutlass-dsl==4.2.1.
import copy
import gc
import logging
import math
from typing import Optional, Tuple
from typing import Callable, Optional, Tuple
logger = logging.getLogger(__name__)
@@ -18,6 +18,7 @@ import cutlass
import cutlass.cute as cute
import torch
from cutlass.cute.runtime import from_dlpack
from flash_attn.cute import utils
from flash_attn.cute.flash_fwd import FlashAttentionForwardSm90
from flash_attn.cute.flash_fwd_sm100 import FlashAttentionForwardSm100
@@ -26,22 +27,6 @@ def maybe_contiguous(x):
return x.contiguous() if x is not None and x.stride(-1) != 1 else x
def _reason_recompile(compile_key, jit_func):
compile_cache = jit_func.compile_cache
compile_key_map = jit_func.compile_key_map
if not compile_cache:
return "not compiled yet"
for k, v in compile_cache.items():
if k == compile_key:
continue
if len(k) != len(compile_key):
continue
for i in range(len(k)):
if k[i] != compile_key[i]:
return f"diff at '{compile_key_map[i]}': {k[i]} vs {compile_key[i]} "
return "unknown reason"
torch2cute_dtype_map = {
torch.float16: cutlass.Float16,
torch.bfloat16: cutlass.BFloat16,
@@ -72,7 +57,11 @@ def _flash_attn_fwd(
num_threads: int = 384,
pack_gqa: Optional[bool] = None,
_compute_capability: Optional[int] = None,
return_softmax_lse: Optional[bool] = False,
score_mod: Callable | None = None,
return_lse: bool = False,
out: Optional[torch.Tensor] = None,
lse: Optional[torch.Tensor] = None,
buffers: Optional[list[torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
q, k, v = [maybe_contiguous(t) for t in (q, k, v)]
num_head, head_dim = q.shape[-2:]
@@ -169,23 +158,51 @@ def _flash_attn_fwd(
q_batch_seqlen_shape = (
(batch_size, seqlen_q) if cu_seqlens_q is None else (total_q,)
)
out = torch.empty(
*q_batch_seqlen_shape,
num_head,
head_dim_v,
dtype=out_torch_dtype,
device=device,
)
lse_shape = (
(batch_size, num_head, seqlen_q)
if cu_seqlens_q is None
else (num_head, total_q)
)
lse = (
torch.empty(lse_shape, dtype=torch.float32, device=device)
if return_softmax_lse
else None
)
requires_grad = q.requires_grad or k.requires_grad or v.requires_grad
if out is None:
out = torch.empty(
*q_batch_seqlen_shape,
num_head,
head_dim_v,
dtype=out_torch_dtype,
device=device,
)
else:
expected_out_shape = (*q_batch_seqlen_shape, num_head, head_dim_v)
assert (
out.shape == expected_out_shape
), f"out tensor shape {out.shape} does not match expected shape {expected_out_shape}"
assert (
out.dtype == out_torch_dtype
), f"out tensor dtype {out.dtype} does not match expected dtype {out_torch_dtype}"
assert (
out.device == device
), f"out tensor device {out.device} does not match input device {device}"
assert out.is_cuda, "out tensor must be on CUDA device"
if lse is None:
lse = (
torch.empty(lse_shape, dtype=torch.float32, device=device)
if requires_grad or return_lse
else None
)
else:
assert (
lse.shape == lse_shape
), f"lse tensor shape {lse.shape} does not match expected shape {lse_shape}"
assert (
lse.dtype == torch.float32
), f"lse tensor dtype {lse.dtype} does not match expected dtype torch.float32"
assert (
lse.device == device
), f"lse tensor device {lse.device} does not match input device {device}"
assert lse.is_cuda, "lse tensor must be on CUDA device"
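
For context, a hypothetical caller-side sketch of the new preallocated-output path (tensor names are illustrative; it assumes q, k, v are the leading positional parameters, the remaining arguments keep their defaults, and out_torch_dtype matches the input dtype, as in the non-varlen branch above):

import torch

batch, seqlen, heads, hdim = 2, 1024, 16, 128
q = torch.randn(batch, seqlen, heads, hdim, dtype=torch.bfloat16, device="cuda")
k = torch.randn(batch, seqlen, heads, hdim, dtype=torch.bfloat16, device="cuda")
v = torch.randn(batch, seqlen, heads, hdim, dtype=torch.bfloat16, device="cuda")
# Preallocated buffers; shapes follow the checks above
# (out: batch x seqlen_q x num_head x head_dim_v, lse: batch x num_head x seqlen_q).
out_buf = torch.empty(batch, seqlen, heads, hdim, dtype=torch.bfloat16, device="cuda")
lse_buf = torch.empty(batch, heads, seqlen, dtype=torch.float32, device="cuda")
out, lse = _flash_attn_fwd(q, k, v, out=out_buf, lse=lse_buf, return_lse=True)
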
dtype = torch2cute_dtype_map[q.dtype]
q_tensor, k_tensor, v_tensor, o_tensor = [
@@ -242,6 +259,7 @@ def _flash_attn_fwd(
current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
if compute_capability == 9: # TODO: tune block size according to hdim
# Perf heuristic from upstream: hdim=128, noncausal, non-local benefits from larger n_block
if head_dim == head_dim_v == 128 and not causal and not local:
n_block_size = 192
if compute_capability == 10:
@@ -253,13 +271,34 @@ def _flash_attn_fwd(
):
pack_gqa = False
if softcap is not None:
assert score_mod is None, "softcap and score_mod cannot be used together"
score_mod = utils.create_softcap_scoremod(softcap)
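
Softcapping squashes the raw attention logits as score = softcap * tanh(score / softcap). A torch-level sketch of an equivalent score_mod is below; the actual utils.create_softcap_scoremod returns a Cute-DSL callable that operates on kernel-resident fragments, so both the signature and the torch.tanh call here are illustrative only.

import torch

def make_softcap_score_mod(softcap: float):
    # Illustrative flex-attention-style score_mod applying tanh softcapping
    # to each attention logit before the softmax.
    def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
        return softcap * torch.tanh(score / softcap)
    return score_mod
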
if score_mod is not None:
is_varlen = (
cu_seqlens_q is not None
or cu_seqlens_k is not None
or seqused_q is not None
or seqused_k is not None
)
if is_varlen:
raise NotImplementedError(
"score_mod with buffers is not yet supported for varlen sequences. This will be fixed in a future PR."
)
cute_buffers = None
if buffers is not None:
cute_buffers = [from_dlpack(buf) for buf in buffers]
compile_key = (
dtype,
head_dim,
head_dim_v,
qhead_per_kvhead,
causal,
softcap is not None,
utils.hash_callable(score_mod) if score_mod is not None else None,
buffers is not None,
lse is None,
cu_seqlens_q is None,
cu_seqlens_k is None,
@@ -276,9 +315,6 @@ def _flash_attn_fwd(
compute_capability,
)
if compile_key not in _flash_attn_fwd.compile_cache:
logger.info(
f"Compiling FA4 kernel with reason: {_reason_recompile(compile_key, _flash_attn_fwd)}"
)
if compute_capability == 9:
assert page_table is None, "paged KV not supported on SM 9.0"
# fa_fwd = FlashAttentionForwardSm80(
@@ -290,12 +326,14 @@ def _flash_attn_fwd(
is_causal=causal,
is_local=local,
pack_gqa=pack_gqa,
m_block_size=m_block_size,
n_block_size=n_block_size,
tile_m=m_block_size,
tile_n=n_block_size,
# num_stages=1,
num_stages=2,
num_threads=num_threads,
Q_in_regs=False,
score_mod=score_mod,
has_buffers=buffers is not None,
)
elif compute_capability == 10:
assert page_size in [
@@ -313,12 +351,15 @@ def _flash_attn_fwd(
and not local
and cu_seqlens_q is None
and seqused_q is None,
score_mod=score_mod,
has_buffers=buffers is not None,
)
else:
raise ValueError(
f"Unsupported compute capability: {compute_capability}. Supported: 9.x, 10.x"
)
# TODO: check @can_implement
# TODO caching for buffers; cute_buffers
_flash_attn_fwd.compile_cache[compile_key] = cute.compile(
fa_fwd,
q_tensor,
@@ -333,10 +374,10 @@ def _flash_attn_fwd(
seqused_q_tensor,
seqused_k_tensor,
page_table_tensor,
softcap,
window_size_left,
window_size_right,
learnable_sink_tensor,
cute_buffers,
)
_flash_attn_fwd.compile_cache[compile_key](
q_tensor,
@@ -351,46 +392,29 @@ def _flash_attn_fwd(
seqused_q_tensor,
seqused_k_tensor,
page_table_tensor,
softcap,
window_size_left,
window_size_right,
learnable_sink_tensor,
cute_buffers,
)
return out, lse
_flash_attn_fwd.compile_cache = {}
_flash_attn_fwd.compile_key_map = [
"dtype",
"head_dim",
"head_dim_v",
"qhead_per_kvhead",
"causal",
"softcap is not None",
"lse is None",
"cu_seqlens_q is None",
"cu_seqlens_k is None",
"seqused_q is None",
"seqused_k is None",
"page_table is not None",
"window_size_left is not None",
"window_size_right is not None",
"learnable_sink is not None",
"m_block_size",
"n_block_size",
"num_threads",
"pack_gqa",
"compute_capability",
]
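
The surrounding code caches compiled kernels as a function attribute keyed by a tuple of every property that affects code generation. A stripped-down sketch of the same pattern, with compile_kernel and run_kernel as hypothetical stand-ins for cute.compile and _flash_attn_fwd:

def compile_kernel(dtype, head_dim, causal):
    # Hypothetical stand-in for cute.compile: returns a specialized callable.
    def launch(x):
        return x  # placeholder for the actual kernel launch
    return launch

def run_kernel(x, causal: bool):
    # Everything that changes the generated code goes into the cache key,
    # mirroring compile_key above.
    key = (x.dtype, x.shape[-1], causal)
    if key not in run_kernel.cache:
        # Expensive specialization happens once per distinct key;
        # later calls only pay the launch cost.
        run_kernel.cache[key] = compile_kernel(x.dtype, x.shape[-1], causal)
    return run_kernel.cache[key](x)

run_kernel.cache = {}
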
def warmup_flash_attn(f):
"""
Decorator for flash_attn_varlen_func:
- On the first call, run several warmup passes with different flag combinations
- Warmups are executed sequentially to minimize peak GPU memory usage
- Does not modify user-provided tensors (clones data)
- Easy to extend with more compile-key dimensions
- On first call, run several warmup passes with different flag combinations:
* return_softmax_lse in {False, True}
* global noncausal (window_size=(None,None))
* causal (window_size=(None,0))
* local sliding window (window_size=(64,64))
* optionally pack_gqa=True if qheads > kvheads and allowed
- No score_mod / softcap (not supported for varlen yet)
- Executes sequentially to minimize peak GPU mem
- Does not modify user tensors (clones)
"""
done = False
@@ -399,30 +423,78 @@ def warmup_flash_attn(f):
def maybe_clone(x):
if isinstance(x, torch.Tensor):
return x.clone()
return x.detach().clone() # detach to avoid autograd edges
return copy.deepcopy(x)
return tuple(maybe_clone(a) for a in args), {
k: maybe_clone(v) for k, v in kwargs.items()
}
def _infer_heads(args, kwargs):
"""Infer q and kv head counts from arguments."""
# Expect signature: (q, k, v, cu_seqlens_q, cu_seqlens_k, ...)
q = args[0] if len(args) > 0 else kwargs.get("q")
k = args[1] if len(args) > 1 else kwargs.get("k")
try:
qh = int(q.shape[-2])
kvh = int(k.shape[-2])
return qh, kvh
except Exception:
return None, None
def _run_warmups(args, kwargs):
"""Run warmup calls sequentially and release memory after each."""
base_args, base_kwargs = _clone_args(args, kwargs)
# Warmup combinations for return_softmax_lse and causal
combos = [
dict(return_softmax_lse=False, causal=False),
dict(return_softmax_lse=False, causal=True),
dict(return_softmax_lse=True, causal=False),
dict(return_softmax_lse=True, causal=True),
qh, kvh = _infer_heads(base_args, base_kwargs)
can_pack_gqa = (
qh is not None and kvh is not None and qh % kvh == 0 and qh // kvh > 1
)
has_page_table = (
"page_table" in base_kwargs and base_kwargs["page_table"] is not None
)
# Window presets covering global, causal, and local
window_presets = [
(None, None), # global noncausal
(None, 0), # causal
(64, 64), # local sliding window
]
lse_flags = [False, True]
# Base combo list
combos = []
for ws in window_presets:
for return_lse_flag in lse_flags:
combos.append(dict(window_size=ws, return_softmax_lse=return_lse_flag))
# Optionally add a pack_gqa=True variant (FA4 may disable it internally for some varlen shapes/SMs)
if can_pack_gqa:
for ws in window_presets:
combos.append(
dict(window_size=ws, return_softmax_lse=False, pack_gqa=True)
)
# If page_table is present, warm one combo with it (page_table in compile key for SM100)
if has_page_table:
combos.append(dict(window_size=(None, None), return_softmax_lse=False))
# Run sequentially
for combo in combos:
wa, wk = _clone_args(base_args, base_kwargs)
# Keep user-provided softcap/score_mod OUT (varlen+score_mod unsupported)
wk.pop("score_mod", None)
if "softcap" in wk and wk["softcap"]:
wk["softcap"] = 0.0
# Apply combo
wk.update(combo)
with torch.cuda.stream(torch.cuda.current_stream()):
f(*wa, **wk)
try:
f(*wa, **wk)
except Exception as e:
# Some combos can be invalid for specific head dims / arch. Ignore and continue.
logger.debug("Warmup combo skipped: %s", e)
del wa, wk
torch.cuda.empty_cache()
gc.collect()
@@ -430,7 +502,9 @@ def warmup_flash_attn(f):
def wrapper(*args, **kwargs):
nonlocal done
if not done:
logger.info("Running flash_attn_varlen_func warmup passes...")
logger.info(
"Running FA4 warmup (global/causal/local, LSE on/off, optional GQA pack)..."
)
_run_warmups(args, kwargs)
done = True
return f(*args, **kwargs)
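
For illustration, the decorator mechanics on a hypothetical entry point (toy_varlen_attention is a stand-in; in this file the decorator is intended for flash_attn_varlen_func): the first call runs the warmup combinations once, guarded by the nonlocal done flag, and every later call is a plain passthrough.

@warmup_flash_attn
def toy_varlen_attention(q, k, v, cu_seqlens_q, cu_seqlens_k, **kwargs):
    ...  # would call _flash_attn_fwd(...) here

# toy_varlen_attention(q, k, v, cu_q, cu_k)  # first call: warmup passes + compile
# toy_varlen_attention(q, k, v, cu_q, cu_k)  # later calls: cached kernels only
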
@@ -472,7 +546,7 @@ def flash_attn_varlen_func(
learnable_sink=learnable_sink,
softcap=softcap,
pack_gqa=pack_gqa,
return_softmax_lse=return_softmax_lse,
return_lse=return_softmax_lse,
)
return (out, lse) if return_softmax_lse else out
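
Finally, a hypothetical end-to-end call of the public API with the LSE output enabled; argument names beyond those visible in this hunk follow the upstream flash-attention varlen interface and may differ slightly here:

import torch

total_tokens, heads, hdim = 2048, 16, 128
q = torch.randn(total_tokens, heads, hdim, dtype=torch.bfloat16, device="cuda")
k = torch.randn(total_tokens, heads, hdim, dtype=torch.bfloat16, device="cuda")
v = torch.randn(total_tokens, heads, hdim, dtype=torch.bfloat16, device="cuda")
cu_seqlens = torch.tensor([0, 1024, 2048], dtype=torch.int32, device="cuda")

# return_softmax_lse=True maps to return_lse=True internally and makes the
# function return the (out, lse) pair, as in the final lines above.
out, lse = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=1024,
    max_seqlen_k=1024,
    causal=True,
    return_softmax_lse=True,
)
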