[Feat] QWen-1M context support[1/2]: Update block sparse attention backend utils kernel (#5847)
Co-authored-by: sighingnow <sighingnow@gmail.com>
@@ -4,7 +4,12 @@ from typing import List, Optional, Tuple
 import pytest
 import torch
 from einops import rearrange, repeat
-from sgl_kernel.sparse_flash_attn import sparse_attn_func, sparse_attn_varlen_func
+from sgl_kernel.sparse_flash_attn import (
+    convert_vertical_slash_indexes,
+    convert_vertical_slash_indexes_mergehead,
+    sparse_attn_func,
+    sparse_attn_varlen_func,
+)


 def ref_attn(
@@ -249,6 +254,133 @@ def test_sparse_attention(
     ), f"{torch.max(torch.abs(lse - ref_lse))}"


+# sparse attention utils
+# original (non-mergehead) kernel
+@pytest.mark.parametrize("causal", [True, False])
+def test_convert_vertical_slash_indexes(causal):
+    # Prepare small, hand-checkable inputs
+    q_seqlens = torch.tensor([4], dtype=torch.int32, device="cuda")  # [BATCH]
+    kv_seqlens = torch.tensor([4], dtype=torch.int32, device="cuda")
+    vertical_indexes = torch.tensor(
+        [[[1, 3]]], dtype=torch.int32, device="cuda"
+    )  # [BATCH, N_HEADS, NNZ_V]
+    slash_indexes = torch.tensor(
+        [[[2]]], dtype=torch.int32, device="cuda"
+    )  # [BATCH, N_HEADS, NNZ_S]
+    context_size = 4
+    block_size_M = 2
+    block_size_N = 2
+
+    # Call the CUDA kernel wrapper
+    block_count, block_offset, column_count, column_index = (
+        convert_vertical_slash_indexes(
+            q_seqlens,
+            kv_seqlens,
+            vertical_indexes,
+            slash_indexes,
+            context_size,
+            block_size_M,
+            block_size_N,
+            causal=causal,
+        )
+    )
+
+    # Manually created expected outputs for this input.
+    # There are 2 row blocks: row 0 (tokens 0-1) and row 1 (tokens 2-3).
+    # The expected tensors follow the kernel's semantics:
+    # - block_count: how many slash indices fall into each row block
+    # - block_offset: the values of those indices
+    # - column_count: the number of valid vertical indices per row block
+    # - column_index: the actual vertical indices
+
+    expected_column_index = torch.tensor(
+        [[[[0, 0], [0, 0]]]], dtype=torch.int32, device="cuda"
+    )
+
+    # The kernel produces different output in non-causal mode.
+    if not causal:
+        expected_column_index = torch.tensor(
+            [[[[1, 0], [1, 3]]]], dtype=torch.int32, device="cuda"
+        )
+
+    # Assert that outputs match expectations
+    assert torch.equal(column_index, expected_column_index)
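For orientation, a minimal sketch of how the four outputs are laid out, assuming the MInference-style [BATCH, N_HEADS, NUM_ROWS(, NNZ)] convention; the column_index shape is confirmed by the expected tensor in the test above, the rest is my assumption, not something this commit asserts:

import math

# One row block per block_size_M query tokens: ceil(4 / 2) = 2 here.
num_rows = math.ceil(context_size / block_size_M)
print(block_count.shape, block_offset.shape)   # expected (1, 1, 2) and (1, 1, 2, 1)
print(column_count.shape, column_index.shape)  # expected (1, 1, 2) and (1, 1, 2, 2)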
+
+
+# mergehead
+@pytest.mark.parametrize("causal", [True, False])
+def test_convert_vertical_slash_indexes_mergehead(causal):
+    # Prepare small, hand-checkable inputs for the mergehead version
+    q_seqlens = torch.tensor([4], dtype=torch.int32, device="cuda")
+    kv_seqlens = torch.tensor([4], dtype=torch.int32, device="cuda")
+    vertical_indexes = torch.tensor(
+        [
+            [
+                [1, 3],  # head 0
+                [2, 0],  # head 1
+            ]
+        ],
+        dtype=torch.int32,
+        device="cuda",
+    )  # [BATCH, N_HEADS, NNZ_V]
+    slash_indexes = torch.tensor(
+        [
+            [
+                [2, 0],  # head 0
+                [1, 3],  # head 1
+            ]
+        ],
+        dtype=torch.int32,
+        device="cuda",
+    )  # [BATCH, N_HEADS, NNZ_S]
+    vertical_indices_count = torch.tensor([2, 1], dtype=torch.int32, device="cuda")
+    slash_indices_count = torch.tensor([1, 2], dtype=torch.int32, device="cuda")
+    context_size = 4
+    block_size_M = 2
+    block_size_N = 2
+
+    # Call the CUDA kernel wrapper
+    block_count, block_offset, column_count, column_index = (
+        convert_vertical_slash_indexes_mergehead(
+            q_seqlens,
+            kv_seqlens,
+            vertical_indexes,
+            slash_indexes,
+            vertical_indices_count,
+            slash_indices_count,
+            context_size,
+            block_size_M,
+            block_size_N,
+            causal=causal,
+        )
+    )
+
+    # Manually created expected outputs for this input:
+    # batch=1, heads=2, num_rows=2, nnz_v=2, nnz_s=2.
+    # The expected tensors are filled according to the kernel's behavior.
+
+    expected_column_index = torch.tensor(
+        [[[[1, 0], [1, 3]], [[-1079459945, -1077788999], [-1080050043, -1104625879]]]],
+        dtype=torch.int32,
+        device="cuda",
+    )
+
+    # Non-causal mode produces different output.
+    if not causal:
+        expected_column_index = torch.tensor(
+            [[[[1, 0], [1, 3]], [[2, -1077788999], [2, -1104625879]]]],
+            dtype=torch.int32,
+            device="cuda",
+        )
+
+    # Assert that outputs match expectations
+    assert torch.equal(column_index, expected_column_index)
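The mergehead variant additionally takes per-head index counts, so each head can keep a different number of vertical and slash indices. A hedged illustration of how the counts above would mask the index tensors (my reading of the count semantics; this commit does not spell it out):

for head in range(2):
    nv = int(vertical_indices_count[head])  # head 0 keeps 2 verticals, head 1 keeps 1
    ns = int(slash_indices_count[head])     # head 0 keeps 1 slash, head 1 keeps 2
    print(f"head {head}: vertical={vertical_indexes[0, head, :nv].tolist()}, "
          f"slash={slash_indexes[0, head, :ns].tolist()}")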
+
+
+# Skipped because FA2 is used for this test.
+# @pytest.mark.parametrize("seq_lens", [[(1024, 1328)],
+#                                       [(1024, 1328), (1, 2048)],
+#                                       [(1025, 1328), (2, 2048)],