[Feat] QWen-1M context support[1/2]: Update block sparse attention backend utils kernel (#5847)

Co-authored-by: sighingnow <sighingnow@gmail.com>
Author: PGFLMG
Date: 2025-04-29 02:03:17 +08:00
Committed by: GitHub
Parent: d364b9b0f2
Commit: ee71ed8a41
6 changed files with 763 additions and 1 deletion


@@ -4,7 +4,12 @@ from typing import List, Optional, Tuple
import pytest
import torch
from einops import rearrange, repeat
-from sgl_kernel.sparse_flash_attn import sparse_attn_func, sparse_attn_varlen_func
+from sgl_kernel.sparse_flash_attn import (
+    convert_vertical_slash_indexes,
+    convert_vertical_slash_indexes_mergehead,
+    sparse_attn_func,
+    sparse_attn_varlen_func,
+)


def ref_attn(
@@ -249,6 +254,133 @@ def test_sparse_attention(
    ), f"{torch.max(torch.abs(lse - ref_lse))}"


# sparse attention utils
# original (non-mergehead) version
@pytest.mark.parametrize("causal", [True, False])
def test_convert_vertical_slash_indexes(causal):
    # Prepare small, hand-checkable inputs
    q_seqlens = torch.tensor([4], dtype=torch.int32, device="cuda")  # [BATCH]
    kv_seqlens = torch.tensor([4], dtype=torch.int32, device="cuda")
    vertical_indexes = torch.tensor(
        [[[1, 3]]], dtype=torch.int32, device="cuda"
    )  # [BATCH, N_HEADS, NNZ_V]
    slash_indexes = torch.tensor(
        [[[2]]], dtype=torch.int32, device="cuda"
    )  # [BATCH, N_HEADS, NNZ_S]
    context_size = 4
    block_size_M = 2
    block_size_N = 2

    # Call the CUDA kernel wrapper
    block_count, block_offset, column_count, column_index = (
        convert_vertical_slash_indexes(
            q_seqlens,
            kv_seqlens,
            vertical_indexes,
            slash_indexes,
            context_size,
            block_size_M,
            block_size_N,
            causal=causal,
        )
    )

    # Expected outputs, worked out by hand for this input.
    # With block_size_M = 2 there are two query-block rows:
    # row 0 (tokens 0-1) and row 1 (tokens 2-3). The kernel returns:
    #   - block_count: how many slash indices fall into each block
    #   - block_offset: the values of those indices
    #   - column_count: number of valid vertical indices per block
    #   - column_index: the actual vertical indices
    # (see the consumption sketch after this test)
    expected_column_index = torch.tensor(
        [[[[0, 0], [0, 0]]]], dtype=torch.int32, device="cuda"
    )
    if not causal:
        # The kernel produces different column indices in non-causal mode
        expected_column_index = torch.tensor(
            [[[[1, 0], [1, 3]]]], dtype=torch.int32, device="cuda"
        )

    # Assert that outputs match expectations
    assert torch.equal(column_index, expected_column_index)
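

# Illustrative sketch (not part of this commit): one plausible way to consume
# the four outputs above, assuming block_count/block_offset enumerate the
# slash-derived KV blocks per query-block row and column_count/column_index
# list the extra vertical columns for that row. The helper name and these
# semantics are assumptions for illustration, not the kernel's documented
# contract.
def _dense_mask_from_sparse_meta(
    block_count, block_offset, column_count, column_index,
    num_rows, context_size, block_size_M, block_size_N,
):
    # mask[m, n] == True means query token m may attend to KV token n
    mask = torch.zeros(
        context_size, context_size, dtype=torch.bool, device="cuda"
    )
    b, h = 0, 0  # single batch entry and single head in the test above
    for row in range(num_rows):
        m0 = row * block_size_M
        # slash-derived blocks: start offsets of KV blocks for this row
        for i in range(int(block_count[b, h, row])):
            n0 = int(block_offset[b, h, row, i])
            mask[m0 : m0 + block_size_M, n0 : n0 + block_size_N] = True
        # vertical columns: individual KV token indices for this row
        for i in range(int(column_count[b, h, row])):
            mask[m0 : m0 + block_size_M, int(column_index[b, h, row, i])] = True
    return mask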


# mergehead version
@pytest.mark.parametrize("causal", [True, False])
def test_convert_vertical_slash_indexes_mergehead(causal):
    # Prepare small, hand-checkable inputs for the mergehead version
    q_seqlens = torch.tensor([4], dtype=torch.int32, device="cuda")
    kv_seqlens = torch.tensor([4], dtype=torch.int32, device="cuda")
    vertical_indexes = torch.tensor(
        [
            [
                [1, 3],  # head 0
                [2, 0],  # head 1
            ]
        ],
        dtype=torch.int32,
        device="cuda",
    )  # [BATCH, N_HEADS, NNZ_V]
    slash_indexes = torch.tensor(
        [
            [
                [2, 0],  # head 0
                [1, 3],  # head 1
            ]
        ],
        dtype=torch.int32,
        device="cuda",
    )  # [BATCH, N_HEADS, NNZ_S]
    # Per-head counts of valid entries in the padded index buffers
    vertical_indices_count = torch.tensor([2, 1], dtype=torch.int32, device="cuda")
    slash_indices_count = torch.tensor([1, 2], dtype=torch.int32, device="cuda")
    context_size = 4
    block_size_M = 2
    block_size_N = 2

    # Call the CUDA kernel wrapper
    block_count, block_offset, column_count, column_index = (
        convert_vertical_slash_indexes_mergehead(
            q_seqlens,
            kv_seqlens,
            vertical_indexes,
            slash_indexes,
            vertical_indices_count,
            slash_indices_count,
            context_size,
            block_size_M,
            block_size_N,
            causal=causal,
        )
    )

    # Expected outputs for this input: batch=1, heads=2, num_rows=2,
    # nnz_v=2, nnz_s=2. The large negative head-1 entries appear to be
    # padding past the valid column counts (see the sketch after this test).
    expected_column_index = torch.tensor(
        [[[[1, 0], [1, 3]], [[-1079459945, -1077788999], [-1080050043, -1104625879]]]],
        dtype=torch.int32,
        device="cuda",
    )
    if not causal:
        # The kernel produces different column indices in non-causal mode
        expected_column_index = torch.tensor(
            [[[[1, 0], [1, 3]], [[2, -1077788999], [2, -1104625879]]]],
            dtype=torch.int32,
            device="cuda",
        )

    # Assert that outputs match expectations
    assert torch.equal(column_index, expected_column_index)
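

# Illustrative sketch (not part of this commit): unlike the original kernel,
# the mergehead variant also takes per-head valid-count tensors, so each head
# may use a different number of vertical/slash indices out of the padded
# [BATCH, N_HEADS, NNZ_*] buffers. The slicing below shows that assumed
# interpretation; it is not the kernel's documented contract.
def _valid_indices_per_head(
    vertical_indexes, slash_indexes, vertical_indices_count, slash_indices_count
):
    out = []
    for head in range(vertical_indexes.shape[1]):
        n_v = int(vertical_indices_count[head])  # valid vertical entries
        n_s = int(slash_indices_count[head])  # valid slash entries
        out.append(
            (
                vertical_indexes[0, head, :n_v].tolist(),
                slash_indexes[0, head, :n_s].tolist(),
            )
        )
    return out
# For the mergehead inputs above this yields:
#   head 0 -> vertical [1, 3], slash [2]
#   head 1 -> vertical [2],    slash [1, 3]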


# Skipped because FA2 is used for this test
# @pytest.mark.parametrize("seq_lens", [[(1024, 1328)],
#                                       [(1024, 1328), (1, 2048)],
#                                       [(1025, 1328), (2, 2048)],