diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt
index 01bafc880..0bd1e7234 100644
--- a/sgl-kernel/CMakeLists.txt
+++ b/sgl-kernel/CMakeLists.txt
@@ -1,10 +1,12 @@
 cmake_minimum_required(VERSION 3.26 FATAL_ERROR)
 project(sgl-kernel LANGUAGES CXX CUDA)
 
-# we only want to download 3rd-party sources, not build them.
-# FetchContent_MakeAvailable would build them.
 cmake_policy(SET CMP0169 OLD)
 
+include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
+
+set(BUILD_FA3 OFF)
+
 find_package(Python COMPONENTS Interpreter Development.Module ${SKBUILD_SABI_COMPONENT} REQUIRED)
 
 enable_language(CUDA)
@@ -22,6 +24,8 @@ elseif ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "11.8")
 endif()
 
 find_package(Torch REQUIRED)
+# Clear Torch's -gencode flags from CMAKE_CUDA_FLAGS
+clear_cuda_arches(CMAKE_FLAG)
 
 include(FetchContent)
 
@@ -53,8 +57,8 @@ FetchContent_Populate(repo-flashinfer)
 FetchContent_Declare(
   repo-flash-attention
   GIT_REPOSITORY https://github.com/sgl-project/sgl-attn
-  GIT_TAG        sgl-kernel
-  GIT_SHALLOW    OFF
+  GIT_TAG sgl-kernel
+  GIT_SHALLOW OFF
 )
 FetchContent_Populate(repo-flash-attention)
 
@@ -92,14 +96,13 @@ set(SGL_KERNEL_CUDA_FLAGS
   "-gencode=arch=compute_90,code=sm_90"
   "-std=c++17"
   "-DFLASHINFER_ENABLE_F16"
+  "-DCUTE_USE_PACKED_TUPLE=1"
   "-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1"
   "-DCUTLASS_VERSIONS_GENERATED"
-  "-DCUTE_USE_PACKED_TUPLE=1"
   "-DCUTLASS_TEST_LEVEL=0"
   "-DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1"
   "-DCUTLASS_DEBUG_TRACE_LEVEL=0"
   "--expt-relaxed-constexpr"
-  "--use_fast_math"
   "-Xcompiler=-Wconversion"
   "-Xcompiler=-fno-strict-aliasing"
 )
@@ -122,6 +125,7 @@ else()
 endif()
 
 if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.4" OR SGL_KERNEL_ENABLE_SM90A)
+  set(BUILD_FA3 ON)
   list(APPEND SGL_KERNEL_CUDA_FLAGS
     "-gencode=arch=compute_90a,code=sm_90a"
   )
@@ -152,30 +156,6 @@ string(REPLACE "-D__CUDA_NO_HALF_CONVERSIONS__" "" CMAKE_CUDA_FLAGS "${CMAKE
 string(REPLACE "-D__CUDA_NO_BFLOAT16_CONVERSIONS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
 string(REPLACE "-D__CUDA_NO_HALF2_OPERATORS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
 
-# set flash-attention source files
-# BF16 source files
-file(GLOB FA3_BF16_GEN_SRCS
-     "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_bf16*_sm90.cu")
-file(GLOB FA3_BF16_GEN_SRCS_
-     "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_bf16*_sm90.cu")
-list(APPEND FA3_BF16_GEN_SRCS ${FA3_BF16_GEN_SRCS_})
-
-# FP16 source files
-file(GLOB FA3_FP16_GEN_SRCS
-     "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_fp16*_sm90.cu")
-file(GLOB FA3_FP16_GEN_SRCS_
-     "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_fp16*_sm90.cu")
-list(APPEND FA3_FP16_GEN_SRCS ${FA3_FP16_GEN_SRCS_})
-
-# FP8 source files
-file(GLOB FA3_FP8_GEN_SRCS
-     "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_e4m3*_sm90.cu")
-file(GLOB FA3_FP8_GEN_SRCS_
-     "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_e4m3*_sm90.cu")
-list(APPEND FA3_FP8_GEN_SRCS ${FA3_FP8_GEN_SRCS_})
-
-set(FA3_GEN_SRCS ${FA3_BF16_GEN_SRCS} ${FA3_FP16_GEN_SRCS} ${FA3_FP8_GEN_SRCS})
-
 set(SOURCES
   "csrc/allreduce/trt_reduce_internal.cu"
   "csrc/allreduce/trt_reduce_kernel.cu"
@@ -202,39 +182,94 @@ set(SOURCES
   "csrc/speculative/eagle_utils.cu"
   "csrc/speculative/speculative_sampling.cu"
   "csrc/speculative/packbit.cu"
-  "csrc/torch_extension.cc"
+  "csrc/common_extension.cc"
   "${repo-flashinfer_SOURCE_DIR}/csrc/norm.cu"
   "${repo-flashinfer_SOURCE_DIR}/csrc/renorm.cu"
   "${repo-flashinfer_SOURCE_DIR}/csrc/sampling.cu"
-  "${repo-flash-attention_SOURCE_DIR}/hopper/flash_prepare_scheduler.cu"
-  "${repo-flash-attention_SOURCE_DIR}/hopper/flash_api.cpp"
-  "${repo-flash-attention_SOURCE_DIR}/hopper/flash_fwd_combine.cu"
-  "${FA3_GEN_SRCS}"
 )
 
-# Support abi3 for build
+# Set flash-attention source files
+# BF16 source files
+if (BUILD_FA3)
+  set(SGL_FLASH_KERNEL_CUDA_FLAGS
+    "-DNDEBUG"
+    "-DOPERATOR_NAMESPACE=sgl-kernel"
+    "-O3"
+    "-Xcompiler"
+    "-fPIC"
+    "-gencode=arch=compute_90a,code=sm_90a"
+    "-std=c++17"
+    "-DCUTE_USE_PACKED_TUPLE=1"
+    "-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1"
+    "-DCUTLASS_VERSIONS_GENERATED"
+    "-DCUTLASS_TEST_LEVEL=0"
+    "-DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1"
+    "-DCUTLASS_DEBUG_TRACE_LEVEL=0"
+    "--expt-relaxed-constexpr"
+    "--expt-extended-lambda"
+    "--use_fast_math"
+    "-Xcompiler=-Wconversion"
+    "-Xcompiler=-fno-strict-aliasing"
+  )
+
+  file(GLOB FA3_BF16_GEN_SRCS
+       "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_bf16*_sm90.cu")
+  file(GLOB FA3_BF16_GEN_SRCS_
+       "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_bf16*_sm90.cu")
+  list(APPEND FA3_BF16_GEN_SRCS ${FA3_BF16_GEN_SRCS_})
+
+  # FP16 source files
+  file(GLOB FA3_FP16_GEN_SRCS
+       "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_fp16*_sm90.cu")
+  file(GLOB FA3_FP16_GEN_SRCS_
+       "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_fp16*_sm90.cu")
+  list(APPEND FA3_FP16_GEN_SRCS ${FA3_FP16_GEN_SRCS_})
+
+  # FP8 source files
+  file(GLOB FA3_FP8_GEN_SRCS
+       "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_e4m3*_sm90.cu")
+  file(GLOB FA3_FP8_GEN_SRCS_
+       "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_e4m3*_sm90.cu")
+  list(APPEND FA3_FP8_GEN_SRCS ${FA3_FP8_GEN_SRCS_})
+
+  set(FA3_GEN_SRCS ${FA3_BF16_GEN_SRCS} ${FA3_FP16_GEN_SRCS} ${FA3_FP8_GEN_SRCS})
+
+  set(FLASH_SOURCES
+    "csrc/flash_extension.cc"
+    "${repo-flash-attention_SOURCE_DIR}/hopper/flash_prepare_scheduler.cu"
+    "${repo-flash-attention_SOURCE_DIR}/hopper/flash_api.cpp"
+    "${repo-flash-attention_SOURCE_DIR}/hopper/flash_fwd_combine.cu"
+    "${FA3_GEN_SRCS}"
+  )
+
+  Python_add_library(flash_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${FLASH_SOURCES})
+
+  target_compile_options(flash_ops PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_FLASH_KERNEL_CUDA_FLAGS}>)
+  target_include_directories(flash_ops PRIVATE ${TORCH_INCLUDE_DIRS})
+  target_link_libraries(flash_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda)
+
+  install(TARGETS flash_ops LIBRARY DESTINATION "sgl_kernel")
+
+  target_compile_definitions(flash_ops PRIVATE
+    FLASHATTENTION_DISABLE_SM8x
+    FLASHATTENTION_DISABLE_BACKWARD
+    FLASHATTENTION_DISABLE_DROPOUT
+    # FLASHATTENTION_DISABLE_ALIBI
+    # FLASHATTENTION_DISABLE_SOFTCAP
+    FLASHATTENTION_DISABLE_UNEVEN_K
+    # FLASHATTENTION_DISABLE_LOCAL
+    FLASHATTENTION_VARLEN_ONLY
+  )
+endif()
+
 Python_add_library(common_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES})
 target_compile_options(common_ops PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_KERNEL_CUDA_FLAGS}>)
-
 target_include_directories(common_ops PRIVATE ${TORCH_INCLUDE_DIRS})
-
 target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cublasLt)
 
 install(TARGETS common_ops LIBRARY DESTINATION "sgl_kernel")
 
-# Add some flash-attention custom flags for inference
-target_compile_definitions(common_ops PRIVATE
-  FLASHATTENTION_DISABLE_SM8x
-  FLASHATTENTION_DISABLE_BACKWARD
-  FLASHATTENTION_DISABLE_DROPOUT
-  # FLASHATTENTION_DISABLE_ALIBI
-  # FLASHATTENTION_DISABLE_SOFTCAP
-  FLASHATTENTION_DISABLE_UNEVEN_K
-  # FLASHATTENTION_DISABLE_LOCAL
-  FLASHATTENTION_VARLEN_ONLY
-)
-
 # JIT Logic
 # DeepGEMM
diff --git a/sgl-kernel/cmake/utils.cmake b/sgl-kernel/cmake/utils.cmake
new file mode 100644
index 000000000..0eaa7a61a
--- /dev/null
+++ b/sgl-kernel/cmake/utils.cmake
@@ -0,0 +1,21 @@
+# Adapted from: https://github.com/neuralmagic/vllm-flash-attention/blob/main/cmake/utils.cmake
+#
+# Clear all `-gencode` flags from `CMAKE_CUDA_FLAGS` and store them in
+# `CUDA_ARCH_FLAGS`.
+#
+# Example:
+#   CMAKE_CUDA_FLAGS="-Wall -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75"
+#   clear_cuda_arches(CUDA_ARCH_FLAGS)
+#   CUDA_ARCH_FLAGS="-gencode arch=compute_70,code=sm_70;-gencode arch=compute_75,code=sm_75"
+#   CMAKE_CUDA_FLAGS="-Wall"
+#
+macro(clear_cuda_arches CUDA_ARCH_FLAGS)
+  # Extract all `-gencode` flags from `CMAKE_CUDA_FLAGS`
+  string(REGEX MATCHALL "-gencode arch=[^ ]+" CUDA_ARCH_FLAGS
+         ${CMAKE_CUDA_FLAGS})
+
+  # Remove all `-gencode` flags from `CMAKE_CUDA_FLAGS`, since they will be modified
+  # and passed back via the `CUDA_ARCHITECTURES` property.
+  string(REGEX REPLACE "-gencode arch=[^ ]+ *" "" CMAKE_CUDA_FLAGS
+         ${CMAKE_CUDA_FLAGS})
+endmacro()
diff --git a/sgl-kernel/csrc/torch_extension.cc b/sgl-kernel/csrc/common_extension.cc
similarity index 88%
rename from sgl-kernel/csrc/torch_extension.cc
rename to sgl-kernel/csrc/common_extension.cc
index 3b91e63cd..a620be120 100644
--- a/sgl-kernel/csrc/torch_extension.cc
+++ b/sgl-kernel/csrc/common_extension.cc
@@ -18,7 +18,7 @@ limitations under the License.
 
 #include "sgl_kernel_ops.h"
 
-TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
+TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   /*
    * From csrc/allreduce
    */
@@ -202,45 +202,6 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
       "top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? "
       "maybe_top_p_arr, float top_p_val, bool deterministic, int cuda_stream) -> ()");
   m.impl("top_p_sampling_from_probs", torch::kCUDA, &top_p_sampling_from_probs);
-
-  /*
-   * From flash-attention
-   */
-  m.def(
-      "fwd(Tensor! q,"
-      "    Tensor k,"
-      "    Tensor v,"
-      "    Tensor? k_new,"
-      "    Tensor? v_new,"
-      "    Tensor? q_v,"
-      "    Tensor!? out,"
-      "    Tensor? cu_seqlens_q,"
-      "    Tensor? cu_seqlens_k,"
-      "    Tensor? cu_seqlens_k_new,"
-      "    Tensor? seqused_q,"
-      "    Tensor? seqused_k,"
-      "    int? max_seqlen_q,"
-      "    int? max_seqlen_k,"
-      "    Tensor? page_table,"
-      "    Tensor? kv_batch_idx,"
-      "    Tensor? leftpad_k,"
-      "    Tensor? rotary_cos,"
-      "    Tensor? rotary_sin,"
-      "    Tensor? seqlens_rotary,"
-      "    Tensor? q_descale,"
-      "    Tensor? k_descale,"
-      "    Tensor? v_descale,"
-      "    float softmax_scale,"
-      "    bool is_causal,"
-      "    int window_size_left,"
-      "    int window_size_right,"
-      "    float softcap,"
-      "    bool is_rotary_interleaved,"
-      "    Tensor? scheduler_metadata,"
-      "    int num_splits,"
-      "    bool? pack_gqa,"
-      "    int sm_margin) -> Tensor[]");
-  m.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd));
 }
 
 REGISTER_EXTENSION(common_ops)
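The switch from `TORCH_LIBRARY_EXPAND` (which wraps `TORCH_LIBRARY`) to `TORCH_LIBRARY_FRAGMENT` is what makes the split possible: `TORCH_LIBRARY` may define the `sgl_kernel` namespace only once per process, while fragments let `common_extension.cc` and `flash_extension.cc` each contribute operators to the same namespace from separate shared objects. A sketch of what this looks like from Python, assuming a build with both modules; the op names are the ones registered in this diff:

```python
import torch
from sgl_kernel import common_ops, flash_ops  # each import registers one fragment

# Both fragments populate the same torch.ops namespace:
print(torch.ops.sgl_kernel.top_p_sampling_from_probs)  # from common_ops
print(torch.ops.sgl_kernel.fwd)                        # from flash_ops
```
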
diff --git a/sgl-kernel/csrc/flash_extension.cc b/sgl-kernel/csrc/flash_extension.cc
new file mode 100644
index 000000000..c4fbe0092
--- /dev/null
+++ b/sgl-kernel/csrc/flash_extension.cc
@@ -0,0 +1,62 @@
+/* Copyright 2025 SGLang Team. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <ATen/core/dispatch/Dispatcher.h>
+#include <torch/all.h>
+#include <torch/library.h>
+
+#include "sgl_flash_kernel_ops.h"
+
+TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
+  /*
+   * From flash-attention
+   */
+  m.def(
+      "fwd(Tensor! q,"
+      "    Tensor k,"
+      "    Tensor v,"
+      "    Tensor? k_new,"
+      "    Tensor? v_new,"
+      "    Tensor? q_v,"
+      "    Tensor!? out,"
+      "    Tensor? cu_seqlens_q,"
+      "    Tensor? cu_seqlens_k,"
+      "    Tensor? cu_seqlens_k_new,"
+      "    Tensor? seqused_q,"
+      "    Tensor? seqused_k,"
+      "    int? max_seqlen_q,"
+      "    int? max_seqlen_k,"
+      "    Tensor? page_table,"
+      "    Tensor? kv_batch_idx,"
+      "    Tensor? leftpad_k,"
+      "    Tensor? rotary_cos,"
+      "    Tensor? rotary_sin,"
+      "    Tensor? seqlens_rotary,"
+      "    Tensor? q_descale,"
+      "    Tensor? k_descale,"
+      "    Tensor? v_descale,"
+      "    float softmax_scale,"
+      "    bool is_causal,"
+      "    int window_size_left,"
+      "    int window_size_right,"
+      "    float softcap,"
+      "    bool is_rotary_interleaved,"
+      "    Tensor? scheduler_metadata,"
+      "    int num_splits,"
+      "    bool? pack_gqa,"
+      "    int sm_margin) -> Tensor[]");
+  m.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd));
+}
+
+REGISTER_EXTENSION(flash_ops)
diff --git a/sgl-kernel/include/sgl_flash_kernel_ops.h b/sgl-kernel/include/sgl_flash_kernel_ops.h
new file mode 100644
index 000000000..c406fa9f3
--- /dev/null
+++ b/sgl-kernel/include/sgl_flash_kernel_ops.h
@@ -0,0 +1,85 @@
+/* Copyright 2025 SGLang Team. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#pragma once
+
+#include <ATen/ATen.h>
+#include <ATen/Tensor.h>
+#include <Python.h>
+#include <torch/library.h>
+#include <torch/torch.h>
+
+#include <vector>
+
+#include "sgl_kernel_torch_shim.h"
+
+#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
+
+#define _CONCAT(A, B) A##B
+#define CONCAT(A, B) _CONCAT(A, B)
+
+#define _STRINGIFY(A) #A
+#define STRINGIFY(A) _STRINGIFY(A)
+
+#define REGISTER_EXTENSION(NAME)                                                                      \
+  PyMODINIT_FUNC CONCAT(PyInit_, NAME)() {                                                            \
+    static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, STRINGIFY(NAME), nullptr, 0, nullptr}; \
+    return PyModule_Create(&module);                                                                  \
+  }
+
+/*
+ * From flash-attention
+ */
+std::vector<at::Tensor> mha_fwd(
+    at::Tensor& q,        // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
+    const at::Tensor& k,  // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size,
+                          // h_k, d) if there is page_table.
+    const at::Tensor& v,  // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages,
+                          // page_size, h_k, dv) if there is page_table.
+    std::optional<at::Tensor>&
+        k_new_,  // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new
+    std::optional<at::Tensor>&
+        v_new_,  // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new
+    std::optional<at::Tensor>& q_v_,  // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q
+    std::optional<at::Tensor>& out_,  // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
+    std::optional<at::Tensor>& cu_seqlens_q_,      // b+1
+    std::optional<at::Tensor>& cu_seqlens_k_,      // b+1
+    std::optional<at::Tensor>& cu_seqlens_k_new_,  // b+1
+    std::optional<at::Tensor>&
+        seqused_q_,  // b. If given, only this many elements of each batch element's queries and outputs are used.
+    std::optional<at::Tensor>&
+        seqused_k_,  // b. If given, only this many elements of each batch element's keys are used.
+    std::optional<int> max_seqlen_q_,
+    // TODO: check if we need max_seqlen_k
+    std::optional<int> max_seqlen_k_,
+    std::optional<at::Tensor>& page_table_,      // (b_k, max_num_pages_per_seq)
+    std::optional<at::Tensor>& kv_batch_idx_,    // b. indices to index into the KV cache
+    std::optional<at::Tensor>& leftpad_k_,       // b
+    std::optional<at::Tensor>& rotary_cos_,      // seqlen_ro x (rotary_dim / 2)
+    std::optional<at::Tensor>& rotary_sin_,      // seqlen_ro x (rotary_dim / 2)
+    std::optional<at::Tensor>& seqlens_rotary_,  // b
+    std::optional<at::Tensor>& q_descale_,       // (b, h_k), not (b, h)
+    std::optional<at::Tensor>& k_descale_,       // (b, h_k)
+    std::optional<at::Tensor>& v_descale_,       // (b, h_k)
+    float const softmax_scale,
+    bool is_causal,
+    int window_size_left,
+    int window_size_right,
+    float const softcap,
+    bool const is_rotary_interleaved,  // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
+    std::optional<at::Tensor>& scheduler_metadata_,  // (b + 1)
+    int num_splits,
+    std::optional<bool> pack_gqa_,
+    int const sm_margin);
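Several of the optional arguments above (`cu_seqlens_q`, `cu_seqlens_k`, `cu_seqlens_k_new`) are "b+1" cumulative sequence-length tensors that the varlen path uses to slice a packed `(total, h, d)` layout. A minimal sketch of how such a tensor is typically built; the lengths are made up:

```python
import torch
import torch.nn.functional as F

seqlens = torch.tensor([5, 3, 8], dtype=torch.int32)  # hypothetical per-sequence lengths, b = 3
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
print(cu_seqlens)  # tensor([ 0,  5,  8, 16], dtype=torch.int32) -- b + 1 entries
# Sequence i occupies rows cu_seqlens[i]:cu_seqlens[i + 1] of the packed q/k tensor.
```
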
diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h
index d89fccbb2..847b24ebe 100644
--- a/sgl-kernel/include/sgl_kernel_ops.h
+++ b/sgl-kernel/include/sgl_kernel_ops.h
@@ -23,8 +23,6 @@ limitations under the License.
 
 #include <vector>
 
-#include "sgl_kernel_torch_shim.h"
-
 #define _CONCAT(A, B) A##B
 #define CONCAT(A, B) _CONCAT(A, B)
 
@@ -293,48 +291,3 @@ void top_p_sampling_from_probs(
     double top_p_val,
     bool deterministic,
     int64_t cuda_stream);
-
-/*
- * From flash-attention
- */
-std::vector<at::Tensor> mha_fwd(
-    at::Tensor& q,        // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
-    const at::Tensor& k,  // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size,
-                          // h_k, d) if there is page_table.
-    const at::Tensor& v,  // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages,
-                          // page_size, h_k, dv) if there is page_table.
-    std::optional<at::Tensor>&
-        k_new_,  // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new
-    std::optional<at::Tensor>&
-        v_new_,  // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new
-    std::optional<at::Tensor>& q_v_,  // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q
-    std::optional<at::Tensor>& out_,  // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
-    std::optional<at::Tensor>& cu_seqlens_q_,      // b+1
-    std::optional<at::Tensor>& cu_seqlens_k_,      // b+1
-    std::optional<at::Tensor>& cu_seqlens_k_new_,  // b+1
-    std::optional<at::Tensor>&
-        seqused_q_,  // b. If given, only this many elements of each batch element's queries and outputs are used.
-    std::optional<at::Tensor>&
-        seqused_k_,  // b. If given, only this many elements of each batch element's keys are used.
-    std::optional<int> max_seqlen_q_,
-    // TODO: check if we need max_seqlen_k
-    std::optional<int> max_seqlen_k_,
-    std::optional<at::Tensor>& page_table_,      // (b_k, max_num_pages_per_seq)
-    std::optional<at::Tensor>& kv_batch_idx_,    // b. indices to index into the KV cache
-    std::optional<at::Tensor>& leftpad_k_,       // b
-    std::optional<at::Tensor>& rotary_cos_,      // seqlen_ro x (rotary_dim / 2)
-    std::optional<at::Tensor>& rotary_sin_,      // seqlen_ro x (rotary_dim / 2)
-    std::optional<at::Tensor>& seqlens_rotary_,  // b
-    std::optional<at::Tensor>& q_descale_,       // (b, h_k), not (b, h)
-    std::optional<at::Tensor>& k_descale_,       // (b, h_k)
-    std::optional<at::Tensor>& v_descale_,       // (b, h_k)
-    float const softmax_scale,
-    bool is_causal,
-    int window_size_left,
-    int window_size_right,
-    float const softcap,
-    bool const is_rotary_interleaved,  // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
-    std::optional<at::Tensor>& scheduler_metadata_,  // (b + 1)
-    int num_splits,
-    std::optional<bool> pack_gqa_,
-    int const sm_margin);
diff --git a/sgl-kernel/python/sgl_kernel/flash_attn.py b/sgl-kernel/python/sgl_kernel/flash_attn.py
index b23a64623..acf0807b0 100644
--- a/sgl-kernel/python/sgl_kernel/flash_attn.py
+++ b/sgl-kernel/python/sgl_kernel/flash_attn.py
@@ -3,15 +3,22 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torch.nn as nn
 
+try:
+    from sgl_kernel import flash_ops
+except ImportError as e:
+    raise ImportError("Cannot import sgl_kernel. Please check your installation.") from e
+
 
 def is_fa3_supported(device=None) -> bool:
     # FA3 can fail without a enough shared memory for a some shapes, currently
     # only 8.0 and 8.7 have enough shared memory for all shapes
     # https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x
-    return FA3_AVAILABLE and (
-        torch.cuda.get_device_capability(device)[0] >= 9
-        or torch.cuda.get_device_capability(device) == (8, 0)
-        or torch.cuda.get_device_capability(device) == (8, 7)
+    # Note: sgl-kernel currently builds FA3 only for sm90a with CUDA >= 12.4.
+    return (
+        (torch.cuda.get_device_capability(device)[0] >= 9)
+        and (torch.version.cuda >= "12.4")
+        # or torch.cuda.get_device_capability(device) == (8, 0)
+        # or torch.cuda.get_device_capability(device) == (8, 7)
     )
 
 
@@ -135,6 +142,10 @@ def flash_attn_with_kvcache(
         logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
         normalization factor).
     """
+    if not is_fa3_supported():
+        raise NotImplementedError(
+            "flash_attn at sgl-kernel is only supported on sm90 and above"
+        )
     assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
     assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
     if softmax_scale is None:
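A short usage sketch of the new guard, with made-up decode-style shapes; `flash_attn_with_kvcache` and `is_fa3_supported` are the functions from this file, and the `causal` keyword is assumed to be part of the existing flash-attention-style API:

```python
import torch
from sgl_kernel.flash_attn import flash_attn_with_kvcache, is_fa3_supported

if is_fa3_supported():
    b, s_q, s_k, h, d = 2, 1, 128, 8, 128  # hypothetical shapes
    q = torch.randn(b, s_q, h, d, device="cuda", dtype=torch.bfloat16)
    k_cache = torch.randn(b, s_k, h, d, device="cuda", dtype=torch.bfloat16)
    v_cache = torch.randn(b, s_k, h, d, device="cuda", dtype=torch.bfloat16)
    out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=True)
else:
    print("FA3 kernel not available on this device/CUDA version")
```
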
diff --git a/sgl-kernel/tests/test_flash_attention.py b/sgl-kernel/tests/test_flash_attention.py
index 37e50cbd7..ff60b7710 100644
--- a/sgl-kernel/tests/test_flash_attention.py
+++ b/sgl-kernel/tests/test_flash_attention.py
@@ -10,7 +10,19 @@ from einops import rearrange, repeat
 
     apply_rotary_emb = None
 
-from sgl_kernel.flash_attn import flash_attn_with_kvcache
+
+def is_fa3_supported(device=None) -> bool:
+    # FA3 can fail without enough shared memory for some shapes; currently
+    # only 8.0 and 8.7 have enough shared memory for all shapes.
+    # https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x
+    # Note: sgl-kernel currently builds FA3 only for sm90a with CUDA >= 12.4.
+    return (
+        (torch.cuda.get_device_capability(device)[0] >= 9)
+        and (torch.version.cuda >= "12.4")
+        # or torch.cuda.get_device_capability(device) == (8, 0)
+        # or torch.cuda.get_device_capability(device) == (8, 7)
+    )
+
 
 DISABLE_BACKWARD = True  # For CI test, we close them to True.
 
@@ -284,6 +296,10 @@ def attention_ref(
     return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
 
 
+@pytest.mark.skipif(
+    not is_fa3_supported(),
+    reason="flash_attn at sgl-kernel is only supported on sm90 and above",
+)
 # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
 @pytest.mark.parametrize(
     "dtype", [torch.bfloat16] + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])
@@ -372,6 +388,8 @@ def test_flash_attn_kvcache(
     mha_type,
     dtype,
 ):
+    from sgl_kernel.flash_attn import flash_attn_with_kvcache
+
     if page_size is not None and seqlen_k % page_size != 0:
         pytest.skip()
     if seqlen_q > seqlen_k and new_kv: