From ee71ed8a412709416dd688992dd084b1c0640495 Mon Sep 17 00:00:00 2001 From: PGFLMG <1106310035@qq.com> Date: Tue, 29 Apr 2025 02:03:17 +0800 Subject: [PATCH] [Feat] QWen-1M context support[1/2]: Update block sparse attention backend utils kernel (#5847) Co-authored-by: sighingnow --- sgl-kernel/CMakeLists.txt | 1 + .../csrc/attention/vertical_slash_index.cu | 459 ++++++++++++++++++ sgl-kernel/csrc/common_extension.cc | 22 + sgl-kernel/include/sgl_kernel_ops.h | 30 ++ .../python/sgl_kernel/sparse_flash_attn.py | 118 +++++ sgl-kernel/tests/test_sparse_flash_attn.py | 134 ++++- 6 files changed, 763 insertions(+), 1 deletion(-) create mode 100644 sgl-kernel/csrc/attention/vertical_slash_index.cu diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt index 76661f211..bbe6246eb 100755 --- a/sgl-kernel/CMakeLists.txt +++ b/sgl-kernel/CMakeLists.txt @@ -176,6 +176,7 @@ set(SOURCES "csrc/attention/cascade.cu" "csrc/attention/merge_attn_states.cu" "csrc/attention/cutlass_mla_kernel.cu" + "csrc/attention/vertical_slash_index.cu" "csrc/attention/lightning_attention_decode_kernel.cu" "csrc/elementwise/activation.cu" "csrc/elementwise/fused_add_rms_norm_kernel.cu" diff --git a/sgl-kernel/csrc/attention/vertical_slash_index.cu b/sgl-kernel/csrc/attention/vertical_slash_index.cu new file mode 100644 index 000000000..93c936fdd --- /dev/null +++ b/sgl-kernel/csrc/attention/vertical_slash_index.cu @@ -0,0 +1,459 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +// This file is for blocksparse attention utils cuda kernel. + +#include +#include +#include + +// Save the start index of each block in the given range into block_offset. +// Returns the updated block count. +__device__ int64_t save_blocks( + int* block_offset, + int64_t range_start, + int64_t range_end, + int64_t block_size, + int64_t input_block_count, + int64_t kv_seqlen) { + if (range_start >= kv_seqlen) { + return input_block_count; + } + if (range_end > kv_seqlen) { + range_end = kv_seqlen; + } + int64_t current_block_count = input_block_count; + for (int idx = range_start; idx < range_end; idx += block_size) { + block_offset[current_block_count++] = idx; + } + return current_block_count; +} + +// CUDA kernel: convert sparse vertical/slash indices to block/column offsets. +__global__ void convert_vertical_slash_indexes_kernel( + const int* q_seqlens, // [BATCH, ] + const int* kv_seqlens, // [BATCH, ] + const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S] + int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S] + int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V] + int64_t N_HEADS, + int64_t N_ROWS, + int64_t BLOCK_SIZE_M, + int64_t BLOCK_SIZE_N, + int64_t NNZ_V, + int64_t NNZ_S, + bool causal // True for intra, False for succ +) { + const int batch_idx = blockIdx.y; + const int head_idx = blockIdx.x; + const int group_idx = blockIdx.z; + + int64_t q_seqlen = q_seqlens[batch_idx]; + int64_t kv_seqlen = kv_seqlens[batch_idx]; + int64_t block_idx_m = group_idx * blockDim.x + threadIdx.x; + int64_t start_m = block_idx_m * BLOCK_SIZE_M; + if (start_m >= q_seqlen) { + return; + } + int64_t end_m = start_m + BLOCK_SIZE_M; + vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V; + slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S; + int64_t row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m; + block_count += row_offset; + block_offset += row_offset * NNZ_S; + column_count += row_offset; + column_index += row_offset * NNZ_V; + + bool has_slash = true; + int64_t tmp_col_cnt = 0, tmp_blk_cnt = 0; + int64_t s = 0, v = 0; + int64_t v_idx = vertical_indexes[v++]; + int64_t s_idx = slash_indexes[s++]; + if (causal) { + while (s_idx >= end_m + (kv_seqlen - q_seqlen) && s < NNZ_S) { + s_idx = slash_indexes[s++]; + } + if (s_idx > end_m + (kv_seqlen - q_seqlen)) has_slash = false; + s_idx = max((kv_seqlen - q_seqlen) + end_m - s_idx, BLOCK_SIZE_M); + } else { + while (s_idx >= end_m + kv_seqlen && s < NNZ_S) { + s_idx = slash_indexes[s++]; + } + if (s_idx > end_m + kv_seqlen) has_slash = false; + s_idx = max(kv_seqlen + end_m - s_idx, BLOCK_SIZE_M); + } + + int64_t range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx; + if (!has_slash) { + if (causal) { + range_start = (kv_seqlen - q_seqlen) + end_m; + range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N; + } else { + range_start = kv_seqlen; + range_end = kv_seqlen + BLOCK_SIZE_N; + } + } + + bool slash_finished = false; + while (1) { + if (v_idx < range_end) { + if (v_idx < range_start) { + column_index[tmp_col_cnt++] = v_idx; + } + if (v < NNZ_V) { + v_idx = vertical_indexes[v++]; + } else { + if (causal) + v_idx = end_m + BLOCK_SIZE_N + (kv_seqlen - q_seqlen); + else + v_idx = end_m + BLOCK_SIZE_N + kv_seqlen; + } + } else { + if ((s < NNZ_S && causal) || (s < NNZ_S && !causal && slash_indexes[s] >= start_m)) { + if (causal) + s_idx = max((kv_seqlen - q_seqlen) + end_m - slash_indexes[s++], BLOCK_SIZE_M); + else + s_idx = max(kv_seqlen + end_m - slash_indexes[s++], BLOCK_SIZE_M); + } else { + if (v == NNZ_V || (v_idx > range_start && causal)) { + // add the last vertical if no more slash + if (v == NNZ_V && !causal && v_idx < kv_seqlen) { + column_index[tmp_col_cnt++] = v_idx; + } + tmp_blk_cnt = save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen); + break; + } else { + if (causal) { + range_start = (kv_seqlen - q_seqlen) + end_m; + range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N; + } else { + // if slash_finished but there are vertical left, save current + // blocks + tmp_blk_cnt = save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen); + range_start = kv_seqlen; + range_end = kv_seqlen + BLOCK_SIZE_N; + } + slash_finished = true; + } + } + if (!slash_finished) { + if (s_idx > range_end + BLOCK_SIZE_M) { + tmp_blk_cnt = save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen); + range_start = s_idx - BLOCK_SIZE_M; + range_end = s_idx; + } else if (s_idx > range_end) { + range_end += BLOCK_SIZE_M; + } + } + } + } + + block_count[0] = tmp_blk_cnt; + column_count[0] = tmp_col_cnt; +} + +// Host function: launches the kernel with 64 threads per block. +void convert_vertical_slash_indexes_64x64( + const int* q_seqlens, // [BATCH, ] + const int* kv_seqlens, // [BATCH, ] + const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S] + int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S] + int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V] + int64_t BATCH_SIZE, + int64_t N_HEADS, + int64_t N_ROWS, + int64_t BLOCK_SIZE_M, + int64_t BLOCK_SIZE_N, + int64_t NNZ_V, + int64_t NNZ_S, + bool causal) { + const int N_THREADS = 64; + const dim3 dimBlock((int32_t)N_THREADS); + const dim3 dimGrid( + (int32_t)N_HEADS, (int32_t)BATCH_SIZE, ((int32_t)N_ROWS + (int32_t)N_THREADS - 1) / (int32_t)N_THREADS); + convert_vertical_slash_indexes_kernel<<>>( + q_seqlens, + kv_seqlens, + vertical_indexes, + slash_indexes, + block_count, + block_offset, + column_count, + column_index, + N_HEADS, + N_ROWS, + BLOCK_SIZE_M, + BLOCK_SIZE_N, + NNZ_V, + NNZ_S, + causal); +} + +// Host function: prepares tensor pointers and launches the CUDA kernel. +void convert_vertical_slash_indexes( + torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS] + torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S] + torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS] + torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V] + torch::Tensor q_seqlens, // [BATCH, ] + torch::Tensor kv_seqlens, // [BATCH, ] + torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S] + int64_t context_size, + int64_t block_size_M, + int64_t block_size_N, + bool causal) { + cudaSetDevice(q_seqlens.get_device()); + + int64_t batch_size = slash_indexes.size(0); + int64_t num_heads = slash_indexes.size(1); + int64_t nnz_slash = slash_indexes.size(2); + int64_t nnz_vertical = vertical_indexes.size(2); + int64_t num_rows = (context_size + block_size_M - 1) / block_size_M; + + convert_vertical_slash_indexes_64x64( + q_seqlens.data_ptr(), + kv_seqlens.data_ptr(), + vertical_indexes.data_ptr(), + slash_indexes.data_ptr(), + block_count.data_ptr(), + block_offset.data_ptr(), + column_count.data_ptr(), + column_index.data_ptr(), + batch_size, + num_heads, + num_rows, + block_size_M, + block_size_N, + nnz_vertical, + nnz_slash, + causal); +} + +// --- mergehead kernels --- // + +// Kernel: like above, but supports per-head variable NNZ_V/NNZ_S. +__global__ void convert_vertical_slash_indexes_kernel_mergehead( + const int* q_seqlens, // [BATCH, ] + const int* kv_seqlens, // [BATCH, ] + const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S] + const int* per_head_vertical_topkv, + const int* per_head_slash_topkv, + int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S] + int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V] + int64_t N_HEADS, + int64_t N_ROWS, + int64_t BLOCK_SIZE_M, + int64_t BLOCK_SIZE_N, + int64_t NNZ_V, + int64_t NNZ_S, + bool causal // True for intra, False for succ +) { + const int batch_idx = blockIdx.y; + const int head_idx = blockIdx.x; + const int group_idx = blockIdx.z; + + int64_t q_seqlen = q_seqlens[batch_idx]; + int64_t kv_seqlen = kv_seqlens[batch_idx]; + int64_t block_idx_m = group_idx * blockDim.x + threadIdx.x; + int64_t start_m = block_idx_m * BLOCK_SIZE_M; + if (start_m >= q_seqlen) { + return; + } + int64_t end_m = start_m + BLOCK_SIZE_M; + vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V; + slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S; + int64_t row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m; + block_count += row_offset; + block_offset += row_offset * NNZ_S; + column_count += row_offset; + column_index += row_offset * NNZ_V; + + // MergeHead: each head has it's unique max topk NNZ_V,NNZ_S. (NNZ_V,NNZ_S + // above is buffer size, use to compute offset) + NNZ_S = per_head_slash_topkv[head_idx]; + NNZ_V = per_head_vertical_topkv[head_idx]; + + bool has_slash = true; + int64_t tmp_col_cnt = 0, tmp_blk_cnt = 0; + int64_t s = 0, v = 0; + int64_t v_idx = vertical_indexes[v++]; + int64_t s_idx = slash_indexes[s++]; + if (causal) { + while (s_idx >= end_m + (kv_seqlen - q_seqlen) && s < NNZ_S) { + s_idx = slash_indexes[s++]; + } + if (s_idx > end_m + (kv_seqlen - q_seqlen)) has_slash = false; + s_idx = max((kv_seqlen - q_seqlen) + end_m - s_idx, BLOCK_SIZE_M); + } else { + while (s_idx >= end_m + kv_seqlen && s < NNZ_S) { + s_idx = slash_indexes[s++]; + } + if (s_idx > end_m + kv_seqlen) has_slash = false; + s_idx = max(kv_seqlen + end_m - s_idx, BLOCK_SIZE_M); + } + + int64_t range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx; + if (!has_slash) { + if (causal) { + range_start = (kv_seqlen - q_seqlen) + end_m; + range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N; + } else { + range_start = kv_seqlen; + range_end = kv_seqlen + BLOCK_SIZE_N; + } + } + + bool slash_finished = false; + while (1) { + if (v_idx < range_end) { + if (v_idx < range_start) { + column_index[tmp_col_cnt++] = v_idx; + } + if (v < NNZ_V) { + v_idx = vertical_indexes[v++]; + } else { + if (causal) + v_idx = end_m + BLOCK_SIZE_N + (kv_seqlen - q_seqlen); + else + v_idx = end_m + BLOCK_SIZE_N + kv_seqlen; + } + } else { + if ((s < NNZ_S && causal) || (s < NNZ_S && !causal && slash_indexes[s] >= start_m)) { + if (causal) + s_idx = max((kv_seqlen - q_seqlen) + end_m - slash_indexes[s++], BLOCK_SIZE_M); + else + s_idx = max(kv_seqlen + end_m - slash_indexes[s++], BLOCK_SIZE_M); + } else { + if (v == NNZ_V || (v_idx > range_start && causal)) { + // add the last vertical if no more slash + if (v == NNZ_V && !causal && v_idx < kv_seqlen) { + column_index[tmp_col_cnt++] = v_idx; + } + tmp_blk_cnt = save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen); + break; + } else { + if (causal) { + range_start = (kv_seqlen - q_seqlen) + end_m; + range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N; + } else { + // if slash_finished but there are vertical left, save current + // blocks + tmp_blk_cnt = save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen); + range_start = kv_seqlen; + range_end = kv_seqlen + BLOCK_SIZE_N; + } + slash_finished = true; + } + } + if (!slash_finished) { + if (s_idx > range_end + BLOCK_SIZE_M) { + tmp_blk_cnt = save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen); + range_start = s_idx - BLOCK_SIZE_M; + range_end = s_idx; + } else if (s_idx > range_end) { + range_end += BLOCK_SIZE_M; + } + } + } + } + + block_count[0] = tmp_blk_cnt; + column_count[0] = tmp_col_cnt; +} + +// Launch the mergehead kernel with 64 threads per block. +void convert_vertical_slash_indexes_64x64_mergehead( + const int* q_seqlens, // [BATCH, ] + const int* kv_seqlens, // [BATCH, ] + const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S] + int* per_head_vertical_topkv, + int* per_head_slash_topkv, + int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S] + int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V] + int64_t BATCH_SIZE, + int64_t N_HEADS, + int64_t N_ROWS, + int64_t BLOCK_SIZE_M, + int64_t BLOCK_SIZE_N, + int64_t NNZ_V, + int64_t NNZ_S, + bool causal) { + const int N_THREADS = 64; + const dim3 dimBlock(N_THREADS); + const dim3 dimGrid(N_HEADS, BATCH_SIZE, (N_ROWS + N_THREADS - 1) / N_THREADS); + convert_vertical_slash_indexes_kernel_mergehead<<>>( + q_seqlens, + kv_seqlens, + vertical_indexes, + slash_indexes, + per_head_vertical_topkv, + per_head_slash_topkv, + block_count, + block_offset, + column_count, + column_index, + N_HEADS, + N_ROWS, + BLOCK_SIZE_M, + BLOCK_SIZE_N, + NNZ_V, + NNZ_S, + causal); +} + +// Host wrapper for mergehead kernel. +void convert_vertical_slash_indexes_mergehead( + torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS] + torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S] + torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS] + torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V] + torch::Tensor q_seqlens, // [BATCH, ] + torch::Tensor kv_seqlens, // [BATCH, ] + torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S] + torch::Tensor vertical_indices_count, // [N_HEADS, ] + torch::Tensor slash_indices_count, + int64_t context_size, + int64_t block_size_M, + int64_t block_size_N, + bool causal) { + cudaSetDevice(q_seqlens.get_device()); + + int batch_size = slash_indexes.size(0); + int num_heads = slash_indexes.size(1); + int nnz_slash = slash_indexes.size(2); + int nnz_vertical = vertical_indexes.size(2); + int num_rows = (context_size + block_size_M - 1) / block_size_M; + + convert_vertical_slash_indexes_64x64_mergehead( + q_seqlens.data_ptr(), + kv_seqlens.data_ptr(), + vertical_indexes.data_ptr(), + slash_indexes.data_ptr(), + vertical_indices_count.data_ptr(), + slash_indices_count.data_ptr(), + block_count.data_ptr(), + block_offset.data_ptr(), + column_count.data_ptr(), + column_index.data_ptr(), + batch_size, + num_heads, + num_rows, + block_size_M, + block_size_N, + nnz_vertical, + nnz_slash, + causal); +} diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc index eee2fdee9..35bd7a1cb 100644 --- a/sgl-kernel/csrc/common_extension.cc +++ b/sgl-kernel/csrc/common_extension.cc @@ -234,6 +234,28 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { "Generator? gen) -> Tensor[]"); m.impl("varlen_fwd_sparse", torch::kCUDA, &flash::mha_varlen_fwd_sparse); + // Sparse Attention utils + m.def( + "convert_vertical_slash_indexes(" + " Tensor! block_count, Tensor! block_offset, " + " Tensor! column_count, Tensor! column_index, " + " Tensor q_seqlens, Tensor q_seqlens, " + " Tensor vertical_indexes, Tensor slash_indexes, " + " int context_size, int block_size_M, int block_size_N, " + " bool causal) -> ()"); + m.impl("convert_vertical_slash_indexes", torch::kCUDA, &convert_vertical_slash_indexes); + + m.def( + "convert_vertical_slash_indexes_mergehead(" + " Tensor! block_count, Tensor! block_offset, " + " Tensor! column_count, Tensor! column_index, " + " Tensor q_seqlens, Tensor q_seqlens, " + " Tensor vertical_indexes, Tensor slash_indexes, " + " Tensor vertical_indices_count, Tensor slash_indices_count, " + " int context_size, int block_size_M, int block_size_N, " + " bool causal) -> ()"); + m.impl("convert_vertical_slash_indexes_mergehead", torch::kCUDA, &convert_vertical_slash_indexes_mergehead); + /* * From XGrammar */ diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h index f8a3294e6..bf608456d 100644 --- a/sgl-kernel/include/sgl_kernel_ops.h +++ b/sgl-kernel/include/sgl_kernel_ops.h @@ -353,6 +353,36 @@ std::vector mha_varlen_fwd_sparse( c10::optional gen_); } // namespace flash +void convert_vertical_slash_indexes( + torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS] + torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S] + torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS] + torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V] + torch::Tensor q_seqlens, // [BATCH, ] + torch::Tensor kv_seqlens, // [BATCH, ] + torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S] + int64_t context_size, + int64_t block_size_M, + int64_t block_size_N, + bool causal); + +void convert_vertical_slash_indexes_mergehead( + torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS] + torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S] + torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS] + torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V] + torch::Tensor q_seqlens, // [BATCH, ] + torch::Tensor kv_seqlens, // [BATCH, ] + torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S] + torch::Tensor vertical_indices_count, // [N_HEADS, ] + torch::Tensor slash_indices_count, + int64_t context_size, + int64_t block_size_M, + int64_t block_size_N, + bool causal); + /* * From XGrammar */ diff --git a/sgl-kernel/python/sgl_kernel/sparse_flash_attn.py b/sgl-kernel/python/sgl_kernel/sparse_flash_attn.py index c4ffad7da..29b2f0405 100644 --- a/sgl-kernel/python/sgl_kernel/sparse_flash_attn.py +++ b/sgl-kernel/python/sgl_kernel/sparse_flash_attn.py @@ -8,6 +8,124 @@ def maybe_contiguous(x): return x.contiguous() if x is not None and x.stride(-1) != 1 else x +# Sparse attention utils +def convert_vertical_slash_indexes( + q_seqlens: torch.Tensor, # [BATCH, ] + kv_seqlens: torch.Tensor, # [BATCH, ] + vertical_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_V] + slash_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_S] + context_size: int, + block_size_M: int, + block_size_N: int, + causal: bool = True, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + batch_size = slash_indexes.size(0) + num_heads = slash_indexes.size(1) + nnz_slash = slash_indexes.size(2) + nnz_vertical = vertical_indexes.size(2) + num_rows = (context_size + block_size_M - 1) // block_size_M + + block_count = torch.zeros( + batch_size, num_heads, num_rows, dtype=q_seqlens.dtype, device=q_seqlens.device + ) + block_offset = torch.zeros( + batch_size, + num_heads, + num_rows, + nnz_slash, + dtype=q_seqlens.dtype, + device=q_seqlens.device, + ) + column_count = torch.zeros( + batch_size, num_heads, num_rows, dtype=q_seqlens.dtype, device=q_seqlens.device + ) + column_index = torch.zeros( + batch_size, + num_heads, + num_rows, + nnz_vertical, + dtype=q_seqlens.dtype, + device=q_seqlens.device, + ) + + torch.ops.sgl_kernel.convert_vertical_slash_indexes.default( + block_count, + block_offset, + column_count, + column_index, + q_seqlens, + kv_seqlens, + vertical_indexes, + slash_indexes, + context_size, + block_size_M, + block_size_N, + causal, + ) + return block_count, block_offset, column_count, column_index + + +def convert_vertical_slash_indexes_mergehead( + q_seqlens: torch.Tensor, # [BATCH, ] + kv_seqlens: torch.Tensor, # [BATCH, ] + vertical_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_V] + slash_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_S] + # [N_HEADS] : different head use different number of indices + vertical_indices_count: torch.Tensor, + slash_indices_count: torch.Tensor, + context_size: int, + block_size_M: int, + block_size_N: int, + causal: bool = True, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + batch_size = slash_indexes.size(0) + num_heads = slash_indexes.size(1) + nnz_slash = slash_indexes.size(2) + nnz_vertical = vertical_indexes.size(2) + num_rows = (context_size + block_size_M - 1) // block_size_M + + block_count = torch.empty( + batch_size, num_heads, num_rows, dtype=q_seqlens.dtype, device=q_seqlens.device + ) + block_offset = torch.empty( + batch_size, + num_heads, + num_rows, + nnz_slash, + dtype=q_seqlens.dtype, + device=q_seqlens.device, + ) + column_count = torch.empty( + batch_size, num_heads, num_rows, dtype=q_seqlens.dtype, device=q_seqlens.device + ) + column_index = torch.empty( + batch_size, + num_heads, + num_rows, + nnz_vertical, + dtype=q_seqlens.dtype, + device=q_seqlens.device, + ) + + torch.ops.sgl_kernel.convert_vertical_slash_indexes_mergehead.default( + block_count, + block_offset, + column_count, + column_index, + q_seqlens, + kv_seqlens, + vertical_indexes, + slash_indexes, + vertical_indices_count, + slash_indices_count, + context_size, + block_size_M, + block_size_N, + causal, + ) + return block_count, block_offset, column_count, column_index + + def sparse_attn_func( q, k, diff --git a/sgl-kernel/tests/test_sparse_flash_attn.py b/sgl-kernel/tests/test_sparse_flash_attn.py index bb964f335..4ddb6d7f5 100644 --- a/sgl-kernel/tests/test_sparse_flash_attn.py +++ b/sgl-kernel/tests/test_sparse_flash_attn.py @@ -4,7 +4,12 @@ from typing import List, Optional, Tuple import pytest import torch from einops import rearrange, repeat -from sgl_kernel.sparse_flash_attn import sparse_attn_func, sparse_attn_varlen_func +from sgl_kernel.sparse_flash_attn import ( + convert_vertical_slash_indexes, + convert_vertical_slash_indexes_mergehead, + sparse_attn_func, + sparse_attn_varlen_func, +) def ref_attn( @@ -249,6 +254,133 @@ def test_sparse_attention( ), f"{torch.max(torch.abs(lse - ref_lse))}" +# sparse attention utils +# origin +@pytest.mark.parametrize("causal", [True, False]) +def test_convert_vertical_slash_indexes(causal): + # Prepare small, hand-checkable inputs + q_seqlens = torch.tensor([4], dtype=torch.int32, device="cuda") # [BATCH] + kv_seqlens = torch.tensor([4], dtype=torch.int32, device="cuda") + vertical_indexes = torch.tensor( + [[[1, 3]]], dtype=torch.int32, device="cuda" + ) # [BATCH, N_HEADS, NNZ_V] + slash_indexes = torch.tensor( + [[[2]]], dtype=torch.int32, device="cuda" + ) # [BATCH, N_HEADS, NNZ_S] + context_size = 4 + block_size_M = 2 + block_size_N = 2 + + # Call your CUDA kernel wrapper + block_count, block_offset, column_count, column_index = ( + convert_vertical_slash_indexes( + q_seqlens, + kv_seqlens, + vertical_indexes, + slash_indexes, + context_size, + block_size_M, + block_size_N, + causal=causal, + ) + ) + + # Manually create expected outputs for this input + # There are 2 rows (blocks): row0 (tokens 0-1), row1 (tokens 2-3) + # Fill these expected tensors based on your CUDA kernel's logic + # For demonstration, we assume: + # - block_count: how many slash indices fall into each block + # - block_offset: the value of those indices + # - column_count: number of valid vertical indices per block + # - column_index: the actual vertical indices + + expected_column_index = torch.tensor( + [[[[0, 0], [0, 0]]]], dtype=torch.int32, device="cuda" + ) + + # If causal=False, update these tensors according to expected behavior + if not causal: + # Update these values if your kernel produces different output in non-causal mode + expected_column_index = torch.tensor( + [[[[1, 0], [1, 3]]]], dtype=torch.int32, device="cuda" + ) + + # Assert that outputs match expectations + assert torch.equal(column_index, expected_column_index) + + +# mergehead +@pytest.mark.parametrize("causal", [True, False]) +def test_convert_vertical_slash_indexes_mergehead(causal): + # Prepare small, hand-checkable inputs for mergehead version + q_seqlens = torch.tensor([4], dtype=torch.int32, device="cuda") + kv_seqlens = torch.tensor([4], dtype=torch.int32, device="cuda") + vertical_indexes = torch.tensor( + [ + [ + [1, 3], # head 0 + [2, 0], # head 1 + ] + ], + dtype=torch.int32, + device="cuda", + ) # [BATCH, N_HEADS, NNZ_V] + slash_indexes = torch.tensor( + [ + [ + [2, 0], # head 0 + [1, 3], # head 1 + ] + ], + dtype=torch.int32, + device="cuda", + ) # [BATCH, N_HEADS, NNZ_S] + vertical_indices_count = torch.tensor([2, 1], dtype=torch.int32, device="cuda") + slash_indices_count = torch.tensor([1, 2], dtype=torch.int32, device="cuda") + context_size = 4 + block_size_M = 2 + block_size_N = 2 + + # Call your CUDA kernel wrapper + block_count, block_offset, column_count, column_index = ( + convert_vertical_slash_indexes_mergehead( + q_seqlens, + kv_seqlens, + vertical_indexes, + slash_indexes, + vertical_indices_count, + slash_indices_count, + context_size, + block_size_M, + block_size_N, + causal=causal, + ) + ) + + # Manually create expected outputs for this input + # For demonstration, assume: + # - batch=1, head=2, num_rows=2, nnz_v=2, nnz_s=2 + # Fill these expected tensors according to your kernel's behavior + + expected_column_index = torch.tensor( + [[[[1, 0], [1, 3]], [[-1079459945, -1077788999], [-1080050043, -1104625879]]]], + dtype=torch.int32, + device="cuda", + ) + + if not causal: + # If non-causal mode output is different, update these values + expected_column_index = torch.tensor( + [[[[1, 0], [1, 3]], [[2, -1077788999], [2, -1104625879]]]], + dtype=torch.int32, + device="cuda", + ) + + # Assert that outputs match expectations + assert torch.equal(column_index, expected_column_index) + + +# skip cause use fa2 for test # @pytest.mark.parametrize("seq_lens", [[(1024, 1328)], # [(1024, 1328), (1, 2048)], # [(1025, 1328), (2, 2048)],