[Fix] fix fa3 build at cu118 (#5036)
sgl-kernel/CMakeLists.txt
@@ -1,10 +1,12 @@
cmake_minimum_required(VERSION 3.26 FATAL_ERROR)
project(sgl-kernel LANGUAGES CXX CUDA)

# we only want to download 3rd, but not build them.
# FetchContent_MakeAvailable will build it.
cmake_policy(SET CMP0169 OLD)

include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)

set(BUILD_FA3, OFF)

find_package(Python COMPONENTS Interpreter Development.Module ${SKBUILD_SABI_COMPONENT} REQUIRED)

enable_language(CUDA)
@@ -22,6 +24,8 @@ elseif ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "11.8")
endif()

find_package(Torch REQUIRED)
# clean Torch Flag
clear_cuda_arches(CMAKE_FLAG)

include(FetchContent)

@@ -53,8 +57,8 @@ FetchContent_Populate(repo-flashinfer)
FetchContent_Declare(
    repo-flash-attention
    GIT_REPOSITORY https://github.com/sgl-project/sgl-attn
    GIT_TAG sgl-kernel
    GIT_SHALLOW OFF
    GIT_TAG sgl-kernel
    GIT_SHALLOW OFF
)
FetchContent_Populate(repo-flash-attention)

@@ -92,14 +96,13 @@ set(SGL_KERNEL_CUDA_FLAGS
    "-gencode=arch=compute_90,code=sm_90"
    "-std=c++17"
    "-DFLASHINFER_ENABLE_F16"
    "-DCUTE_USE_PACKED_TUPLE=1"
    "-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1"
    "-DCUTLASS_VERSIONS_GENERATED"
    "-DCUTE_USE_PACKED_TUPLE=1"
    "-DCUTLASS_TEST_LEVEL=0"
    "-DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1"
    "-DCUTLASS_DEBUG_TRACE_LEVEL=0"
    "--expt-relaxed-constexpr"
    "--use_fast_math"
    "-Xcompiler=-Wconversion"
    "-Xcompiler=-fno-strict-aliasing"
)
@@ -122,6 +125,7 @@ else()
endif()

if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.4" OR SGL_KERNEL_ENABLE_SM90A)
    set(BUILD_FA3 ON)
    list(APPEND SGL_KERNEL_CUDA_FLAGS
        "-gencode=arch=compute_90a,code=sm_90a"
    )
@@ -152,30 +156,6 @@ string(REPLACE "-D__CUDA_NO_HALF_CONVERSIONS__" "" CMAKE_CUDA_FLAGS "${CMAKE
string(REPLACE "-D__CUDA_NO_BFLOAT16_CONVERSIONS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
string(REPLACE "-D__CUDA_NO_HALF2_OPERATORS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")

# set flash-attention sources file
# BF16 source files
file(GLOB FA3_BF16_GEN_SRCS
    "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_bf16*_sm90.cu")
file(GLOB FA3_BF16_GEN_SRCS_
    "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_bf16*_sm90.cu")
list(APPEND FA3_BF16_GEN_SRCS ${FA3_BF16_GEN_SRCS_})

# FP16 source files
file(GLOB FA3_FP16_GEN_SRCS
    "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_fp16*_sm90.cu")
file(GLOB FA3_FP16_GEN_SRCS_
    "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_fp16*_sm90.cu")
list(APPEND FA3_FP16_GEN_SRCS ${FA3_FP16_GEN_SRCS_})

# FP8 source files
file(GLOB FA3_FP8_GEN_SRCS
    "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_e4m3*_sm90.cu")
file(GLOB FA3_FP8_GEN_SRCS_
    "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_e4m3*_sm90.cu")
list(APPEND FA3_FP8_GEN_SRCS ${FA3_FP8_GEN_SRCS_})

set(FA3_GEN_SRCS ${FA3_BF16_GEN_SRCS} ${FA3_FP16_GEN_SRCS} ${FA3_FP8_GEN_SRCS})

set(SOURCES
    "csrc/allreduce/trt_reduce_internal.cu"
    "csrc/allreduce/trt_reduce_kernel.cu"
@@ -202,39 +182,94 @@ set(SOURCES
    "csrc/speculative/eagle_utils.cu"
    "csrc/speculative/speculative_sampling.cu"
    "csrc/speculative/packbit.cu"
    "csrc/torch_extension.cc"
    "csrc/common_extension.cc"
    "${repo-flashinfer_SOURCE_DIR}/csrc/norm.cu"
    "${repo-flashinfer_SOURCE_DIR}/csrc/renorm.cu"
    "${repo-flashinfer_SOURCE_DIR}/csrc/sampling.cu"
    "${repo-flash-attention_SOURCE_DIR}/hopper/flash_prepare_scheduler.cu"
    "${repo-flash-attention_SOURCE_DIR}/hopper/flash_api.cpp"
    "${repo-flash-attention_SOURCE_DIR}/hopper/flash_fwd_combine.cu"
    "${FA3_GEN_SRCS}"
)

# Support abi3 for build
# set flash-attention sources file
# BF16 source files
if (BUILD_FA3)
    set(SGL_FLASH_KERNEL_CUDA_FLAGS
        "-DNDEBUG"
        "-DOPERATOR_NAMESPACE=sgl-kernel"
        "-O3"
        "-Xcompiler"
        "-fPIC"
        "-gencode=arch=compute_90a,code=sm_90a"
        "-std=c++17"
        "-DCUTE_USE_PACKED_TUPLE=1"
        "-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1"
        "-DCUTLASS_VERSIONS_GENERATED"
        "-DCUTLASS_TEST_LEVEL=0"
        "-DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1"
        "-DCUTLASS_DEBUG_TRACE_LEVEL=0"
        "--expt-relaxed-constexpr"
        "--expt-extended-lambda"
        "--use_fast_math"
        "-Xcompiler=-Wconversion"
        "-Xcompiler=-fno-strict-aliasing"
    )

    file(GLOB FA3_BF16_GEN_SRCS
        "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_bf16*_sm90.cu")
    file(GLOB FA3_BF16_GEN_SRCS_
        "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_bf16*_sm90.cu")
    list(APPEND FA3_BF16_GEN_SRCS ${FA3_BF16_GEN_SRCS_})

    # FP16 source files
    file(GLOB FA3_FP16_GEN_SRCS
        "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_fp16*_sm90.cu")
    file(GLOB FA3_FP16_GEN_SRCS_
        "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_fp16*_sm90.cu")
    list(APPEND FA3_FP16_GEN_SRCS ${FA3_FP16_GEN_SRCS_})

    # FP8 source files
    file(GLOB FA3_FP8_GEN_SRCS
        "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_e4m3*_sm90.cu")
    file(GLOB FA3_FP8_GEN_SRCS_
        "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_e4m3*_sm90.cu")
    list(APPEND FA3_FP8_GEN_SRCS ${FA3_FP8_GEN_SRCS_})

    set(FA3_GEN_SRCS ${FA3_BF16_GEN_SRCS} ${FA3_FP16_GEN_SRCS} ${FA3_FP8_GEN_SRCS})

    set(FLASH_SOURCES
        "csrc/flash_extension.cc"
        "${repo-flash-attention_SOURCE_DIR}/hopper/flash_prepare_scheduler.cu"
        "${repo-flash-attention_SOURCE_DIR}/hopper/flash_api.cpp"
        "${repo-flash-attention_SOURCE_DIR}/hopper/flash_fwd_combine.cu"
        "${FA3_GEN_SRCS}"
    )

    Python_add_library(flash_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${FLASH_SOURCES})

    target_compile_options(flash_ops PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_FLASH_KERNEL_CUDA_FLAGS}>)
    target_include_directories(flash_ops PRIVATE ${TORCH_INCLUDE_DIRS})
    target_link_libraries(flash_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda)

    install(TARGETS flash_ops LIBRARY DESTINATION "sgl_kernel")

    target_compile_definitions(flash_ops PRIVATE
        FLASHATTENTION_DISABLE_SM8x
        FLASHATTENTION_DISABLE_BACKWARD
        FLASHATTENTION_DISABLE_DROPOUT
        # FLASHATTENTION_DISABLE_ALIBI
        # FLASHATTENTION_DISABLE_SOFTCAP
        FLASHATTENTION_DISABLE_UNEVEN_K
        # FLASHATTENTION_DISABLE_LOCAL
        FLASHATTENTION_VARLEN_ONLY
    )
endif()

Python_add_library(common_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES})

target_compile_options(common_ops PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_KERNEL_CUDA_FLAGS}>)

target_include_directories(common_ops PRIVATE ${TORCH_INCLUDE_DIRS})

target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cublasLt)

install(TARGETS common_ops LIBRARY DESTINATION "sgl_kernel")

# Add some flash-attention custom flag for inference
target_compile_definitions(common_ops PRIVATE
    FLASHATTENTION_DISABLE_SM8x
    FLASHATTENTION_DISABLE_BACKWARD
    FLASHATTENTION_DISABLE_DROPOUT
    # FLASHATTENTION_DISABLE_ALIBI
    # FLASHATTENTION_DISABLE_SOFTCAP
    FLASHATTENTION_DISABLE_UNEVEN_K
    # FLASHATTENTION_DISABLE_LOCAL
    FLASHATTENTION_VARLEN_ONLY
)

# JIT Logic
# DeepGEMM
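With this change, FA3 is compiled into a separate flash_ops extension that is only built and installed when BUILD_FA3 is ON (CUDA >= 12.4 or SGL_KERNEL_ENABLE_SM90A), so a cu118 wheel simply ships without it. A minimal sketch of how a caller can probe whether the installed wheel includes FA3 (the helper name is hypothetical; only the sgl_kernel.flash_ops module name comes from this diff):

import importlib.util

def fa3_was_built() -> bool:
    # flash_ops is installed into the "sgl_kernel" package only when BUILD_FA3 was ON,
    # so the presence of the module tells whether this wheel carries the FA3 kernels.
    return importlib.util.find_spec("sgl_kernel.flash_ops") is not None

print(fa3_was_built())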
sgl-kernel/cmake/utils.cmake (new file, 21 lines)
@@ -0,0 +1,21 @@
# Adapt from: https://github.com/neuralmagic/vllm-flash-attention/blob/main/cmake/utils.cmake
#
# Clear all `-gencode` flags from `CMAKE_CUDA_FLAGS` and store them in
# `CUDA_ARCH_FLAGS`.
#
# Example:
# CMAKE_CUDA_FLAGS="-Wall -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75"
# clear_cuda_arches(CUDA_ARCH_FLAGS)
# CUDA_ARCH_FLAGS="-gencode arch=compute_70,code=sm_70;-gencode arch=compute_75,code=sm_75"
# CMAKE_CUDA_FLAGS="-Wall"
#
macro(clear_cuda_arches CUDA_ARCH_FLAGS)
    # Extract all `-gencode` flags from `CMAKE_CUDA_FLAGS`
    string(REGEX MATCHALL "-gencode arch=[^ ]+" CUDA_ARCH_FLAGS
        ${CMAKE_CUDA_FLAGS})

    # Remove all `-gencode` flags from `CMAKE_CUDA_FLAGS` since they will be modified
    # and passed back via the `CUDA_ARCHITECTURES` property.
    string(REGEX REPLACE "-gencode arch=[^ ]+ *" "" CMAKE_CUDA_FLAGS
        ${CMAKE_CUDA_FLAGS})
endmacro()
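The two regexes above are what make clear_cuda_arches behave as documented in its header comment: the first collects every -gencode arch=... token into CUDA_ARCH_FLAGS, the second strips them from CMAKE_CUDA_FLAGS. For illustration only, the same transformation expressed in Python (not part of the build):

import re

cmake_cuda_flags = "-Wall -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75"

# Equivalent of: string(REGEX MATCHALL "-gencode arch=[^ ]+" CUDA_ARCH_FLAGS ...)
cuda_arch_flags = re.findall(r"-gencode arch=[^ ]+", cmake_cuda_flags)

# Equivalent of: string(REGEX REPLACE "-gencode arch=[^ ]+ *" "" CMAKE_CUDA_FLAGS ...)
cmake_cuda_flags = re.sub(r"-gencode arch=[^ ]+ *", "", cmake_cuda_flags).strip()

print(cuda_arch_flags)   # ['-gencode arch=compute_70,code=sm_70', '-gencode arch=compute_75,code=sm_75']
print(cmake_cuda_flags)  # -Wall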
@@ -18,7 +18,7 @@ limitations under the License.

#include "sgl_kernel_ops.h"

TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
  /*
   * From csrc/allreduce
   */
@@ -202,45 +202,6 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
      "top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? "
      "maybe_top_p_arr, float top_p_val, bool deterministic, int cuda_stream) -> ()");
  m.impl("top_p_sampling_from_probs", torch::kCUDA, &top_p_sampling_from_probs);

  /*
   * From flash-attention
   */
  m.def(
      "fwd(Tensor! q,"
      " Tensor k,"
      " Tensor v,"
      " Tensor? k_new,"
      " Tensor? v_new,"
      " Tensor? q_v,"
      " Tensor!? out,"
      " Tensor? cu_seqlens_q,"
      " Tensor? cu_seqlens_k,"
      " Tensor? cu_seqlens_k_new,"
      " Tensor? seqused_q,"
      " Tensor? seqused_k,"
      " int? max_seqlen_q,"
      " int? max_seqlen_k,"
      " Tensor? page_table,"
      " Tensor? kv_batch_idx,"
      " Tensor? leftpad_k,"
      " Tensor? rotary_cos,"
      " Tensor? rotary_sin,"
      " Tensor? seqlens_rotary,"
      " Tensor? q_descale,"
      " Tensor? k_descale,"
      " Tensor? v_descale,"
      " float softmax_scale,"
      " bool is_causal,"
      " int window_size_left,"
      " int window_size_right,"
      " float softcap,"
      " bool is_rotary_interleaved,"
      " Tensor? scheduler_metadata,"
      " int num_splits,"
      " bool? pack_gqa,"
      " int sm_margin) -> Tensor[]");
  m.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd));
}

REGISTER_EXTENSION(common_ops)
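The switch from TORCH_LIBRARY_EXPAND (which expands to TORCH_LIBRARY) to TORCH_LIBRARY_FRAGMENT is what lets the fwd registration move out of common_ops: several extension modules can now contribute operators to the same sgl_kernel namespace. A rough sketch of the resulting Python-side behavior (import paths follow this diff's install destination and are otherwise assumptions):

import torch
from sgl_kernel import common_ops  # noqa: F401 -- registers the ops kept in this file

try:
    # Only present in builds where BUILD_FA3 was ON; importing it registers "fwd".
    from sgl_kernel import flash_ops  # noqa: F401
except ImportError:
    flash_ops = None

print(hasattr(torch.ops.sgl_kernel, "top_p_sampling_from_probs"))  # True once common_ops is loaded
print(hasattr(torch.ops.sgl_kernel, "fwd"))                        # True only if flash_ops was importable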
sgl-kernel/csrc/flash_extension.cc (new file, 62 lines)
@@ -0,0 +1,62 @@
/* Copyright 2025 SGLang Team. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/all.h>
#include <torch/library.h>

#include "sgl_flash_kernel_ops.h"

TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
  /*
   * From flash-attention
   */
  m.def(
      "fwd(Tensor! q,"
      " Tensor k,"
      " Tensor v,"
      " Tensor? k_new,"
      " Tensor? v_new,"
      " Tensor? q_v,"
      " Tensor!? out,"
      " Tensor? cu_seqlens_q,"
      " Tensor? cu_seqlens_k,"
      " Tensor? cu_seqlens_k_new,"
      " Tensor? seqused_q,"
      " Tensor? seqused_k,"
      " int? max_seqlen_q,"
      " int? max_seqlen_k,"
      " Tensor? page_table,"
      " Tensor? kv_batch_idx,"
      " Tensor? leftpad_k,"
      " Tensor? rotary_cos,"
      " Tensor? rotary_sin,"
      " Tensor? seqlens_rotary,"
      " Tensor? q_descale,"
      " Tensor? k_descale,"
      " Tensor? v_descale,"
      " float softmax_scale,"
      " bool is_causal,"
      " int window_size_left,"
      " int window_size_right,"
      " float softcap,"
      " bool is_rotary_interleaved,"
      " Tensor? scheduler_metadata,"
      " int num_splits,"
      " bool? pack_gqa,"
      " int sm_margin) -> Tensor[]");
  m.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd));
}

REGISTER_EXTENSION(flash_ops)
sgl-kernel/include/sgl_flash_kernel_ops.h (new file, 85 lines)
@@ -0,0 +1,85 @@
/* Copyright 2025 SGLang Team. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#pragma once

#include <ATen/ATen.h>
#include <ATen/Tensor.h>
#include <Python.h>
#include <torch/library.h>
#include <torch/torch.h>

#include <vector>

#include "sgl_kernel_torch_shim.h"

#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)

#define _CONCAT(A, B) A##B
#define CONCAT(A, B) _CONCAT(A, B)

#define _STRINGIFY(A) #A
#define STRINGIFY(A) _STRINGIFY(A)

#define REGISTER_EXTENSION(NAME)                                                                      \
  PyMODINIT_FUNC CONCAT(PyInit_, NAME)() {                                                            \
    static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, STRINGIFY(NAME), nullptr, 0, nullptr}; \
    return PyModule_Create(&module);                                                                  \
  }

/*
 * From flash-attention
 */
std::vector<at::Tensor> mha_fwd(
    at::Tensor& q,        // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
    const at::Tensor& k,  // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size,
                          // h_k, d) if there is page_table.
    const at::Tensor& v,  // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages,
                          // page_size, h_k, dv) if there is page_table.
    std::optional<const at::Tensor>&
        k_new_,  // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new
    std::optional<const at::Tensor>&
        v_new_,  // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new
    std::optional<const at::Tensor>& q_v_,  // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q
    std::optional<at::Tensor>& out_,        // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
    std::optional<const at::Tensor>& cu_seqlens_q_,      // b+1
    std::optional<const at::Tensor>& cu_seqlens_k_,      // b+1
    std::optional<const at::Tensor>& cu_seqlens_k_new_,  // b+1
    std::optional<const at::Tensor>&
        seqused_q_,  // b. If given, only this many elements of each batch element's queries and outputs are used.
    std::optional<const at::Tensor>&
        seqused_k_,  // b. If given, only this many elements of each batch element's keys are used.
    std::optional<int> max_seqlen_q_,
    // TODO: check if we need max_seqlen_k
    std::optional<int> max_seqlen_k_,
    std::optional<const at::Tensor>& page_table_,    // (b_k, max_num_pages_per_seq)
    std::optional<const at::Tensor>& kv_batch_idx_,  // b. indices to index into the KV cache
    std::optional<const at::Tensor>& leftpad_k_,     // b
    std::optional<const at::Tensor>& rotary_cos_,    // seqlen_ro x (rotary_dim / 2)
    std::optional<const at::Tensor>& rotary_sin_,    // seqlen_ro x (rotary_dim / 2)
    std::optional<const at::Tensor>& seqlens_rotary_,  // b
    std::optional<at::Tensor>& q_descale_,  // (b, h_k), not (b, h)
    std::optional<at::Tensor>& k_descale_,  // (b, h_k)
    std::optional<at::Tensor>& v_descale_,  // (b, h_k)
    float const softmax_scale,
    bool is_causal,
    int window_size_left,
    int window_size_right,
    float const softcap,
    bool const is_rotary_interleaved,  // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
    std::optional<at::Tensor>& scheduler_metadata_,  // (b + 1)
    int num_splits,
    std::optional<bool> pack_gqa_,
    int const sm_margin);
@@ -23,8 +23,6 @@ limitations under the License.

#include <vector>

#include "sgl_kernel_torch_shim.h"

#define _CONCAT(A, B) A##B
#define CONCAT(A, B) _CONCAT(A, B)

@@ -293,48 +291,3 @@ void top_p_sampling_from_probs(
    double top_p_val,
    bool deterministic,
    int64_t cuda_stream);

/*
 * From flash-attention
 */
std::vector<at::Tensor> mha_fwd(
    at::Tensor& q,        // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
    const at::Tensor& k,  // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size,
                          // h_k, d) if there is page_table.
    const at::Tensor& v,  // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages,
                          // page_size, h_k, dv) if there is page_table.
    std::optional<const at::Tensor>&
        k_new_,  // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new
    std::optional<const at::Tensor>&
        v_new_,  // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new
    std::optional<const at::Tensor>& q_v_,  // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q
    std::optional<at::Tensor>& out_,        // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
    std::optional<const at::Tensor>& cu_seqlens_q_,      // b+1
    std::optional<const at::Tensor>& cu_seqlens_k_,      // b+1
    std::optional<const at::Tensor>& cu_seqlens_k_new_,  // b+1
    std::optional<const at::Tensor>&
        seqused_q_,  // b. If given, only this many elements of each batch element's queries and outputs are used.
    std::optional<const at::Tensor>&
        seqused_k_,  // b. If given, only this many elements of each batch element's keys are used.
    std::optional<int> max_seqlen_q_,
    // TODO: check if we need max_seqlen_k
    std::optional<int> max_seqlen_k_,
    std::optional<const at::Tensor>& page_table_,    // (b_k, max_num_pages_per_seq)
    std::optional<const at::Tensor>& kv_batch_idx_,  // b. indices to index into the KV cache
    std::optional<const at::Tensor>& leftpad_k_,     // b
    std::optional<const at::Tensor>& rotary_cos_,    // seqlen_ro x (rotary_dim / 2)
    std::optional<const at::Tensor>& rotary_sin_,    // seqlen_ro x (rotary_dim / 2)
    std::optional<const at::Tensor>& seqlens_rotary_,  // b
    std::optional<at::Tensor>& q_descale_,  // (b, h_k), not (b, h)
    std::optional<at::Tensor>& k_descale_,  // (b, h_k)
    std::optional<at::Tensor>& v_descale_,  // (b, h_k)
    float const softmax_scale,
    bool is_causal,
    int window_size_left,
    int window_size_right,
    float const softcap,
    bool const is_rotary_interleaved,  // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
    std::optional<at::Tensor>& scheduler_metadata_,  // (b + 1)
    int num_splits,
    std::optional<bool> pack_gqa_,
    int const sm_margin);
@@ -3,15 +3,22 @@ from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn

try:
    from sgl_kernel import flash_ops
except:
    raise ImportError("Can not import sgl_kernel. Please check your installation.")


def is_fa3_supported(device=None) -> bool:
    # FA3 can fail without a enough shared memory for a some shapes, currently
    # only 8.0 and 8.7 have enough shared memory for all shapes
    # https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x
    return FA3_AVAILABLE and (
        torch.cuda.get_device_capability(device)[0] >= 9
        or torch.cuda.get_device_capability(device) == (8, 0)
        or torch.cuda.get_device_capability(device) == (8, 7)
    # now sgl-kernel only build fa3 for sm90a && cuda >= 12.4
    return (
        (torch.cuda.get_device_capability(device)[0] >= 9)
        and (torch.version.cuda >= "12.4")
        # or torch.cuda.get_device_capability(device) == (8, 0)
        # or torch.cuda.get_device_capability(device) == (8, 7)
    )


@@ -135,6 +142,10 @@ def flash_attn_with_kvcache(
        logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
        normalization factor).
    """
    if not is_fa3_supported():
        raise NotImplementedError(
            "flash_attn at sgl-kernel is only supported on sm90 and above"
        )
    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
    if softmax_scale is None:
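Putting the new gate and wrapper together, a hedged usage sketch (shapes and keyword arguments follow the upstream flash-attention interface and are assumptions, not part of this diff; only is_fa3_supported and flash_attn_with_kvcache come from it):

import torch
from sgl_kernel.flash_attn import flash_attn_with_kvcache, is_fa3_supported

if is_fa3_supported():
    q = torch.randn(1, 1, 8, 128, device="cuda", dtype=torch.bfloat16)           # (batch, seqlen_q, heads, head_dim)
    k_cache = torch.randn(1, 1024, 8, 128, device="cuda", dtype=torch.bfloat16)  # (batch, seqlen_k, kv_heads, head_dim)
    v_cache = torch.randn(1, 1024, 8, 128, device="cuda", dtype=torch.bfloat16)
    cache_seqlens = torch.full((1,), 1024, dtype=torch.int32, device="cuda")
    out = flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=True)
else:
    # On pre-sm90 GPUs or CUDA < 12.4 builds (e.g. cu118) the wrapper raises NotImplementedError.
    out = None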
@@ -10,7 +10,19 @@ from einops import rearrange, repeat

apply_rotary_emb = None

from sgl_kernel.flash_attn import flash_attn_with_kvcache

def is_fa3_supported(device=None) -> bool:
    # FA3 can fail without a enough shared memory for a some shapes, currently
    # only 8.0 and 8.7 have enough shared memory for all shapes
    # https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x
    # now sgl-kernel only build fa3 for sm90a && cuda >= 12.4
    return (
        (torch.cuda.get_device_capability(device)[0] >= 9)
        and (torch.version.cuda >= "12.4")
        # or torch.cuda.get_device_capability(device) == (8, 0)
        # or torch.cuda.get_device_capability(device) == (8, 7)
    )


DISABLE_BACKWARD = True
# For CI test, we close them to True.
@@ -284,6 +296,10 @@ def attention_ref(
    return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)


@pytest.mark.skipif(
    not is_fa3_supported(),
    reason="flash_attn at sgl-kernel is only supported on sm90 and above",
)
# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
@pytest.mark.parametrize(
    "dtype", [torch.bfloat16] + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])
@@ -372,6 +388,8 @@ def test_flash_attn_kvcache(
    mha_type,
    dtype,
):
    from sgl_kernel.flash_attn import flash_attn_with_kvcache

    if page_size is not None and seqlen_k % page_size != 0:
        pytest.skip()
    if seqlen_q > seqlen_k and new_kv: