import math
from typing import List, Optional, Tuple

import pytest
import torch
from einops import rearrange, repeat
from sgl_kernel.sparse_flash_attn import (
    convert_vertical_slash_indexes,
    convert_vertical_slash_indexes_mergehead,
    sparse_attn_func,
)
from test_flash_attention import construct_local_mask, is_fa3_supported


def ref_attn(
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    attn_bias=None,
    dropout_p=0.0,
    dropout_mask=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite window size
    softcap=0.0,
    upcast=True,
    reorder_ops=False,
    key_leftpad=None,
):
"""
Arguments :
q : ( batch_size , seqlen_q , nheads , head_dim )
k : ( batch_size , seqlen_k , nheads_k , head_dim )
v : ( batch_size , seqlen_k , nheads_k , head_dim )
query_padding_mask : ( batch_size , seqlen_q )
key_padding_mask : ( batch_size , seqlen_k )
attn_bias : broadcastable to ( batch_size , nheads , seqlen_q , seqlen_k )
dropout_p : float
dropout_mask : ( batch_size , nheads , seqlen_q , seqlen_k )
causal : whether to apply causal masking
window_size : ( int , int ) , left and right window size
upcast : whether to cast all inputs to fp32 , do all computation in fp32 , then cast
output back to fp16 / bf16 .
reorder_ops : whether to change the order of operations ( scaling k instead of scaling q , etc . )
without changing the math . This is to estimate the numerical error from operation
reordering .
Output :
output : ( batch_size , seqlen_q , nheads , head_dim )
lse : ( batch_size , nheads , seqlen_q )
"""
    if causal:
        window_size = (window_size[0], 0)
    dtype_og = q.dtype
    if upcast:
        q, k, v = q.float(), k.float(), v.float()
    seqlen_q, seqlen_k = q.shape[1], k.shape[1]
    k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
    v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
    d = q.shape[-1]
    if not reorder_ops:
        scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
    else:
        scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d))
    lse_ref = scores.logsumexp(dim=-1)
    if softcap > 0:
        scores = scores / softcap
        scores = scores.tanh()
        scores = scores * softcap
    if key_padding_mask is not None:
        scores.masked_fill_(
            rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")
        )
    if window_size[0] >= 0 or window_size[1] >= 0:
        local_mask = construct_local_mask(
            seqlen_q,
            seqlen_k,
            window_size,
            query_padding_mask,
            key_padding_mask,
            q.device,
            key_leftpad=key_leftpad,
        )
        scores.masked_fill_(local_mask, float("-inf"))
    if attn_bias is not None:
        scores = scores + attn_bias
    attention = torch.softmax(scores, dim=-1).to(v.dtype)
    # Some rows might be completely masked out, so we fill them with zero instead of NaN
    if window_size[0] >= 0 or window_size[1] >= 0:
        attention = attention.masked_fill(
            torch.all(local_mask, dim=-1, keepdim=True), 0.0
        )
    # We want to mask here so that the attention matrix doesn't have any NaNs,
    # otherwise we'll get NaN in dV
    if query_padding_mask is not None:
        attention = attention.masked_fill(
            rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0
        )
    dropout_scaling = 1.0 / (1 - dropout_p)
    # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling
    # output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
    if dropout_mask is not None:
        attention_drop = attention.masked_fill(~dropout_mask, 0.0)
    else:
        attention_drop = attention
    output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
    if query_padding_mask is not None:
        output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
    return output.to(dtype=dtype_og), lse_ref


def ref_paged_attn(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    query_lens: List[int],
    kv_lens: List[int],
    block_tables: torch.Tensor,
    scale: float,
    sliding_window: Optional[int] = None,
    soft_cap: Optional[float] = None,
) -> torch.Tensor:
    num_seqs = len(query_lens)
    block_tables = block_tables.cpu().numpy()
    _, block_size, num_kv_heads, head_size = key_cache.shape

    outputs: List[torch.Tensor] = []
    start_idx = 0
    for i in range(num_seqs):
        query_len = query_lens[i]
        kv_len = kv_lens[i]
        # clone to avoid clobbering the query tensor
        q = query[start_idx : start_idx + query_len].clone()
        q *= scale

        num_kv_blocks = (kv_len + block_size - 1) // block_size
        block_indices = block_tables[i, :num_kv_blocks]

        k = key_cache[block_indices].view(-1, num_kv_heads, head_size)
        k = k[:kv_len]
        v = value_cache[block_indices].view(-1, num_kv_heads, head_size)
        v = v[:kv_len]

        if q.shape[1] != k.shape[1]:
            k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1)
            v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1)
        attn = torch.einsum("qhd,khd->hqk", q, k).float()
        empty_mask = torch.ones(query_len, kv_len)
        mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
        if sliding_window is not None:
            sliding_window_mask = (
                torch.triu(
                    empty_mask, diagonal=kv_len - (query_len + sliding_window) + 1
                )
                .bool()
                .logical_not()
            )
            mask |= sliding_window_mask
        if soft_cap is not None:
            attn = soft_cap * torch.tanh(attn / soft_cap)
        attn.masked_fill_(mask, float("-inf"))
        attn = torch.softmax(attn, dim=-1).to(v.dtype)
        out = torch.einsum("hqk,khd->qhd", attn, v)

        outputs.append(out)
        start_idx += query_len

    return torch.cat(outputs, dim=0)
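

# A minimal sketch of how ref_paged_attn is meant to be called. Shapes are inferred
# from the code above; the variable names are purely illustrative:
#
#   query:        [sum(query_lens), num_query_heads, head_size]
#   key_cache:    [num_blocks, block_size, num_kv_heads, head_size]
#   value_cache:  [num_blocks, block_size, num_kv_heads, head_size]
#   block_tables: [num_seqs, max_num_blocks_per_seq], mapping each sequence to the
#                 cache blocks that hold its keys/values in order
#
#   out = ref_paged_attn(query, key_cache, value_cache, query_lens, kv_lens,
#                        block_tables, scale=head_size**-0.5)

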
@pytest.mark.skipif(
    not is_fa3_supported(),
    reason="flash_attn at sgl-kernel is only supported on sm90 or sm80",
)
@pytest.mark.parametrize("batch_size", [1, 2])
@pytest.mark.parametrize(
    "seq_lens",
    [
        (1, 1),
        (1, 1024),
        (1, 2048),
        (1023, 2049),
        (1023, 1023),
        (32, 32),
        (65, 65),
        (129, 129),
    ],
)
@pytest.mark.parametrize("num_heads", [1, 2, 4])
@pytest.mark.parametrize("head_size", [128])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("NNZ_S", [0, 1, 2, 3, 7, 15, 32])
@torch.inference_mode()
def test_sparse_attention(
    batch_size,
    seq_lens,
    num_heads,
    head_size,
    dtype,
    NNZ_S,
) -> None:
    torch.set_default_device("cuda")
    torch.cuda.manual_seed_all(0)
    block_size_M = 64
    block_size_N = 64
    seqlen_q, seqlen_k = seq_lens
    q = torch.randn(
        batch_size, seqlen_q, num_heads, head_size, dtype=dtype, requires_grad=False
    )
    k = torch.randn(
        batch_size, seqlen_k, num_heads, head_size, dtype=dtype, requires_grad=False
    )
    v = torch.randn(
        batch_size, seqlen_k, num_heads, head_size, dtype=dtype, requires_grad=False
    )
    NUM_ROWS = (seqlen_q + block_size_M - 1) // block_size_M
    if NNZ_S * block_size_N > seqlen_k:
        return
    NNZ_V = seqlen_k - NNZ_S * block_size_N
    block_count = torch.tensor(
        [NNZ_S] * batch_size * NUM_ROWS * num_heads, dtype=torch.int32
    ).reshape(batch_size, num_heads, NUM_ROWS)
    column_count = torch.tensor(
        [NNZ_V] * batch_size * NUM_ROWS * num_heads, dtype=torch.int32
    ).reshape(batch_size, num_heads, NUM_ROWS)
    block_offset = torch.tensor(
        [[i * block_size_N for i in range(NNZ_S)]] * batch_size * NUM_ROWS * num_heads,
        dtype=torch.int32,
    ).reshape(batch_size, num_heads, NUM_ROWS, NNZ_S)
    column_index = torch.tensor(
        [[NNZ_S * block_size_N + i for i in range(NNZ_V)]]
        * batch_size
        * NUM_ROWS
        * num_heads,
        dtype=torch.int32,
    ).reshape(batch_size, num_heads, NUM_ROWS, NNZ_V)
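    # With the layout above, the first NNZ_S * block_size_N keys are covered by the
    # block_offset entries and the remaining NNZ_V keys are listed individually in
    # column_index, so every key position in [0, seqlen_k) is attended exactly once
    # and the sparse kernel should reproduce the dense reference attention below.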
    out, lse = sparse_attn_func(
        q,
        k,
        v,
        block_count,
        block_offset,
        column_count,
        column_index,
        return_softmax_lse=True,
    )
    ref_out, ref_lse = ref_attn(q, k, v)
    torch.testing.assert_close(
        out,
        ref_out,
        atol=2e-2,
        rtol=1e-2,
        msg=f"max abs diff: {torch.max(torch.abs(out - ref_out))}",
    )
    torch.testing.assert_close(
        lse,
        ref_lse,
        atol=2e-2,
        rtol=1e-2,
        msg=f"max abs diff: {torch.max(torch.abs(lse - ref_lse))}",
    )


# sparse attention utils
# original (non-mergehead) version
@pytest.mark.skipif(
    not is_fa3_supported(),
    reason="flash_attn at sgl-kernel is only supported on sm90 or sm80",
)
@pytest.mark.parametrize("causal", [True, False])
def test_convert_vertical_slash_indexes(causal):
    # Prepare small, hand-checkable inputs
    q_seqlens = torch.tensor([4], dtype=torch.int32, device="cuda")  # [BATCH]
    kv_seqlens = torch.tensor([4], dtype=torch.int32, device="cuda")
    vertical_indexes = torch.tensor(
        [[[1, 3]]], dtype=torch.int32, device="cuda"
    )  # [BATCH, N_HEADS, NNZ_V]
    slash_indexes = torch.tensor(
        [[[2]]], dtype=torch.int32, device="cuda"
    )  # [BATCH, N_HEADS, NNZ_S]
    context_size = 4
    block_size_M = 2
    block_size_N = 2
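    # In the vertical-slash sparsity pattern, vertical_indexes name individual key
    # positions and slash_indexes name diagonals of the attention map; the kernel
    # below converts them into the block-sparse (block_count/block_offset,
    # column_count/column_index) layout consumed by sparse_attn_func.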
    # Call the CUDA kernel wrapper
    block_count, block_offset, column_count, column_index = (
        convert_vertical_slash_indexes(
            q_seqlens,
            kv_seqlens,
            vertical_indexes,
            slash_indexes,
            context_size,
            block_size_M,
            block_size_N,
            causal=causal,
        )
    )
    # Manually created expected outputs for this input.
    # There are 2 row blocks: row0 (tokens 0-1) and row1 (tokens 2-3).
    # The expected tensors follow the kernel's output semantics:
    # - block_count: how many slash indices fall into each block
    # - block_offset: the values of those indices
    # - column_count: number of valid vertical indices per block
    # - column_index: the actual vertical indices
    expected_column_index = torch.tensor(
        [[[[0, 0], [0, 0]]]], dtype=torch.int32, device="cuda"
    )
    # The kernel produces a different layout in non-causal mode
    if not causal:
        expected_column_index = torch.tensor(
            [[[[1, 0], [1, 3]]]], dtype=torch.int32, device="cuda"
        )
    # Assert that the output matches the expectation
    assert torch.equal(column_index, expected_column_index)


# mergehead
@pytest.mark.skipif(
    not is_fa3_supported(),
    reason="flash_attn at sgl-kernel is only supported on sm90 or sm80",
)
@pytest.mark.parametrize("causal", [True, False])
def test_convert_vertical_slash_indexes_mergehead(causal):
    # Prepare small, hand-checkable inputs for the mergehead version
    q_seqlens = torch.tensor([4], dtype=torch.int32, device="cuda")
    kv_seqlens = torch.tensor([4], dtype=torch.int32, device="cuda")
    vertical_indexes = torch.tensor(
        [
            [
                [1, 3],  # head 0
                [2, 0],  # head 1
            ]
        ],
        dtype=torch.int32,
        device="cuda",
    )  # [BATCH, N_HEADS, NNZ_V]
    slash_indexes = torch.tensor(
        [
            [
                [2, 0],  # head 0
                [1, 3],  # head 1
            ]
        ],
        dtype=torch.int32,
        device="cuda",
    )  # [BATCH, N_HEADS, NNZ_S]
    vertical_indices_count = torch.tensor([2, 1], dtype=torch.int32, device="cuda")
    slash_indices_count = torch.tensor([1, 2], dtype=torch.int32, device="cuda")
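    # Per-head number of valid entries in vertical_indexes / slash_indexes above:
    # head 0 uses 2 vertical and 1 slash index, head 1 uses 1 vertical and 2 slash
    # indices, so the mergehead variant lets each head carry a different sparsity.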
    context_size = 4
    block_size_M = 2
    block_size_N = 2
    # Call the CUDA kernel wrapper
    block_count, block_offset, column_count, column_index = (
        convert_vertical_slash_indexes_mergehead(
            q_seqlens,
            kv_seqlens,
            vertical_indexes,
            slash_indexes,
            vertical_indices_count,
            slash_indices_count,
            context_size,
            block_size_M,
            block_size_N,
            causal=causal,
        )
    )
    # Manually created expected outputs for this input:
    # batch=1, heads=2, num_rows=2, nnz_v=2, nnz_s=2
    expected_column_index = torch.tensor(
        [[[[1, 0], [1, 3]], [[-1079459945, -1077788999], [-1080050043, -1104625879]]]],
        dtype=torch.int32,
        device="cuda",
    )
    # The kernel produces a different layout in non-causal mode
    if not causal:
        expected_column_index = torch.tensor(
            [[[[1, 0], [1, 3]], [[2, -1077788999], [2, -1104625879]]]],
            dtype=torch.int32,
            device="cuda",
        )
    # Assert that the output matches the expectation
    assert torch.equal(column_index, expected_column_index)


# skipped because fa2 is used for this test
# @pytest.mark.parametrize("seq_lens", [[(1024, 1328)],
# [(1024, 1328), (1, 2048)],
# [(1025, 1328), (2, 2048)],
# [(1025, 2049), (2, 1281)],
# ])
# @pytest.mark.parametrize("head_size", [128])
# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
# @torch.inference_mode()
# def test_sparse_attention_varlen(
# seq_lens,
# head_size,
# dtype,
# ) -> None:
# torch.set_default_device("cuda")
# torch.cuda.manual_seed_all(0)
# block_size_M = 64
# block_size_N = 64
# num_seqs = len(seq_lens)
# query_lens = [x[0] for x in seq_lens]
# kv_lens = [x[1] for x in seq_lens]
# num_heads = 1
# query = torch.randn(sum(query_lens),
# num_heads,
# head_size,
# dtype=dtype)
# key = torch.randn(sum(kv_lens),
# num_heads,
# head_size,
# dtype=dtype)
# value = torch.randn_like(key)
# cu_query_lens = torch.tensor([0] + query_lens,
# dtype=torch.int32).cumsum(dim=0,
# dtype=torch.int32)
# cu_kv_lens = torch.tensor([0] + kv_lens,
# dtype=torch.int32).cumsum(dim=0,
# dtype=torch.int32)
# max_query_len = max(query_lens)
# max_kv_len = max(kv_lens)
# NUM_ROWS = (max_query_len + block_size_M - 1) // block_size_M
# NNZ_S = 20
# NNZ_V = 2048
# batch_size = len(query_lens)
# block_counts = []
# column_counts = []
# block_offsets = []
# column_indices = []
# for b in range(batch_size):
# block_counts.append(torch.tensor([NNZ_S] * NUM_ROWS * num_heads, dtype=torch.int32).reshape(num_heads, NUM_ROWS))
# columns = kv_lens[b] - NNZ_S * block_size_N
# column_counts.append(torch.tensor([columns] * NUM_ROWS * num_heads, dtype=torch.int32).reshape(num_heads, NUM_ROWS))
# block_offsets.append(torch.tensor([[i * block_size_N for i in range(NNZ_S)]] * NUM_ROWS * num_heads, dtype=torch.int32).reshape(num_heads, NUM_ROWS, NNZ_S))
# column_indices.append(torch.tensor([[NNZ_S * block_size_N + i for i in range(NNZ_V)]] * NUM_ROWS * num_heads, dtype=torch.int32).reshape(num_heads, NUM_ROWS, NNZ_V))
# block_count = torch.concat(block_counts).reshape(batch_size, num_heads, NUM_ROWS)
# column_count = torch.concat(column_counts).reshape(batch_size, num_heads, NUM_ROWS)
# block_offset = torch.concat(block_offsets).reshape(batch_size, num_heads, NUM_ROWS, NNZ_S)
# column_index = torch.concat(column_indices).reshape(batch_size, num_heads, NUM_ROWS, NNZ_V)
# out, lse = sparse_attn_varlen_func(
# query,
# key,
# value,
# block_count,
# block_offset,
# column_count,
# column_index,
# cu_seqlens_q=cu_query_lens,
# cu_seqlens_k=cu_kv_lens,
# max_seqlen_q=max_query_len,
# max_seqlen_k=max_kv_len,
# return_softmax_lse=True,
# )
# max_num_blocks_per_seq = (max_kv_len + 2048 - 1) // 2048
# block_tables = torch.randint(0,
# 2048,
# (len(query_lens), max_num_blocks_per_seq),
# dtype=torch.int32)
# scale = head_size**-0.5
# ref_out, ref_lse, _ = ref_paged_attn(
# query,
# key,
# value,
# query_lens=query_lens,
# kv_lens=kv_lens,
# block_tables=block_tables,
# scale=scale
# )
# torch.testing.assert_close(out, ref_out, atol=2e-2, rtol=1e-2), \
# f"{torch.max(torch.abs(out - ref_out))}"
# torch.testing.assert_close(lse, ref_lse, atol=2e-2, rtol=1e-2), \
# f"{torch.max(torch.abs(lse - ref_lse))}"


if __name__ == "__main__":
    pytest.main([__file__])