[NVIDIA] FA3/FA4 Fix (#11606)
Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>
@@ -1,4 +1,4 @@
-# Adapted from https://github.com/Dao-AILab/flash-attention/blob/b31ae1e4cd22cf5f820a2995b74b7cd3bd54355a/tests/cute/test_flash_attn.py
+# Adapted from https://github.com/Dao-AILab/flash-attention/blob/8ecf128f683266735ba68e3c106ff67a2611886e/tests/cute/test_flash_attn.py
 
 # Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 
@@ -10,12 +10,25 @@ import pytest
 import torch
 import torch.nn.functional as F
 from einops import rearrange, repeat
-from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
-from utils import is_hopper
-
-try:
-    from flash_attn.layers.rotary import apply_rotary_emb
-except ImportError:
-    apply_rotary_emb = None
+
+from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
+from sgl_kernel.testing.rotary_embedding import _apply_rotary_emb as apply_rotary_emb
+
+# from utils import is_hopper  # Not used in this test
+
+# Force sgl_kernel.flash_attn wrappers to use FA4 (Cute-DSL) implementations.
+# The wrappers accept a superset of args; for FA4, extra args are ignored.
+flash_attn_varlen_func = partial(flash_attn_varlen_func, ver=4)
+flash_attn_with_kvcache = partial(flash_attn_with_kvcache, ver=4)
+
+# Skip this test on Hopper machine
+skip_condition = torch.cuda.get_device_capability() < (10, 0)
 
 
 def unpad_input(hidden_states, attention_mask, unused_mask=None):
     """
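A note on the rebinding above: `functools.partial` captures the original wrapper object and pins `ver=4`, so every later call in the file takes the FA4 path without repeating the argument (the hunk does not show it, but the file presumably imports `partial` from `functools` elsewhere). A minimal, self-contained sketch with a stand-in function, not the real wrapper:

```python
from functools import partial

# Stand-in for the sgl_kernel.flash_attn wrappers; only the rebinding
# pattern is the point here.
def attn(q, k, v, causal=False, ver=3):
    return f"ver={ver}, causal={causal}"

# partial() holds a reference to the original function object, so rebinding
# the same name is safe: later calls default to ver=4, other args pass through.
attn = partial(attn, ver=4)
print(attn(None, None, None, causal=True))  # -> ver=4, causal=True
```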
@@ -88,6 +101,11 @@ def generate_random_padding_mask(
         lengths = torch.randint(
             max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device
         )
+    else:
+        # This should never happen due to the assertion above, but for linter
+        lengths = torch.full(
+            (batch_size, 1), max_seqlen, device=device, dtype=torch.int32
+        )
 
     if zero_lengths:
         # Generate zero-lengths every 5 batches and the last batch.
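For context on the sampling in this hunk: `torch.randint`'s upper bound is exclusive, so `max_seqlen + 1` keeps full-length sequences possible, and the new `else` branch guarantees `lengths` is bound on every path. A sketch of one standard way such lengths become a padding mask (the helper's exact broadcast may differ):

```python
import torch

batch_size, max_seqlen = 4, 12
lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1))
# Broadcast compare: position j in row i is "real" iff j < lengths[i].
padding_mask = torch.arange(max_seqlen)[None, :] < lengths  # (batch, seqlen) bool
```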
@@ -482,8 +500,7 @@ def attention_ref(
 
 
 @pytest.mark.skipif(
-    is_hopper(),
-    reason="skip on hopper",
+    skip_condition, reason="FA4 Requires compute capability of 10 or above."
 )
 # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
 @pytest.mark.parametrize("dtype", [torch.bfloat16])
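The guard works because `torch.cuda.get_device_capability()` returns a `(major, minor)` tuple and Python compares tuples lexicographically; a quick check of the boundary cases:

```python
# (9, 0) is Hopper, (10, 0) is Blackwell; FA4 here wants >= (10, 0).
hopper, blackwell = (9, 0), (10, 0)
assert hopper < (10, 0)            # skip_condition True -> test skipped
assert not (blackwell < (10, 0))   # skip_condition False -> test runs
```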
@@ -497,8 +514,8 @@ def attention_ref(
 @pytest.mark.parametrize("deterministic", [False])
 # @pytest.mark.parametrize("softcap", [0.0, 15.0])
 @pytest.mark.parametrize("softcap", [0.0])
-@pytest.mark.parametrize("local", [False, True])
-# @pytest.mark.parametrize("local", [False])
+# @pytest.mark.parametrize("local", [False, True])
+@pytest.mark.parametrize("local", [False])
 @pytest.mark.parametrize("causal", [False, True])
 # @pytest.mark.parametrize("causal", [False])
 # @pytest.mark.parametrize("add_unused_qkv", [False, True])
@@ -522,11 +539,11 @@ def attention_ref(
         (64, 128),
         (128, 128),
         (256, 256),
-        (113, 203),
-        (128, 217),
-        (113, 211),
-        (108, 256),
-        (256, 512),
+        # (113, 203),
+        # (128, 217),
+        # (113, 211),
+        # (108, 256),
+        # (256, 512),
         (307, 256),
         (640, 128),
         (512, 256),
@@ -658,25 +675,7 @@ def test_flash_attn_varlen_output(
     if causal or local:
         key_padding_mask = query_padding_mask
 
-    (
-        q_unpad,
-        k_unpad,
-        v_unpad,
-        qv_unpad,
-        cu_seqlens_q,
-        cu_seqlens_k,
-        seqused_q,
-        seqused_k,
-        max_seqlen_q,
-        max_seqlen_k,
-        q,
-        k,
-        v,
-        qv,
-        output_pad_fn,
-        dq_pad_fn,
-        dk_pad_fn,
-    ) = generate_qkv(
+    result = generate_qkv(
         q,
         k,
         v,
@@ -687,6 +686,25 @@ def test_flash_attn_varlen_output(
         query_unused_mask=query_unused_mask,
         key_unused_mask=key_unused_mask,
     )
+    (
+        q_unpad,  # 0
+        k_unpad,  # 1
+        v_unpad,  # 2
+        qv_unpad,  # 3
+        cu_seqlens_q,  # 4
+        cu_seqlens_k,  # 5
+        seqused_q,  # 6
+        seqused_k,  # 7
+        max_seqlen_q,  # 8
+        max_seqlen_k,  # 9
+        q,  # 10
+        k,  # 11
+        v,  # 12
+        qv,  # 13
+        output_pad_fn,  # 14
+        dq_pad_fn,  # 15
+        dk_pad_fn,  # 16
+    ) = result
     q_unpad, k_unpad, v_unpad = [
         x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad)
     ]
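Two idioms in this hunk are worth spelling out: capturing the 17-tuple in `result` before destructuring makes the index comments easy to check against `generate_qkv`'s return order, and `detach().to(dtype).requires_grad_()` produces fresh leaf tensors in the test dtype so autograd accumulates gradients on them rather than on the reference tensors. A generic sketch (shapes and dtype illustrative only):

```python
import torch

x_ref = torch.randn(4, 8, dtype=torch.float32)
# New leaf tensor: cut from x_ref's graph, cast, then marked trainable.
x = x_ref.detach().to(torch.bfloat16).requires_grad_()
assert x.is_leaf and x.requires_grad and x.dtype == torch.bfloat16
x.sum().backward()                 # gradients land on x, not x_ref
assert x.grad is not None and x_ref.grad is None
```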
@@ -746,20 +764,16 @@ def test_flash_attn_varlen_output(
         v_unpad,
         cu_seqlens_q=cu_seqlens_q,
         cu_seqlens_k=cu_seqlens_k,
-        max_seqlen_q=None,
-        max_seqlen_k=None,
-        # seqused_q=seqused_q,
-        # seqused_k=seqused_k,
+        # max_seqlen_q and max_seqlen_k not needed for FA4
+        seqused_q=seqused_q,
+        seqused_k=seqused_k,
         causal=causal,
         # qv=qv_unpad,
         # q_descale=q_descale,
         # k_descale=k_descale, v_descale=v_descale,
         window_size=window_size,
         # attention_chunk=attention_chunk,
-        sinks=learnable_sink,
         softcap=softcap,
+        sinks=learnable_sink,  # FA4 uses learnable_sink, not sinks
         pack_gqa=pack_gqa,
         return_softmax_lse=True,
-        ver=4,  # Use FA4
     )
     out = output_pad_fn(out_unpad)
     if query_unused_mask is not None:
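The call above leans on varlen conventions: packed tokens are addressed through cumulative boundaries (`cu_seqlens_*`), while `seqused_q`/`seqused_k` carry the per-sequence lengths actually attended to, which is why the diff can drop `max_seqlen_q`/`max_seqlen_k` for FA4. A pure-torch sketch of the boundary layout:

```python
import torch
import torch.nn.functional as F

seqlens = torch.tensor([3, 5, 2], dtype=torch.int32)
cu_seqlens = F.pad(torch.cumsum(seqlens, 0, dtype=torch.int32), (1, 0))
# tensor([0, 3, 8, 10], dtype=torch.int32): sequence i owns packed
# tokens cu_seqlens[i] : cu_seqlens[i + 1].
total_tokens = int(cu_seqlens[-1])  # 10 rows in q_unpad / k_unpad / v_unpad
```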
@@ -875,8 +889,7 @@ def test_flash_attn_varlen_output(
 
 
 @pytest.mark.skipif(
-    is_hopper(),
-    reason="skip on hopper",
+    skip_condition, reason="FA4 Requires compute capability of 10 or above."
 )
 # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
 @pytest.mark.parametrize("dtype", [torch.bfloat16])
@@ -887,8 +900,8 @@ def test_flash_attn_varlen_output(
 # @pytest.mark.parametrize("has_learnable_sink", [False])
 # @pytest.mark.parametrize("new_kv", [False, True])
 @pytest.mark.parametrize("new_kv", [False])
-@pytest.mark.parametrize("local", [False, True])
-# @pytest.mark.parametrize("local", [False])
+# @pytest.mark.parametrize("local", [False, True])
+@pytest.mark.parametrize("local", [False])
 # @pytest.mark.parametrize("causal", [False, True])
 @pytest.mark.parametrize("causal", [True])
 # @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False])
@@ -900,8 +913,8 @@ def test_flash_attn_varlen_output(
 # @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
 @pytest.mark.parametrize("rotary_fraction", [0.0])
 # @pytest.mark.parametrize("page_size", [None] + ([1, 4, 128]))
-@pytest.mark.parametrize("page_size", [None, 128])
-# @pytest.mark.parametrize("page_size", [128])
+# @pytest.mark.parametrize("page_size", [None, 128])
+@pytest.mark.parametrize("page_size", [128])
 # @pytest.mark.parametrize("has_leftpad", [False, True])
 @pytest.mark.parametrize("has_leftpad", [False])
 # @pytest.mark.parametrize("has_batch_idx", [False, True])
@@ -1085,6 +1098,7 @@ def test_flash_attn_kvcache(
             .to(dtype_ref)
         )
         page_table = None
+        num_blocks = None
     else:
         (
             k_cache,
@@ -1301,31 +1315,24 @@ def test_flash_attn_kvcache(
     else:
         k_cache_paged.copy_(k_cache_saved)
         v_cache_paged.copy_(v_cache_saved)
-    # out, lse, *rest = flash_attn_with_kvcache(
-    out, lse, *rest = flash_attn_with_kvcache(
+    # For FA4, use flash_attn_varlen_func directly instead of flash_attn_with_kvcache
+    # This matches the pattern from the original FA4 test
+    out, lse = flash_attn_varlen_func(
         q if not varlen_q else q_unpad,
         k_cache if page_size is None else k_cache_paged,
         v_cache if page_size is None else v_cache_paged,
-        # k if not new_kv or not varlen_q else k_unpad,
-        # v if not new_kv or not varlen_q else v_unpad,
-        # qv=qv if not varlen_q else qv_unpad,
-        # rotary_cos=cos,
-        # rotary_sin=sin,
-        cache_seqlens=cache_seqlens,
-        # cache_batch_idx=cache_batch_idx,
-        # cache_leftpad=cache_leftpad,
-        page_table=page_table,
         cu_seqlens_q=cu_seqlens_q,
-        # cu_seqlens_k_new=cu_seqlens_k_new,
-        # rotary_seqlens=rotary_seqlens,
+        cu_seqlens_k=None,  # FA4 doesn't use cu_seqlens_k for KV cache
+        # max_seqlen_q and max_seqlen_k not needed for FA4
+        seqused_k=cache_seqlens,  # Use cache_seqlens as seqused_k
+        page_table=page_table,
         causal=causal,
         window_size=window_size,
-        sinks=learnable_sink,
         # attention_chunk=attention_chunk,
-        # rotary_interleaved=rotary_interleaved,
-        # scheduler_metadata=scheduler_metadata,
-        # num_splits=num_splits,
+        sinks=learnable_sink,  # FA4 uses learnable_sink, not sinks
         softcap=0.0,
         pack_gqa=None,
         return_softmax_lse=True,
-        ver=4,  # Use FA4
     )
     if varlen_q:
         out = output_pad_fn(out)
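Since the rewritten call feeds `k_cache_paged`/`v_cache_paged` plus `page_table` straight into `flash_attn_varlen_func`, it helps to recall what the page table encodes: each sequence maps its logical pages onto physical blocks of the cache. A pure-torch sketch of that indexing (block layout assumed for illustration, not taken from the kernel):

```python
import torch

page_size, num_blocks, head_dim = 4, 8, 16
# Two sequences, two logical pages each, mapped to physical blocks.
page_table = torch.tensor([[5, 2], [0, 7]])
k_cache_paged = torch.randn(num_blocks, page_size, 1, head_dim)

b, t = 1, 6                                  # sequence 1, token position 6
block = page_table[b, t // page_size]        # -> physical block 7
k_tok = k_cache_paged[block, t % page_size]  # that token's key, shape (1, 16)
```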