[Feat] Add sparse attn to sgl-kernel (#5327)
@@ -5,8 +5,6 @@ cmake_policy(SET CMP0169 OLD)

 include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)

-set(BUILD_FA3, OFF)
-
 find_package(Python COMPONENTS Interpreter Development.Module ${SKBUILD_SABI_COMPONENT} REQUIRED)

 enable_language(CUDA)
@@ -80,7 +78,6 @@ include_directories(
   ${repo-cutlass_SOURCE_DIR}/examples/common
   ${repo-flashinfer_SOURCE_DIR}/include
   ${repo-flashinfer_SOURCE_DIR}/csrc
-  ${repo-flash-attention_SOURCE_DIR}/hopper
 )

 set(CMAKE_CXX_STANDARD 17)
@@ -115,6 +112,9 @@ option(SGL_KERNEL_ENABLE_BF16 "Enable BF16" ON)
 option(SGL_KERNEL_ENABLE_FP8 "Enable FP8" ON)
 option(SGL_KERNEL_ENABLE_FP4 "Enable FP4" OFF)
+option(SGL_KERNEL_ENABLE_FA3 "Enable FA3" OFF)

 if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8" OR SGL_KERNEL_ENABLE_SM100A)
   list(APPEND SGL_KERNEL_CUDA_FLAGS
     "-gencode=arch=compute_100,code=sm_100"
@@ -127,7 +127,7 @@ else()
 endif()

 if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.4" OR SGL_KERNEL_ENABLE_SM90A)
-  set(BUILD_FA3 ON)
+  set(SGL_KERNEL_ENABLE_FA3 ON)
   list(APPEND SGL_KERNEL_CUDA_FLAGS
     "-gencode=arch=compute_90a,code=sm_90a"
   )
@@ -187,11 +187,33 @@ set(SOURCES
   "${repo-flashinfer_SOURCE_DIR}/csrc/norm.cu"
   "${repo-flashinfer_SOURCE_DIR}/csrc/renorm.cu"
   "${repo-flashinfer_SOURCE_DIR}/csrc/sampling.cu"
+  "${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_causal_sm80.cu"
+  "${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_sm80.cu"
+  "${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_fp16_causal_sm80.cu"
+  "${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_fp16_sm80.cu"
+  "${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/flash_sparse_api.cpp"
 )

+Python_add_library(common_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES})
+
+target_compile_options(common_ops PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_KERNEL_CUDA_FLAGS}>)
+target_include_directories(common_ops PRIVATE
+  ${TORCH_INCLUDE_DIRS}
+  ${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src)
+target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cublasLt)
+
+target_compile_definitions(common_ops PRIVATE
+  FLASHATTENTION_DISABLE_BACKWARD
+  FLASHATTENTION_DISABLE_DROPOUT
+  FLASHATTENTION_DISABLE_UNEVEN_K
+)
+
+install(TARGETS common_ops LIBRARY DESTINATION "sgl_kernel")
+
+# ============================ Optional Install ============================= #
 # set flash-attention sources file
 # BF16 source files
-if (BUILD_FA3)
+if (SGL_KERNEL_ENABLE_FA3)
   set(SGL_FLASH_KERNEL_CUDA_FLAGS
     "-DNDEBUG"
     "-DOPERATOR_NAMESPACE=sgl-kernel"
@@ -246,7 +268,9 @@ if (BUILD_FA3)
 Python_add_library(flash_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${FLASH_SOURCES})

 target_compile_options(flash_ops PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_FLASH_KERNEL_CUDA_FLAGS}>)
-target_include_directories(flash_ops PRIVATE ${TORCH_INCLUDE_DIRS})
+target_include_directories(flash_ops PRIVATE
+  ${TORCH_INCLUDE_DIRS}
+  ${repo-flash-attention_SOURCE_DIR}/hopper)
 target_link_libraries(flash_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda)

 install(TARGETS flash_ops LIBRARY DESTINATION "sgl_kernel")
@@ -260,14 +284,6 @@ if (BUILD_FA3)
 )
 endif()

-Python_add_library(common_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES})
-
-target_compile_options(common_ops PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_KERNEL_CUDA_FLAGS}>)
-target_include_directories(common_ops PRIVATE ${TORCH_INCLUDE_DIRS})
-target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cublasLt)
-
-install(TARGETS common_ops LIBRARY DESTINATION "sgl_kernel")
-
 # JIT Logic
 # DeepGEMM
@@ -206,6 +206,28 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
       "top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? "
       "maybe_top_p_arr, float top_p_val, bool deterministic, int cuda_stream) -> ()");
   m.impl("top_p_sampling_from_probs", torch::kCUDA, &top_p_sampling_from_probs);

+  /*
+   * From Sparse Flash Attention
+   */
+  m.def(
+      "fwd_sparse(Tensor! q, Tensor k, Tensor v, "
+      "Tensor block_count, Tensor block_offset, Tensor column_count, Tensor column_index, "
+      "Tensor!? out, Tensor? alibi_slopes, "
+      "float p_dropout, float softmax_scale, bool is_causal, "
+      "float softcap, bool return_softmax, Generator? gen)"
+      "-> Tensor[]");
+  m.impl("fwd_sparse", torch::kCUDA, &flash::mha_fwd_sparse);
+
+  m.def(
+      "varlen_fwd_sparse(Tensor! q, Tensor k, Tensor v, "
+      "Tensor block_count, Tensor block_offset, Tensor column_count, Tensor column_index, "
+      "Tensor!? out, Tensor cu_seqlens_q, "
+      "Tensor cu_seqlens_k, Tensor? seqused_k, Tensor? alibi_slopes, "
+      "int max_seqlen_q, int max_seqlen_k, float p_dropout, float softmax_scale, bool zero_tensors, "
+      "bool is_causal, float softcap, bool return_softmax, "
+      "Generator? gen) -> Tensor[]");
+  m.impl("varlen_fwd_sparse", torch::kCUDA, &flash::mha_varlen_fwd_sparse);
 }

 REGISTER_EXTENSION(common_ops)
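The schemas registered above are exposed under the torch.ops.sgl_kernel namespace; the Python wrapper added later in this PR calls them as torch.ops.sgl_kernel.fwd_sparse.default(...) and torch.ops.sgl_kernel.varlen_fwd_sparse.default(...). A minimal sanity-check sketch (not part of the diff), assuming the built common_ops extension is loaded by importing the sgl_kernel package:

# Sketch only: print the schemas of the newly registered sparse-attention ops.
# Assumes `import sgl_kernel` loads the common_ops extension built above.
import torch
import sgl_kernel  # noqa: F401  (registers the sgl_kernel custom ops)

print(torch.ops.sgl_kernel.fwd_sparse.default._schema)
print(torch.ops.sgl_kernel.varlen_fwd_sparse.default._schema)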
@@ -256,18 +256,21 @@ void min_p_sampling_from_probs(
     double min_p_val,
     bool deterministic,
     int64_t cuda_stream);

 void top_k_renorm_probs(
     at::Tensor probs,
     at::Tensor renorm_probs,
     std::optional<at::Tensor> maybe_top_k_arr,
     int64_t top_k_val,
     int64_t cuda_stream);

 void top_p_renorm_probs(
     at::Tensor probs,
     at::Tensor renorm_probs,
     std::optional<at::Tensor> maybe_top_p_arr,
     double top_p_val,
     int64_t cuda_stream);

 void top_k_top_p_sampling_from_probs(
     at::Tensor probs,
     at::Tensor uniform_samples,
@@ -279,6 +282,7 @@ void top_k_top_p_sampling_from_probs(
     double top_p_val,
     bool deterministic,
     int64_t cuda_stream);

 void top_p_sampling_from_probs(
     at::Tensor probs,
     at::Tensor uniform_samples,
@@ -288,3 +292,49 @@ void top_p_sampling_from_probs(
     double top_p_val,
     bool deterministic,
     int64_t cuda_stream);
+
+namespace flash {
+/*
+ * From fa2 sparse
+ */
+std::vector<at::Tensor> mha_fwd_sparse(
+    at::Tensor& q,        // batch_size x seqlen_q x num_heads x head_size
+    const at::Tensor& k,  // batch_size x seqlen_k x num_heads_k x head_size
+    const at::Tensor& v,  // batch_size x seqlen_k x num_heads_k x head_size
+    const at::Tensor& block_count,
+    const at::Tensor& block_offset,
+    const at::Tensor& column_count,
+    const at::Tensor& column_index,
+    const std::optional<at::Tensor>& out_,           // batch_size x seqlen_q x num_heads x head_size
+    const std::optional<at::Tensor>& alibi_slopes_,  // num_heads or batch_size x num_heads
+    const double p_dropout,
+    const double softmax_scale,
+    bool is_causal,
+    const double softcap,
+    const bool return_softmax,
+    std::optional<at::Generator> gen_);
+
+std::vector<at::Tensor> mha_varlen_fwd_sparse(
+    at::Tensor& q,        // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
+    const at::Tensor& k,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i.
+    const at::Tensor& v,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i.
+    const at::Tensor& block_count,
+    const at::Tensor& block_offset,
+    const at::Tensor& column_count,
+    const at::Tensor& column_index,
+    const c10::optional<at::Tensor>& out_,  // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
+    const at::Tensor& cu_seqlens_q,         // b+1
+    const at::Tensor& cu_seqlens_k,         // b+1
+    const c10::optional<at::Tensor>&
+        seqused_k,  // b. If given, only this many elements of each batch element's keys are used.
+    const c10::optional<at::Tensor>& alibi_slopes_,  // num_heads or b x num_heads
+    int64_t max_seqlen_q,
+    const int64_t max_seqlen_k,
+    const double p_dropout,
+    const double softmax_scale,
+    const bool zero_tensors,
+    bool is_causal,
+    const double softcap,
+    const bool return_softmax,
+    c10::optional<at::Generator> gen_);
+}  // namespace flash
sgl-kernel/python/sgl_kernel/sparse_flash_attn.py (new file, 175 lines)
@@ -0,0 +1,175 @@
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn


def maybe_contiguous(x):
    return x.contiguous() if x is not None and x.stride(-1) != 1 else x


def sparse_attn_func(
    q,
    k,
    v,
    block_count,
    block_offset,
    column_count,
    column_index,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    *,
    return_softmax_lse=False,
    out=None,
):
    """Compute attention with vertical and slash sparsity patterns.

    Most arguments are the same as in the flash_attn_func interface, except for 4 extra args:
    block_count and block_offset for slash sparsity patterns, and
    column_count and column_index for vertical sparsity patterns.
    For more details please refer to Appendix C.4.2 of the paper https://arxiv.org/abs/2407.02490.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k: (batch_size, seqlen, nheads_k, headdim)
        v: (batch_size, seqlen, nheads_k, headdim)
        block_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
        block_offset: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_S)
        column_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
        column_index: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_V)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply a causal attention mask (e.g., for auto-regressive modeling).
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)

    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    out, softmax_lse = torch.ops.sgl_kernel.fwd_sparse.default(
        q,
        k,
        v,
        block_count,
        block_offset,
        column_count,
        column_index,
        out,
        alibi_slopes,
        dropout_p,
        softmax_scale,
        causal,
        softcap,
        return_attn_probs and dropout_p > 0,
        None,
    )
    return (out, softmax_lse) if return_softmax_lse else out


def sparse_attn_varlen_func(
    q,
    k,
    v,
    block_count,
    block_offset,
    column_count,
    column_index,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    *,
    return_softmax_lse=False,
    out=None,
):
    """Compute attention with vertical and slash sparsity patterns.

    Most arguments are the same as in the flash_attn_varlen_func interface, except for 4 extra args:
    block_count and block_offset for slash sparsity patterns, and
    column_count and column_index for vertical sparsity patterns.
    For more details please refer to Appendix C.4.2 of the paper https://arxiv.org/abs/2407.02490.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        block_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
        block_offset: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_S)
        column_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
        column_index: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_V)
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into kv.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply a causal attention mask (e.g., for auto-regressive modeling).
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]: (nheads, total_q_seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)

    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    out, softmax_lse = torch.ops.sgl_kernel.varlen_fwd_sparse.default(
        q,
        k,
        v,
        block_count,
        block_offset,
        column_count,
        column_index,
        out,
        cu_seqlens_q,
        cu_seqlens_k,
        None,
        alibi_slopes,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,
        causal,
        softcap,
        return_attn_probs and dropout_p > 0,
        None,
    )
    return (out, softmax_lse) if return_softmax_lse else out
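For reference, a minimal usage sketch of sparse_attn_func (not part of the diff). It builds dense-equivalent block/column indices the same way test_sparse_attention below does, and assumes a CUDA build of sgl-kernel with the new sparse kernels compiled in (head_size 128, fp16/bf16):

import torch
from sgl_kernel.sparse_flash_attn import sparse_attn_func

torch.set_default_device("cuda")
B, S, H, D = 1, 256, 2, 128      # batch, seqlen, heads, head_dim (hdim128 kernels)
BLOCK_M, BLOCK_N = 64, 64        # block sizes used by the test below
NNZ_S = 2                        # "slash" blocks per query row-block
NNZ_V = S - NNZ_S * BLOCK_N      # remaining keys addressed as "vertical" columns
NUM_ROWS = (S + BLOCK_M - 1) // BLOCK_M

q = torch.randn(B, S, H, D, dtype=torch.float16)
k = torch.randn(B, S, H, D, dtype=torch.float16)
v = torch.randn(B, S, H, D, dtype=torch.float16)

# Indices that together cover every key, so the result should match dense attention.
block_count = torch.full((B, H, NUM_ROWS), NNZ_S, dtype=torch.int32)
block_offset = (torch.arange(NNZ_S, dtype=torch.int32) * BLOCK_N).repeat(B, H, NUM_ROWS, 1)
column_count = torch.full((B, H, NUM_ROWS), NNZ_V, dtype=torch.int32)
column_index = (NNZ_S * BLOCK_N + torch.arange(NNZ_V, dtype=torch.int32)).repeat(B, H, NUM_ROWS, 1)

out, lse = sparse_attn_func(
    q, k, v, block_count, block_offset, column_count, column_index,
    return_softmax_lse=True,
)
print(out.shape, lse.shape)  # (B, S, H, D) and (B, H, S)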
sgl-kernel/tests/test_sparse_flash_attn.py (new file, 348 lines)
@@ -0,0 +1,348 @@
import math
from typing import List, Optional, Tuple

import pytest
import torch
from einops import rearrange, repeat
from sgl_kernel.sparse_flash_attn import sparse_attn_func, sparse_attn_varlen_func


def ref_attn(
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    attn_bias=None,
    dropout_p=0.0,
    dropout_mask=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite window size
    softcap=0.0,
    upcast=True,
    reorder_ops=False,
    key_leftpad=None,
):
    """
    Arguments:
        q: (batch_size, seqlen_q, nheads, head_dim)
        k: (batch_size, seqlen_k, nheads_k, head_dim)
        v: (batch_size, seqlen_k, nheads_k, head_dim)
        query_padding_mask: (batch_size, seqlen_q)
        key_padding_mask: (batch_size, seqlen_k)
        attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
        dropout_p: float
        dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
        causal: whether to apply causal masking
        window_size: (int, int), left and right window size
        upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
            output back to fp16/bf16.
        reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
            without changing the math. This is to estimate the numerical error from operation
            reordering.
    Output:
        output: (batch_size, seqlen_q, nheads, head_dim)
        lse: (batch_size, nheads, seqlen_q)
    """
    if causal:
        window_size = (window_size[0], 0)
    dtype_og = q.dtype
    if upcast:
        q, k, v = q.float(), k.float(), v.float()
    seqlen_q, seqlen_k = q.shape[1], k.shape[1]
    k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
    v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
    d = q.shape[-1]
    if not reorder_ops:
        scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
    else:
        scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d))

    lse_ref = scores.logsumexp(dim=-1)

    if softcap > 0:
        scores = scores / softcap
        scores = scores.tanh()
        scores = scores * softcap
    if key_padding_mask is not None:
        scores.masked_fill_(
            rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")
        )
    if window_size[0] >= 0 or window_size[1] >= 0:
        local_mask = construct_local_mask(
            seqlen_q,
            seqlen_k,
            window_size,
            query_padding_mask,
            key_padding_mask,
            q.device,
            key_leftpad=key_leftpad,
        )
        scores.masked_fill_(local_mask, float("-inf"))
    if attn_bias is not None:
        scores = scores + attn_bias
    attention = torch.softmax(scores, dim=-1).to(v.dtype)
    # Some rows might be completely masked out so we fill them with zero instead of NaN
    if window_size[0] >= 0 or window_size[1] >= 0:
        attention = attention.masked_fill(
            torch.all(local_mask, dim=-1, keepdim=True), 0.0
        )
    # We want to mask here so that the attention matrix doesn't have any NaNs
    # Otherwise we'll get NaN in dV
    if query_padding_mask is not None:
        attention = attention.masked_fill(
            rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0
        )
    dropout_scaling = 1.0 / (1 - dropout_p)
    # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling
    # output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
    if dropout_mask is not None:
        attention_drop = attention.masked_fill(~dropout_mask, 0.0)
    else:
        attention_drop = attention
    output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
    if query_padding_mask is not None:
        output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)

    return output.to(dtype=dtype_og), lse_ref


def ref_paged_attn(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    query_lens: List[int],
    kv_lens: List[int],
    block_tables: torch.Tensor,
    scale: float,
    sliding_window: Optional[int] = None,
    soft_cap: Optional[float] = None,
) -> torch.Tensor:
    num_seqs = len(query_lens)
    block_tables = block_tables.cpu().numpy()
    _, block_size, num_kv_heads, head_size = key_cache.shape

    outputs: List[torch.Tensor] = []
    start_idx = 0
    for i in range(num_seqs):
        query_len = query_lens[i]
        kv_len = kv_lens[i]
        # clone to avoid clobbering the query tensor
        q = query[start_idx : start_idx + query_len].clone()
        q *= scale

        num_kv_blocks = (kv_len + block_size - 1) // block_size
        block_indices = block_tables[i, :num_kv_blocks]

        k = key_cache[block_indices].view(-1, num_kv_heads, head_size)
        k = k[:kv_len]
        v = value_cache[block_indices].view(-1, num_kv_heads, head_size)
        v = v[:kv_len]

        if q.shape[1] != k.shape[1]:
            k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1)
            v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1)
        attn = torch.einsum("qhd,khd->hqk", q, k).float()
        empty_mask = torch.ones(query_len, kv_len)
        mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
        if sliding_window is not None:
            sliding_window_mask = (
                torch.triu(
                    empty_mask, diagonal=kv_len - (query_len + sliding_window) + 1
                )
                .bool()
                .logical_not()
            )
            mask |= sliding_window_mask
        if soft_cap is not None:
            attn = soft_cap * torch.tanh(attn / soft_cap)
        attn.masked_fill_(mask, float("-inf"))
        attn = torch.softmax(attn, dim=-1).to(v.dtype)
        out = torch.einsum("hqk,khd->qhd", attn, v)

        outputs.append(out)
        start_idx += query_len

    return torch.cat(outputs, dim=0)


@pytest.mark.parametrize("batch_size", [1, 2])
@pytest.mark.parametrize(
    "seq_lens",
    [
        (1, 1),
        (1, 1024),
        (1, 2048),
        (1023, 2049),
        (1023, 1023),
        (32, 32),
        (65, 65),
        (129, 129),
    ],
)
@pytest.mark.parametrize("num_heads", [1, 2, 4])
@pytest.mark.parametrize("head_size", [128])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("NNZ_S", [0, 1, 2, 3, 7, 15, 32])
@torch.inference_mode()
def test_sparse_attention(
    batch_size,
    seq_lens,
    num_heads,
    head_size,
    dtype,
    NNZ_S,
) -> None:
    torch.set_default_device("cuda")
    torch.cuda.manual_seed_all(0)
    block_size_M = 64
    block_size_N = 64
    seqlen_q, seqlen_k = seq_lens
    q = torch.randn(
        batch_size, seqlen_q, num_heads, head_size, dtype=dtype, requires_grad=False
    )
    k = torch.randn(
        batch_size, seqlen_k, num_heads, head_size, dtype=dtype, requires_grad=False
    )
    v = torch.randn(
        batch_size, seqlen_k, num_heads, head_size, dtype=dtype, requires_grad=False
    )
    NUM_ROWS = (seqlen_q + block_size_M - 1) // block_size_M
    if NNZ_S * block_size_N > seqlen_k:
        return
    NNZ_V = seqlen_k - NNZ_S * block_size_N
    block_count = torch.tensor(
        [NNZ_S] * batch_size * NUM_ROWS * num_heads, dtype=torch.int32
    ).reshape(batch_size, num_heads, NUM_ROWS)
    column_count = torch.tensor(
        [NNZ_V] * batch_size * NUM_ROWS * num_heads, dtype=torch.int32
    ).reshape(batch_size, num_heads, NUM_ROWS)
    block_offset = torch.tensor(
        [[i * block_size_N for i in range(NNZ_S)]] * batch_size * NUM_ROWS * num_heads,
        dtype=torch.int32,
    ).reshape(batch_size, num_heads, NUM_ROWS, NNZ_S)
    column_index = torch.tensor(
        [[NNZ_S * block_size_N + i for i in range(NNZ_V)]]
        * batch_size
        * NUM_ROWS
        * num_heads,
        dtype=torch.int32,
    ).reshape(batch_size, num_heads, NUM_ROWS, NNZ_V)
    out, lse = sparse_attn_func(
        q,
        k,
        v,
        block_count,
        block_offset,
        column_count,
        column_index,
        return_softmax_lse=True,
    )

    ref_out, ref_lse = ref_attn(q, k, v)

    torch.testing.assert_close(
        out, ref_out, atol=2e-2, rtol=1e-2
    ), f"{torch.max(torch.abs(out - ref_out))}"
    torch.testing.assert_close(
        lse, ref_lse, atol=2e-2, rtol=1e-2
    ), f"{torch.max(torch.abs(lse - ref_lse))}"


# @pytest.mark.parametrize("seq_lens", [[(1024, 1328)],
#                                       [(1024, 1328), (1, 2048)],
#                                       [(1025, 1328), (2, 2048)],
#                                       [(1025, 2049), (2, 1281)],
#                                       ])
# @pytest.mark.parametrize("head_size", [128])
# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
# @torch.inference_mode()
# def test_sparse_attention_varlen(
#     seq_lens,
#     head_size,
#     dtype,
# ) -> None:
#     torch.set_default_device("cuda")
#     torch.cuda.manual_seed_all(0)
#     block_size_M = 64
#     block_size_N = 64
#     num_seqs = len(seq_lens)
#     query_lens = [x[0] for x in seq_lens]
#     kv_lens = [x[1] for x in seq_lens]
#     num_heads = 1
#     query = torch.randn(sum(query_lens),
#                         num_heads,
#                         head_size,
#                         dtype=dtype)
#     key = torch.randn(sum(kv_lens),
#                       num_heads,
#                       head_size,
#                       dtype=dtype)
#     value = torch.randn_like(key)
#     cu_query_lens = torch.tensor([0] + query_lens,
#                                  dtype=torch.int32).cumsum(dim=0,
#                                                            dtype=torch.int32)
#     cu_kv_lens = torch.tensor([0] + kv_lens,
#                               dtype=torch.int32).cumsum(dim=0,
#                                                         dtype=torch.int32)
#     max_query_len = max(query_lens)
#     max_kv_len = max(kv_lens)

#     NUM_ROWS = (max_query_len + block_size_M - 1) // block_size_M
#     NNZ_S = 20
#     NNZ_V = 2048
#     batch_size = len(query_lens)

#     block_counts = []
#     column_counts = []
#     block_offsets = []
#     column_indices = []
#     for b in range(batch_size):
#         block_counts.append(torch.tensor([NNZ_S] * NUM_ROWS * num_heads, dtype=torch.int32).reshape(num_heads, NUM_ROWS))
#         columns = kv_lens[b] - NNZ_S * block_size_N
#         column_counts.append(torch.tensor([columns] * NUM_ROWS * num_heads, dtype=torch.int32).reshape(num_heads, NUM_ROWS))
#         block_offsets.append(torch.tensor([[i * block_size_N for i in range(NNZ_S)]] * NUM_ROWS * num_heads, dtype=torch.int32).reshape(num_heads, NUM_ROWS, NNZ_S))
#         column_indices.append(torch.tensor([[NNZ_S * block_size_N + i for i in range(NNZ_V)]] * NUM_ROWS * num_heads, dtype=torch.int32).reshape(num_heads, NUM_ROWS, NNZ_V))
#     block_count = torch.concat(block_counts).reshape(batch_size, num_heads, NUM_ROWS)
#     column_count = torch.concat(column_counts).reshape(batch_size, num_heads, NUM_ROWS)
#     block_offset = torch.concat(block_offsets).reshape(batch_size, num_heads, NUM_ROWS, NNZ_S)
#     column_index = torch.concat(column_indices).reshape(batch_size, num_heads, NUM_ROWS, NNZ_V)
#     out, lse = sparse_attn_varlen_func(
#         query,
#         key,
#         value,
#         block_count,
#         block_offset,
#         column_count,
#         column_index,
#         cu_seqlens_q=cu_query_lens,
#         cu_seqlens_k=cu_kv_lens,
#         max_seqlen_q=max_query_len,
#         max_seqlen_k=max_kv_len,
#         return_softmax_lse=True,
#     )

#     max_num_blocks_per_seq = (max_kv_len + 2048 - 1) // 2048
#     block_tables = torch.randint(0,
#                                  2048,
#                                  (len(query_lens), max_num_blocks_per_seq),
#                                  dtype=torch.int32)
#     scale = head_size**-0.5

#     ref_out, ref_lse, _ = ref_paged_attn(
#         query,
#         key,
#         value,
#         query_lens=query_lens,
#         kv_lens=kv_lens,
#         block_tables=block_tables,
#         scale=scale
#     )

#     torch.testing.assert_close(out, ref_out, atol=2e-2, rtol=1e-2), \
#         f"{torch.max(torch.abs(out - ref_out))}"
#     torch.testing.assert_close(lse, ref_lse, atol=2e-2, rtol=1e-2), \
#         f"{torch.max(torch.abs(lse - ref_lse))}"


if __name__ == "__main__":
    pytest.main([__file__])