[NVIDIA] FA3/FA4 Fix (#11606)

Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>
Author: Johnny
Date: 2025-10-20 02:10:10 +02:00
Committed by: GitHub
Parent: cbb5fc2edc
Commit: 252dc4e112
10 changed files with 382 additions and 219 deletions


@@ -1,4 +1,4 @@
# Adapted from https://github.com/Dao-AILab/flash-attention/blob/b31ae1e4cd22cf5f820a2995b74b7cd3bd54355a/tests/cute/test_flash_attn.py
# Adapted from https://github.com/Dao-AILab/flash-attention/blob/8ecf128f683266735ba68e3c106ff67a2611886e/tests/cute/test_flash_attn.py
# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
@@ -10,12 +10,25 @@ import pytest
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
from utils import is_hopper
try:
from flash_attn.layers.rotary import apply_rotary_emb
except ImportError:
apply_rotary_emb = None
from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
from sgl_kernel.testing.rotary_embedding import _apply_rotary_emb as apply_rotary_emb
# from utils import is_hopper # Not used in this test
# Force sgl_kernel.flash_attn wrappers to use FA4 (Cute-DSL) implementations.
# The wrappers accept a superset of args; for FA4, extra args are ignored.
flash_attn_varlen_func = partial(flash_attn_varlen_func, ver=4)
flash_attn_with_kvcache = partial(flash_attn_with_kvcache, ver=4)
# Skip this test on pre-Blackwell GPUs (Hopper and older); FA4 needs SM 10.0+
skip_condition = torch.cuda.get_device_capability() < (10, 0)
def unpad_input(hidden_states, attention_mask, unused_mask=None):
"""
@@ -88,6 +101,11 @@ def generate_random_padding_mask(
lengths = torch.randint(
max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device
)
else:
# Unreachable given the assertion above; kept so the linter sees `lengths` defined on every path
lengths = torch.full(
(batch_size, 1), max_seqlen, device=device, dtype=torch.int32
)
if zero_lengths:
# Generate zero-lengths every 5 batches and the last batch.
@@ -482,8 +500,7 @@ def attention_ref(
@pytest.mark.skipif(
is_hopper(),
reason="skip on hopper",
skip_condition, reason="FA4 requires compute capability 10.0 or above."
)
# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@@ -497,8 +514,8 @@ def attention_ref(
@pytest.mark.parametrize("deterministic", [False])
# @pytest.mark.parametrize("softcap", [0.0, 15.0])
@pytest.mark.parametrize("softcap", [0.0])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [False])
# @pytest.mark.parametrize("local", [False, True])
@pytest.mark.parametrize("local", [False])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [False])
# @pytest.mark.parametrize("add_unused_qkv", [False, True])
@@ -522,11 +539,11 @@ def attention_ref(
(64, 128),
(128, 128),
(256, 256),
(113, 203),
(128, 217),
(113, 211),
(108, 256),
(256, 512),
# (113, 203),
# (128, 217),
# (113, 211),
# (108, 256),
# (256, 512),
(307, 256),
(640, 128),
(512, 256),
@@ -658,25 +675,7 @@ def test_flash_attn_varlen_output(
if causal or local:
key_padding_mask = query_padding_mask
(
q_unpad,
k_unpad,
v_unpad,
qv_unpad,
cu_seqlens_q,
cu_seqlens_k,
seqused_q,
seqused_k,
max_seqlen_q,
max_seqlen_k,
q,
k,
v,
qv,
output_pad_fn,
dq_pad_fn,
dk_pad_fn,
) = generate_qkv(
result = generate_qkv(
q,
k,
v,
@@ -687,6 +686,25 @@ def test_flash_attn_varlen_output(
query_unused_mask=query_unused_mask,
key_unused_mask=key_unused_mask,
)
(
q_unpad, # 0
k_unpad, # 1
v_unpad, # 2
qv_unpad, # 3
cu_seqlens_q, # 4
cu_seqlens_k, # 5
seqused_q, # 6
seqused_k, # 7
max_seqlen_q, # 8
max_seqlen_k, # 9
q, # 10
k, # 11
v, # 12
qv, # 13
output_pad_fn, # 14
dq_pad_fn, # 15
dk_pad_fn, # 16
) = result
q_unpad, k_unpad, v_unpad = [
x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad)
]
@@ -746,20 +764,16 @@ def test_flash_attn_varlen_output(
v_unpad,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=None,
max_seqlen_k=None,
# seqused_q=seqused_q,
# seqused_k=seqused_k,
# max_seqlen_q and max_seqlen_k not needed for FA4
seqused_q=seqused_q,
seqused_k=seqused_k,
causal=causal,
# qv=qv_unpad,
# q_descale=q_descale,
# k_descale=k_descale, v_descale=v_descale,
window_size=window_size,
# attention_chunk=attention_chunk,
sinks=learnable_sink,
softcap=softcap,
sinks=learnable_sink,  # the wrapper's sinks arg maps to FA4's learnable_sink parameter
pack_gqa=pack_gqa,
return_softmax_lse=True,
ver=4, # Use FA4
)
out = output_pad_fn(out_unpad)
if query_unused_mask is not None:
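
Stripped of the diff context, the FA4 varlen call above has the following shape. Names are taken from the test; the packed `[total_tokens, nheads, headdim]` layout for q_unpad/k_unpad/v_unpad is an assumption about the surrounding helpers.

```python
# Sketch of the FA4 varlen call: FA4 takes per-sequence used lengths via
# seqused_q / seqused_k; max_seqlen_q and max_seqlen_k are not needed.
out_unpad, lse = flash_attn_varlen_func(
    q_unpad,                     # packed queries, assumed [total_q, nheads, headdim]
    k_unpad,
    v_unpad,
    cu_seqlens_q=cu_seqlens_q,   # cumulative sequence offsets for q
    cu_seqlens_k=cu_seqlens_k,   # cumulative sequence offsets for k/v
    seqused_q=seqused_q,         # valid tokens per query sequence
    seqused_k=seqused_k,         # valid tokens per key/value sequence
    causal=causal,
    window_size=window_size,
    sinks=learnable_sink,
    softcap=softcap,
    pack_gqa=pack_gqa,
    return_softmax_lse=True,
    ver=4,                       # route through the FA4 (Cute-DSL) implementation
)
out = output_pad_fn(out_unpad)   # re-pad to [batch, seqlen, nheads, headdim]
```
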
@@ -875,8 +889,7 @@ def test_flash_attn_varlen_output(
@pytest.mark.skipif(
is_hopper(),
reason="skip on hopper",
skip_condition, reason="FA4 requires compute capability 10.0 or above."
)
# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@@ -887,8 +900,8 @@ def test_flash_attn_varlen_output(
# @pytest.mark.parametrize("has_learnable_sink", [False])
# @pytest.mark.parametrize("new_kv", [False, True])
@pytest.mark.parametrize("new_kv", [False])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [False])
# @pytest.mark.parametrize("local", [False, True])
@pytest.mark.parametrize("local", [False])
# @pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("causal", [True])
# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False])
@@ -900,8 +913,8 @@ def test_flash_attn_varlen_output(
# @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
@pytest.mark.parametrize("rotary_fraction", [0.0])
# @pytest.mark.parametrize("page_size", [None] + ([1, 4, 128]))
@pytest.mark.parametrize("page_size", [None, 128])
# @pytest.mark.parametrize("page_size", [128])
# @pytest.mark.parametrize("page_size", [None, 128])
@pytest.mark.parametrize("page_size", [128])
# @pytest.mark.parametrize("has_leftpad", [False, True])
@pytest.mark.parametrize("has_leftpad", [False])
# @pytest.mark.parametrize("has_batch_idx", [False, True])
@@ -1085,6 +1098,7 @@ def test_flash_attn_kvcache(
.to(dtype_ref)
)
page_table = None
num_blocks = None
else:
(
k_cache,
@@ -1301,31 +1315,24 @@ def test_flash_attn_kvcache(
else:
k_cache_paged.copy_(k_cache_saved)
v_cache_paged.copy_(v_cache_saved)
# out, lse, *rest = flash_attn_with_kvcache(
out, lse, *rest = flash_attn_with_kvcache(
# For FA4, use flash_attn_varlen_func directly instead of flash_attn_with_kvcache
# This matches the pattern from the original FA4 test
out, lse = flash_attn_varlen_func(
q if not varlen_q else q_unpad,
k_cache if page_size is None else k_cache_paged,
v_cache if page_size is None else v_cache_paged,
# k if not new_kv or not varlen_q else k_unpad,
# v if not new_kv or not varlen_q else v_unpad,
# qv=qv if not varlen_q else qv_unpad,
# rotary_cos=cos,
# rotary_sin=sin,
cache_seqlens=cache_seqlens,
# cache_batch_idx=cache_batch_idx,
# cache_leftpad=cache_leftpad,
page_table=page_table,
cu_seqlens_q=cu_seqlens_q,
# cu_seqlens_k_new=cu_seqlens_k_new,
# rotary_seqlens=rotary_seqlens,
cu_seqlens_k=None, # FA4 doesn't use cu_seqlens_k for KV cache
# max_seqlen_q and max_seqlen_k not needed for FA4
seqused_k=cache_seqlens, # Use cache_seqlens as seqused_k
page_table=page_table,
causal=causal,
window_size=window_size,
sinks=learnable_sink,
# attention_chunk=attention_chunk,
# rotary_interleaved=rotary_interleaved,
# scheduler_metadata=scheduler_metadata,
# num_splits=num_splits,
sinks=learnable_sink,  # the wrapper's sinks arg maps to FA4's learnable_sink parameter
softcap=0.0,
pack_gqa=None,
return_softmax_lse=True,
ver=4, # Use FA4
)
if varlen_q:
out = output_pad_fn(out)
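
The paged KV-cache call above reduces to the following shape. This is a sketch with names from the test; it assumes `page_table` is the per-batch block table and `cache_seqlens` holds the number of valid KV tokens per batch entry, as the inline comments indicate.

```python
# Sketch of the FA4 KV-cache path: FA4 is exercised through flash_attn_varlen_func
# rather than flash_attn_with_kvcache, with cache lengths passed as seqused_k.
out, lse = flash_attn_varlen_func(
    q_unpad if varlen_q else q,
    k_cache_paged if page_size is not None else k_cache,
    v_cache_paged if page_size is not None else v_cache,
    cu_seqlens_q=cu_seqlens_q,    # only the query side is variable-length here
    cu_seqlens_k=None,            # FA4 does not take cu_seqlens_k for the KV cache
    seqused_k=cache_seqlens,      # valid KV tokens per batch entry
    page_table=page_table,        # block table mapping batch entries to cache pages
    causal=causal,
    window_size=window_size,
    sinks=learnable_sink,
    softcap=0.0,
    pack_gqa=None,
    return_softmax_lse=True,
    ver=4,
)
if varlen_q:
    out = output_pad_fn(out)
```
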