adapt to sglang v0.5.2rc1 on dcu
This commit is contained in:
15
sgl-kernel/.clang-format
Normal file
15
sgl-kernel/.clang-format
Normal file
@@ -0,0 +1,15 @@
|
||||
BasedOnStyle: Google
|
||||
IndentWidth: 2
|
||||
ColumnLimit: 120
|
||||
AllowShortFunctionsOnASingleLine: Empty
|
||||
DerivePointerAlignment: false
|
||||
PointerAlignment: Left
|
||||
NamespaceIndentation: None
|
||||
SortIncludes: true
|
||||
AllowShortLoopsOnASingleLine: false
|
||||
BinPackParameters: false # Prevents packing parameters in declarations
|
||||
BinPackArguments: false # Prevents packing arguments in function calls
|
||||
AlignAfterOpenBracket: AlwaysBreak # Forces a break after the opening parenthesis
|
||||
AlignOperands: Align # Aligns arguments vertically
|
||||
PenaltyBreakBeforeFirstCallParameter: 1 # Encourages breaking before the first argument
|
||||
PenaltyReturnTypeOnItsOwnLine: 100 # Keeps return type with function name
|
||||
503
sgl-kernel/CMakeLists.txt
Normal file
503
sgl-kernel/CMakeLists.txt
Normal file
@@ -0,0 +1,503 @@
|
||||
cmake_minimum_required(VERSION 3.26 FATAL_ERROR)
|
||||
project(sgl-kernel LANGUAGES CXX CUDA)
|
||||
|
||||
# CMake
|
||||
cmake_policy(SET CMP0169 OLD)
|
||||
cmake_policy(SET CMP0177 NEW)
|
||||
include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
|
||||
set(CMAKE_COLOR_DIAGNOSTICS ON)
|
||||
set(CMAKE_VERBOSE_MAKEFILE ON CACHE BOOL "ON")
|
||||
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
|
||||
set(CMAKE_SHARED_LIBRARY_PREFIX "")
|
||||
|
||||
# Python
|
||||
find_package(Python COMPONENTS Interpreter Development.Module ${SKBUILD_SABI_COMPONENT} REQUIRED)
|
||||
|
||||
# CXX
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3")
|
||||
|
||||
# CUDA
|
||||
enable_language(CUDA)
|
||||
find_package(CUDAToolkit REQUIRED)
|
||||
set_property(GLOBAL PROPERTY CUDA_SEPARABLE_COMPILATION ON)
|
||||
|
||||
message(STATUS "Detected CUDA_VERSION=${CUDA_VERSION}")
|
||||
if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "13.0")
|
||||
message("CUDA_VERSION ${CUDA_VERSION} >= 13.0")
|
||||
elseif ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8")
|
||||
message("CUDA_VERSION ${CUDA_VERSION} >= 12.8")
|
||||
elseif ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.4")
|
||||
message("CUDA_VERSION ${CUDA_VERSION} >= 12.4")
|
||||
elseif ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.1")
|
||||
message("CUDA_VERSION ${CUDA_VERSION} >= 12.1")
|
||||
elseif ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "11.8")
|
||||
message("CUDA_VERSION ${CUDA_VERSION} >= 11.8")
|
||||
endif()
|
||||
|
||||
# Torch
|
||||
find_package(Torch REQUIRED)
|
||||
# clean Torch Flag
|
||||
clear_cuda_arches(CMAKE_FLAG)
|
||||
|
||||
include(FetchContent)
|
||||
|
||||
# cutlass
|
||||
FetchContent_Declare(
|
||||
repo-cutlass
|
||||
GIT_REPOSITORY https://github.com/NVIDIA/cutlass
|
||||
GIT_TAG a49a78ffefc86a87160dfe0ccc3a3a2d1622c918
|
||||
GIT_SHALLOW OFF
|
||||
)
|
||||
FetchContent_Populate(repo-cutlass)
|
||||
|
||||
# DeepGEMM
|
||||
FetchContent_Declare(
|
||||
repo-deepgemm
|
||||
GIT_REPOSITORY https://github.com/sgl-project/DeepGEMM
|
||||
GIT_TAG sgl
|
||||
GIT_SHALLOW OFF
|
||||
)
|
||||
FetchContent_Populate(repo-deepgemm)
|
||||
|
||||
FetchContent_Declare(
|
||||
repo-fmt
|
||||
GIT_REPOSITORY https://github.com/fmtlib/fmt
|
||||
GIT_TAG 553ec11ec06fbe0beebfbb45f9dc3c9eabd83d28
|
||||
GIT_SHALLOW OFF
|
||||
)
|
||||
FetchContent_Populate(repo-fmt)
|
||||
|
||||
# Triton
|
||||
FetchContent_Declare(
|
||||
repo-triton
|
||||
GIT_REPOSITORY "https://github.com/triton-lang/triton"
|
||||
GIT_TAG 8f9f695ea8fde23a0c7c88e4ab256634ca27789f
|
||||
GIT_SHALLOW OFF
|
||||
)
|
||||
FetchContent_Populate(repo-triton)
|
||||
|
||||
# flashinfer
|
||||
FetchContent_Declare(
|
||||
repo-flashinfer
|
||||
GIT_REPOSITORY https://github.com/flashinfer-ai/flashinfer.git
|
||||
GIT_TAG 018b551825c8e5579206e6eb9d3229fa679202b3
|
||||
GIT_SHALLOW OFF
|
||||
)
|
||||
FetchContent_Populate(repo-flashinfer)
|
||||
|
||||
# flash-attention
|
||||
FetchContent_Declare(
|
||||
repo-flash-attention
|
||||
GIT_REPOSITORY https://github.com/sgl-project/sgl-attn
|
||||
GIT_TAG sgl-kernel
|
||||
GIT_SHALLOW OFF
|
||||
)
|
||||
FetchContent_Populate(repo-flash-attention)
|
||||
|
||||
# mscclpp
|
||||
FetchContent_Declare(
|
||||
repo-mscclpp
|
||||
GIT_REPOSITORY https://github.com/microsoft/mscclpp.git
|
||||
GIT_TAG 51eca89d20f0cfb3764ccd764338d7b22cd486a6
|
||||
GIT_SHALLOW OFF
|
||||
)
|
||||
FetchContent_Populate(repo-mscclpp)
|
||||
|
||||
# ccache option
|
||||
option(ENABLE_CCACHE "Whether to use ccache" ON)
|
||||
find_program(CCACHE_FOUND ccache)
|
||||
if(CCACHE_FOUND AND ENABLE_CCACHE AND DEFINED ENV{CCACHE_DIR})
|
||||
message(STATUS "Building with CCACHE enabled")
|
||||
set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE "ccache")
|
||||
set_property(GLOBAL PROPERTY RULE_LAUNCH_LINK "ccache")
|
||||
endif()
|
||||
|
||||
# Enable gencode below SM90
|
||||
option(ENABLE_BELOW_SM90 "Enable below SM90" ON)
|
||||
|
||||
if (CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64")
|
||||
set(ENABLE_BELOW_SM90 OFF)
|
||||
message(STATUS "For aarch64, disable gencode below SM90 by default")
|
||||
endif()
|
||||
|
||||
include_directories(
|
||||
${PROJECT_SOURCE_DIR}/include
|
||||
${PROJECT_SOURCE_DIR}/csrc
|
||||
${repo-cutlass_SOURCE_DIR}/include
|
||||
${repo-cutlass_SOURCE_DIR}/tools/util/include
|
||||
${repo-flashinfer_SOURCE_DIR}/include
|
||||
${repo-flashinfer_SOURCE_DIR}/csrc
|
||||
${repo-mscclpp_SOURCE_DIR}/include
|
||||
)
|
||||
|
||||
set(SGL_KERNEL_CUDA_FLAGS
|
||||
"-DNDEBUG"
|
||||
"-DOPERATOR_NAMESPACE=sgl-kernel"
|
||||
"-O3"
|
||||
"-Xcompiler"
|
||||
"-fPIC"
|
||||
"-gencode=arch=compute_90,code=sm_90"
|
||||
"-std=c++17"
|
||||
"-DFLASHINFER_ENABLE_F16"
|
||||
"-DCUTE_USE_PACKED_TUPLE=1"
|
||||
"-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1"
|
||||
"-DCUTLASS_VERSIONS_GENERATED"
|
||||
"-DCUTLASS_TEST_LEVEL=0"
|
||||
"-DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1"
|
||||
"-DCUTLASS_DEBUG_TRACE_LEVEL=0"
|
||||
"--expt-relaxed-constexpr"
|
||||
"--expt-extended-lambda"
|
||||
"--threads=32"
|
||||
|
||||
# Supress warnings
|
||||
"-Xcompiler=-Wno-clang-format-violations"
|
||||
"-Xcompiler=-Wno-conversion"
|
||||
"-Xcompiler=-Wno-deprecated-declarations"
|
||||
"-Xcompiler=-Wno-terminate"
|
||||
"-Xcompiler=-Wfatal-errors"
|
||||
"-Xcompiler=-ftemplate-backtrace-limit=1"
|
||||
"-Xcudafe=--diag_suppress=177" # variable was declared but never referenced
|
||||
|
||||
# uncomment to debug
|
||||
# "--ptxas-options=-v"
|
||||
# "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage"
|
||||
)
|
||||
|
||||
option(SGL_KERNEL_ENABLE_BF16 "Enable BF16" ON)
|
||||
option(SGL_KERNEL_ENABLE_FP8 "Enable FP8" ON)
|
||||
option(SGL_KERNEL_ENABLE_FP4 "Enable FP4" OFF)
|
||||
option(SGL_KERNEL_ENABLE_FA3 "Enable FA3" OFF)
|
||||
option(SGL_KERNEL_ENABLE_SM90A "Enable SM90A" OFF)
|
||||
option(SGL_KERNEL_ENABLE_SM100A "Enable SM100A" OFF)
|
||||
|
||||
if (SGL_KERNEL_ENABLE_BF16)
|
||||
list(APPEND SGL_KERNEL_CUDA_FLAGS
|
||||
"-DFLASHINFER_ENABLE_BF16"
|
||||
)
|
||||
endif()
|
||||
|
||||
if (SGL_KERNEL_ENABLE_FP8)
|
||||
list(APPEND SGL_KERNEL_CUDA_FLAGS
|
||||
"-DFLASHINFER_ENABLE_FP8"
|
||||
"-DFLASHINFER_ENABLE_FP8_E4M3"
|
||||
"-DFLASHINFER_ENABLE_FP8_E5M2"
|
||||
)
|
||||
endif()
|
||||
|
||||
if (ENABLE_BELOW_SM90)
|
||||
list(APPEND SGL_KERNEL_CUDA_FLAGS
|
||||
"-gencode=arch=compute_80,code=sm_80"
|
||||
"-gencode=arch=compute_89,code=sm_89"
|
||||
)
|
||||
endif()
|
||||
|
||||
if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8" OR SGL_KERNEL_ENABLE_SM100A)
|
||||
list(APPEND SGL_KERNEL_CUDA_FLAGS
|
||||
"-gencode=arch=compute_100,code=sm_100"
|
||||
"-gencode=arch=compute_100a,code=sm_100a"
|
||||
"-gencode=arch=compute_120,code=sm_120"
|
||||
"-gencode=arch=compute_120a,code=sm_120a"
|
||||
)
|
||||
|
||||
# refer sm_121, sm_110 and sm_101 description https://github.com/pytorch/pytorch/pull/156176
|
||||
if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "13.0")
|
||||
list(APPEND SGL_KERNEL_CUDA_FLAGS
|
||||
"-gencode=arch=compute_103,code=sm_103"
|
||||
"-gencode=arch=compute_103a,code=sm_103a"
|
||||
"-gencode=arch=compute_110,code=sm_110"
|
||||
"-gencode=arch=compute_110a,code=sm_110a"
|
||||
"-gencode=arch=compute_121,code=sm_121"
|
||||
"-gencode=arch=compute_121a,code=sm_121a"
|
||||
"--compress-mode=size"
|
||||
)
|
||||
else()
|
||||
list(APPEND SGL_KERNEL_CUDA_FLAGS
|
||||
"-gencode=arch=compute_101,code=sm_101"
|
||||
"-gencode=arch=compute_101a,code=sm_101a"
|
||||
)
|
||||
endif()
|
||||
|
||||
else()
|
||||
list(APPEND SGL_KERNEL_CUDA_FLAGS
|
||||
"-use_fast_math"
|
||||
)
|
||||
endif()
|
||||
|
||||
if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.4" OR SGL_KERNEL_ENABLE_SM90A)
|
||||
set(SGL_KERNEL_ENABLE_FA3 ON)
|
||||
list(APPEND SGL_KERNEL_CUDA_FLAGS
|
||||
"-gencode=arch=compute_90a,code=sm_90a"
|
||||
)
|
||||
endif()
|
||||
|
||||
if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8" OR SGL_KERNEL_ENABLE_FP4)
|
||||
list(APPEND SGL_KERNEL_CUDA_FLAGS
|
||||
"-DENABLE_NVFP4=1"
|
||||
)
|
||||
endif()
|
||||
|
||||
set(SOURCES
|
||||
"csrc/allreduce/custom_all_reduce.cu"
|
||||
"csrc/allreduce/mscclpp_allreduce.cu"
|
||||
"csrc/attention/cascade.cu"
|
||||
"csrc/attention/cutlass_mla_kernel.cu"
|
||||
"csrc/attention/lightning_attention_decode_kernel.cu"
|
||||
"csrc/attention/merge_attn_states.cu"
|
||||
"csrc/attention/vertical_slash_index.cu"
|
||||
"csrc/elementwise/activation.cu"
|
||||
"csrc/elementwise/cast.cu"
|
||||
"csrc/elementwise/fused_add_rms_norm_kernel.cu"
|
||||
"csrc/elementwise/rope.cu"
|
||||
"csrc/common_extension.cc"
|
||||
|
||||
"csrc/gemm/awq_kernel.cu"
|
||||
"csrc/gemm/bmm_fp8.cu"
|
||||
"csrc/gemm/dsv3_fused_a_gemm.cu"
|
||||
"csrc/gemm/dsv3_router_gemm_bf16_out.cu"
|
||||
"csrc/gemm/dsv3_router_gemm_entry.cu"
|
||||
"csrc/gemm/dsv3_router_gemm_float_out.cu"
|
||||
"csrc/gemm/fp8_blockwise_gemm_kernel.cu"
|
||||
"csrc/gemm/fp8_gemm_kernel.cu"
|
||||
"csrc/gemm/int8_gemm_kernel.cu"
|
||||
"csrc/gemm/nvfp4_expert_quant.cu"
|
||||
"csrc/gemm/nvfp4_quant_entry.cu"
|
||||
"csrc/gemm/nvfp4_quant_kernels.cu"
|
||||
"csrc/gemm/nvfp4_scaled_mm_entry.cu"
|
||||
"csrc/gemm/nvfp4_scaled_mm_kernels.cu"
|
||||
"csrc/gemm/per_tensor_quant_fp8.cu"
|
||||
"csrc/gemm/per_token_group_quant_8bit.cu"
|
||||
"csrc/gemm/per_token_quant_fp8.cu"
|
||||
"csrc/gemm/qserve_w4a8_per_chn_gemm.cu"
|
||||
"csrc/gemm/qserve_w4a8_per_group_gemm.cu"
|
||||
"csrc/gemm/marlin/gptq_marlin.cu"
|
||||
"csrc/gemm/marlin/gptq_marlin_repack.cu"
|
||||
"csrc/gemm/marlin/awq_marlin_repack.cu"
|
||||
"csrc/gemm/gptq/gptq_kernel.cu"
|
||||
|
||||
"csrc/grammar/apply_token_bitmask_inplace_cuda.cu"
|
||||
|
||||
"csrc/moe/cutlass_moe/w4a8/scaled_mm_entry.cu"
|
||||
"csrc/moe/cutlass_moe/w4a8/w4a8_moe_data.cu"
|
||||
"csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cu"
|
||||
"csrc/moe/marlin_moe_wna16/ops.cu"
|
||||
"csrc/moe/moe_align_kernel.cu"
|
||||
"csrc/moe/moe_fused_gate.cu"
|
||||
"csrc/moe/moe_topk_softmax_kernels.cu"
|
||||
"csrc/moe/nvfp4_blockwise_moe.cu"
|
||||
"csrc/moe/fp8_blockwise_moe_kernel.cu"
|
||||
"csrc/moe/prepare_moe_input.cu"
|
||||
"csrc/moe/ep_moe_reorder_kernel.cu"
|
||||
"csrc/moe/ep_moe_silu_and_mul_kernel.cu"
|
||||
|
||||
"csrc/memory/store.cu"
|
||||
"csrc/kvcacheio/transfer.cu"
|
||||
|
||||
"csrc/speculative/eagle_utils.cu"
|
||||
"csrc/speculative/packbit.cu"
|
||||
"csrc/speculative/speculative_sampling.cu"
|
||||
|
||||
"${repo-flashinfer_SOURCE_DIR}/csrc/norm.cu"
|
||||
"${repo-flashinfer_SOURCE_DIR}/csrc/renorm.cu"
|
||||
"${repo-flashinfer_SOURCE_DIR}/csrc/sampling.cu"
|
||||
|
||||
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_causal_sm80.cu"
|
||||
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_sm80.cu"
|
||||
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_fp16_causal_sm80.cu"
|
||||
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_fp16_sm80.cu"
|
||||
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/flash_sparse_api.cpp"
|
||||
)
|
||||
|
||||
Python_add_library(common_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES})
|
||||
|
||||
target_compile_options(common_ops PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_KERNEL_CUDA_FLAGS}>)
|
||||
target_include_directories(common_ops PRIVATE
|
||||
${repo-cutlass_SOURCE_DIR}/examples/77_blackwell_fmha
|
||||
${repo-cutlass_SOURCE_DIR}/examples/common
|
||||
${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src
|
||||
)
|
||||
|
||||
find_package(Python3 COMPONENTS Interpreter REQUIRED)
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} -c "import torch; print(int(torch._C._GLIBCXX_USE_CXX11_ABI))"
|
||||
OUTPUT_VARIABLE TORCH_CXX11_ABI
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
)
|
||||
if(TORCH_CXX11_ABI STREQUAL "0")
|
||||
message(STATUS "Using old C++ ABI (-D_GLIBCXX_USE_CXX11_ABI=0)")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=0")
|
||||
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=0")
|
||||
else()
|
||||
message(STATUS "Using new C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI=1)")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=1")
|
||||
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=1")
|
||||
endif()
|
||||
|
||||
# mscclpp
|
||||
set(MSCCLPP_USE_CUDA ON)
|
||||
set(MSCCLPP_BYPASS_GPU_CHECK ON)
|
||||
set(MSCCLPP_BUILD_TESTS OFF)
|
||||
add_subdirectory(
|
||||
${repo-mscclpp_SOURCE_DIR}
|
||||
${CMAKE_CURRENT_BINARY_DIR}/mscclpp-build
|
||||
)
|
||||
target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cublasLt mscclpp_static)
|
||||
|
||||
# flash attention
|
||||
target_compile_definitions(common_ops PRIVATE
|
||||
FLASHATTENTION_DISABLE_BACKWARD
|
||||
FLASHATTENTION_DISABLE_DROPOUT
|
||||
FLASHATTENTION_DISABLE_UNEVEN_K
|
||||
)
|
||||
|
||||
install(TARGETS common_ops LIBRARY DESTINATION sgl_kernel)
|
||||
|
||||
# ============================ Optional Install ============================= #
|
||||
# set flash-attention sources file
|
||||
# Now FA3 support sm80/sm86/sm90
|
||||
if (SGL_KERNEL_ENABLE_FA3)
|
||||
set(SGL_FLASH_KERNEL_CUDA_FLAGS
|
||||
"-DNDEBUG"
|
||||
"-DOPERATOR_NAMESPACE=sgl-kernel"
|
||||
"-O3"
|
||||
"-Xcompiler"
|
||||
"-fPIC"
|
||||
"-gencode=arch=compute_90a,code=sm_90a"
|
||||
"-std=c++17"
|
||||
"-DCUTE_USE_PACKED_TUPLE=1"
|
||||
"-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1"
|
||||
"-DCUTLASS_VERSIONS_GENERATED"
|
||||
"-DCUTLASS_TEST_LEVEL=0"
|
||||
"-DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1"
|
||||
"-DCUTLASS_DEBUG_TRACE_LEVEL=0"
|
||||
"--expt-relaxed-constexpr"
|
||||
"--expt-extended-lambda"
|
||||
"--use_fast_math"
|
||||
"-Xcompiler=-Wconversion"
|
||||
"-Xcompiler=-fno-strict-aliasing"
|
||||
)
|
||||
|
||||
if (ENABLE_BELOW_SM90)
|
||||
list(APPEND SGL_FLASH_KERNEL_CUDA_FLAGS
|
||||
"-gencode=arch=compute_80,code=sm_80"
|
||||
"-gencode=arch=compute_86,code=sm_86"
|
||||
)
|
||||
# SM8X Logic
|
||||
file(GLOB FA3_SM8X_GEN_SRCS
|
||||
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdim*_sm80.cu")
|
||||
endif()
|
||||
|
||||
file(GLOB FA3_BF16_GEN_SRCS
|
||||
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_bf16*_sm90.cu")
|
||||
file(GLOB FA3_BF16_GEN_SRCS_
|
||||
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_bf16*_sm90.cu")
|
||||
list(APPEND FA3_BF16_GEN_SRCS ${FA3_BF16_GEN_SRCS_})
|
||||
|
||||
# FP16 source files
|
||||
file(GLOB FA3_FP16_GEN_SRCS
|
||||
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_fp16*_sm90.cu")
|
||||
file(GLOB FA3_FP16_GEN_SRCS_
|
||||
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_fp16*_sm90.cu")
|
||||
list(APPEND FA3_FP16_GEN_SRCS ${FA3_FP16_GEN_SRCS_})
|
||||
|
||||
# FP8 source files
|
||||
file(GLOB FA3_FP8_GEN_SRCS
|
||||
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_e4m3*_sm90.cu")
|
||||
file(GLOB FA3_FP8_GEN_SRCS_
|
||||
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_e4m3*_sm90.cu")
|
||||
list(APPEND FA3_FP8_GEN_SRCS ${FA3_FP8_GEN_SRCS_})
|
||||
|
||||
set(FA3_GEN_SRCS ${FA3_BF16_GEN_SRCS} ${FA3_FP16_GEN_SRCS} ${FA3_FP8_GEN_SRCS} ${FA3_SM8X_GEN_SRCS})
|
||||
|
||||
set(FLASH_SOURCES
|
||||
"csrc/flash_extension.cc"
|
||||
"${repo-flash-attention_SOURCE_DIR}/hopper/flash_prepare_scheduler.cu"
|
||||
"${repo-flash-attention_SOURCE_DIR}/hopper/flash_api.cpp"
|
||||
"${repo-flash-attention_SOURCE_DIR}/hopper/flash_fwd_combine.cu"
|
||||
"${FA3_GEN_SRCS}"
|
||||
)
|
||||
|
||||
Python_add_library(flash_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${FLASH_SOURCES})
|
||||
|
||||
target_compile_options(flash_ops PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_FLASH_KERNEL_CUDA_FLAGS}>)
|
||||
target_include_directories(flash_ops PRIVATE
|
||||
${repo-flash-attention_SOURCE_DIR}/hopper
|
||||
)
|
||||
target_link_libraries(flash_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda)
|
||||
|
||||
install(TARGETS flash_ops LIBRARY DESTINATION "sgl_kernel")
|
||||
set(FLASH_OPS_COMPILE_DEFS
|
||||
FLASHATTENTION_DISABLE_BACKWARD
|
||||
FLASHATTENTION_DISABLE_DROPOUT
|
||||
FLASHATTENTION_DISABLE_UNEVEN_K
|
||||
FLASHATTENTION_VARLEN_ONLY
|
||||
)
|
||||
|
||||
if(NOT ENABLE_BELOW_SM90)
|
||||
list(APPEND FLASH_OPS_COMPILE_DEFS FLASHATTENTION_DISABLE_SM8x)
|
||||
endif()
|
||||
target_compile_definitions(flash_ops PRIVATE ${FLASH_OPS_COMPILE_DEFS})
|
||||
endif()
|
||||
|
||||
# Build spatial_ops as a separate, optional extension for green contexts
|
||||
set(SPATIAL_SOURCES
|
||||
"csrc/spatial/greenctx_stream.cu"
|
||||
"csrc/spatial_extension.cc"
|
||||
)
|
||||
|
||||
Python_add_library(spatial_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SPATIAL_SOURCES})
|
||||
target_compile_options(spatial_ops PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_KERNEL_CUDA_FLAGS}>)
|
||||
target_link_libraries(spatial_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda)
|
||||
install(TARGETS spatial_ops LIBRARY DESTINATION sgl_kernel)
|
||||
|
||||
|
||||
# ============================ DeepGEMM (JIT) ============================= #
|
||||
# Create a separate library for DeepGEMM's Python API.
|
||||
# This keeps its compilation isolated from the main common_ops.
|
||||
set(DEEPGEMM_SOURCES
|
||||
"${repo-deepgemm_SOURCE_DIR}/csrc/python_api.cpp"
|
||||
)
|
||||
|
||||
Python_add_library(deep_gemm_cpp MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${DEEPGEMM_SOURCES})
|
||||
|
||||
# Link against necessary libraries, including nvrtc for JIT compilation.
|
||||
target_link_libraries(deep_gemm_cpp PRIVATE ${TORCH_LIBRARIES} c10 cuda nvrtc mscclpp_static)
|
||||
|
||||
# Add include directories needed by DeepGEMM.
|
||||
target_include_directories(deep_gemm_cpp PRIVATE
|
||||
${repo-deepgemm_SOURCE_DIR}/deep_gemm/include
|
||||
${repo-cutlass_SOURCE_DIR}/include
|
||||
${repo-fmt_SOURCE_DIR}/include
|
||||
)
|
||||
|
||||
# Apply the same compile options as common_ops.
|
||||
target_compile_options(deep_gemm_cpp PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_KERNEL_CUDA_FLAGS}>)
|
||||
|
||||
# Create an empty __init__.py to make `deepgemm` a Python package.
|
||||
file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/deepgemm_pkg_init.py "")
|
||||
install(
|
||||
FILES ${CMAKE_CURRENT_BINARY_DIR}/deepgemm_pkg_init.py
|
||||
DESTINATION deep_gemm
|
||||
RENAME __init__.py
|
||||
)
|
||||
|
||||
# Install the compiled DeepGEMM API library.
|
||||
install(TARGETS deep_gemm_cpp LIBRARY DESTINATION deep_gemm)
|
||||
|
||||
# Install the source files required by DeepGEMM for runtime JIT compilation.
|
||||
install(
|
||||
DIRECTORY ${repo-deepgemm_SOURCE_DIR}/deep_gemm/
|
||||
DESTINATION deep_gemm
|
||||
)
|
||||
|
||||
install(DIRECTORY "${repo-cutlass_SOURCE_DIR}/include/cute/"
|
||||
DESTINATION "deep_gemm/include/cute")
|
||||
|
||||
install(DIRECTORY "${repo-cutlass_SOURCE_DIR}/include/cutlass/"
|
||||
DESTINATION "deep_gemm/include/cutlass")
|
||||
|
||||
# triton_kernels
|
||||
install(DIRECTORY "${repo-triton_SOURCE_DIR}/python/triton_kernels/triton_kernels/"
|
||||
DESTINATION "triton_kernels"
|
||||
PATTERN ".git*" EXCLUDE
|
||||
PATTERN "__pycache__" EXCLUDE)
|
||||
201
sgl-kernel/LICENSE
Normal file
201
sgl-kernel/LICENSE
Normal file
@@ -0,0 +1,201 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright 2023-2024 SGLang Team
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
71
sgl-kernel/Makefile
Normal file
71
sgl-kernel/Makefile
Normal file
@@ -0,0 +1,71 @@
|
||||
.PHONY: help check-deps install-deps tree ln submodule install build clean rebuild test format update
|
||||
|
||||
# Show help for each target
|
||||
help: ## Show this help message
|
||||
@echo "Available targets:"
|
||||
@grep -E '^[a-zA-Z0-9_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-20s\033[0m %s\n", $$1, $$2}'
|
||||
|
||||
check-deps: ## Check and install required Python formatting dependencies
|
||||
@command -v isort >/dev/null 2>&1 || (echo "Installing isort..." && pip install isort)
|
||||
@command -v black >/dev/null 2>&1 || (echo "Installing black..." && pip install black)
|
||||
|
||||
install-deps: ## Install Python formatting tools (isort and black)
|
||||
pip install scikit-build-core isort black
|
||||
|
||||
tree: ## Show project directory structure
|
||||
@tree --prune -I "__pycache__|*.egg-info|*.so|build|3rdparty|dist"
|
||||
|
||||
submodule: ## Initialize and update git submodules
|
||||
@git submodule update --init --recursive
|
||||
|
||||
ln: submodule ## Create compilation database
|
||||
@rm -rf build && mkdir build && cd build && cmake .. -DCMAKE_EXPORT_COMPILE_COMMANDS=YES -DCMAKE_POLICY_VERSION_MINIMUM=3.5
|
||||
|
||||
install: submodule ## Install package in development mode
|
||||
@pip install -e . --no-build-isolation
|
||||
|
||||
build: install-deps submodule ## Build and install wheel package
|
||||
@rm -rf dist/* || true && CMAKE_POLICY_VERSION_MINIMUM=3.5 MAX_JOBS=$(nproc) CMAKE_BUILD_PARALLEL_LEVEL=$(nproc) uv build --wheel -Cbuild-dir=build . --verbose --color=always --no-build-isolation && pip3 install dist/*whl --force-reinstall --no-deps
|
||||
|
||||
clean: ## Remove build artifacts
|
||||
@rm -rf build dist *.egg-info
|
||||
|
||||
rebuild: clean submodule build ## Clean and rebuild the project
|
||||
@echo "Succeed to rebuild"
|
||||
|
||||
test: ## Run all tests
|
||||
@find tests -name "test_*.py" | xargs -n 1 python3
|
||||
|
||||
format: check-deps ## Format all source files
|
||||
@echo "Formatting source files..."
|
||||
@find csrc tests -name '*.cc' -o -name '*.cu' -o -name '*.cuh' -o -name '*.h' -o -name '*.hpp' | xargs clang-format -i
|
||||
@find python tests -name '*.py' | xargs isort
|
||||
@find python tests -name '*.py' | xargs black
|
||||
@pre-commit run --all-files
|
||||
|
||||
FILES_TO_UPDATE = python/sgl_kernel/version.py \
|
||||
pyproject.toml \
|
||||
pyproject_rocm.toml \
|
||||
pyproject_cpu.toml \
|
||||
../docker/Dockerfile \
|
||||
../.github/workflows/pr-test-pd-router.yml
|
||||
|
||||
update: ## Update version numbers across project files. Usage: make update <new_version>
|
||||
@if [ -z "$(filter-out $@,$(MAKECMDGOALS))" ]; then \
|
||||
echo "Version required. Usage: make update <new_version>"; \
|
||||
exit 1; \
|
||||
fi
|
||||
@OLD_VERSION=$$(grep "version" python/sgl_kernel/version.py | cut -d '"' -f2); \
|
||||
NEW_VERSION=$(filter-out $@,$(MAKECMDGOALS)); \
|
||||
echo "Updating version from $$OLD_VERSION to $$NEW_VERSION"; \
|
||||
for file in $(FILES_TO_UPDATE); do \
|
||||
if [ "$(shell uname)" = "Darwin" ]; then \
|
||||
sed -i '' -e "s/$$OLD_VERSION/$$NEW_VERSION/g" $$file; \
|
||||
else \
|
||||
sed -i -e "s/$$OLD_VERSION/$$NEW_VERSION/g" $$file; \
|
||||
fi \
|
||||
done; \
|
||||
echo "Version update complete"
|
||||
|
||||
%:
|
||||
@:
|
||||
261
sgl-kernel/README.md
Normal file
261
sgl-kernel/README.md
Normal file
@@ -0,0 +1,261 @@
|
||||
# SGL Kernel
|
||||
|
||||
[Kernel Library](https://github.com/sgl-project/sglang/tree/main/sgl-kernel) for SGLang
|
||||
|
||||
[](https://pypi.org/project/sgl-kernel)
|
||||
|
||||
## Installation
|
||||
For CUDA 12.1 and above:
|
||||
|
||||
```bash
|
||||
pip3 install sgl-kernel
|
||||
```
|
||||
|
||||
For CUDA 11.8:
|
||||
|
||||
```bash
|
||||
pip3 install sgl-kernel -i https://docs.sglang.ai/whl/cu118
|
||||
```
|
||||
|
||||
## Build from source
|
||||
|
||||
Development build:
|
||||
|
||||
```bash
|
||||
make build
|
||||
```
|
||||
|
||||
Note:
|
||||
|
||||
The `sgl-kernel` is rapidly evolving. If you experience a compilation failure, try using `make rebuild`.
|
||||
|
||||
### Build with [ccache](https://github.com/ccache/ccache)
|
||||
```bash
|
||||
# or `yum install -y ccache`.
|
||||
apt-get install -y ccache
|
||||
# Building with ccache is enabled when ccache is installed and CCACHE_DIR is set.
|
||||
export CCACHE_DIR=/path/to/your/ccache/dir
|
||||
export CCACHE_BACKEND=""
|
||||
export CCACHE_KEEP_LOCAL_STORAGE="TRUE"
|
||||
unset CCACHE_READONLY
|
||||
python -m uv build --wheel -Cbuild-dir=build --color=always .
|
||||
```
|
||||
|
||||
### Configuring CMake Build Options
|
||||
Cmake options can be configuring by adding `-Ccmake.define.<option>=<value>` to the `uv build` flags.
|
||||
For example, to enable building FP4 kernels, use:
|
||||
```bash
|
||||
python -m uv build --wheel -Cbuild-dir=build -Ccmake.define.SGL_KERNEL_ENABLE_FP4=1 --color=always .
|
||||
```
|
||||
See CMakeLists.txt for more options.
|
||||
|
||||
### Parallel Build
|
||||
|
||||
We highly recommend you build sgl-kernel with Ninja. Ninja can automatically build sgl-kernel in parallel.
|
||||
And if you build the sgl-kernel with cmake, you need to add `CMAKE_BUILD_PARALLEL_LEVEL` for parallel build like:
|
||||
|
||||
```bash
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=$(nproc) python -m uv build --wheel -Cbuild-dir=build --color=always .
|
||||
```
|
||||
|
||||
### ⚠️ Compilation Issue with `sgl-kernel` and CUDA 12.6
|
||||
|
||||
When compiling `sgl-kernel` with FlashAttention on a Hopper GPU using CUDA 12.6, you may encounter a segmentation fault:
|
||||
|
||||
```bash
|
||||
kernel/build/_deps/repo-flash-attention-src/hopper/instantiations/flash_fwd_hdimall_bf16_paged_softcap_sm90.cu -o CMakeFiles/flash_ops.dir/_deps/repo-flash-attention-src/hopper/instantiations/flash_fwd_hdimall_bf16_paged_softcap_sm90.cu.o
|
||||
Segmentation fault (core dumped)
|
||||
```
|
||||
|
||||
⚠️ **Note**: To ensure that FlashAttention compiles correctly on Hopper GPU Architecture(sm90), it is strongly [recommended](https://github.com/Dao-AILab/flash-attention/issues/1453) to use:
|
||||
- nvcc version: 12.6
|
||||
- ptxas version: 12.8
|
||||
|
||||
**1. Check Current Versions**
|
||||
|
||||
Before proceeding, verify your current CUDA tool versions:
|
||||
```bash
|
||||
nvcc --version
|
||||
ptxas --version
|
||||
```
|
||||
**2. Update ptxas to 12.8 (if needed)**
|
||||
|
||||
1. Save the following script to a file (e.g., `update_ptxas.sh`).
|
||||
```bash
|
||||
#!/usr/bin/env bash
|
||||
# Source: https://github.com/Dao-AILab/flash-attention/blob/7ff1b621112ba8b538e2fc6a316f2a6b6f22e518/hopper/setup.py#L404
|
||||
set -ex
|
||||
|
||||
if [ -z "$1" ]; then
|
||||
echo "Usage: $0 <CUDA_VERSION>"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
CUDA_VERSION=$1
|
||||
|
||||
if awk "BEGIN {exit !("$CUDA_VERSION" >= 12.6 && "$CUDA_VERSION" < 12.8)}"; then
|
||||
NVCC_ARCHIVE_VERSION="12.8.93"
|
||||
NVCC_ARCHIVE_NAME="cuda_nvcc-linux-x86_64-${NVCC_ARCHIVE_VERSION}-archive"
|
||||
NVCC_ARCHIVE_TAR="${NVCC_ARCHIVE_NAME}.tar.xz"
|
||||
NVCC_ARCHIVE_URL="https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/linux-x86_64/${NVCC_ARCHIVE_TAR}"
|
||||
|
||||
wget "$NVCC_ARCHIVE_URL"
|
||||
tar -xf "$NVCC_ARCHIVE_TAR"
|
||||
|
||||
mkdir -p /usr/local/cuda/bin
|
||||
cp "${NVCC_ARCHIVE_NAME}/bin/ptxas" /usr/local/cuda/bin/
|
||||
|
||||
# Clean up temporary files
|
||||
rm -f "${NVCC_ARCHIVE_TAR}"
|
||||
rm -rf "${NVCC_ARCHIVE_NAME}"
|
||||
fi
|
||||
```
|
||||
2. Run the script with your CUDA version as the argument, using `sudo`:
|
||||
```bash
|
||||
sudo bash update_ptxas.sh 12.6
|
||||
# Check the version
|
||||
ptxas --version
|
||||
```
|
||||
|
||||
# Developer Guide
|
||||
|
||||
## Development Environment Setup
|
||||
|
||||
Use Docker to set up the development environment. See [Docker setup guide](https://github.com/sgl-project/sglang/blob/main/docs/developer_guide/development_guide_using_docker.md#setup-docker-container).
|
||||
|
||||
Create and enter development container:
|
||||
```bash
|
||||
docker run -itd --shm-size 32g --gpus all -v $HOME/.cache:/root/.cache --ipc=host --name sglang_zhyncs lmsysorg/sglang:dev /bin/zsh
|
||||
docker exec -it sglang_zhyncs /bin/zsh
|
||||
```
|
||||
|
||||
## Project Structure
|
||||
|
||||
### Dependencies
|
||||
|
||||
Third-party libraries:
|
||||
|
||||
- [CUTLASS](https://github.com/NVIDIA/cutlass)
|
||||
- [FlashInfer](https://github.com/flashinfer-ai/flashinfer)
|
||||
- [DeepGEMM](https://github.com/deepseek-ai/DeepGEMM)
|
||||
- [FlashAttention](https://github.com/Dao-AILab/flash-attention)
|
||||
|
||||
### FlashAttention FYI
|
||||
|
||||
FA3 can fail without a enough shared memory for a some shapes, such as higher hidden_dim or some special cases. Right now, fa3 is supported for sm80/sm87 and sm86/sm89.
|
||||
|
||||
The main different Between sm80/sm87 and sm86/sm89 is the shared memory size. you can follow the link below for more information https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x.
|
||||
|
||||
And for sgl-kernel right now, we can build fa3 on sm80/sm86/sm89/sm90a. That means if you use **A100(tested)**/A*0/**L20(tested)**/L40/L40s/**3090(tested)** you can use fa3.
|
||||
|
||||
### Kernel Development
|
||||
|
||||
Steps to add a new kernel:
|
||||
|
||||
1. Implement the kernel in [csrc](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc)
|
||||
2. Expose the interface in [include/sgl_kernel_ops.h](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/include/sgl_kernel_ops.h)
|
||||
3. Create torch extension in [csrc/common_extension.cc](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/csrc/common_extension.cc)
|
||||
4. Update [CMakeLists.txt](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/CMakeLists.txt) to include new CUDA source
|
||||
5. Expose Python interface in [python](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/python/sgl_kernel)
|
||||
|
||||
### Development Tips
|
||||
|
||||
1. When implementing kernels in [csrc](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc), only define pure CUDA files and C++ interfaces. If you need to use `Torch::tensor`, use `<torch/all.h>` instead of `<torch/extension.h>`. Using `<torch/extension.h>` will cause compilation errors when using SABI.
|
||||
|
||||
2. When creating torch extensions, add the function definition with `m.def`, and device binding with `m.impl`:
|
||||
- Using torch.compile need `m.def` with schema, it helps auto capture the custom kernel. Reference: [How to add FakeTensor](https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit?tab=t.0#heading=h.ptttacy8y1u9)
|
||||
|
||||
- How to write schema: [Schema reference](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#func)
|
||||
|
||||
```cpp
|
||||
// We need def with schema here for torch.compile
|
||||
m.def(
|
||||
"bmm_fp8(Tensor A, Tensor B, Tensor! D, Tensor A_scale, Tensor B_scale, Tensor workspace_buffer, int "
|
||||
"cublas_handle, int cuda_stream) -> ()");
|
||||
m.impl("bmm_fp8", torch::kCUDA, &bmm_fp8);
|
||||
```
|
||||
|
||||
3. When exposing Python interfaces, avoid using kwargs in C++ interface kernels.
|
||||
|
||||
**Avoid this:**
|
||||
|
||||
```cpp
|
||||
torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache.default(
|
||||
q=query.view(query.shape[0], -1, head_size),
|
||||
k=key.view(key.shape[0], -1, head_size),
|
||||
q_rope=query.view(query.shape[0], -1, head_size),
|
||||
k_rope=key.view(key.shape[0], -1, head_size),
|
||||
cos_sin_cache=cos_sin_cache,
|
||||
pos_ids=positions.long(),
|
||||
interleave=(not is_neox),
|
||||
cuda_stream=get_cuda_stream(),
|
||||
)
|
||||
```
|
||||
|
||||
**Use this instead:**
|
||||
|
||||
```cpp
|
||||
torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache.default(
|
||||
query.view(query.shape[0], -1, head_size),
|
||||
key.view(key.shape[0], -1, head_size),
|
||||
query.view(query.shape[0], -1, head_size),
|
||||
key.view(key.shape[0], -1, head_size),
|
||||
cos_sin_cache,
|
||||
positions.long(),
|
||||
(not is_neox),
|
||||
get_cuda_stream(),
|
||||
)
|
||||
```
|
||||
|
||||
### Integrating Third-Party Libraries with Data Type Conversion
|
||||
|
||||
When integrating new third-party libraries like flash-attention, you may encounter data type compatibility issues between the C++ interface and PyTorch bindings. For example, the third-party code might use `float` or `int` types, while PyTorch requires `double` and `int64_t`.
|
||||
|
||||
> The reason we need `double` and `int64_t` in torch binding is that TORCH_LIBRARY handles the `Python-to-C++` conversion process. Python's `float` data type actually corresponds to `double` in C++, while Python's `int` corresponds to `int64_t` in C++.
|
||||
|
||||
To address this issue, we provide the `make_pytorch_shim` function in [sgl_kernel_torch_shim](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/include/sgl_kernel_torch_shim.h) that handles data type conversions automatically.
|
||||
|
||||
When you need to support new data type conversions, you can easily add conversion functions like this:
|
||||
|
||||
```cpp
|
||||
// Map `int` -> `int64_t`
|
||||
template <>
|
||||
struct pytorch_library_compatible_type<int> {
|
||||
using type = int64_t;
|
||||
static int convert_from_type(int64_t arg) {
|
||||
TORCH_CHECK(arg <= std::numeric_limits<int>::max(), "int64_t value is too large to be converted to int");
|
||||
TORCH_CHECK(arg >= std::numeric_limits<int>::min(), "int64_t value is too small to be converted to int");
|
||||
return arg;
|
||||
}
|
||||
};
|
||||
```
|
||||
|
||||
To use this with your library functions, simply wrap them with make_pytorch_shim:
|
||||
|
||||
```cpp
|
||||
/*
|
||||
* From flash-attention
|
||||
*/
|
||||
m.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd));
|
||||
```
|
||||
|
||||
### Testing & Benchmarking
|
||||
|
||||
1. Add pytest tests in [tests/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/tests), if you need to skip some test, please use `@pytest.mark.skipif`
|
||||
|
||||
```python
|
||||
@pytest.mark.skipif(
|
||||
skip_condition, reason="Nvfp4 Requires compute capability of 10 or above."
|
||||
)
|
||||
```
|
||||
|
||||
2. Add benchmarks using [triton benchmark](https://triton-lang.org/main/python-api/generated/triton.testing.Benchmark.html) in [benchmark/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/benchmark)
|
||||
3. Run test suite
|
||||
|
||||
### FAQ
|
||||
|
||||
- When encountering this error while compiling using ccache: `ImportError: /usr/local/lib/python3.10/dist-packages/sgl_kernel/common_ops.abi3.so: undefined symbol: _ZN3c108ListType3getERKNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEENS_4Type24SingletonOrSharedTypePtrIS9_EE`, please modify the last command as follows to resolve it: `python3 -m uv build --wheel -Cbuild-dir=build . --color=always --no-build-isolation` .
|
||||
|
||||
### Release new version
|
||||
|
||||
Update version in [pyproject.toml](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/pyproject.toml) and [version.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/python/sgl_kernel/version.py)
|
||||
488
sgl-kernel/THIRDPARTYNOTICES.txt
Normal file
488
sgl-kernel/THIRDPARTYNOTICES.txt
Normal file
@@ -0,0 +1,488 @@
|
||||
Notice for flashinfer-ai/flashinfer
|
||||
-------------------------------
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
|
||||
-------------------------------------------------------------------------------------------------
|
||||
Some of the code in this project are adapted from other open-source projects with different
|
||||
licenses. This product also bundles some third-party components under other open source licenses.
|
||||
This section summarizes those components and their licenses.
|
||||
See licenses/ for text of these licenses.
|
||||
|
||||
BSD 3-Clause License
|
||||
--------------------
|
||||
|
||||
include/flashinfer/attention/hopper/epilogue.cuh
|
||||
include/flashinfer/attention/hopper/mainloop.cuh
|
||||
include/flashinfer/attention/hopper/kernel_traits.cuh
|
||||
include/flashinfer/attention/hopper/named_barrier.cuh
|
||||
include/flashinfer/attention/hopper/tile_scheduler.cuh
|
||||
include/flashinfer/attention/hopper/utils.cuh
|
||||
|
||||
BSD 3-Clause "New" License
|
||||
--------------------------
|
||||
|
||||
3rdparty/cutlass
|
||||
include/flashinfer/attention/hopper/block_sparse_gather.cuh
|
||||
|
||||
Notice for NVIDIA/TensorRT-LLM
|
||||
-------------------------------
|
||||
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
|
||||
Notice for deepseek-ai/DeepGEMM
|
||||
-------------------------------
|
||||
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2025 DeepSeek
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
||||
Notice for Dao-AILab/flash-attention
|
||||
-------------------------------
|
||||
|
||||
BSD 3-Clause License
|
||||
|
||||
Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file.
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
* Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
* Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
153
sgl-kernel/benchmark/bench_activation.py
Normal file
153
sgl-kernel/benchmark/bench_activation.py
Normal file
@@ -0,0 +1,153 @@
|
||||
# Benchmarks SGLang kernels versus vLLM across
|
||||
# (kernel, dtype, batch_size, seq_len, dim) and prints speed-up.
|
||||
import argparse
|
||||
import itertools
|
||||
import re
|
||||
from typing import List, Tuple
|
||||
|
||||
import sgl_kernel
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import triton
|
||||
import triton.testing
|
||||
from sgl_kernel import gelu_quick # activation-only kernel
|
||||
from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
|
||||
from vllm import _custom_ops as vllm_ops
|
||||
|
||||
if not hasattr(vllm_ops, "silu_and_mul"):
|
||||
vllm_ops = torch.ops._C
|
||||
|
||||
|
||||
def str2int_list(arg: str) -> List[int]:
|
||||
if arg in ("", None):
|
||||
return []
|
||||
if re.fullmatch(r"\d+(,\d+)*", arg.strip()) is None:
|
||||
raise argparse.ArgumentTypeError(f"Bad int list: {arg}")
|
||||
return [int(x) for x in arg.split(",")]
|
||||
|
||||
|
||||
def calculate_diff(
|
||||
kernel: str, dtype: torch.dtype, batch_size: int, seq_len: int, dim: int
|
||||
) -> bool:
|
||||
"""Compare vLLM with SGLang for one shape."""
|
||||
device = torch.device("cuda")
|
||||
|
||||
# activation-only quick GELU
|
||||
if kernel == "gelu_quick":
|
||||
x = torch.randn(batch_size, seq_len, dim, dtype=dtype, device=device)
|
||||
ref_out = torch.zeros_like(x)
|
||||
getattr(vllm_ops, kernel)(ref_out, x)
|
||||
test_out = getattr(sgl_kernel, kernel)(x)
|
||||
# fused activation x mul kernels
|
||||
else:
|
||||
x = torch.randn(batch_size, seq_len, 2 * dim, dtype=dtype, device=device)
|
||||
ref_out = torch.zeros(batch_size, seq_len, dim, dtype=dtype, device=device)
|
||||
getattr(vllm_ops, kernel)(ref_out, x)
|
||||
test_out = getattr(sgl_kernel, kernel)(x)
|
||||
|
||||
ok = torch.allclose(ref_out, test_out, rtol=1e-3, atol=1e-5)
|
||||
tag = "✅ match" if ok else "❌ mismatch"
|
||||
print(
|
||||
f"[{kernel:14s} | {str(dtype):9s} | B={batch_size:3d} | "
|
||||
f"L={seq_len:3d} | D={dim:5d}] {tag}"
|
||||
)
|
||||
return ok
|
||||
|
||||
|
||||
kernels = ["silu_and_mul", "gelu_and_mul", "gelu_tanh_and_mul", "gelu_quick"]
|
||||
dtypes = [torch.float16, torch.bfloat16]
|
||||
|
||||
|
||||
def make_configs(bsizes: List[int], slens: List[int], dims_: List[int]) -> List[Tuple]:
|
||||
return list(itertools.product(kernels, dtypes, bsizes, slens, dims_))
|
||||
|
||||
|
||||
default_batch_sizes = [2**i for i in range(0, 5, 2)] # 1,4,16
|
||||
default_seq_lens = [2**i for i in range(0, 8, 2)] # 1,4,16,64
|
||||
default_dims = [2**i for i in range(7, 15)] # 128...16384
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["kernel", "dtype", "batch_size", "seq_len", "dim"],
|
||||
x_vals=[],
|
||||
line_arg="provider",
|
||||
line_vals=["vllm", "sglang", "speedup"],
|
||||
line_names=["vLLM", "SGL Kernel", "Speed-up (x)"],
|
||||
styles=[("blue", "-"), ("green", "-"), ("red", "--")],
|
||||
ylabel="µs (median) or × (speed-up)",
|
||||
plot_name="activation-performance",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(kernel, dtype, batch_size, seq_len, dim, provider):
|
||||
device = torch.device("cuda")
|
||||
in_mult = 1 if kernel == "gelu_quick" else 2
|
||||
x = torch.randn(batch_size, seq_len, in_mult * dim, dtype=dtype, device=device)
|
||||
y0 = torch.zeros(batch_size, seq_len, dim, dtype=dtype, device=device)
|
||||
|
||||
vllm_kernel = getattr(vllm_ops, kernel)
|
||||
sglang_kernel = getattr(sgl_kernel, kernel)
|
||||
|
||||
def baseline():
|
||||
tmp = y0.clone()
|
||||
vllm_kernel(tmp, x)
|
||||
return tmp
|
||||
|
||||
def sglang():
|
||||
return sglang_kernel(x)
|
||||
|
||||
# one-time correctness check
|
||||
if provider == "vllm" and not calculate_diff(
|
||||
kernel, dtype, batch_size, seq_len, dim
|
||||
):
|
||||
raise ValueError("Mismatch – abort benchmark")
|
||||
|
||||
# timing helper
|
||||
def timed(fn):
|
||||
for _ in range(5):
|
||||
fn()
|
||||
torch.cuda.synchronize()
|
||||
ms, qmin, qmax = triton.testing.do_bench(fn, quantiles=[0.5, 0.2, 0.8])
|
||||
return 1000 * ms, 1000 * qmax, 1000 * qmin
|
||||
|
||||
if provider == "vllm":
|
||||
return timed(baseline)
|
||||
if provider == "sglang":
|
||||
return timed(sglang)
|
||||
|
||||
# provider == "speedup"
|
||||
t_ref, _, _ = timed(baseline)
|
||||
t_sgl, _, _ = timed(sglang)
|
||||
spd = t_ref / t_sgl
|
||||
return (spd, spd, spd)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
p = argparse.ArgumentParser("Activation kernel benchmark")
|
||||
p.add_argument("--batch_sizes", type=str2int_list, default=default_batch_sizes)
|
||||
p.add_argument("--seq_lens", type=str2int_list, default=default_seq_lens)
|
||||
p.add_argument("--dims", type=str2int_list, default=default_dims)
|
||||
p.add_argument("--verify_only", action="store_true")
|
||||
args = p.parse_args()
|
||||
|
||||
# coerce lists
|
||||
if isinstance(args.batch_sizes, str):
|
||||
args.batch_sizes = str2int_list(args.batch_sizes)
|
||||
if isinstance(args.seq_lens, str):
|
||||
args.seq_lens = str2int_list(args.seq_lens)
|
||||
if isinstance(args.dims, str):
|
||||
args.dims = str2int_list(args.dims)
|
||||
|
||||
# patch perf_report grid
|
||||
benchmark_grid = make_configs(args.batch_sizes, args.seq_lens, args.dims)
|
||||
if hasattr(benchmark, "benchmarks"):
|
||||
benchmark.benchmarks.x_vals = benchmark_grid
|
||||
else:
|
||||
benchmark.benchmark.x_vals = benchmark_grid
|
||||
|
||||
if args.verify_only:
|
||||
ok = calculate_diff("gelu_quick", torch.float16, 1, 1, args.dims[0])
|
||||
print("✅ sanity pass" if ok else "❌ mismatch")
|
||||
else:
|
||||
benchmark.run(print_data=True)
|
||||
118
sgl-kernel/benchmark/bench_awq_dequant.py
Normal file
118
sgl-kernel/benchmark/bench_awq_dequant.py
Normal file
@@ -0,0 +1,118 @@
|
||||
import itertools
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.testing
|
||||
from sgl_kernel import awq_dequantize
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
|
||||
def vllm_awq_dequantize(
|
||||
qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
return ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
|
||||
|
||||
|
||||
def sglang_awq_dequantize(
|
||||
qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
return awq_dequantize(qweight, scales, qzeros)
|
||||
|
||||
|
||||
def calculate_diff(qweight_row: int, qweight_col: int):
|
||||
"""Calculate difference between VLLM and SGLang implementations."""
|
||||
device = torch.device("cuda")
|
||||
qweight = torch.randint(
|
||||
0,
|
||||
torch.iinfo(torch.int32).max,
|
||||
(qweight_row, qweight_col),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
group_size = qweight_row
|
||||
scales_row = qweight_row // group_size
|
||||
scales_col = qweight_col * 8
|
||||
scales = torch.rand(scales_row, scales_col, dtype=torch.float16, device=device)
|
||||
qzeros = torch.randint(
|
||||
0,
|
||||
torch.iinfo(torch.int32).max,
|
||||
(scales_row, qweight_col),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
vllm_out = vllm_awq_dequantize(qweight, scales, qzeros)
|
||||
sglang_out = sglang_awq_dequantize(qweight, scales, qzeros)
|
||||
|
||||
output_diff = torch.abs(vllm_out.float() - sglang_out.float()).mean().item()
|
||||
|
||||
if torch.allclose(
|
||||
vllm_out.to(torch.float32), sglang_out.to(torch.float32), rtol=1e-3, atol=1e-5
|
||||
):
|
||||
print("✅ All implementations match")
|
||||
else:
|
||||
print("❌ Implementations differ")
|
||||
|
||||
|
||||
qweight_row_range = [3584, 18944, 128, 256, 512, 1024]
|
||||
qweight_cols_range = [448, 576, 4736, 16, 32, 64, 128]
|
||||
|
||||
configs = list(itertools.product(qweight_row_range, qweight_cols_range))
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["qweight_row", "qweight_col"],
|
||||
x_vals=configs,
|
||||
line_arg="provider",
|
||||
line_vals=["vllm", "sglang"],
|
||||
line_names=["VLLM", "SGL Kernel"],
|
||||
styles=[("blue", "-"), ("green", "-")],
|
||||
ylabel="us",
|
||||
plot_name="awq-dequantize-performance",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(qweight_row, qweight_col, provider):
|
||||
dtype = torch.float16
|
||||
device = torch.device("cuda")
|
||||
qweight = torch.randint(
|
||||
0,
|
||||
torch.iinfo(torch.int32).max,
|
||||
(qweight_row, qweight_col),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
group_size = qweight_row
|
||||
scales_row = qweight_row // group_size
|
||||
scales_col = qweight_col * 8
|
||||
scales = torch.rand(scales_row, scales_col, dtype=torch.float16, device=device)
|
||||
qzeros = torch.randint(
|
||||
0,
|
||||
torch.iinfo(torch.int32).max,
|
||||
(scales_row, qweight_col),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
|
||||
if provider == "vllm":
|
||||
fn = lambda: vllm_awq_dequantize(
|
||||
qweight.clone(), scales.clone(), qzeros.clone()
|
||||
)
|
||||
elif provider == "sglang":
|
||||
fn = lambda: sglang_awq_dequantize(
|
||||
qweight.clone(), scales.clone(), qzeros.clone()
|
||||
)
|
||||
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)
|
||||
|
||||
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
calculate_diff(qweight_row=3584, qweight_col=448)
|
||||
benchmark.run(print_data=True)
|
||||
145
sgl-kernel/benchmark/bench_cutlass_mla.py
Normal file
145
sgl-kernel/benchmark/bench_cutlass_mla.py
Normal file
@@ -0,0 +1,145 @@
|
||||
import argparse
|
||||
import copy
|
||||
import itertools
|
||||
|
||||
import torch
|
||||
import triton
|
||||
from sgl_kernel import cutlass_mla_decode, cutlass_mla_get_workspace_size
|
||||
|
||||
bs_range = [1, 8, 32, 64, 128, 256]
|
||||
qlen_range = [1, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
|
||||
|
||||
configs = list(itertools.product(bs_range, qlen_range))
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["batch_size", "seq_len"],
|
||||
x_vals=configs,
|
||||
x_log=False,
|
||||
line_arg="provider",
|
||||
line_vals=[
|
||||
"128 heads",
|
||||
"64 heads",
|
||||
"32 heads",
|
||||
"16 heads",
|
||||
],
|
||||
line_names=[
|
||||
"128 heads",
|
||||
"64 heads",
|
||||
"32 heads",
|
||||
"16 heads",
|
||||
],
|
||||
styles=[("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")],
|
||||
ylabel="GB/s",
|
||||
plot_name="cutlass mla",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(batch_size, seq_len, provider, block_size, num_kv_splits):
|
||||
d = 576
|
||||
dn = 64
|
||||
dv = 512
|
||||
|
||||
h_q_map = {
|
||||
"128": 128,
|
||||
"64": 64,
|
||||
"32": 32,
|
||||
"16": 16,
|
||||
}
|
||||
parsed_h_q = next(
|
||||
(value for key, value in h_q_map.items() if key in provider), None
|
||||
)
|
||||
|
||||
if parsed_h_q is None:
|
||||
raise ValueError(f"Unknown head configuration in provider: {provider}")
|
||||
h_q = parsed_h_q
|
||||
|
||||
seq_lens = torch.full((batch_size,), seq_len, dtype=torch.int32, device="cuda")
|
||||
max_seq_len = seq_lens.max().item()
|
||||
block_num = (max_seq_len + block_size - 1) // block_size
|
||||
|
||||
# Pad block_num so that small blocks can be packed into full 128-sized CUTLASS tiles.
|
||||
# One 128-wide tile can hold (128 // block_size) small blocks.
|
||||
pack_factor = 128 // block_size
|
||||
block_num = ((block_num + pack_factor - 1) // pack_factor) * pack_factor
|
||||
|
||||
qn = (
|
||||
torch.randn(h_q, batch_size, d - dn, dtype=torch.bfloat16, device="cuda")
|
||||
* 100.0
|
||||
)
|
||||
qr = torch.randn(batch_size, h_q, dn, dtype=torch.bfloat16, device="cuda") * 100.0
|
||||
block_table = torch.randint(
|
||||
0,
|
||||
batch_size * block_num,
|
||||
(batch_size, block_num),
|
||||
dtype=torch.int32,
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
kv_cache = torch.randn(
|
||||
block_table.numel(), block_size, d, dtype=torch.bfloat16, device="cuda"
|
||||
)
|
||||
|
||||
workspace_size = cutlass_mla_get_workspace_size(
|
||||
block_num * block_size, batch_size, num_kv_splits=num_kv_splits
|
||||
)
|
||||
workspace = torch.empty(workspace_size, device="cuda", dtype=torch.uint8)
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: cutlass_mla_decode(
|
||||
qn.transpose(0, 1),
|
||||
qr,
|
||||
kv_cache,
|
||||
seq_lens,
|
||||
block_table,
|
||||
workspace,
|
||||
1.44,
|
||||
num_kv_splits,
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
|
||||
q_size = qn.numel() * qn.element_size() + qr.numel() * qr.element_size()
|
||||
|
||||
gbps = (
|
||||
lambda ms: (
|
||||
q_size + q_size * dv / d + kv_cache.numel() * kv_cache.element_size()
|
||||
)
|
||||
* 1e-9
|
||||
/ (ms * 1e-3)
|
||||
)
|
||||
return gbps(ms), gbps(max_ms), gbps(min_ms)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--block-sizes",
|
||||
nargs="+",
|
||||
type=int,
|
||||
default=[1, 32, 64, 128],
|
||||
help="List of batch sizes",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-kv-splits",
|
||||
nargs="+",
|
||||
type=int,
|
||||
default=[-1],
|
||||
help="List of batch sizes",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
for block_size in args.block_sizes:
|
||||
for kv_split in args.num_kv_splits:
|
||||
print(f"block_size={block_size}, num_kv_splits={kv_split}: ")
|
||||
benchmark.run(
|
||||
print_data=True,
|
||||
show_plots=True,
|
||||
save_path="bench_blackwell_mla_res",
|
||||
block_size=block_size,
|
||||
num_kv_splits=kv_split,
|
||||
)
|
||||
|
||||
print("Benchmark finished!")
|
||||
57
sgl-kernel/benchmark/bench_dsv3_fused_a_gemm.py
Normal file
57
sgl-kernel/benchmark/bench_dsv3_fused_a_gemm.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import triton
|
||||
import triton.testing
|
||||
from sgl_kernel import dsv3_fused_a_gemm
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["num_tokens"],
|
||||
x_vals=[i + 1 for i in range(16)],
|
||||
x_log=False,
|
||||
line_arg="impl",
|
||||
line_vals=["torch", "sgl-kernel"],
|
||||
line_names=["torch (bf16)", "dsv3_fused_a_gemm"],
|
||||
styles=[("blue", "-"), ("orange", "-")],
|
||||
ylabel="TFLOPs",
|
||||
plot_name="bf16 dsv3 fused a GEMM throughput",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(num_tokens, impl):
|
||||
kHdIn = 7168
|
||||
kHdOut = 2112
|
||||
M, K, N = num_tokens, kHdIn, kHdOut
|
||||
|
||||
mat_a = torch.randn((M, K), dtype=torch.bfloat16, device="cuda").contiguous()
|
||||
mat_b = torch.randn((N, K), dtype=torch.bfloat16, device="cuda").transpose(0, 1)
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
|
||||
if impl == "torch":
|
||||
|
||||
def runner():
|
||||
F.linear(mat_a, mat_b.T)
|
||||
|
||||
elif impl == "sgl-kernel":
|
||||
|
||||
def runner():
|
||||
dsv3_fused_a_gemm(mat_a, mat_b)
|
||||
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(runner, quantiles=quantiles)
|
||||
|
||||
def tflops(t_ms):
|
||||
flops = 2 * M * K * N
|
||||
return flops / (t_ms * 1e-3) / 1e12
|
||||
|
||||
return tflops(ms), tflops(max_ms), tflops(min_ms)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
args = parser.parse_args()
|
||||
|
||||
benchmark.run(print_data=True, show_plots=True, save_path="bench_dsv3_gemm")
|
||||
127
sgl-kernel/benchmark/bench_dsv3_router_gemm.py
Normal file
127
sgl-kernel/benchmark/bench_dsv3_router_gemm.py
Normal file
@@ -0,0 +1,127 @@
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import triton
|
||||
import triton.testing
|
||||
from sgl_kernel import dsv3_router_gemm
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["num_tokens"],
|
||||
x_vals=[i + 1 for i in range(16)],
|
||||
x_log=False,
|
||||
line_arg="impl",
|
||||
line_vals=["torch-256", "sgl-kernel-256", "torch-384", "sgl-kernel-384"],
|
||||
line_names=[
|
||||
"torch-256",
|
||||
"dsv3_router_gemm-256",
|
||||
"torch-384",
|
||||
"dsv3_router_gemm-384",
|
||||
],
|
||||
styles=[("blue", "-"), ("orange", "-"), ("green", "-"), ("red", "-")],
|
||||
ylabel="TFLOPs",
|
||||
plot_name="input-bf16-output-bf16 dsv3 router gemm throughput",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark_bf16_output(num_tokens, impl):
|
||||
# M: num_tokens, K: hidden_dim, N: num_experts
|
||||
M, K = num_tokens, 7168
|
||||
|
||||
if impl == "torch-256" or impl == "sgl-kernel-256":
|
||||
N = 256
|
||||
elif impl == "torch-384" or impl == "sgl-kernel-384":
|
||||
N = 384
|
||||
else:
|
||||
raise ValueError(f"Unknown impl: {impl}")
|
||||
|
||||
mat_a = torch.randn((M, K), dtype=torch.bfloat16, device="cuda").contiguous()
|
||||
mat_b = torch.randn((N, K), dtype=torch.bfloat16, device="cuda").contiguous()
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
|
||||
if impl == "torch-256" or impl == "torch-384":
|
||||
|
||||
def runner():
|
||||
F.linear(mat_a, mat_b)
|
||||
|
||||
elif impl == "sgl-kernel-256" or impl == "sgl-kernel-384":
|
||||
|
||||
def runner():
|
||||
dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.bfloat16)
|
||||
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(runner, quantiles=quantiles)
|
||||
|
||||
def tflops(t_ms):
|
||||
flops = 2 * M * K * N
|
||||
return flops / (t_ms * 1e-3) / 1e12
|
||||
|
||||
return tflops(ms), tflops(max_ms), tflops(min_ms)
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["num_tokens"],
|
||||
x_vals=[i + 1 for i in range(16)],
|
||||
x_log=False,
|
||||
line_arg="impl",
|
||||
line_vals=["torch-256", "sgl-kernel-256", "torch-384", "sgl-kernel-384"],
|
||||
line_names=[
|
||||
"torch-256",
|
||||
"dsv3_router_gemm-256",
|
||||
"torch-384",
|
||||
"dsv3_router_gemm-384",
|
||||
],
|
||||
styles=[("blue", "-"), ("orange", "-"), ("green", "-"), ("red", "-")],
|
||||
ylabel="TFLOPs",
|
||||
plot_name="input-bf16-output-fp32 dsv3 router gemm throughput",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark_float_output(num_tokens, impl):
|
||||
# M: num_tokens, K: hidden_dim, N: num_experts
|
||||
M, K = num_tokens, 7168
|
||||
|
||||
if impl == "torch-256" or impl == "sgl-kernel-256":
|
||||
N = 256
|
||||
elif impl == "torch-384" or impl == "sgl-kernel-384":
|
||||
N = 384
|
||||
else:
|
||||
raise ValueError(f"Unknown impl: {impl}")
|
||||
|
||||
mat_a = torch.randn((M, K), dtype=torch.bfloat16, device="cuda").contiguous()
|
||||
mat_b = torch.randn((N, K), dtype=torch.bfloat16, device="cuda").contiguous()
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
|
||||
if impl == "torch-256" or impl == "torch-384":
|
||||
|
||||
def runner():
|
||||
F.linear(mat_a, mat_b).to(torch.float32)
|
||||
|
||||
elif impl == "sgl-kernel-256" or impl == "sgl-kernel-384":
|
||||
|
||||
def runner():
|
||||
dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.float32)
|
||||
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(runner, quantiles=quantiles)
|
||||
|
||||
def tflops(t_ms):
|
||||
flops = 2 * M * K * N
|
||||
return flops / (t_ms * 1e-3) / 1e12
|
||||
|
||||
return tflops(ms), tflops(max_ms), tflops(min_ms)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
args = parser.parse_args()
|
||||
|
||||
benchmark_bf16_output.run(
|
||||
print_data=True, show_plots=True, save_path="bench_dsv3_router_gemm"
|
||||
)
|
||||
benchmark_float_output.run(
|
||||
print_data=True, show_plots=True, save_path="bench_dsv3_router_gemm"
|
||||
)
|
||||
210
sgl-kernel/benchmark/bench_fp4_gemm.py
Executable file
210
sgl-kernel/benchmark/bench_fp4_gemm.py
Executable file
@@ -0,0 +1,210 @@
|
||||
import argparse
|
||||
import copy
|
||||
import csv
|
||||
import itertools
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import triton
|
||||
from flashinfer import mm_fp4
|
||||
from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant
|
||||
|
||||
FLOAT4_E2M1_MAX = 6.0
|
||||
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
|
||||
|
||||
|
||||
def get_weight_shapes(args):
|
||||
models_tps = args.tp_sizes
|
||||
|
||||
if models_tps == [4]:
|
||||
return [[1024, 3584], [7168, 256], [7168, 2304], [9216, 3584]]
|
||||
|
||||
if models_tps == [8]:
|
||||
return [[512, 3584], [7168, 128], [7168, 1152], [4608, 3584]]
|
||||
return [
|
||||
[1024, 3584],
|
||||
[7168, 256],
|
||||
[7168, 2304],
|
||||
[9216, 3584],
|
||||
[512, 3584],
|
||||
[7168, 128],
|
||||
[7168, 1152],
|
||||
[4608, 3584],
|
||||
]
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["batch_size"],
|
||||
x_vals=[
|
||||
1,
|
||||
2,
|
||||
4,
|
||||
8,
|
||||
16,
|
||||
32,
|
||||
64,
|
||||
128,
|
||||
256,
|
||||
512,
|
||||
1024,
|
||||
2048,
|
||||
3072,
|
||||
4096,
|
||||
8192,
|
||||
16384,
|
||||
],
|
||||
# x_vals = [64],
|
||||
x_log=False,
|
||||
line_arg="provider",
|
||||
line_vals=["cutlass", "cudnn", "trtllm"],
|
||||
line_names=["baseline cutlass fp4", "cudnn fp4", "trtllm fp4"],
|
||||
styles=[("red", "solid"), ("blue", "solid"), ("green", "solid")],
|
||||
ylabel="latency (ms)",
|
||||
plot_name="fp4_gemm_benchmark",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(batch_size, provider, N, K, dtype, correctness, csv_file):
|
||||
M = batch_size
|
||||
packed_k = K
|
||||
K = 2 * packed_k
|
||||
a_dtype = torch.randn((M, K), dtype=dtype, device="cuda")
|
||||
b_dtype = torch.randn((N, K), dtype=dtype, device="cuda")
|
||||
a_global_scale = (
|
||||
(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a_dtype.flatten(), dim=-1)
|
||||
).to(torch.float32)
|
||||
b_global_scale = (
|
||||
(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(b_dtype.flatten(), dim=-1)
|
||||
).to(torch.float32)
|
||||
|
||||
alpha = 1.0 / (a_global_scale * b_global_scale)
|
||||
a_fp4, a_scale_interleaved = scaled_fp4_quant(a_dtype, a_global_scale)
|
||||
# print("a_fp4", a_fp4)
|
||||
b_fp4, b_scale_interleaved = scaled_fp4_quant(b_dtype, b_global_scale)
|
||||
res_fi = torch.empty((M, N), dtype=dtype, device="cuda")
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
if provider == "cutlass":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: cutlass_scaled_fp4_mm(
|
||||
a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, dtype
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
if provider == "cudnn":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: mm_fp4(
|
||||
a_fp4,
|
||||
b_fp4.T,
|
||||
a_scale_interleaved,
|
||||
b_scale_interleaved.T,
|
||||
alpha,
|
||||
dtype,
|
||||
res_fi,
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
if provider == "trtllm":
|
||||
a_scale_interleaved = a_scale_interleaved.to(torch.uint8)
|
||||
b_scale_interleaved = b_scale_interleaved.to(torch.uint8)
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: mm_fp4(
|
||||
a_fp4,
|
||||
b_fp4.T,
|
||||
a_scale_interleaved,
|
||||
b_scale_interleaved.T,
|
||||
alpha,
|
||||
dtype,
|
||||
res_fi,
|
||||
backend="trtllm",
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
if correctness:
|
||||
res_cutlass = cutlass_scaled_fp4_mm(
|
||||
a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, dtype
|
||||
)
|
||||
mm_fp4(
|
||||
a_fp4,
|
||||
b_fp4.T,
|
||||
a_scale_interleaved,
|
||||
b_scale_interleaved.T,
|
||||
alpha,
|
||||
dtype,
|
||||
res_fi,
|
||||
backend="cudnn",
|
||||
)
|
||||
assert torch.allclose(
|
||||
res_fi, res_cutlass, atol=1e-3, rtol=1e-3
|
||||
), "cudnn fp4 doesn't match cutlass fp4"
|
||||
mm_fp4(
|
||||
a_fp4,
|
||||
b_fp4.T,
|
||||
a_scale_interleaved,
|
||||
b_scale_interleaved.T,
|
||||
alpha,
|
||||
dtype,
|
||||
res_fi,
|
||||
backend="trtllm",
|
||||
)
|
||||
assert torch.allclose(
|
||||
res_fi, res_cutlass, atol=1e-3, rtol=1e-3
|
||||
), "trtllm fp4 doesn't match cutlass fp4"
|
||||
|
||||
if csv_file:
|
||||
with open(csv_file, "a", newline="") as f:
|
||||
writer = csv.writer(f)
|
||||
writer.writerow([provider, M, N, K, ms])
|
||||
|
||||
return ms, min_ms, max_ms
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--tp-sizes",
|
||||
nargs="+",
|
||||
type=int,
|
||||
default=[1],
|
||||
help="List of tensor parallel sizes",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
type=torch.dtype,
|
||||
default=torch.bfloat16,
|
||||
help="Data type",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--correctness",
|
||||
action="store_true",
|
||||
help="Check correctness",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--csv",
|
||||
type=str,
|
||||
default="results_cutlass_cudnn.csv",
|
||||
help="CSV file to save results",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.csv:
|
||||
with open(args.csv, "w", newline="") as f:
|
||||
writer = csv.writer(f)
|
||||
writer.writerow(["provider", "m", "n", "k", "time_ms"])
|
||||
|
||||
NKs = get_weight_shapes(args)
|
||||
for N, K in NKs:
|
||||
print(f"DeepSeek-R1-0528-FP4 N={N} K={K}: ")
|
||||
benchmark.run(
|
||||
print_data=True,
|
||||
show_plots=True,
|
||||
save_path="bench_fp4_res",
|
||||
N=N,
|
||||
K=K,
|
||||
dtype=args.dtype,
|
||||
correctness=args.correctness,
|
||||
csv_file=args.csv,
|
||||
)
|
||||
|
||||
print("Benchmark finished!")
|
||||
183
sgl-kernel/benchmark/bench_fp8_blockwise_gemm.py
Normal file
183
sgl-kernel/benchmark/bench_fp8_blockwise_gemm.py
Normal file
@@ -0,0 +1,183 @@
|
||||
import argparse
|
||||
import copy
|
||||
import itertools
|
||||
|
||||
import deep_gemm
|
||||
import torch
|
||||
import triton
|
||||
from deep_gemm import get_col_major_tma_aligned_tensor
|
||||
from sgl_kernel import fp8_blockwise_scaled_mm
|
||||
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
|
||||
|
||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
w8a8_block_fp8_matmul_triton as w8a8_block_fp8_matmul,
|
||||
)
|
||||
|
||||
|
||||
def get_weight_shapes(args):
|
||||
models_tps = list(itertools.product(args.models, args.tp_sizes))
|
||||
# NOTE(HandH1998): The weight shapes only works for DeepSeek-V3. Modify them, if you tune for another different model.
|
||||
# cannot TP
|
||||
total = [
|
||||
(512 + 64, 7168),
|
||||
((128 + 64) * 128, 7168),
|
||||
(128 * (128 + 128), 512),
|
||||
(7168, 16384),
|
||||
(7168, 18432),
|
||||
]
|
||||
# N can TP
|
||||
n_tp = [
|
||||
(18432 * 2, 7168),
|
||||
((128 + 64) * 128, 7168),
|
||||
(128 * (128 + 128), 512),
|
||||
(24576, 1536),
|
||||
(4096, 7168),
|
||||
]
|
||||
# K can TP
|
||||
k_tp = [(7168, 18432), (7168, 16384), (7168, 2048)]
|
||||
# only support Deepseek-V3
|
||||
SUPPORT_MODEL = ["deepseek-ai/DeepSeek-V3"]
|
||||
|
||||
weight_shapes = []
|
||||
for model, tp_size in models_tps:
|
||||
assert model in SUPPORT_MODEL
|
||||
for t in total:
|
||||
new_t = [t[0], t[1], model]
|
||||
weight_shapes.append(new_t)
|
||||
for n_t in n_tp:
|
||||
new_t = [n_t[0] // tp_size, n_t[1], model]
|
||||
weight_shapes.append(new_t)
|
||||
for k_t in k_tp:
|
||||
new_t = [k_t[0], k_t[1] // tp_size, model]
|
||||
weight_shapes.append(new_t)
|
||||
return weight_shapes
|
||||
|
||||
|
||||
def cdiv(a: int, b: int) -> int:
|
||||
"""Ceiling division."""
|
||||
return -(a // -b)
|
||||
|
||||
|
||||
def fp8_gemm_deepgemm(
|
||||
x_fp8: torch.Tensor,
|
||||
x_scale: torch.Tensor,
|
||||
y_fp8: torch.Tensor,
|
||||
y_scale: torch.Tensor,
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
):
|
||||
"""DeepGEMM implementation of FP8 GEMM"""
|
||||
out = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
# Run DeepGEMM kernel
|
||||
deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale), (y_fp8, y_scale), out)
|
||||
return out
|
||||
|
||||
|
||||
def scale_shape(shape, group_shape):
|
||||
assert len(shape) == len(group_shape)
|
||||
return tuple(cdiv(shape[i], group_shape[i]) for i in range(len(group_shape)))
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["batch_size"],
|
||||
x_vals=[1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096],
|
||||
x_log=False,
|
||||
line_arg="provider",
|
||||
line_vals=["vllm", "sgl-kernel", "triton", "deepgemm"],
|
||||
line_names=["vllm", "sgl-kernel", "sglang triton", "deepgemm"],
|
||||
styles=[("blue", "-"), ("orange", "-"), ("red", "-"), ("yellow", "-")],
|
||||
ylabel="GB/s",
|
||||
plot_name="fp8 blockwise scaled matmul",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(batch_size, provider, N, K):
|
||||
M = batch_size
|
||||
fp8_info = torch.finfo(torch.float8_e4m3fn)
|
||||
fp8_max, fp8_min = fp8_info.max, fp8_info.min
|
||||
|
||||
a_fp32 = (torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
|
||||
a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
|
||||
|
||||
b_fp32 = (torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
|
||||
b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
|
||||
|
||||
scale_a_group_shape = (1, 128)
|
||||
scale_b_group_shape = (128, 128)
|
||||
scale_a_shape = scale_shape(a_fp8.shape, scale_a_group_shape)
|
||||
scale_b_shape = scale_shape(b_fp8.shape, scale_b_group_shape)
|
||||
|
||||
scale_a = torch.randn(scale_a_shape, device="cuda", dtype=torch.float32)
|
||||
scale_b = torch.randn(scale_b_shape, device="cuda", dtype=torch.float32)
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
if provider == "sgl-kernel":
|
||||
scale_a = scale_a.t().contiguous().t()
|
||||
b_fp8, scale_b = b_fp8.t(), scale_b.t()
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: fp8_blockwise_scaled_mm(
|
||||
a_fp8, b_fp8, scale_a, scale_b, torch.float16
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
if provider == "vllm":
|
||||
scale_a = scale_a.t().contiguous().t()
|
||||
b_fp8, scale_b = b_fp8.t(), scale_b.t()
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, torch.float16),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
if provider == "triton":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: w8a8_block_fp8_matmul(
|
||||
a_fp8, b_fp8, scale_a, scale_b, [128, 128], torch.float16
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
if provider == "deepgemm":
|
||||
scale_a_col_major = get_col_major_tma_aligned_tensor(scale_a.clone())
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: fp8_gemm_deepgemm(
|
||||
a_fp8, scale_a_col_major, b_fp8, scale_b, M, N, K
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
return ms * 1000, max_ms * 1000, min_ms * 1000 # convert to ms
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--models",
|
||||
nargs="+",
|
||||
type=str,
|
||||
default=["deepseek-ai/DeepSeek-V3"],
|
||||
help="List of models to benchmark",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tp-sizes",
|
||||
nargs="+",
|
||||
type=int,
|
||||
default=[1],
|
||||
help="List of tensor parallel sizes",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
NK_model_names = get_weight_shapes(args)
|
||||
for N, K, model_name in NK_model_names:
|
||||
if N % 128 != 0 or K % 128 != 0:
|
||||
print(f"Skip {N=}, {K=} now")
|
||||
continue
|
||||
print(f"{model_name} N={N} K={K}: ")
|
||||
benchmark.run(
|
||||
print_data=True,
|
||||
show_plots=True,
|
||||
save_path="bench_fp8_blockwise_res",
|
||||
N=N,
|
||||
K=K,
|
||||
)
|
||||
|
||||
print("Benchmark finished!")
|
||||
328
sgl-kernel/benchmark/bench_fp8_blockwise_group_gemm.py
Normal file
328
sgl-kernel/benchmark/bench_fp8_blockwise_group_gemm.py
Normal file
@@ -0,0 +1,328 @@
|
||||
import argparse
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Tuple
|
||||
|
||||
import deep_gemm
|
||||
import torch
|
||||
from sgl_kernel import fp8_blockwise_scaled_grouped_mm
|
||||
|
||||
|
||||
def get_m_alignment_for_contiguous_layout():
|
||||
return 128
|
||||
|
||||
|
||||
def ceil_div(x: int, y: int) -> int:
|
||||
return (x + y - 1) // y
|
||||
|
||||
|
||||
def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
assert x.dim() == 2
|
||||
m, n = x.shape
|
||||
pad_size = (128 - (n % 128)) % 128
|
||||
x = torch.nn.functional.pad(x, (0, pad_size), value=0) if pad_size > 0 else x
|
||||
x_view = x.view(m, -1, 128)
|
||||
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
|
||||
fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
|
||||
return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
|
||||
|
||||
|
||||
def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
assert x.dim() == 2
|
||||
m, n = x.shape
|
||||
x_padded = torch.zeros(
|
||||
(ceil_div(m, 128) * 128, ceil_div(n, 128) * 128), dtype=x.dtype, device=x.device
|
||||
)
|
||||
x_padded[:m, :n] = x
|
||||
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
|
||||
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
|
||||
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
|
||||
return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(
|
||||
x_view.size(0), x_view.size(2)
|
||||
)
|
||||
|
||||
|
||||
def construct_contiguous_grouped(
|
||||
num_groups: int, expected_m_per_group: int, k: int, n: int
|
||||
) -> Tuple[
|
||||
int,
|
||||
Tuple[torch.Tensor, torch.Tensor],
|
||||
Tuple[torch.Tensor, torch.Tensor],
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
]:
|
||||
alignment = get_m_alignment_for_contiguous_layout()
|
||||
group_ms = [int(expected_m_per_group) for _ in range(num_groups)]
|
||||
m = sum([ceil_div(x, alignment) * alignment for x in group_ms])
|
||||
|
||||
x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
|
||||
y = torch.randn((num_groups, n, k), device="cuda", dtype=torch.bfloat16)
|
||||
m_indices = torch.empty(m, device="cuda", dtype=torch.int32)
|
||||
out = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
start = 0
|
||||
for i, group_m in enumerate(group_ms):
|
||||
actual_end = start + group_m
|
||||
aligned_end = start + ceil_div(group_m, alignment) * alignment
|
||||
m_indices[start:actual_end] = i
|
||||
m_indices[actual_end:aligned_end] = -1
|
||||
start = aligned_end
|
||||
|
||||
assert m % 4 == 0, f"TMA alignment error: {m}"
|
||||
x_fp8 = per_token_cast_to_fp8(x)
|
||||
y_fp8 = (
|
||||
torch.empty_like(y, dtype=torch.float8_e4m3fn),
|
||||
torch.empty(
|
||||
(num_groups, ceil_div(n, 128), k // 128), device="cuda", dtype=torch.float
|
||||
),
|
||||
)
|
||||
for i in range(num_groups):
|
||||
y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i])
|
||||
|
||||
return m, x_fp8, y_fp8, m_indices, out
|
||||
|
||||
|
||||
def bench_deepgemm(
|
||||
expected_m_per_group: int,
|
||||
n: int,
|
||||
k: int,
|
||||
num_groups: int,
|
||||
num_warmup: int,
|
||||
num_run: int,
|
||||
) -> Tuple[float, int]:
|
||||
# construct tensors
|
||||
m, x_fp8, y_fp8, m_indices, out = construct_contiguous_grouped(
|
||||
num_groups, expected_m_per_group, k, n
|
||||
)
|
||||
|
||||
def run_deepgemm():
|
||||
deep_gemm.m_grouped_fp8_gemm_nt_contiguous(x_fp8, y_fp8, out, m_indices)
|
||||
|
||||
# warmup
|
||||
for _ in range(num_warmup):
|
||||
run_deepgemm()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# run
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
latencies: list[float] = []
|
||||
start_event.record()
|
||||
for _ in range(num_run):
|
||||
run_deepgemm()
|
||||
end_event.record()
|
||||
end_event.synchronize()
|
||||
torch.cuda.synchronize()
|
||||
avg = start_event.elapsed_time(end_event) / num_run * 1000 # us
|
||||
|
||||
return avg, m
|
||||
|
||||
|
||||
def bench_cutlass(
|
||||
expected_m_per_group: int,
|
||||
n: int,
|
||||
k: int,
|
||||
num_groups: int,
|
||||
num_warmup: int,
|
||||
num_run: int,
|
||||
) -> Tuple[float, int]:
|
||||
device = "cuda"
|
||||
alignment = 16
|
||||
n_g = ceil_div(n, alignment) * alignment
|
||||
k_g = ceil_div(k, alignment) * alignment
|
||||
out_dtype = torch.bfloat16
|
||||
|
||||
expert_offsets = torch.zeros((num_groups + 1), device=device, dtype=torch.int32)
|
||||
problem_sizes = torch.zeros((num_groups, 3), device=device, dtype=torch.int32)
|
||||
layout_sfa = torch.zeros((num_groups, 5), device=device, dtype=torch.int32)
|
||||
layout_sfb = torch.zeros((num_groups, 5), device=device, dtype=torch.int32)
|
||||
|
||||
a_tensors = []
|
||||
b_tensors = []
|
||||
a_scales_tensors = []
|
||||
b_scales_tensors = []
|
||||
|
||||
# TODO(@TianQiLin666666): Unique group_ms in all bench function
|
||||
group_ms = [
|
||||
alignment * ceil_div(int(expected_m_per_group), alignment)
|
||||
for _ in range(num_groups)
|
||||
]
|
||||
for g in range(num_groups):
|
||||
m_g = group_ms[g]
|
||||
expert_offsets[g + 1] = expert_offsets[g] + m_g
|
||||
problem_sizes[g][:] = torch.tensor([m_g, n_g, k_g], device=device)
|
||||
|
||||
a_g, a_scale = per_token_cast_to_fp8(torch.randn((m_g, k_g), device=device))
|
||||
b_g, b_scale = per_block_cast_to_fp8(torch.randn((n_g, k_g), device=device).t())
|
||||
a_tensors.append(a_g)
|
||||
b_tensors.append(b_g)
|
||||
a_scales_tensors.append(a_scale)
|
||||
b_scales_tensors.append(b_scale)
|
||||
|
||||
a_stack = torch.empty(
|
||||
(expert_offsets[-1], k_g), device=device, dtype=torch.float8_e4m3fn
|
||||
)
|
||||
b_stack = torch.empty(
|
||||
(num_groups, n_g, k_g), device=device, dtype=torch.float8_e4m3fn
|
||||
)
|
||||
|
||||
for g in range(num_groups):
|
||||
a_stack[expert_offsets[g] : expert_offsets[g + 1]] = a_tensors[g]
|
||||
b_stack[g] = b_tensors[g].t()
|
||||
b_stack = b_stack.transpose(1, 2)
|
||||
|
||||
a_scale_stack = torch.empty(
|
||||
(expert_offsets[-1], k_g // 128), device=device, dtype=torch.float32
|
||||
)
|
||||
b_scale_stack = torch.empty(
|
||||
(num_groups, n_g // 128, k_g // 128), device=device, dtype=torch.float32
|
||||
)
|
||||
|
||||
for g in range(num_groups):
|
||||
a_scale_stack[expert_offsets[g] : expert_offsets[g + 1]] = a_scales_tensors[g]
|
||||
b_scale_stack[g] = b_scales_tensors[g].t()
|
||||
b_scale_stack = b_scale_stack.transpose(1, 2)
|
||||
|
||||
c_out = torch.empty((expert_offsets[-1], n_g), device=device, dtype=out_dtype)
|
||||
a_strides = torch.full(
|
||||
(num_groups,), a_stack.stride(0), device=device, dtype=torch.int64
|
||||
)
|
||||
c_strides = torch.full(
|
||||
(num_groups,), c_out.stride(0), device=device, dtype=torch.int64
|
||||
)
|
||||
workspace = torch.empty((1024 * 1024 * 1024), device=device, dtype=torch.uint8)
|
||||
a_ptrs = torch.empty((num_groups,), device=device, dtype=torch.int64)
|
||||
b_ptrs = torch.empty((num_groups,), device=device, dtype=torch.int64)
|
||||
out_ptrs = torch.empty((num_groups,), device=device, dtype=torch.int64)
|
||||
a_scales_ptrs = torch.empty((num_groups,), device=device, dtype=torch.int64)
|
||||
b_scales_ptrs = torch.empty((num_groups,), device=device, dtype=torch.int64)
|
||||
|
||||
def run_cutlass():
|
||||
fp8_blockwise_scaled_grouped_mm(
|
||||
c_out,
|
||||
a_ptrs,
|
||||
b_ptrs,
|
||||
out_ptrs,
|
||||
a_scales_ptrs,
|
||||
b_scales_ptrs,
|
||||
a_stack,
|
||||
b_stack,
|
||||
a_scale_stack,
|
||||
b_scale_stack,
|
||||
a_strides,
|
||||
a_strides,
|
||||
c_strides,
|
||||
layout_sfa,
|
||||
layout_sfb,
|
||||
problem_sizes,
|
||||
expert_offsets[:-1],
|
||||
workspace,
|
||||
)
|
||||
|
||||
# warmup
|
||||
for _ in range(num_warmup):
|
||||
run_cutlass()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# run
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
start_event.record()
|
||||
for _ in range(num_run):
|
||||
run_cutlass()
|
||||
end_event.record()
|
||||
end_event.synchronize()
|
||||
torch.cuda.synchronize()
|
||||
avg = start_event.elapsed_time(end_event) / num_run * 1000 # us
|
||||
|
||||
return avg, expert_offsets[-1]
|
||||
|
||||
|
||||
def bench_sglang_triton(
|
||||
expected_m_per_group: int,
|
||||
n: int,
|
||||
k: int,
|
||||
num_groups: int,
|
||||
num_warmup: int,
|
||||
num_run: int,
|
||||
) -> Tuple[float, int]:
|
||||
pass
|
||||
|
||||
|
||||
benchmark_kernels = {
|
||||
"deepgemm": bench_deepgemm,
|
||||
"cutlass": bench_cutlass,
|
||||
# "triton": bench_sglang_triton,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ShapeArg:
|
||||
expected_m_per_group: int
|
||||
n: int
|
||||
k: int
|
||||
num_groups: int
|
||||
|
||||
|
||||
def benchmark_one_shape(
|
||||
shape_args: List[ShapeArg],
|
||||
num_warmup: int,
|
||||
num_run: int,
|
||||
):
|
||||
for shape in shape_args:
|
||||
print(
|
||||
f"\nBenchmark: expected_m_per_group={shape.expected_m_per_group}, n={shape.n}, k={shape.k}, num_groups={shape.num_groups}"
|
||||
)
|
||||
for kernel_name, kernel_func in benchmark_kernels.items():
|
||||
average_time, m = kernel_func(
|
||||
shape.expected_m_per_group,
|
||||
shape.n,
|
||||
shape.k,
|
||||
shape.num_groups,
|
||||
num_warmup,
|
||||
num_run,
|
||||
)
|
||||
print(f"{kernel_name}: {average_time} us")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--num-warmup", type=int, default=3)
|
||||
parser.add_argument("--num-run", type=int, default=10)
|
||||
shape_args = [
|
||||
# Prefill, DeepSeek-R1, gateup, chunk_size = 4096, TP = 8
|
||||
ShapeArg(expected_m_per_group=128, n=512, k=7168, num_groups=256),
|
||||
# Prefill, DeepSeek-R1, gateup, chunk_size = 8192, TP = 8
|
||||
ShapeArg(expected_m_per_group=256, n=512, k=7168, num_groups=256),
|
||||
# Prefill, DeepSeek-R1, gateup, chunk_size = 8192, TP = 16
|
||||
ShapeArg(expected_m_per_group=256, n=256, k=7168, num_groups=256),
|
||||
# Prefill, DeepSeek-R1, gateup, chunk_size = 16384, TP = 16
|
||||
ShapeArg(expected_m_per_group=512, n=256, k=7168, num_groups=256),
|
||||
# Decode, DeepSeek-R1, gateup, bs = 32, TP = 8
|
||||
ShapeArg(expected_m_per_group=1, n=512, k=7168, num_groups=256),
|
||||
# Decode, DeepSeek-R1, gateup, bs = 64, TP = 16
|
||||
ShapeArg(expected_m_per_group=2, n=256, k=7168, num_groups=256),
|
||||
# Prefill, DeepSeek-R1, gateup, chunk_size = 8192, EP = 8
|
||||
ShapeArg(expected_m_per_group=256, n=4096, k=7168, num_groups=32),
|
||||
# Prefill, DeepSeek-R1, gateup, chunk_size = 16384, EP = 16
|
||||
ShapeArg(expected_m_per_group=512, n=4096, k=7168, num_groups=16),
|
||||
# Decode, DeepSeek-R1, gateup, bs = 128, EP = 8
|
||||
ShapeArg(expected_m_per_group=4, n=4096, k=7168, num_groups=32),
|
||||
# Decode, DeepSeek-R1, gateup, bs = 256, EP = 16
|
||||
ShapeArg(expected_m_per_group=8, n=4096, k=7168, num_groups=16),
|
||||
# Prefill, Qwen3-235B-A22B-FP8, gateup, chunk_size = 16384, TP = 4
|
||||
ShapeArg(expected_m_per_group=1024, n=768, k=4096, num_groups=128),
|
||||
# Prefill, Qwen3-235B-A22B-FP8, down, chunk_size = 16384, TP = 4
|
||||
ShapeArg(expected_m_per_group=1024, n=4096, k=384, num_groups=128),
|
||||
# Decode, Qwen3-235B-A22B-FP8, gateup, bs = 256, TP = 4
|
||||
ShapeArg(expected_m_per_group=16, n=768, k=4096, num_groups=128),
|
||||
# Decode, Qwen3-235B-A22B-FP8, down, bs = 256, TP = 4
|
||||
ShapeArg(expected_m_per_group=16, n=4096, k=384, num_groups=128),
|
||||
]
|
||||
args = parser.parse_args()
|
||||
benchmark_one_shape(shape_args, args.num_warmup, args.num_run)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
184
sgl-kernel/benchmark/bench_fp8_gemm.py
Normal file
184
sgl-kernel/benchmark/bench_fp8_gemm.py
Normal file
@@ -0,0 +1,184 @@
|
||||
import argparse
|
||||
import copy
|
||||
import itertools
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import triton
|
||||
from sgl_kernel import fp8_scaled_mm as sgl_scaled_mm
|
||||
from sgl_kernel import sgl_per_tensor_quant_fp8
|
||||
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
|
||||
from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant
|
||||
|
||||
# Weight Shapes are in the format
|
||||
# ([K, N], TP_SPLIT_DIM)
|
||||
# Example:
|
||||
# A shape of ([14336, 4096], 0) indicates the following GEMM shape,
|
||||
# - TP1 : K = 14336, N = 4096
|
||||
# - TP2 : K = 7168, N = 4096
|
||||
# A shape of ([4096, 6144], 1) indicates the following GEMM shape,
|
||||
# - TP1 : K = 4096, N = 6144
|
||||
# - TP4 : K = 4096, N = 1536
|
||||
|
||||
# TP1 shapes
|
||||
WEIGHT_SHAPES = {
|
||||
"meta-llama/Llama-3.1-8B-Instruct": [
|
||||
([4096, 6144], 1),
|
||||
([4096, 4096], 0),
|
||||
([4096, 28672], 1),
|
||||
([14336, 4096], 0),
|
||||
],
|
||||
"meta-llama/Llama-3.3-70B-Instruct": [
|
||||
([8192, 10240], 1),
|
||||
([8192, 8192], 0),
|
||||
([8192, 57344], 1),
|
||||
([28672, 8192], 0),
|
||||
],
|
||||
"mistralai/Mistral-Large-Instruct-2407": [
|
||||
([12288, 14336], 1),
|
||||
([12288, 12288], 0),
|
||||
([12288, 57344], 1),
|
||||
([28672, 12288], 0),
|
||||
],
|
||||
"Qwen/Qwen2.5-7B-Instruct": [
|
||||
([3584, 4608], 1),
|
||||
([3584, 3584], 0),
|
||||
([3584, 37888], 1),
|
||||
([18944, 3584], 0),
|
||||
],
|
||||
"Qwen/Qwen2.5-32B-Instruct": [
|
||||
([5120, 7168], 1),
|
||||
([5120, 5120], 0),
|
||||
([5120, 55296], 1),
|
||||
([27648, 5120], 0),
|
||||
],
|
||||
"Qwen/Qwen2.5-72B-Instruct": [
|
||||
([8192, 10240], 1),
|
||||
([8192, 8192], 0),
|
||||
([8192, 59136], 1),
|
||||
([29568, 8192], 0),
|
||||
],
|
||||
"deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": [
|
||||
([2048, 3072], 1),
|
||||
([2048, 4096], 1),
|
||||
([2048, 2048], 0),
|
||||
([2048, 576], 0),
|
||||
([2048, 21888], 1),
|
||||
([10944, 2048], 0),
|
||||
([2048, 2816], 1),
|
||||
([1408, 2048], 0),
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def sglang_scaled_fp8_quant(
|
||||
input: torch.Tensor,
|
||||
scale: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
fp8_type_: torch.dtype = torch.float8_e4m3fn
|
||||
output = torch.empty_like(input, device=input.device, dtype=fp8_type_)
|
||||
is_static = True
|
||||
if scale is None:
|
||||
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
|
||||
is_static = False
|
||||
sgl_per_tensor_quant_fp8(input, output, scale, is_static)
|
||||
|
||||
return output, scale
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["batch_size"],
|
||||
x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048],
|
||||
x_log=False,
|
||||
line_arg="provider",
|
||||
line_vals=[
|
||||
"vllm-fp8-fp16",
|
||||
"vllm-fp8-bf16",
|
||||
"sglang-fp8-fp16",
|
||||
"sglang-fp8-bf16",
|
||||
],
|
||||
line_names=[
|
||||
"vllm-fp8-fp16",
|
||||
"vllm-fp8-bf16",
|
||||
"sglang-fp8-fp16",
|
||||
"sglang-fp8-bf16",
|
||||
],
|
||||
styles=[("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")],
|
||||
ylabel="GB/s",
|
||||
plot_name="fp8 scaled matmul",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(batch_size, provider, N, K):
|
||||
# M, N, K = batch_size, 4096, 8192
|
||||
M = batch_size
|
||||
a = torch.ones((M, K), device="cuda") * 5.0
|
||||
b = torch.ones((N, K), device="cuda") * 5.0
|
||||
scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
|
||||
scale_b = torch.randn((N,), device="cuda", dtype=torch.float32)
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
|
||||
dtype = torch.float16 if "fp16" in provider else torch.bfloat16
|
||||
|
||||
if "vllm-fp8" in provider:
|
||||
a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)
|
||||
b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
|
||||
b_fp8 = b_fp8.t()
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
elif "sglang-fp8" in provider:
|
||||
a_fp8, scale_a_fp8 = sglang_scaled_fp8_quant(a, scale_a)
|
||||
b_fp8, scale_b_fp8 = sglang_scaled_fp8_quant(b, scale_b)
|
||||
b_fp8 = b_fp8.t()
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: sgl_scaled_mm(
|
||||
a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype, bias=None
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
|
||||
gbps = lambda ms: (2 * M * N * K + M * N) * a.element_size() * 1e-9 / (ms * 1e-3)
|
||||
return gbps(ms), gbps(max_ms), gbps(min_ms)
|
||||
|
||||
|
||||
def prepare_shapes(args):
|
||||
KN_model_names = []
|
||||
models_tps = list(itertools.product(args.models, args.tp_sizes))
|
||||
for model, tp_size in models_tps:
|
||||
assert model in WEIGHT_SHAPES
|
||||
for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
|
||||
KN[tp_split_dim] = KN[tp_split_dim] // tp_size
|
||||
KN.append(model)
|
||||
KN_model_names.append(KN)
|
||||
return KN_model_names
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--models",
|
||||
nargs="+",
|
||||
type=str,
|
||||
default=["meta-llama/Llama-3.1-8B-Instruct"],
|
||||
help="List of models to benchmark",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tp-sizes",
|
||||
nargs="+",
|
||||
type=int,
|
||||
default=[1],
|
||||
help="List of tensor parallel sizes",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
KN_model_names = prepare_shapes(args)
|
||||
for K, N, model_name in KN_model_names:
|
||||
print(f"{model_name} N={N} K={K}: ")
|
||||
benchmark.run(
|
||||
print_data=True, show_plots=True, save_path="bench_fp8_res", N=N, K=K
|
||||
)
|
||||
|
||||
print("Benchmark finished!")
|
||||
146
sgl-kernel/benchmark/bench_int8_gemm.py
Normal file
146
sgl-kernel/benchmark/bench_int8_gemm.py
Normal file
@@ -0,0 +1,146 @@
|
||||
import argparse
|
||||
import copy
|
||||
import itertools
|
||||
|
||||
import torch
|
||||
import triton
|
||||
from sgl_kernel import int8_scaled_mm
|
||||
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
|
||||
|
||||
|
||||
def to_int8(tensor: torch.Tensor) -> torch.Tensor:
|
||||
return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
|
||||
|
||||
|
||||
WEIGHT_SHAPES = {
|
||||
"meta-llama/Llama-3.1-8B-Instruct": [
|
||||
([4096, 6144], 1),
|
||||
([4096, 4096], 0),
|
||||
([4096, 28672], 1),
|
||||
([14336, 4096], 0),
|
||||
],
|
||||
"meta-llama/Llama-3.3-70B-Instruct": [
|
||||
([8192, 10240], 1),
|
||||
([8192, 8192], 0),
|
||||
([8192, 57344], 1),
|
||||
([28672, 8192], 0),
|
||||
],
|
||||
"mistralai/Mistral-Large-Instruct-2407": [
|
||||
([12288, 14336], 1),
|
||||
([12288, 12288], 0),
|
||||
([12288, 57344], 1),
|
||||
([28672, 12288], 0),
|
||||
],
|
||||
"Qwen/Qwen2.5-7B-Instruct": [
|
||||
([3584, 4608], 1),
|
||||
([3584, 3584], 0),
|
||||
([3584, 37888], 1),
|
||||
([18944, 3584], 0),
|
||||
],
|
||||
"Qwen/Qwen2.5-32B-Instruct": [
|
||||
([5120, 7168], 1),
|
||||
([5120, 5120], 0),
|
||||
([5120, 55296], 1),
|
||||
([27648, 5120], 0),
|
||||
],
|
||||
"Qwen/Qwen2.5-72B-Instruct": [
|
||||
([8192, 10240], 1),
|
||||
([8192, 8192], 0),
|
||||
([8192, 59136], 1),
|
||||
([29568, 8192], 0),
|
||||
],
|
||||
"deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": [
|
||||
([2048, 3072], 1),
|
||||
([2048, 4096], 1),
|
||||
([2048, 2048], 0),
|
||||
([2048, 576], 0),
|
||||
([2048, 21888], 1),
|
||||
([10944, 2048], 0),
|
||||
([2048, 2816], 1),
|
||||
([1408, 2048], 0),
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["batch_size"],
|
||||
x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048],
|
||||
x_log=False,
|
||||
line_arg="provider",
|
||||
line_vals=["vllm", "sgl-kernel"],
|
||||
line_names=["vllm int8 gemm", "sgl-kernel int8 gemm"],
|
||||
styles=[("blue", "-"), ("orange", "-")],
|
||||
ylabel="GB/s",
|
||||
plot_name="int8 scaled matmul",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(batch_size, provider, N, K):
|
||||
M = batch_size
|
||||
a = to_int8(torch.randn((M, K), device="cuda") * 5)
|
||||
b = to_int8(torch.randn((N, K), device="cuda").t() * 5)
|
||||
scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
|
||||
scale_b = torch.randn((N,), device="cuda", dtype=torch.float32)
|
||||
bias = torch.randn((N,), device="cuda", dtype=torch.float16)
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
if provider == "sgl-kernel":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: int8_scaled_mm(a, b, scale_a, scale_b, torch.float16, bias),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
if provider == "vllm":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: vllm_scaled_mm(a, b, scale_a, scale_b, torch.float16, bias),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
gbps = (
|
||||
lambda ms: (
|
||||
(2 * M * N * K - M * N) * a.element_size()
|
||||
+ (3 * M * N) * scale_a.element_size()
|
||||
)
|
||||
* 1e-9
|
||||
/ (ms * 1e-3)
|
||||
)
|
||||
return gbps(ms), gbps(max_ms), gbps(min_ms)
|
||||
|
||||
|
||||
def prepare_shapes(args):
|
||||
KN_model_names = []
|
||||
models_tps = list(itertools.product(args.models, args.tp_sizes))
|
||||
for model, tp_size in models_tps:
|
||||
assert model in WEIGHT_SHAPES
|
||||
for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
|
||||
KN[tp_split_dim] = KN[tp_split_dim] // tp_size
|
||||
KN.append(model)
|
||||
KN_model_names.append(KN)
|
||||
return KN_model_names
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--models",
|
||||
nargs="+",
|
||||
type=str,
|
||||
default=["meta-llama/Llama-3.1-8B-Instruct"],
|
||||
help="List of models to benchmark",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tp-sizes",
|
||||
nargs="+",
|
||||
type=int,
|
||||
default=[1],
|
||||
help="List of tensor parallel sizes",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
KN_model_names = prepare_shapes(args)
|
||||
for K, N, model_name in KN_model_names:
|
||||
print(f"{model_name} N={N} K={K}: ")
|
||||
benchmark.run(
|
||||
print_data=True, show_plots=True, save_path="bench_int8_res", N=N, K=K
|
||||
)
|
||||
|
||||
print("Benchmark finished!")
|
||||
299
sgl-kernel/benchmark/bench_lightning_attention_decode.py
Normal file
299
sgl-kernel/benchmark/bench_lightning_attention_decode.py
Normal file
@@ -0,0 +1,299 @@
|
||||
import itertools
|
||||
import math
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from sgl_kernel import lightning_attention_decode
|
||||
|
||||
|
||||
def next_power_of_2(n):
|
||||
return 2 ** (int(math.ceil(math.log(n, 2))))
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _decode_kernel(
|
||||
Q,
|
||||
K,
|
||||
V,
|
||||
KV,
|
||||
Out,
|
||||
S,
|
||||
b: tl.constexpr,
|
||||
h: tl.constexpr,
|
||||
n: tl.constexpr,
|
||||
d: tl.constexpr,
|
||||
d_original: tl.constexpr,
|
||||
e: tl.constexpr,
|
||||
e_original: tl.constexpr,
|
||||
):
|
||||
off_bh = tl.program_id(0)
|
||||
off_h = off_bh % h
|
||||
|
||||
qk_offset = off_bh * n * d
|
||||
v_offset = off_bh * n * e
|
||||
o_offset = off_bh * n * e
|
||||
kv_offset = off_bh * d * e
|
||||
|
||||
s = tl.load(S + off_h)
|
||||
ratio = tl.exp(-s)
|
||||
|
||||
d_idx = tl.arange(0, d)
|
||||
e_idx = tl.arange(0, e)
|
||||
|
||||
# Create masks for original dimensions
|
||||
d_mask = d_idx < d_original
|
||||
e_mask = e_idx < e_original
|
||||
|
||||
# Load with masking
|
||||
q = tl.load(Q + qk_offset + d_idx, mask=d_mask, other=0.0)
|
||||
k = tl.load(K + qk_offset + d_idx, mask=d_mask, other=0.0)
|
||||
v = tl.load(V + v_offset + e_idx, mask=e_mask, other=0.0)
|
||||
|
||||
# Load KV with 2D masking
|
||||
kv = tl.load(
|
||||
KV + kv_offset + d_idx[:, None] * e + e_idx[None, :],
|
||||
mask=(d_mask[:, None] & e_mask[None, :]),
|
||||
other=0.0,
|
||||
)
|
||||
|
||||
# Compute outer product using element-wise operations
|
||||
k_v_prod = k[:, None] * v[None, :]
|
||||
kv = ratio * kv + k_v_prod
|
||||
|
||||
# Store KV with 2D masking
|
||||
tl.store(
|
||||
KV + kv_offset + d_idx[:, None] * e + e_idx[None, :],
|
||||
kv.to(KV.dtype.element_ty),
|
||||
mask=(d_mask[:, None] & e_mask[None, :]),
|
||||
)
|
||||
|
||||
# Compute matrix-vector multiplication using element-wise operations and reduction
|
||||
o = tl.sum(q[:, None] * kv, axis=0)
|
||||
|
||||
# Store output with masking
|
||||
tl.store(Out + o_offset + e_idx, o.to(Out.dtype.element_ty), mask=e_mask)
|
||||
|
||||
|
||||
def triton_lightning_attn_decode(q, k, v, kv, s):
|
||||
"""Triton implementation of Lightning Attention decode operation"""
|
||||
b, h, n, d = q.shape
|
||||
e = v.shape[-1]
|
||||
assert n == 1, "Sequence length must be 1 in decode mode"
|
||||
|
||||
# Get padded dimensions (power of 2)
|
||||
d_padded = next_power_of_2(d)
|
||||
e_padded = next_power_of_2(e)
|
||||
|
||||
# Create output tensor (padded)
|
||||
o_padded = torch.empty(b, h, n, e_padded, dtype=v.dtype, device=v.device)
|
||||
|
||||
# Create padded tensors without actually padding the data
|
||||
q_padded = torch.empty(b, h, n, d_padded, dtype=q.dtype, device=q.device)
|
||||
k_padded = torch.empty(b, h, n, d_padded, dtype=k.dtype, device=k.device)
|
||||
v_padded = torch.empty(b, h, n, e_padded, dtype=v.dtype, device=v.device)
|
||||
kv_padded = torch.empty(
|
||||
b, h, d_padded, e_padded, dtype=torch.float32, device=kv.device
|
||||
)
|
||||
|
||||
# Copy data to padded tensors
|
||||
q_padded[..., :d] = q
|
||||
k_padded[..., :d] = k
|
||||
v_padded[..., :e] = v
|
||||
kv_padded[..., :d, :e] = kv
|
||||
|
||||
# Launch kernel
|
||||
grid = (b * h, 1)
|
||||
_decode_kernel[grid](
|
||||
q_padded,
|
||||
k_padded,
|
||||
v_padded,
|
||||
kv_padded,
|
||||
o_padded,
|
||||
s,
|
||||
b=b,
|
||||
h=h,
|
||||
n=n,
|
||||
d=d_padded,
|
||||
d_original=d,
|
||||
e=e_padded,
|
||||
e_original=e,
|
||||
)
|
||||
|
||||
# Get unpadded outputs
|
||||
o = o_padded[..., :e]
|
||||
kv_out = kv_padded[..., :d, :e]
|
||||
|
||||
return o, kv_out
|
||||
|
||||
|
||||
def lightning_attention_decode_naive(q, k, v, past_kv, slope):
|
||||
"""Naive implementation of lightning attention decode"""
|
||||
original_dtype = q.dtype
|
||||
ratio = torch.exp(-slope) # [h, 1, 1]
|
||||
|
||||
kv = past_kv
|
||||
b, h, n, d = q.shape
|
||||
|
||||
output = []
|
||||
for i in range(n):
|
||||
kv = ratio * kv.to(torch.float32) + torch.einsum(
|
||||
"... n d, ... n e -> ... d e",
|
||||
k[:, :, i : i + 1],
|
||||
v[:, :, i : i + 1],
|
||||
)
|
||||
qkv = torch.einsum(
|
||||
"... n e, ... e d -> ... n d",
|
||||
q[:, :, i : i + 1].to(torch.float32),
|
||||
kv.to(torch.float32),
|
||||
)
|
||||
output.append(qkv)
|
||||
output = torch.cat(output, dim=-2)
|
||||
|
||||
return output.to(original_dtype), kv
|
||||
|
||||
|
||||
def lightning_attention_decode_kernel(q, k, v, past_kv, slope, output, new_kv):
|
||||
return lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv)
|
||||
|
||||
|
||||
def calculate_diff(batch_size):
|
||||
dtype = torch.bfloat16
|
||||
device = torch.device("cuda")
|
||||
num_heads = 64
|
||||
head_dim = 96
|
||||
seq_len = 1
|
||||
|
||||
q = torch.randn(
|
||||
batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
|
||||
)
|
||||
k = torch.randn(
|
||||
batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
|
||||
)
|
||||
v = torch.randn(
|
||||
batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
|
||||
)
|
||||
past_kv = torch.randn(batch_size, num_heads, head_dim, head_dim, device=device)
|
||||
slope = torch.randn(num_heads, 1, 1, device=device)
|
||||
|
||||
output_naive, new_kv_naive = lightning_attention_decode_naive(
|
||||
q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone()
|
||||
)
|
||||
|
||||
output_kernel = torch.empty_like(output_naive)
|
||||
new_kv_kernel = torch.empty_like(new_kv_naive)
|
||||
lightning_attention_decode_kernel(
|
||||
q.clone(),
|
||||
k.clone(),
|
||||
v.clone(),
|
||||
past_kv.clone(),
|
||||
slope.clone(),
|
||||
output_kernel,
|
||||
new_kv_kernel,
|
||||
)
|
||||
|
||||
output_triton, new_kv_triton = triton_lightning_attn_decode(
|
||||
q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone()
|
||||
)
|
||||
|
||||
if (
|
||||
torch.allclose(output_naive, output_kernel, atol=1e-2, rtol=1e-2)
|
||||
and torch.allclose(output_naive, output_triton, atol=1e-2, rtol=1e-2)
|
||||
and torch.allclose(new_kv_naive, new_kv_kernel, atol=1e-2, rtol=1e-2)
|
||||
and torch.allclose(new_kv_naive, new_kv_triton, atol=1e-2, rtol=1e-2)
|
||||
):
|
||||
print("✅ All implementations match")
|
||||
else:
|
||||
print("❌ Implementations differ")
|
||||
|
||||
|
||||
batch_size_range = [i for i in range(1, 65)] # 1 to 128
|
||||
configs = [(bs,) for bs in batch_size_range]
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["batch_size"],
|
||||
x_vals=[list(_) for _ in configs],
|
||||
line_arg="provider",
|
||||
line_vals=["naive", "kernel", "triton"],
|
||||
line_names=["PyTorch Naive", "SGL Kernel", "Triton"],
|
||||
styles=[("blue", "-"), ("red", "-"), ("green", "-")],
|
||||
ylabel="us",
|
||||
plot_name="lightning-attention-decode-performance",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(batch_size, provider):
|
||||
dtype = torch.bfloat16
|
||||
device = torch.device("cuda")
|
||||
num_heads = 64
|
||||
head_dim = 96
|
||||
seq_len = 1
|
||||
|
||||
q = torch.randn(
|
||||
batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
|
||||
)
|
||||
k = torch.randn(
|
||||
batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
|
||||
)
|
||||
v = torch.randn(
|
||||
batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
|
||||
)
|
||||
past_kv = torch.randn(batch_size, num_heads, head_dim, head_dim, device=device)
|
||||
slope = torch.randn(num_heads, 1, 1, device=device)
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
|
||||
if provider == "naive":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: lightning_attention_decode_naive(
|
||||
q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone()
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
elif provider == "kernel":
|
||||
output = torch.empty(
|
||||
batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
|
||||
)
|
||||
new_kv = torch.empty(batch_size, num_heads, head_dim, head_dim, device=device)
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: lightning_attention_decode_kernel(
|
||||
q.clone(),
|
||||
k.clone(),
|
||||
v.clone(),
|
||||
past_kv.clone(),
|
||||
slope.clone(),
|
||||
output,
|
||||
new_kv,
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
elif provider == "triton":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: triton_lightning_attn_decode(
|
||||
q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone()
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
|
||||
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--save_path",
|
||||
type=str,
|
||||
default="./configs/benchmark_ops/lightning_attention_decode_sgl/",
|
||||
help="Path to save lightning attention decode benchmark results",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Run correctness test
|
||||
calculate_diff(batch_size=4)
|
||||
|
||||
# Run performance benchmark
|
||||
benchmark.run(print_data=True)
|
||||
401
sgl-kernel/benchmark/bench_moe_align_block_size.py
Normal file
401
sgl-kernel/benchmark/bench_moe_align_block_size.py
Normal file
@@ -0,0 +1,401 @@
|
||||
import argparse
|
||||
import itertools
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
|
||||
|
||||
try:
|
||||
from vllm import _custom_ops as ops
|
||||
except ImportError:
|
||||
ops = None
|
||||
|
||||
USE_RANDOM_PERM = False
|
||||
|
||||
|
||||
def ceil_div(a, b):
|
||||
return (a + b - 1) // b
|
||||
|
||||
|
||||
@triton.jit
|
||||
def moe_align_block_size_stage1(
|
||||
topk_ids_ptr,
|
||||
tokens_cnts_ptr,
|
||||
num_experts: tl.constexpr,
|
||||
numel: tl.constexpr,
|
||||
tokens_per_thread: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(0)
|
||||
start_idx = pid * tokens_per_thread
|
||||
off_c = (pid + 1) * num_experts
|
||||
|
||||
for i in range(tokens_per_thread):
|
||||
if start_idx + i < numel:
|
||||
idx = tl.load(topk_ids_ptr + start_idx + i)
|
||||
token_cnt = tl.load(tokens_cnts_ptr + off_c + idx)
|
||||
tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def moe_align_block_size_stage2(
|
||||
tokens_cnts_ptr,
|
||||
num_experts: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(0)
|
||||
last_cnt = 0
|
||||
for i in range(1, num_experts + 1):
|
||||
token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
|
||||
last_cnt = last_cnt + token_cnt
|
||||
tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def moe_align_block_size_stage3(
|
||||
total_tokens_post_pad_ptr,
|
||||
tokens_cnts_ptr,
|
||||
cumsum_ptr,
|
||||
num_experts: tl.constexpr,
|
||||
block_size: tl.constexpr,
|
||||
):
|
||||
last_cumsum = 0
|
||||
off_cnt = num_experts * num_experts
|
||||
for i in range(1, num_experts + 1):
|
||||
token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1)
|
||||
last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size
|
||||
tl.store(cumsum_ptr + i, last_cumsum)
|
||||
tl.store(total_tokens_post_pad_ptr, last_cumsum)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def moe_align_block_size_stage4(
|
||||
topk_ids_ptr,
|
||||
sorted_token_ids_ptr,
|
||||
expert_ids_ptr,
|
||||
tokens_cnts_ptr,
|
||||
cumsum_ptr,
|
||||
num_experts: tl.constexpr,
|
||||
block_size: tl.constexpr,
|
||||
numel: tl.constexpr,
|
||||
tokens_per_thread: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(0)
|
||||
start_idx = tl.load(cumsum_ptr + pid)
|
||||
end_idx = tl.load(cumsum_ptr + pid + 1)
|
||||
|
||||
for i in range(start_idx, end_idx, block_size):
|
||||
tl.store(expert_ids_ptr + i // block_size, pid)
|
||||
|
||||
start_idx = pid * tokens_per_thread
|
||||
off_t = pid * num_experts
|
||||
|
||||
for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread, numel)):
|
||||
expert_id = tl.load(topk_ids_ptr + i)
|
||||
token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id)
|
||||
rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id)
|
||||
tl.store(sorted_token_ids_ptr + rank_post_pad, i)
|
||||
tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1)
|
||||
|
||||
|
||||
def moe_align_block_size_triton(
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
block_size: int,
|
||||
sorted_token_ids: torch.Tensor,
|
||||
expert_ids: torch.Tensor,
|
||||
num_tokens_post_pad: torch.Tensor,
|
||||
) -> None:
|
||||
numel = topk_ids.numel()
|
||||
grid = (num_experts,)
|
||||
tokens_cnts = torch.zeros(
|
||||
(num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device
|
||||
)
|
||||
cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device)
|
||||
tokens_per_thread = ceil_div(numel, num_experts)
|
||||
|
||||
moe_align_block_size_stage1[grid](
|
||||
topk_ids,
|
||||
tokens_cnts,
|
||||
num_experts,
|
||||
numel,
|
||||
tokens_per_thread,
|
||||
)
|
||||
moe_align_block_size_stage2[grid](
|
||||
tokens_cnts,
|
||||
num_experts,
|
||||
)
|
||||
moe_align_block_size_stage3[(1,)](
|
||||
num_tokens_post_pad,
|
||||
tokens_cnts,
|
||||
cumsum,
|
||||
num_experts,
|
||||
block_size,
|
||||
)
|
||||
moe_align_block_size_stage4[grid](
|
||||
topk_ids,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
tokens_cnts,
|
||||
cumsum,
|
||||
num_experts,
|
||||
block_size,
|
||||
numel,
|
||||
tokens_per_thread,
|
||||
)
|
||||
|
||||
|
||||
def calculate_diff(num_tokens, num_experts=256, block_size=128, topk=8):
|
||||
topk_ids = torch.stack(
|
||||
[
|
||||
torch.randperm(num_experts, dtype=torch.int32, device="cuda")[:topk]
|
||||
for _ in range(num_tokens)
|
||||
]
|
||||
)
|
||||
|
||||
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
|
||||
sorted_ids_cuda = torch.empty(
|
||||
(max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
|
||||
)
|
||||
sorted_ids_cuda.fill_(topk_ids.numel())
|
||||
max_num_m_blocks = max_num_tokens_padded // block_size
|
||||
expert_ids_cuda = torch.zeros(
|
||||
(max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
|
||||
)
|
||||
num_tokens_post_pad_cuda = torch.empty(
|
||||
(1), dtype=torch.int32, device=topk_ids.device
|
||||
)
|
||||
cumsum_buffer = torch.zeros(
|
||||
num_experts + 1, dtype=torch.int32, device=topk_ids.device
|
||||
)
|
||||
|
||||
sorted_ids_triton = torch.empty_like(sorted_ids_cuda)
|
||||
sorted_ids_triton.fill_(topk_ids.numel())
|
||||
expert_ids_triton = torch.zeros_like(expert_ids_cuda)
|
||||
num_tokens_post_pad_triton = torch.empty_like(num_tokens_post_pad_cuda)
|
||||
|
||||
sorted_ids_vllm = torch.empty_like(sorted_ids_cuda)
|
||||
sorted_ids_vllm.fill_(topk_ids.numel())
|
||||
expert_ids_vllm = torch.zeros_like(expert_ids_cuda)
|
||||
num_tokens_post_pad_vllm = torch.empty_like(num_tokens_post_pad_cuda)
|
||||
|
||||
# compare the performance of cuda, triton and vllm implementation
|
||||
sgl_moe_align_block_size(
|
||||
topk_ids,
|
||||
num_experts,
|
||||
block_size,
|
||||
sorted_ids_cuda,
|
||||
expert_ids_cuda,
|
||||
num_tokens_post_pad_cuda,
|
||||
cumsum_buffer,
|
||||
)
|
||||
moe_align_block_size_triton(
|
||||
topk_ids,
|
||||
num_experts,
|
||||
block_size,
|
||||
sorted_ids_triton,
|
||||
expert_ids_triton,
|
||||
num_tokens_post_pad_triton,
|
||||
)
|
||||
|
||||
try:
|
||||
ops.moe_align_block_size(
|
||||
topk_ids,
|
||||
num_experts,
|
||||
block_size,
|
||||
sorted_ids_vllm,
|
||||
expert_ids_vllm,
|
||||
num_tokens_post_pad_vllm,
|
||||
)
|
||||
print(f"✅ VLLM implementation works with {num_experts} experts!")
|
||||
vllm_works = True
|
||||
except Exception as e:
|
||||
print(f"❌ VLLM implementation failed with {num_experts} experts: {e}")
|
||||
vllm_works = False
|
||||
|
||||
if torch.allclose(expert_ids_cuda, expert_ids_triton) and torch.allclose(
|
||||
num_tokens_post_pad_cuda, num_tokens_post_pad_triton
|
||||
):
|
||||
print("✅ SGL and Triton implementations match")
|
||||
else:
|
||||
print("❌ SGL and Triton implementations do not match")
|
||||
print("SGL expert_ids:", expert_ids_cuda)
|
||||
print("Triton expert_ids:", expert_ids_triton)
|
||||
print("SGL num_tokens_post_pad:", num_tokens_post_pad_cuda)
|
||||
print("Triton num_tokens_post_pad:", num_tokens_post_pad_triton)
|
||||
|
||||
if (
|
||||
vllm_works
|
||||
and torch.allclose(expert_ids_cuda, expert_ids_vllm)
|
||||
and torch.allclose(num_tokens_post_pad_cuda, num_tokens_post_pad_vllm)
|
||||
):
|
||||
print("✅ SGL and VLLM implementations match")
|
||||
else:
|
||||
if not vllm_works:
|
||||
print("⚠️ VLLM comparison skipped due to failure")
|
||||
else:
|
||||
print("❌ SGL and VLLM implementations do not match")
|
||||
print("SGL expert_ids:", expert_ids_cuda)
|
||||
print("VLLM expert_ids:", expert_ids_vllm)
|
||||
print("SGL num_tokens_post_pad:", num_tokens_post_pad_cuda)
|
||||
print("VLLM num_tokens_post_pad:", num_tokens_post_pad_vllm)
|
||||
|
||||
|
||||
# Test range
|
||||
num_tokens_range = [1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
|
||||
num_experts_range = [8, 32, 64, 128, 256]
|
||||
topk_range = [1, 2, 4, 8]
|
||||
|
||||
configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range))
|
||||
|
||||
|
||||
def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor:
|
||||
topk_ids = torch.zeros((num_tokens, topk), dtype=torch.int32, device="cuda")
|
||||
for i in range(num_tokens):
|
||||
topk_ids[i, :] = torch.randperm(num_experts, dtype=torch.int32, device="cuda")[
|
||||
:topk
|
||||
]
|
||||
return topk_ids
|
||||
|
||||
|
||||
def sgl_moe_align_block_size_with_empty(
|
||||
topk_ids,
|
||||
num_experts,
|
||||
block_size,
|
||||
sorted_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_pad,
|
||||
pad_sorted_token_ids=False,
|
||||
):
|
||||
if not pad_sorted_token_ids:
|
||||
sorted_ids.fill_(topk_ids.numel())
|
||||
|
||||
cumsum_buffer = torch.empty(
|
||||
num_experts + 1, dtype=torch.int32, device=topk_ids.device
|
||||
)
|
||||
|
||||
sgl_moe_align_block_size(
|
||||
topk_ids,
|
||||
num_experts,
|
||||
block_size,
|
||||
sorted_ids.clone(),
|
||||
expert_ids.clone(),
|
||||
num_tokens_post_pad.clone(),
|
||||
cumsum_buffer,
|
||||
pad_sorted_token_ids,
|
||||
)
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["num_tokens", "num_experts", "topk"],
|
||||
x_vals=configs,
|
||||
line_arg="provider",
|
||||
line_vals=["sgl", "sgl_fusion", "triton"],
|
||||
line_names=["SGL", "SGL Fusion", "Triton"],
|
||||
styles=[("blue", "-"), ("red", "-"), ("green", "-")],
|
||||
ylabel="us",
|
||||
plot_name="moe-align-block-size-performance",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(num_tokens, num_experts, topk, provider):
|
||||
block_size = 128
|
||||
|
||||
if USE_RANDOM_PERM:
|
||||
topk_ids = get_topk_ids(num_tokens, num_experts, topk)
|
||||
else:
|
||||
topk_ids = torch.randint(
|
||||
0,
|
||||
num_experts,
|
||||
(num_tokens, topk),
|
||||
dtype=torch.int32,
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
|
||||
sorted_ids = torch.empty(
|
||||
(max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
|
||||
)
|
||||
max_num_m_blocks = max_num_tokens_padded // block_size
|
||||
expert_ids = torch.empty(
|
||||
(max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
|
||||
)
|
||||
num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
if provider == "sgl":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: sgl_moe_align_block_size_with_empty(
|
||||
topk_ids,
|
||||
num_experts,
|
||||
block_size,
|
||||
sorted_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_pad,
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
elif provider == "sgl_fusion":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: sgl_moe_align_block_size_with_empty(
|
||||
topk_ids,
|
||||
num_experts,
|
||||
block_size,
|
||||
sorted_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_pad,
|
||||
pad_sorted_token_ids=True,
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
elif provider == "triton":
|
||||
sorted_ids.fill_(topk_ids.numel())
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: moe_align_block_size_triton(
|
||||
topk_ids,
|
||||
num_experts,
|
||||
block_size,
|
||||
sorted_ids.clone(),
|
||||
expert_ids.clone(),
|
||||
num_tokens_post_pad.clone(),
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
|
||||
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--save_path",
|
||||
type=str,
|
||||
default="./configs/benchmark_ops/moe_align_blocks/",
|
||||
help="Path to save moe align benchmark results",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_experts",
|
||||
type=int,
|
||||
default=256,
|
||||
choices=[8, 16, 32, 64, 128, 256],
|
||||
help="Number of experts for benchmark",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--topk",
|
||||
type=int,
|
||||
default=8,
|
||||
choices=[2, 4, 8],
|
||||
help="Top-k value for benchmark",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip_full_benchmark",
|
||||
action="store_true",
|
||||
help="Only run the calculate_diff function, skip full benchmarking",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
calculate_diff(num_tokens=1024, num_experts=args.num_experts, topk=args.topk)
|
||||
|
||||
if not args.skip_full_benchmark:
|
||||
print(f"\n📊 Running performance benchmark for {args.num_experts} experts...")
|
||||
benchmark.run(print_data=True)
|
||||
93
sgl-kernel/benchmark/bench_moe_ep_post_reorder.py
Normal file
93
sgl-kernel/benchmark/bench_moe_ep_post_reorder.py
Normal file
@@ -0,0 +1,93 @@
|
||||
import torch
|
||||
import triton
|
||||
from sgl_kernel import ep_moe_post_reorder
|
||||
|
||||
from sglang.srt.layers.moe.ep_moe.kernels import post_reorder_triton_kernel
|
||||
|
||||
batch_sizes = [64, 128, 256, 512, 640, 768, 1024, 2048, 4096]
|
||||
configs = [(bs,) for bs in batch_sizes]
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["batch_size"],
|
||||
x_vals=[list(_) for _ in configs],
|
||||
line_arg="provider",
|
||||
line_vals=["cuda", "triton"],
|
||||
line_names=["CUDA Kernel", "Triton Kernel"],
|
||||
styles=[("green", "-"), ("orange", "-")],
|
||||
ylabel="us",
|
||||
plot_name="ep-moe-post-reorder-performance",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(batch_size, provider):
|
||||
dtype = torch.bfloat16
|
||||
device = torch.device("cuda")
|
||||
hidden_size, topk, start_expert_id, end_expert_id, block_size = 4096, 8, 0, 255, 512
|
||||
|
||||
def alloc_tensors():
|
||||
down_output = torch.randn(
|
||||
batch_size * topk, hidden_size, dtype=dtype, device=device
|
||||
)
|
||||
output = torch.zeros(batch_size, hidden_size, dtype=dtype, device=device)
|
||||
src2dst = torch.randint(
|
||||
0, batch_size * topk, (batch_size, topk), dtype=torch.int32, device=device
|
||||
)
|
||||
topk_ids = torch.randint(
|
||||
start_expert_id,
|
||||
end_expert_id + 1,
|
||||
(batch_size, topk),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
topk_weights = torch.rand(batch_size, topk, dtype=dtype, device=device)
|
||||
return down_output, output, src2dst, topk_ids, topk_weights
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
|
||||
if provider == "cuda":
|
||||
d_out, out, s2d, tk_ids, tk_weights = alloc_tensors()
|
||||
|
||||
def run_cuda():
|
||||
ep_moe_post_reorder(
|
||||
d_out,
|
||||
out,
|
||||
s2d,
|
||||
tk_ids,
|
||||
tk_weights,
|
||||
start_expert_id,
|
||||
end_expert_id,
|
||||
topk,
|
||||
)
|
||||
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(run_cuda, quantiles=quantiles)
|
||||
|
||||
elif provider == "triton":
|
||||
d_out, out, s2d, tk_ids, tk_weights = alloc_tensors()
|
||||
|
||||
def run_triton():
|
||||
post_reorder_triton_kernel[(batch_size,)](
|
||||
d_out.view(-1),
|
||||
out.view(-1),
|
||||
s2d.view(-1),
|
||||
tk_ids.view(-1),
|
||||
tk_weights.view(-1),
|
||||
start_expert_id,
|
||||
end_expert_id,
|
||||
topk,
|
||||
hidden_size,
|
||||
0,
|
||||
block_size,
|
||||
)
|
||||
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(run_triton, quantiles=quantiles)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown provider: {provider}")
|
||||
|
||||
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
benchmark.run(print_data=True)
|
||||
103
sgl-kernel/benchmark/bench_moe_ep_pre_reorder.py
Normal file
103
sgl-kernel/benchmark/bench_moe_ep_pre_reorder.py
Normal file
@@ -0,0 +1,103 @@
|
||||
import torch
|
||||
import triton
|
||||
from sgl_kernel import ep_moe_pre_reorder
|
||||
|
||||
from sglang.srt.layers.moe.ep_moe.kernels import pre_reorder_triton_kernel
|
||||
|
||||
batch_sizes = [64, 128, 256, 512, 640, 768, 1024, 2048, 4096]
|
||||
configs = [(bs,) for bs in batch_sizes]
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["batch_size"],
|
||||
x_vals=[list(_) for _ in configs],
|
||||
line_arg="provider",
|
||||
line_vals=["cuda", "triton"],
|
||||
line_names=["CUDA Kernel", "Triton Kernel"],
|
||||
styles=[("green", "-"), ("orange", "-")],
|
||||
ylabel="us",
|
||||
plot_name="ep-moe-pre-reorder-performance",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(batch_size, provider):
|
||||
dtype = torch.bfloat16
|
||||
device = torch.device("cuda")
|
||||
hidden_size, topk, start_expert_id, end_expert_id, block_size = (
|
||||
4096,
|
||||
8,
|
||||
0,
|
||||
255,
|
||||
512,
|
||||
)
|
||||
|
||||
# Allocate fresh tensors for every run to match bench_moe_fused_gate style
|
||||
def alloc_tensors():
|
||||
input_ = torch.randn(batch_size, hidden_size, dtype=dtype, device=device)
|
||||
gateup_input = torch.zeros(
|
||||
batch_size * topk, hidden_size, dtype=dtype, device=device
|
||||
)
|
||||
src2dst = torch.randint(
|
||||
0, batch_size * topk, (batch_size, topk), dtype=torch.int32, device=device
|
||||
)
|
||||
topk_ids = torch.randint(
|
||||
start_expert_id,
|
||||
end_expert_id + 1,
|
||||
(batch_size, topk),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
a1_scales = torch.rand(
|
||||
end_expert_id - start_expert_id + 1, dtype=torch.float32, device=device
|
||||
)
|
||||
return input_, gateup_input, src2dst, topk_ids, a1_scales
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
|
||||
if provider == "cuda":
|
||||
inp, gout, s2d, tk_ids, scales = alloc_tensors()
|
||||
|
||||
def run_cuda():
|
||||
ep_moe_pre_reorder(
|
||||
inp,
|
||||
gout,
|
||||
s2d,
|
||||
tk_ids,
|
||||
scales,
|
||||
start_expert_id,
|
||||
end_expert_id,
|
||||
topk,
|
||||
True,
|
||||
)
|
||||
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(run_cuda, quantiles=quantiles)
|
||||
|
||||
elif provider == "triton":
|
||||
inp, gout, s2d, tk_ids, scales = alloc_tensors()
|
||||
|
||||
def run_triton():
|
||||
pre_reorder_triton_kernel[(batch_size,)](
|
||||
inp.view(-1),
|
||||
gout.view(-1),
|
||||
s2d.view(-1),
|
||||
tk_ids.view(-1),
|
||||
scales,
|
||||
start_expert_id,
|
||||
end_expert_id,
|
||||
topk,
|
||||
hidden_size,
|
||||
block_size,
|
||||
True,
|
||||
)
|
||||
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(run_triton, quantiles=quantiles)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown provider: {provider}")
|
||||
|
||||
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
benchmark.run(print_data=True)
|
||||
77
sgl-kernel/benchmark/bench_moe_fused_gate.py
Normal file
77
sgl-kernel/benchmark/bench_moe_fused_gate.py
Normal file
@@ -0,0 +1,77 @@
|
||||
import itertools
|
||||
import math
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from sgl_kernel import moe_fused_gate
|
||||
|
||||
from sglang.srt.layers.moe.topk import biased_grouped_topk
|
||||
|
||||
|
||||
def biased_grouped_topk_org(scores, bias, num_expert_group, topk_group, topk):
|
||||
return biased_grouped_topk(
|
||||
scores,
|
||||
scores,
|
||||
bias,
|
||||
topk=topk,
|
||||
renormalize=True,
|
||||
num_expert_group=num_expert_group,
|
||||
topk_group=topk_group,
|
||||
routed_scaling_factor=2.5, # DeepSeek-R1 : 2.5, Kimi K2: 2.872
|
||||
)
|
||||
|
||||
|
||||
def biased_grouped_topk_org_fuse_kernel(
|
||||
scores, bias, num_expert_group, topk_group, topk
|
||||
):
|
||||
return moe_fused_gate(scores, bias, num_expert_group, topk_group, topk)
|
||||
|
||||
|
||||
seq_length_range = [5000, 10000, 15000, 20000, 25000, 30000, 35000, 40000]
|
||||
configs = [(sq,) for sq in seq_length_range]
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["seq_length"],
|
||||
x_vals=[list(_) for _ in configs],
|
||||
line_arg="provider",
|
||||
line_vals=["original", "kernel"],
|
||||
line_names=["Original", "SGL Kernel"],
|
||||
styles=[("blue", "-"), ("red", "-")],
|
||||
ylabel="us",
|
||||
plot_name="moe-fused-gate-performance",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(seq_length, provider):
|
||||
dtype = torch.bfloat16
|
||||
device = torch.device("cuda")
|
||||
num_experts, num_expert_group, topk_group, topk = 256, 8, 4, 8
|
||||
|
||||
scores = torch.randn((seq_length, num_experts), device=device, dtype=dtype)
|
||||
bias = torch.rand(num_experts, device=device, dtype=dtype)
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
|
||||
if provider == "original":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: biased_grouped_topk_org(
|
||||
scores.clone(), bias.clone(), num_expert_group, topk_group, topk
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
elif provider == "kernel":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: biased_grouped_topk_org_fuse_kernel(
|
||||
scores.clone(), bias.clone(), num_expert_group, topk_group, topk
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
|
||||
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
benchmark.run(print_data=True)
|
||||
92
sgl-kernel/benchmark/bench_moe_silu_and_mul.py
Normal file
92
sgl-kernel/benchmark/bench_moe_silu_and_mul.py
Normal file
@@ -0,0 +1,92 @@
|
||||
import itertools
|
||||
|
||||
import torch
|
||||
import triton
|
||||
from sgl_kernel import ep_moe_silu_and_mul
|
||||
|
||||
from sglang.srt.layers.moe.ep_moe.kernels import silu_and_mul_triton_kernel
|
||||
|
||||
batch_size_range = [64, 128, 256, 512, 640, 768, 1024, 2048, 4096]
|
||||
hidden_size_range = [1024, 2048, 4096, 8192]
|
||||
block_size_range = [128, 256, 512]
|
||||
configs = list(itertools.product(batch_size_range, hidden_size_range, block_size_range))
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["batch_size", "hidden_size", "block_size"],
|
||||
x_vals=[list(cfg) for cfg in configs],
|
||||
line_arg="provider",
|
||||
line_vals=["cuda", "triton"],
|
||||
line_names=["CUDA Kernel", "Triton Kernel"],
|
||||
styles=[("green", "-"), ("orange", "-")],
|
||||
ylabel="us",
|
||||
plot_name="ep-moe-silu-and-mul-performance",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(batch_size, hidden_size, block_size, provider):
|
||||
dtype = torch.bfloat16
|
||||
device = torch.device("cuda")
|
||||
|
||||
half_hidden_size = hidden_size // 2
|
||||
start_expert_id, end_expert_id = 0, 255
|
||||
block_size = 512
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
|
||||
def alloc_tensors():
|
||||
gateup_output = torch.randn(batch_size, hidden_size, dtype=dtype, device=device)
|
||||
down_input = torch.empty(
|
||||
batch_size, half_hidden_size, dtype=dtype, device=device
|
||||
)
|
||||
reorder_topk_ids = torch.randint(
|
||||
start_expert_id,
|
||||
end_expert_id + 1,
|
||||
(batch_size,),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
scales = torch.rand(
|
||||
end_expert_id - start_expert_id + 1, dtype=torch.float32, device=device
|
||||
)
|
||||
return gateup_output, down_input, reorder_topk_ids, scales
|
||||
|
||||
if provider == "cuda":
|
||||
gateup, down, ids, scales = alloc_tensors()
|
||||
|
||||
def run_cuda():
|
||||
ep_moe_silu_and_mul(
|
||||
gateup,
|
||||
down,
|
||||
ids,
|
||||
scales,
|
||||
start_expert_id,
|
||||
end_expert_id,
|
||||
)
|
||||
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(run_cuda, quantiles=quantiles)
|
||||
|
||||
elif provider == "triton":
|
||||
gateup, down, ids, scales = alloc_tensors()
|
||||
|
||||
def run_triton():
|
||||
silu_and_mul_triton_kernel[(batch_size,)](
|
||||
gateup.view(-1),
|
||||
down.view(-1),
|
||||
hidden_size,
|
||||
ids,
|
||||
scales,
|
||||
start_expert_id,
|
||||
end_expert_id,
|
||||
block_size,
|
||||
)
|
||||
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(run_triton, quantiles=quantiles)
|
||||
else:
|
||||
raise ValueError(f"Unknown provider: {provider}")
|
||||
|
||||
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
benchmark.run(print_data=True)
|
||||
116
sgl-kernel/benchmark/bench_moe_topk_softmax.py
Normal file
116
sgl-kernel/benchmark/bench_moe_topk_softmax.py
Normal file
@@ -0,0 +1,116 @@
|
||||
import itertools
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import triton
|
||||
from sgl_kernel import topk_softmax
|
||||
from vllm import _custom_ops as vllm_custom_ops
|
||||
|
||||
|
||||
def vllm_topk_softmax(gating_output, topk):
|
||||
num_tokens, num_experts = gating_output.shape
|
||||
|
||||
topk_weights = torch.empty(
|
||||
(num_tokens, topk), device=gating_output.device, dtype=torch.float32
|
||||
)
|
||||
topk_indices = torch.empty(
|
||||
(num_tokens, topk), dtype=torch.int32, device=gating_output.device
|
||||
)
|
||||
token_expert_indices = torch.empty(
|
||||
(num_tokens, topk), dtype=torch.int32, device=gating_output.device
|
||||
)
|
||||
torch.ops._moe_C.topk_softmax(
|
||||
topk_weights, topk_indices, token_expert_indices, gating_output
|
||||
)
|
||||
return topk_weights, topk_indices
|
||||
|
||||
|
||||
def sglang_topk_softmax(gating_output, topk):
|
||||
num_tokens, num_experts = gating_output.shape
|
||||
|
||||
topk_weights = torch.empty(
|
||||
(num_tokens, topk), device=gating_output.device, dtype=torch.float32
|
||||
)
|
||||
topk_indices = torch.empty(
|
||||
(num_tokens, topk), dtype=torch.int32, device=gating_output.device
|
||||
)
|
||||
|
||||
topk_softmax(
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_indices,
|
||||
gating_output=gating_output,
|
||||
)
|
||||
|
||||
return topk_weights, topk_indices
|
||||
|
||||
|
||||
def calculate_diff(num_tokens, num_experts, topk):
|
||||
gating_output = torch.randn(
|
||||
(num_tokens, num_experts), device="cuda", dtype=torch.float32
|
||||
)
|
||||
weights_vllm, indices_vllm = vllm_topk_softmax(gating_output.clone(), topk)
|
||||
weights_sglang, indices_sglang = sglang_topk_softmax(gating_output.clone(), topk)
|
||||
|
||||
weights_diff = torch.abs(weights_vllm - weights_sglang).mean().item()
|
||||
indices_match = torch.equal(indices_vllm, indices_sglang)
|
||||
|
||||
if (
|
||||
torch.allclose(weights_vllm, weights_sglang, atol=1e-3, rtol=1e-3)
|
||||
and indices_match
|
||||
):
|
||||
print("✅ VLLM and SGLang topk_softmax implementations match")
|
||||
else:
|
||||
print(
|
||||
f"❌ Implementations differ: Weights diff={weights_diff}, Indices match={indices_match}"
|
||||
)
|
||||
|
||||
|
||||
num_tokens_range = [128, 512, 1024, 2048, 4096, 8192, 16384, 32768]
|
||||
num_experts_range = [32, 64, 128, 256, 12, 512]
|
||||
topk_range = [1, 2, 4, 8]
|
||||
|
||||
configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range))
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["num_tokens", "num_experts", "topk"],
|
||||
x_vals=configs,
|
||||
line_arg="provider",
|
||||
line_vals=["sglang", "vllm"],
|
||||
line_names=["SGLang", "VLLM"],
|
||||
styles=[("blue", "-"), ("green", "-")],
|
||||
ylabel="Latency (us)",
|
||||
plot_name="topk-softmax-performance",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(num_tokens, num_experts, topk, provider):
|
||||
|
||||
gating_output = torch.randn(
|
||||
(num_tokens, num_experts), device="cuda", dtype=torch.float32
|
||||
)
|
||||
|
||||
if provider == "vllm" or provider == "vllm1":
|
||||
fn = lambda: vllm_topk_softmax(gating_output, topk)
|
||||
elif provider == "sglang" or provider == "sglang1":
|
||||
fn = lambda: sglang_topk_softmax(gating_output, topk)
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)
|
||||
|
||||
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
configs = [
|
||||
(20, 256, 4),
|
||||
(20, 256, 8),
|
||||
(20, 12, 4),
|
||||
(20, 12, 1),
|
||||
(20, 512, 4),
|
||||
(20, 512, 1),
|
||||
]
|
||||
for num_tokens, num_experts, topk in configs:
|
||||
calculate_diff(num_tokens, num_experts, topk)
|
||||
benchmark.run(print_data=True)
|
||||
172
sgl-kernel/benchmark/bench_nvfp4_scaled_gemm.py
Normal file
172
sgl-kernel/benchmark/bench_nvfp4_scaled_gemm.py
Normal file
@@ -0,0 +1,172 @@
|
||||
import argparse
|
||||
import copy
|
||||
import itertools
|
||||
|
||||
import torch
|
||||
import triton
|
||||
from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant
|
||||
|
||||
FLOAT4_E2M1_MAX = 6.0
|
||||
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
|
||||
|
||||
# Weight Shapes are in the format
|
||||
# ([K, N], TP_SPLIT_DIM)
|
||||
# Example:
|
||||
# A shape of ([14336, 4096], 0) indicates the following GEMM shape,
|
||||
# - TP1 : K = 14336, N = 4096
|
||||
# - TP2 : K = 7168, N = 4096
|
||||
# A shape of ([4096, 6144], 1) indicates the following GEMM shape,
|
||||
# - TP1 : K = 4096, N = 6144
|
||||
# - TP4 : K = 4096, N = 1536
|
||||
|
||||
# TP1 shapes
|
||||
WEIGHT_SHAPES = {
|
||||
"meta-llama/Llama-3.1-8B-Instruct": [
|
||||
([4096, 6144], 1),
|
||||
([4096, 4096], 0),
|
||||
([4096, 28672], 1),
|
||||
([14336, 4096], 0),
|
||||
],
|
||||
"meta-llama/Llama-3.3-70B-Instruct": [
|
||||
([8192, 10240], 1),
|
||||
([8192, 8192], 0),
|
||||
([8192, 57344], 1),
|
||||
([28672, 8192], 0),
|
||||
],
|
||||
"mistralai/Mistral-Large-Instruct-2407": [
|
||||
([12288, 14336], 1),
|
||||
([12288, 12288], 0),
|
||||
([12288, 57344], 1),
|
||||
([28672, 12288], 0),
|
||||
],
|
||||
"Qwen/Qwen2.5-7B-Instruct": [
|
||||
([3584, 4608], 1),
|
||||
([3584, 3584], 0),
|
||||
([3584, 37888], 1),
|
||||
([18944, 3584], 0),
|
||||
],
|
||||
"Qwen/Qwen2.5-32B-Instruct": [
|
||||
([5120, 7168], 1),
|
||||
([5120, 5120], 0),
|
||||
([5120, 55296], 1),
|
||||
([27648, 5120], 0),
|
||||
],
|
||||
"Qwen/Qwen2.5-72B-Instruct": [
|
||||
([8192, 10240], 1),
|
||||
([8192, 8192], 0),
|
||||
([8192, 59136], 1),
|
||||
([29568, 8192], 0),
|
||||
],
|
||||
"deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": [
|
||||
([2048, 3072], 1),
|
||||
([2048, 4096], 1),
|
||||
([2048, 2048], 0),
|
||||
([2048, 576], 0),
|
||||
([2048, 21888], 1),
|
||||
([10944, 2048], 0),
|
||||
([2048, 2816], 1),
|
||||
([1408, 2048], 0),
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["batch_size"],
|
||||
x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048],
|
||||
x_log=False,
|
||||
line_arg="provider",
|
||||
line_vals=[
|
||||
"sglang-fp4-fp16",
|
||||
"sglang-fp4-bf16",
|
||||
],
|
||||
line_names=[
|
||||
"sglang-fp4-fp16",
|
||||
"sglang-fp4-bf16",
|
||||
],
|
||||
styles=[("green", "-"), ("blue", "-")],
|
||||
ylabel="TFLOPS",
|
||||
plot_name="fp4 block scaled matmul",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(batch_size, provider, N, K):
|
||||
# M, N, K = batch_size, 4096, 8192
|
||||
run_step = 100
|
||||
dtype = torch.float16 if "fp16" in provider else torch.bfloat16
|
||||
M = batch_size
|
||||
a = torch.randn((M, K), dtype=dtype, device="cuda")
|
||||
b = torch.randn((N, K), dtype=dtype, device="cuda")
|
||||
a_global_scale = (
|
||||
(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a.flatten(), dim=-1)
|
||||
).to(torch.float32)
|
||||
b_global_scale = (
|
||||
(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(b.flatten(), dim=-1)
|
||||
).to(torch.float32)
|
||||
alpha = 1.0 / (a_global_scale * b_global_scale)
|
||||
a_fp4, a_scale_interleaved = scaled_fp4_quant(a, a_global_scale)
|
||||
b_fp4, b_scale_interleaved = scaled_fp4_quant(b, b_global_scale)
|
||||
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
|
||||
# Bridging the gap between CPU and GPU
|
||||
for _ in range(25):
|
||||
c = a @ b.t()
|
||||
# Warmup
|
||||
for _ in range(5):
|
||||
cutlass_scaled_fp4_mm(
|
||||
a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, dtype
|
||||
)
|
||||
start_event.record()
|
||||
for _ in range(run_step):
|
||||
cutlass_scaled_fp4_mm(
|
||||
a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, dtype
|
||||
)
|
||||
end_event.record()
|
||||
end_event.synchronize()
|
||||
torch.cuda.synchronize()
|
||||
ms = start_event.elapsed_time(end_event) / run_step
|
||||
|
||||
tflops = lambda ms: (2 * M * N * K) * 1e-9 / ms
|
||||
return tflops(ms)
|
||||
|
||||
|
||||
def prepare_shapes(args):
|
||||
KN_model_names = []
|
||||
models_tps = list(itertools.product(args.models, args.tp_sizes))
|
||||
for model, tp_size in models_tps:
|
||||
assert model in WEIGHT_SHAPES
|
||||
for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
|
||||
KN[tp_split_dim] = KN[tp_split_dim] // tp_size
|
||||
KN.append(model)
|
||||
KN_model_names.append(KN)
|
||||
return KN_model_names
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--models",
|
||||
nargs="+",
|
||||
type=str,
|
||||
default=["meta-llama/Llama-3.1-8B-Instruct"],
|
||||
help="List of models to benchmark",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tp-sizes",
|
||||
nargs="+",
|
||||
type=int,
|
||||
default=[1],
|
||||
help="List of tensor parallel sizes",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
KN_model_names = prepare_shapes(args)
|
||||
for K, N, model_name in KN_model_names:
|
||||
print(f"{model_name} N={N} K={K}: ")
|
||||
benchmark.run(
|
||||
print_data=True, show_plots=True, save_path="bench_fp4_res", N=N, K=K
|
||||
)
|
||||
|
||||
print("Benchmark finished!")
|
||||
98
sgl-kernel/benchmark/bench_per_tensor_quant_fp8.py
Normal file
98
sgl-kernel/benchmark/bench_per_tensor_quant_fp8.py
Normal file
@@ -0,0 +1,98 @@
|
||||
import itertools
|
||||
import math
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import triton
|
||||
import triton.testing
|
||||
from sgl_kernel import sgl_per_tensor_quant_fp8
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
from sglang.srt.utils import is_hip
|
||||
|
||||
_is_hip = is_hip()
|
||||
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
||||
|
||||
|
||||
def vllm_scaled_fp8_quant(
|
||||
input: torch.Tensor,
|
||||
scale: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
return ops.scaled_fp8_quant(input, scale)
|
||||
|
||||
|
||||
def sglang_scaled_fp8_quant(
|
||||
input: torch.Tensor,
|
||||
scale: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
fp8_type_: torch.dtype = torch.float8_e4m3fn
|
||||
output = torch.empty_like(input, device=input.device, dtype=fp8_type_)
|
||||
is_static = True
|
||||
if scale is None:
|
||||
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
|
||||
is_static = False
|
||||
sgl_per_tensor_quant_fp8(input, output, scale, is_static)
|
||||
|
||||
return output, scale
|
||||
|
||||
|
||||
def calculate_diff(batch_size: int, seq_len: int):
|
||||
"""Calculate difference between VLLM and SGLang implementations."""
|
||||
device = torch.device("cuda")
|
||||
x = torch.rand((batch_size, seq_len), dtype=torch.float16, device=device)
|
||||
|
||||
vllm_out, vllm_scale = vllm_scaled_fp8_quant(x)
|
||||
sglang_out, sglang_scale = sglang_scaled_fp8_quant(x)
|
||||
|
||||
scale_diff = torch.abs(vllm_scale - sglang_scale).item()
|
||||
output_diff = torch.abs(vllm_out.float() - sglang_out.float()).mean().item()
|
||||
|
||||
if torch.allclose(
|
||||
vllm_out.to(torch.float32), sglang_out.to(torch.float32), rtol=1e-3, atol=1e-5
|
||||
) and torch.allclose(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-5):
|
||||
print("✅ All implementations match")
|
||||
else:
|
||||
print("❌ Implementations differ")
|
||||
|
||||
|
||||
batch_size_range = [16, 32, 64, 128]
|
||||
seq_len_range = [64, 128, 256, 512, 1024, 2048]
|
||||
|
||||
configs = list(itertools.product(batch_size_range, seq_len_range))
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["batch_size", "seq_len"],
|
||||
x_vals=configs,
|
||||
line_arg="provider",
|
||||
line_vals=["vllm", "sglang"],
|
||||
line_names=["VLLM", "SGL Kernel"],
|
||||
styles=[("blue", "-"), ("green", "-")],
|
||||
ylabel="us",
|
||||
plot_name="per-tensor-quant-fp8-performance",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(batch_size, seq_len, provider):
|
||||
dtype = torch.float16
|
||||
device = torch.device("cuda")
|
||||
|
||||
x = torch.randn(batch_size * seq_len, 4096, device=device, dtype=dtype)
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
|
||||
if provider == "vllm":
|
||||
fn = lambda: vllm_scaled_fp8_quant(x.clone())
|
||||
elif provider == "sglang":
|
||||
fn = lambda: sglang_scaled_fp8_quant(x.clone())
|
||||
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)
|
||||
|
||||
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
calculate_diff(batch_size=4, seq_len=4096)
|
||||
benchmark.run(print_data=True)
|
||||
98
sgl-kernel/benchmark/bench_per_token_group_quant_8bit.py
Normal file
98
sgl-kernel/benchmark/bench_per_token_group_quant_8bit.py
Normal file
@@ -0,0 +1,98 @@
|
||||
import itertools
|
||||
import time
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import triton
|
||||
|
||||
from sglang.srt.bench_utils import bench_kineto
|
||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
create_per_token_group_quant_fp8_output_scale,
|
||||
)
|
||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
per_token_group_quant_8bit as triton_per_token_group_quant_8bit,
|
||||
)
|
||||
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_8bit
|
||||
from sglang.srt.utils import is_hip
|
||||
|
||||
_is_hip = is_hip()
|
||||
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
||||
|
||||
|
||||
num_tokens_range = [1, 4, 16, 64, 256, 768, 2048, 8192, 16384]
|
||||
hidden_dim_range = [1536, 7168, 18432] # For DeepSeek V3/R1
|
||||
group_size_range = [128] # For DeepSeek V3/R1
|
||||
# TODO test int8
|
||||
dst_dtype_range = [fp8_type_]
|
||||
flags_range = [
|
||||
dict(
|
||||
column_major_scales=False,
|
||||
scale_tma_aligned=False,
|
||||
scale_ue8m0=False,
|
||||
),
|
||||
dict(
|
||||
column_major_scales=True,
|
||||
scale_tma_aligned=False,
|
||||
scale_ue8m0=False,
|
||||
),
|
||||
dict(
|
||||
column_major_scales=True,
|
||||
scale_tma_aligned=True,
|
||||
scale_ue8m0=False,
|
||||
),
|
||||
dict(
|
||||
column_major_scales=True,
|
||||
scale_tma_aligned=True,
|
||||
scale_ue8m0=True,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
configs = list(
|
||||
itertools.product(
|
||||
num_tokens_range,
|
||||
hidden_dim_range,
|
||||
group_size_range,
|
||||
dst_dtype_range,
|
||||
flags_range,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["num_tokens", "hidden_dim", "group_size", "dst_dtype", "flags"],
|
||||
x_vals=configs,
|
||||
line_arg="provider",
|
||||
line_vals=["triton", "sglang"],
|
||||
line_names=["Triton", "SGL Kernel"],
|
||||
styles=[("blue", "-"), ("green", "-")],
|
||||
ylabel="us",
|
||||
plot_name="per-token-group-quant-8bit-performance",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(num_tokens, hidden_dim, group_size, dst_dtype, flags, provider):
|
||||
if flags["scale_ue8m0"] and group_size != 128:
|
||||
return
|
||||
|
||||
device = torch.device("cuda")
|
||||
|
||||
x = torch.randn(num_tokens, hidden_dim, device=device, dtype=torch.bfloat16)
|
||||
|
||||
fn, kernel_names = {
|
||||
"triton": (triton_per_token_group_quant_8bit, "_per_token_group_quant_fp8"),
|
||||
"sglang": (
|
||||
sglang_per_token_group_quant_8bit,
|
||||
"per_token_group_quant_8bit_kernel",
|
||||
),
|
||||
}[provider]
|
||||
bench_fn = lambda: fn(x=x, group_size=group_size, dst_dtype=dst_dtype, **flags)
|
||||
|
||||
time_s = bench_kineto(bench_fn, kernel_names=kernel_names)
|
||||
return time_s * 1e6
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
benchmark.run(print_data=True)
|
||||
177
sgl-kernel/benchmark/bench_per_token_quant_fp8.py
Normal file
177
sgl-kernel/benchmark/bench_per_token_quant_fp8.py
Normal file
@@ -0,0 +1,177 @@
|
||||
import itertools
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.testing
|
||||
from sgl_kernel import sgl_per_token_quant_fp8
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
from sglang.srt.utils import is_hip
|
||||
|
||||
_is_hip = is_hip()
|
||||
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
||||
|
||||
# Get correct FP8 E4M3 maximum value
|
||||
if _is_hip:
|
||||
FP8_E4M3_MAX = 224.0 # ROCM uses 224.0
|
||||
else:
|
||||
# For CUDA, get the actual max value from the type
|
||||
FP8_E4M3_MAX = float(torch.finfo(fp8_type_).max)
|
||||
|
||||
|
||||
def torch_per_token_quant_fp8(
|
||||
input: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Pure PyTorch reference implementation for per-token FP8 quantization."""
|
||||
device = input.device
|
||||
dtype = input.dtype
|
||||
|
||||
# Find max absolute value per token (row) - exactly like CUDA kernel
|
||||
max_vals = torch.abs(input).max(dim=1)[0] # [num_tokens]
|
||||
|
||||
# Calculate scale per token - exactly like CUDA kernel: scale = max_value / FP8_E4M3_MAX
|
||||
scales = max_vals / FP8_E4M3_MAX # [num_tokens]
|
||||
|
||||
# No special zero handling - directly compute 1.0 / scale like CUDA kernel
|
||||
scale_inv = 1.0 / scales # [num_tokens]
|
||||
|
||||
# Quantize: input * scale_inv, then clamp to FP8 range
|
||||
quantized_float = input * scale_inv.unsqueeze(1) # Broadcast scale_inv
|
||||
quantized_float = torch.clamp(quantized_float, -FP8_E4M3_MAX, FP8_E4M3_MAX)
|
||||
|
||||
# Convert to FP8 - use more explicit conversion
|
||||
quantized_fp8 = quantized_float.to(fp8_type_)
|
||||
|
||||
return quantized_fp8, scales
|
||||
|
||||
|
||||
def vllm_per_token_quant_fp8(
|
||||
input: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
return ops.scaled_fp8_quant(input, use_per_token_if_dynamic=True)
|
||||
|
||||
|
||||
def sglang_per_token_quant_fp8(
|
||||
input: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
scale = torch.zeros(input.size(0), device=input.device, dtype=torch.float32)
|
||||
output = torch.empty_like(input, device=input.device, dtype=fp8_type_)
|
||||
sgl_per_token_quant_fp8(input, output, scale)
|
||||
|
||||
return output, scale
|
||||
|
||||
|
||||
def calculate_diff(batch_size: int, seq_len: int, hidden_dim: int):
|
||||
"""Compare Torch reference, VLLM, and SGLang implementations."""
|
||||
device = torch.device("cuda")
|
||||
x = torch.rand(
|
||||
(batch_size * seq_len, hidden_dim), dtype=torch.float16, device=device
|
||||
)
|
||||
|
||||
# Get all three implementations
|
||||
torch_out, torch_scale = torch_per_token_quant_fp8(x)
|
||||
vllm_out, vllm_scale = vllm_per_token_quant_fp8(x)
|
||||
sglang_out, sglang_scale = sglang_per_token_quant_fp8(x)
|
||||
|
||||
print(f"\n=== Comparison for hidden_dim={hidden_dim} ===")
|
||||
|
||||
# Compare scales
|
||||
torch_vllm_scale_diff = torch.abs(torch_scale - vllm_scale).mean().item()
|
||||
torch_sglang_scale_diff = torch.abs(torch_scale - sglang_scale).mean().item()
|
||||
vllm_sglang_scale_diff = torch.abs(vllm_scale - sglang_scale).mean().item()
|
||||
|
||||
print(f"Scale differences:")
|
||||
print(f" Torch vs VLLM: {torch_vllm_scale_diff:.8f}")
|
||||
print(f" Torch vs SGLang: {torch_sglang_scale_diff:.8f}")
|
||||
print(f" VLLM vs SGLang: {vllm_sglang_scale_diff:.8f}")
|
||||
|
||||
# Compare outputs
|
||||
torch_vllm_out_diff = torch.abs(torch_out.float() - vllm_out.float()).mean().item()
|
||||
torch_sglang_out_diff = (
|
||||
torch.abs(torch_out.float() - sglang_out.float()).mean().item()
|
||||
)
|
||||
vllm_sglang_out_diff = (
|
||||
torch.abs(vllm_out.float() - sglang_out.float()).mean().item()
|
||||
)
|
||||
|
||||
print(f"Output differences:")
|
||||
print(f" Torch vs VLLM: {torch_vllm_out_diff:.8f}")
|
||||
print(f" Torch vs SGLang: {torch_sglang_out_diff:.8f}")
|
||||
print(f" VLLM vs SGLang: {vllm_sglang_out_diff:.8f}")
|
||||
|
||||
# Check tolerances
|
||||
rtol, atol = 1e-3, 1e-5
|
||||
|
||||
torch_vllm_match = torch.allclose(
|
||||
torch_out.float(), vllm_out.float(), rtol=rtol, atol=atol
|
||||
) and torch.allclose(torch_scale, vllm_scale, rtol=rtol, atol=atol)
|
||||
torch_sglang_match = torch.allclose(
|
||||
torch_out.float(), sglang_out.float(), rtol=rtol, atol=atol
|
||||
) and torch.allclose(torch_scale, sglang_scale, rtol=rtol, atol=atol)
|
||||
|
||||
if hidden_dim == 1368:
|
||||
rtol = 1e-2
|
||||
# we found vllm sglang has diff when hidden dim is not dividable by 16
|
||||
# and we believe SGLang is closer to Torch implementation
|
||||
|
||||
vllm_sglang_match = torch.allclose(
|
||||
vllm_out.float(), sglang_out.float(), rtol=rtol, atol=atol
|
||||
) and torch.allclose(vllm_scale, sglang_scale, rtol=rtol, atol=atol)
|
||||
|
||||
print(f"Matches (rtol={rtol}, atol={atol}):")
|
||||
print(f" Torch vs VLLM: {'✅' if torch_vllm_match else '❌'}")
|
||||
print(f" Torch vs SGLang: {'✅' if torch_sglang_match else '❌'}")
|
||||
print(f" VLLM vs SGLang: {'✅' if vllm_sglang_match else '❌'}")
|
||||
|
||||
|
||||
batch_size_range = [16, 32, 64, 128]
|
||||
seq_len_range = [64, 128, 256, 512, 1024, 2048, 4096]
|
||||
hidden_dim_range = [1368, 2048, 4096]
|
||||
|
||||
configs = list(itertools.product(batch_size_range, seq_len_range, hidden_dim_range))
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["batch_size", "seq_len", "hidden_dim"],
|
||||
x_vals=configs,
|
||||
line_arg="provider",
|
||||
line_vals=["torch", "vllm", "sglang"],
|
||||
line_names=["Torch Reference", "VLLM", "SGL Kernel"],
|
||||
styles=[("red", "-"), ("blue", "-"), ("green", "-")],
|
||||
ylabel="us",
|
||||
plot_name="per-token-dynamic-quant-fp8-performance",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark_quantization(batch_size, seq_len, hidden_dim, provider):
|
||||
dtype = torch.float16
|
||||
device = torch.device("cuda")
|
||||
|
||||
x = torch.randn(batch_size * seq_len, hidden_dim, device=device, dtype=dtype)
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
|
||||
if provider == "torch":
|
||||
fn = lambda: torch_per_token_quant_fp8(x.clone())
|
||||
elif provider == "vllm":
|
||||
fn = lambda: vllm_per_token_quant_fp8(x.clone())
|
||||
elif provider == "sglang":
|
||||
fn = lambda: sglang_per_token_quant_fp8(x.clone())
|
||||
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)
|
||||
|
||||
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Test various hidden dimensions for correctness
|
||||
test_dims = [1368, 2048, 4096]
|
||||
|
||||
for dim in test_dims:
|
||||
calculate_diff(batch_size=4, seq_len=4096, hidden_dim=dim)
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Starting performance benchmark...")
|
||||
benchmark_quantization.run(print_data=True)
|
||||
198
sgl-kernel/benchmark/bench_qserve_w4a8_gemm.py
Normal file
198
sgl-kernel/benchmark/bench_qserve_w4a8_gemm.py
Normal file
@@ -0,0 +1,198 @@
|
||||
import argparse
|
||||
import copy
|
||||
import itertools
|
||||
|
||||
import torch
|
||||
import triton
|
||||
from sgl_kernel import (
|
||||
int8_scaled_mm,
|
||||
qserve_w4a8_per_chn_gemm,
|
||||
qserve_w4a8_per_group_gemm,
|
||||
)
|
||||
|
||||
|
||||
def to_int8(tensor: torch.Tensor) -> torch.Tensor:
|
||||
return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
|
||||
|
||||
|
||||
WEIGHT_SHAPES = {
|
||||
"meta-llama/Llama-3.1-8B-Instruct": [
|
||||
([4096, 6144], 1),
|
||||
([4096, 4096], 0),
|
||||
([4096, 28672], 1),
|
||||
([14336, 4096], 0),
|
||||
],
|
||||
"meta-llama/Llama-3.3-70B-Instruct": [
|
||||
([8192, 10240], 1),
|
||||
([8192, 8192], 0),
|
||||
([8192, 57344], 1),
|
||||
([28672, 8192], 0),
|
||||
],
|
||||
"mistralai/Mistral-Large-Instruct-2407": [
|
||||
([12288, 14336], 1),
|
||||
([12288, 12288], 0),
|
||||
([12288, 57344], 1),
|
||||
([28672, 12288], 0),
|
||||
],
|
||||
"Qwen/Qwen2.5-7B-Instruct": [
|
||||
([3584, 4608], 1),
|
||||
([3584, 3584], 0),
|
||||
([3584, 37888], 1),
|
||||
([18944, 3584], 0),
|
||||
],
|
||||
"Qwen/Qwen2.5-32B-Instruct": [
|
||||
([5120, 7168], 1),
|
||||
([5120, 5120], 0),
|
||||
([5120, 55296], 1),
|
||||
([27648, 5120], 0),
|
||||
],
|
||||
"Qwen/Qwen2.5-72B-Instruct": [
|
||||
([8192, 10240], 1),
|
||||
([8192, 8192], 0),
|
||||
([8192, 59136], 1),
|
||||
([29568, 8192], 0),
|
||||
],
|
||||
"deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": [
|
||||
([2048, 3072], 1),
|
||||
([2048, 4096], 1),
|
||||
([2048, 2048], 0),
|
||||
([2048, 576], 0),
|
||||
([2048, 21888], 1),
|
||||
([10944, 2048], 0),
|
||||
([2048, 2816], 1),
|
||||
([1408, 2048], 0),
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["batch_size"],
|
||||
x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048],
|
||||
x_log=False,
|
||||
line_arg="provider",
|
||||
line_vals=["FP16", "W8A8", "Qserve_W4A8_Per_Channel", "Qserve_W4A8_Per_Group"],
|
||||
line_names=["FP16", "W8A8", "Qserve_W4A8_Per_Channel", "Qserve_W4A8_Per_Group"],
|
||||
styles=[("blue", "-"), ("orange", "-"), ("green", "-"), ("red", "-")],
|
||||
ylabel="ms",
|
||||
plot_name="FP16_vs_W8A8_vs_Qserve_W4A8_GEMM",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(batch_size, provider, N, K):
|
||||
M = batch_size
|
||||
# For W8A8
|
||||
a = to_int8(torch.randn((M, K), device="cuda") * 5)
|
||||
b = to_int8(torch.randn((N, K), device="cuda").t() * 5)
|
||||
a_fp16 = a.to(torch.float16)
|
||||
b_fp16 = b.to(torch.float16)
|
||||
scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
|
||||
scale_b = torch.randn((N,), device="cuda", dtype=torch.float32)
|
||||
|
||||
# For Qserve W4A8 per channel
|
||||
a_qserve_chn = a
|
||||
# two int4s pack into one int8
|
||||
b_qserve_chn = to_int8(torch.randn((N, K // 2), device="cuda") * 5)
|
||||
# b_qserve_chn = b.t().contiguous()
|
||||
scale_a_qserve_chn = scale_a.to(torch.float16)
|
||||
scale_b_qserve_chn = scale_b.to(torch.float16)
|
||||
szero_b_qserve_chn = torch.randn((N,), device="cuda", dtype=torch.float16)
|
||||
a_sum_qserve_chn = torch.randn((M,), device="cuda", dtype=torch.float16)
|
||||
|
||||
# For Qserve W4A8 per group
|
||||
group_size = 128
|
||||
assert K % group_size == 0, "K must be divisible by group_size"
|
||||
a_qserve_group = a
|
||||
# two int4s pack into one int8
|
||||
b_qserve_group = to_int8(torch.randn((N, K // 2), device="cuda") * 5)
|
||||
# b_qserve_group = b.t().contiguous()
|
||||
scale_a_qserve_group = scale_a.to(torch.float16)
|
||||
scale_b_qserve_group = scale_b.to(torch.float16)
|
||||
scale_i8_b_qserve_group = to_int8(
|
||||
torch.randn((K // group_size, N), device="cuda", dtype=torch.float16)
|
||||
)
|
||||
zero_i8_b_qserve_group = to_int8(
|
||||
torch.randn((K // group_size, N), device="cuda", dtype=torch.float16)
|
||||
)
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
if provider == "FP16":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: torch.matmul(a_fp16, b_fp16),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
if provider == "W8A8":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: int8_scaled_mm(a, b, scale_a, scale_b, torch.float16),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
if provider == "Qserve_W4A8_Per_Channel":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: qserve_w4a8_per_chn_gemm(
|
||||
a_qserve_chn,
|
||||
b_qserve_chn,
|
||||
scale_b_qserve_chn,
|
||||
scale_a_qserve_chn,
|
||||
szero_b_qserve_chn,
|
||||
a_sum_qserve_chn,
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
if provider == "Qserve_W4A8_Per_Group":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: qserve_w4a8_per_group_gemm(
|
||||
a_qserve_group,
|
||||
b_qserve_group,
|
||||
zero_i8_b_qserve_group,
|
||||
scale_i8_b_qserve_group,
|
||||
scale_b_qserve_group,
|
||||
scale_a_qserve_group,
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
|
||||
return ms, max_ms, min_ms
|
||||
|
||||
|
||||
def prepare_shapes(args):
|
||||
KN_model_names = []
|
||||
models_tps = list(itertools.product(args.models, args.tp_sizes))
|
||||
for model, tp_size in models_tps:
|
||||
assert model in WEIGHT_SHAPES
|
||||
for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
|
||||
KN[tp_split_dim] = KN[tp_split_dim] // tp_size
|
||||
KN.append(model)
|
||||
KN_model_names.append(KN)
|
||||
return KN_model_names
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--models",
|
||||
nargs="+",
|
||||
type=str,
|
||||
default=["meta-llama/Llama-3.1-8B-Instruct"],
|
||||
help="List of models to benchmark",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tp-sizes",
|
||||
nargs="+",
|
||||
type=int,
|
||||
default=[1],
|
||||
help="List of tensor parallel sizes",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
KN_model_names = prepare_shapes(args)
|
||||
for K, N, model_name in KN_model_names:
|
||||
print(f"{model_name} N={N} K={K}: ")
|
||||
benchmark.run(
|
||||
print_data=True,
|
||||
show_plots=True,
|
||||
save_path="bench_qserve_w4a8_gemm_res",
|
||||
N=N,
|
||||
K=K,
|
||||
)
|
||||
|
||||
print("Benchmark finished!")
|
||||
96
sgl-kernel/benchmark/bench_rotary_embedding.py
Normal file
96
sgl-kernel/benchmark/bench_rotary_embedding.py
Normal file
@@ -0,0 +1,96 @@
|
||||
import itertools
|
||||
|
||||
import torch
|
||||
import triton
|
||||
from sgl_kernel import FusedSetKVBufferArg
|
||||
from sgl_kernel.testing.rotary_embedding import (
|
||||
FlashInferRotaryEmbedding,
|
||||
MHATokenToKVPool,
|
||||
RotaryEmbedding,
|
||||
create_inputs,
|
||||
)
|
||||
|
||||
from sglang.srt.bench_utils import bench_kineto
|
||||
|
||||
configs = [
|
||||
(batch_size, seq_len, save_kv_cache)
|
||||
for batch_size, seq_len in (
|
||||
(1, 1),
|
||||
(32, 1),
|
||||
(128, 1),
|
||||
(512, 1),
|
||||
(2, 512),
|
||||
(4, 4096),
|
||||
)
|
||||
for save_kv_cache in (False, True)
|
||||
]
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["batch_size", "seq_len", "save_kv_cache"],
|
||||
x_vals=configs,
|
||||
line_arg="provider",
|
||||
line_vals=["sglang"],
|
||||
line_names=["SGL Kernel"],
|
||||
styles=[("green", "-")],
|
||||
ylabel="us",
|
||||
plot_name="bench_rotary_embedding",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(batch_size, seq_len, save_kv_cache, provider):
|
||||
device = torch.device("cuda")
|
||||
|
||||
num_q_heads = 32
|
||||
num_kv_heads = 8
|
||||
head_size = 64
|
||||
dtype = torch.bfloat16
|
||||
|
||||
config = dict(
|
||||
head_size=head_size,
|
||||
rotary_dim=64,
|
||||
max_position_embeddings=4096,
|
||||
base=8000,
|
||||
is_neox_style=True,
|
||||
dtype=dtype,
|
||||
)
|
||||
rope_flashinfer = FlashInferRotaryEmbedding(**config).to(device)
|
||||
pool_flashinfer = MHATokenToKVPool(head_num=num_kv_heads, head_dim=head_size)
|
||||
|
||||
inputs = create_inputs(
|
||||
head_size=head_size,
|
||||
batch_size=batch_size,
|
||||
seq_len=seq_len,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
num_q_heads=num_q_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
)
|
||||
|
||||
query_flashinfer, key_flashinfer = inputs["query"].clone(), inputs["key"].clone()
|
||||
|
||||
bench_fn = lambda: rope_flashinfer.forward_cuda(
|
||||
inputs["pos_ids"],
|
||||
query_flashinfer,
|
||||
key_flashinfer,
|
||||
fused_set_kv_buffer_arg=(
|
||||
FusedSetKVBufferArg(
|
||||
value=inputs["value"],
|
||||
k_buffer=pool_flashinfer.k_buffer[0].view(-1, num_kv_heads * head_size),
|
||||
v_buffer=pool_flashinfer.v_buffer[0].view(-1, num_kv_heads * head_size),
|
||||
k_scale=None,
|
||||
v_scale=None,
|
||||
cache_loc=inputs["out_cache_loc"],
|
||||
)
|
||||
if save_kv_cache
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
time_s = bench_kineto(bench_fn, kernel_names="BatchQKApplyRotaryPosIds")
|
||||
return time_s * 1e6
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
benchmark.run(print_data=True)
|
||||
128
sgl-kernel/benchmark/bench_top_k_top_p_sampling.py
Normal file
128
sgl-kernel/benchmark/bench_top_k_top_p_sampling.py
Normal file
@@ -0,0 +1,128 @@
|
||||
import itertools
|
||||
|
||||
import sgl_kernel
|
||||
import torch
|
||||
import triton
|
||||
import triton.testing
|
||||
|
||||
|
||||
def torch_top_k_top_p_joint_sampling_from_probs(
|
||||
normalized_prob, top_k, top_p, eps=1e-4
|
||||
):
|
||||
"""Reference PyTorch implementation of joint top-k top-p sampling."""
|
||||
batch_size, vocab_size = normalized_prob.shape
|
||||
samples = torch.empty(batch_size, dtype=torch.int64, device=normalized_prob.device)
|
||||
|
||||
for i in range(batch_size):
|
||||
p_val = top_p[i].item()
|
||||
k_val = top_k[i].item()
|
||||
|
||||
# top-p mask
|
||||
sorted_prob, indices = torch.sort(normalized_prob[i], descending=False)
|
||||
cdf = torch.cumsum(sorted_prob, dim=-1)
|
||||
mask_top_p = torch.zeros(
|
||||
vocab_size, dtype=torch.int32, device=normalized_prob.device
|
||||
)
|
||||
mask_top_p.scatter_add_(0, indices, (cdf > (1 - p_val) - eps).int())
|
||||
|
||||
# top-k mask
|
||||
sorted_prob_desc, _ = torch.sort(normalized_prob[i], descending=True)
|
||||
pivot = sorted_prob_desc[k_val - 1]
|
||||
mask_top_k = (normalized_prob[i] >= pivot).int()
|
||||
|
||||
# joint mask
|
||||
mask = torch.minimum(mask_top_p, mask_top_k).bool()
|
||||
|
||||
# sample from masked probs
|
||||
masked_probs = normalized_prob[i] * mask
|
||||
masked_probs = masked_probs / masked_probs.sum()
|
||||
idx = torch.multinomial(masked_probs, 1)
|
||||
samples[i] = idx
|
||||
|
||||
return samples
|
||||
|
||||
|
||||
def calculate_diff(batch_size, vocab_size, p):
|
||||
"""Compare Torch reference and SGLang kernel for correctness."""
|
||||
torch.manual_seed(42)
|
||||
if p == 0.1:
|
||||
k = int(vocab_size * 0.5)
|
||||
elif p == 0.5:
|
||||
k = int(vocab_size * 0.1)
|
||||
else:
|
||||
raise ValueError("p not recognized")
|
||||
|
||||
device = torch.device("cuda")
|
||||
pre_norm_prob = torch.rand(batch_size, vocab_size, device=device)
|
||||
normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
|
||||
|
||||
top_p_tensor = torch.full((batch_size,), p, device=device)
|
||||
top_k_tensor = torch.full((batch_size,), k, device=device)
|
||||
|
||||
torch_samples = torch_top_k_top_p_joint_sampling_from_probs(
|
||||
normalized_prob, top_k_tensor, top_p_tensor
|
||||
)
|
||||
sglang_samples = sgl_kernel.top_k_top_p_sampling_from_probs(
|
||||
normalized_prob, top_k_tensor, top_p_tensor, filter_apply_order="joint"
|
||||
)
|
||||
|
||||
|
||||
# parameter space
|
||||
batch_size_range = [16, 64, 128]
|
||||
vocab_size_range = [111, 32000]
|
||||
p_range = [0.1, 0.5]
|
||||
configs = list(itertools.product(batch_size_range, vocab_size_range, p_range))
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["batch_size", "vocab_size", "p"],
|
||||
x_vals=configs,
|
||||
line_arg="provider",
|
||||
line_vals=["torch", "sglang"],
|
||||
line_names=["Torch Reference", "SGL Kernel"],
|
||||
styles=[("red", "-"), ("green", "-")],
|
||||
ylabel="us",
|
||||
plot_name="top-k-top-p-joint-sampling-performance",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark_sampling(batch_size, vocab_size, p, provider):
|
||||
torch.manual_seed(42)
|
||||
if p == 0.1:
|
||||
k = int(vocab_size * 0.5)
|
||||
elif p == 0.5:
|
||||
k = int(vocab_size * 0.1)
|
||||
else:
|
||||
raise ValueError("p not recognized")
|
||||
|
||||
device = torch.device("cuda")
|
||||
pre_norm_prob = torch.rand(batch_size, vocab_size, device=device)
|
||||
normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
|
||||
top_p_tensor = torch.full((batch_size,), p, device=device)
|
||||
top_k_tensor = torch.full((batch_size,), k, device=device)
|
||||
|
||||
if provider == "torch":
|
||||
fn = lambda: torch_top_k_top_p_joint_sampling_from_probs(
|
||||
normalized_prob.clone(), top_k_tensor, top_p_tensor
|
||||
)
|
||||
elif provider == "sglang":
|
||||
fn = lambda: sgl_kernel.top_k_top_p_sampling_from_probs(
|
||||
normalized_prob.clone(),
|
||||
top_k_tensor,
|
||||
top_p_tensor,
|
||||
filter_apply_order="joint",
|
||||
)
|
||||
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=[0.5, 0.2, 0.8])
|
||||
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Correctness check
|
||||
for cfg in configs:
|
||||
calculate_diff(*cfg)
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Starting performance benchmark...")
|
||||
benchmark_sampling.run(print_data=True)
|
||||
75
sgl-kernel/build.sh
Executable file
75
sgl-kernel/build.sh
Executable file
@@ -0,0 +1,75 @@
|
||||
#!/bin/bash
|
||||
set -ex
|
||||
|
||||
PYTHON_VERSION=$1
|
||||
CUDA_VERSION=$2
|
||||
PYTHON_ROOT_PATH=/opt/python/cp${PYTHON_VERSION//.}-cp${PYTHON_VERSION//.}
|
||||
|
||||
if [ -z "$3" ]; then
|
||||
ARCH=$(uname -i)
|
||||
else
|
||||
ARCH=$3
|
||||
fi
|
||||
|
||||
echo "ARCH: $ARCH"
|
||||
if [ ${ARCH} = "aarch64" ]; then
|
||||
LIBCUDA_ARCH="sbsa"
|
||||
BUILDER_NAME="pytorch/manylinuxaarch64-builder"
|
||||
CMAKE_BUILD_PARALLEL_LEVEL=16
|
||||
else
|
||||
LIBCUDA_ARCH=${ARCH}
|
||||
BUILDER_NAME="pytorch/manylinux2_28-builder"
|
||||
fi
|
||||
|
||||
if [ ${CUDA_VERSION} = "12.9" ]; then
|
||||
DOCKER_IMAGE="${BUILDER_NAME}:cuda${CUDA_VERSION}"
|
||||
TORCH_INSTALL="pip install --no-cache-dir torch==2.8.0 --index-url https://download.pytorch.org/whl/cu129"
|
||||
elif [ ${CUDA_VERSION} = "12.8" ]; then
|
||||
DOCKER_IMAGE="${BUILDER_NAME}:cuda${CUDA_VERSION}"
|
||||
TORCH_INSTALL="pip install --no-cache-dir torch==2.8.0 --index-url https://download.pytorch.org/whl/cu128"
|
||||
else
|
||||
DOCKER_IMAGE="${BUILDER_NAME}:cuda${CUDA_VERSION}"
|
||||
TORCH_INSTALL="pip install --no-cache-dir torch==2.8.0 --index-url https://download.pytorch.org/whl/cu126"
|
||||
fi
|
||||
|
||||
docker run --rm \
|
||||
-v $(pwd):/sgl-kernel \
|
||||
${DOCKER_IMAGE} \
|
||||
bash -c "
|
||||
# Install CMake (version >= 3.26) - Robust Installation
|
||||
export CMAKE_VERSION_MAJOR=3.31
|
||||
export CMAKE_VERSION_MINOR=1
|
||||
# Setting these flags to reduce OOM chance only on ARM
|
||||
if [ \"${ARCH}\" = \"aarch64\" ]; then
|
||||
export CUDA_NVCC_FLAGS=\"-Xcudafe --threads=2\"
|
||||
export MAKEFLAGS='-j2'
|
||||
export CMAKE_BUILD_PARALLEL_LEVEL=2
|
||||
export NINJAFLAGS='-j2'
|
||||
fi
|
||||
echo \"Downloading CMake from: https://cmake.org/files/v\${CMAKE_VERSION_MAJOR}/cmake-\${CMAKE_VERSION_MAJOR}.\${CMAKE_VERSION_MINOR}-linux-${ARCH}.tar.gz\"
|
||||
wget https://cmake.org/files/v\${CMAKE_VERSION_MAJOR}/cmake-\${CMAKE_VERSION_MAJOR}.\${CMAKE_VERSION_MINOR}-linux-${ARCH}.tar.gz
|
||||
tar -xzf cmake-\${CMAKE_VERSION_MAJOR}.\${CMAKE_VERSION_MINOR}-linux-${ARCH}.tar.gz
|
||||
mv cmake-\${CMAKE_VERSION_MAJOR}.\${CMAKE_VERSION_MINOR}-linux-${ARCH} /opt/cmake
|
||||
export PATH=/opt/cmake/bin:\$PATH
|
||||
export LD_LIBRARY_PATH=/lib64:\$LD_LIBRARY_PATH
|
||||
|
||||
# Debugging CMake
|
||||
echo \"PATH: \$PATH\"
|
||||
which cmake
|
||||
cmake --version
|
||||
|
||||
yum install numactl-devel -y && \
|
||||
yum install libibverbs -y --nogpgcheck && \
|
||||
ln -sv /usr/lib64/libibverbs.so.1 /usr/lib64/libibverbs.so && \
|
||||
${PYTHON_ROOT_PATH}/bin/${TORCH_INSTALL} && \
|
||||
${PYTHON_ROOT_PATH}/bin/pip install --no-cache-dir ninja setuptools==75.0.0 wheel==0.41.0 numpy uv scikit-build-core && \
|
||||
export TORCH_CUDA_ARCH_LIST='8.0 8.9 9.0+PTX' && \
|
||||
export CUDA_VERSION=${CUDA_VERSION} && \
|
||||
mkdir -p /usr/lib/${ARCH}-linux-gnu/ && \
|
||||
ln -s /usr/local/cuda-${CUDA_VERSION}/targets/${LIBCUDA_ARCH}-linux/lib/stubs/libcuda.so /usr/lib/${ARCH}-linux-gnu/libcuda.so && \
|
||||
|
||||
cd /sgl-kernel && \
|
||||
ls -la ${PYTHON_ROOT_PATH}/lib/python${PYTHON_VERSION}/site-packages/wheel/ && \
|
||||
PYTHONPATH=${PYTHON_ROOT_PATH}/lib/python${PYTHON_VERSION}/site-packages ${PYTHON_ROOT_PATH}/bin/python -m uv build --wheel -Cbuild-dir=build . --color=always --no-build-isolation && \
|
||||
./rename_wheels.sh
|
||||
"
|
||||
@@ -0,0 +1,117 @@
|
||||
import ctypes
|
||||
import os
|
||||
import platform
|
||||
|
||||
import torch
|
||||
|
||||
SYSTEM_ARCH = platform.machine()
|
||||
|
||||
cuda_path = f"/usr/local/cuda/targets/{SYSTEM_ARCH}-linux/lib/libcudart.so.12"
|
||||
if os.path.exists(cuda_path):
|
||||
ctypes.CDLL(cuda_path, mode=ctypes.RTLD_GLOBAL)
|
||||
|
||||
from sgl_kernel import common_ops
|
||||
from sgl_kernel.allreduce import *
|
||||
from sgl_kernel.attention import (
|
||||
cutlass_mla_decode,
|
||||
cutlass_mla_get_workspace_size,
|
||||
lightning_attention_decode,
|
||||
merge_state,
|
||||
merge_state_v2,
|
||||
)
|
||||
from sgl_kernel.cutlass_moe import cutlass_w4a8_moe_mm, get_cutlass_w4a8_moe_mm_data
|
||||
from sgl_kernel.elementwise import (
|
||||
FusedSetKVBufferArg,
|
||||
apply_rope_with_cos_sin_cache_inplace,
|
||||
downcast_fp8,
|
||||
fused_add_rmsnorm,
|
||||
gelu_and_mul,
|
||||
gelu_tanh_and_mul,
|
||||
gemma_fused_add_rmsnorm,
|
||||
gemma_rmsnorm,
|
||||
rmsnorm,
|
||||
silu_and_mul,
|
||||
)
|
||||
|
||||
if torch.version.hip is not None:
|
||||
from sgl_kernel.elementwise import gelu_quick
|
||||
|
||||
from sgl_kernel.fused_moe import fused_marlin_moe
|
||||
from sgl_kernel.gemm import (
|
||||
awq_dequantize,
|
||||
bmm_fp8,
|
||||
cutlass_scaled_fp4_mm,
|
||||
dsv3_fused_a_gemm,
|
||||
dsv3_router_gemm,
|
||||
fp8_blockwise_scaled_mm,
|
||||
fp8_scaled_mm,
|
||||
gptq_gemm,
|
||||
gptq_marlin_gemm,
|
||||
gptq_shuffle,
|
||||
int8_scaled_mm,
|
||||
qserve_w4a8_per_chn_gemm,
|
||||
qserve_w4a8_per_group_gemm,
|
||||
scaled_fp4_experts_quant,
|
||||
scaled_fp4_grouped_quant,
|
||||
scaled_fp4_quant,
|
||||
sgl_per_tensor_quant_fp8,
|
||||
sgl_per_token_group_quant_fp8,
|
||||
sgl_per_token_group_quant_int8,
|
||||
sgl_per_token_quant_fp8,
|
||||
shuffle_rows,
|
||||
silu_and_mul_scaled_fp4_grouped_quant,
|
||||
)
|
||||
from sgl_kernel.grammar import apply_token_bitmask_inplace_cuda
|
||||
from sgl_kernel.kvcacheio import (
|
||||
transfer_kv_all_layer,
|
||||
transfer_kv_all_layer_mla,
|
||||
transfer_kv_per_layer,
|
||||
transfer_kv_per_layer_mla,
|
||||
)
|
||||
from sgl_kernel.marlin import (
|
||||
awq_marlin_moe_repack,
|
||||
awq_marlin_repack,
|
||||
gptq_marlin_repack,
|
||||
)
|
||||
from sgl_kernel.memory import set_kv_buffer_kernel
|
||||
from sgl_kernel.moe import (
|
||||
apply_shuffle_mul_sum,
|
||||
cutlass_fp4_group_mm,
|
||||
ep_moe_post_reorder,
|
||||
ep_moe_pre_reorder,
|
||||
ep_moe_silu_and_mul,
|
||||
fp8_blockwise_scaled_grouped_mm,
|
||||
moe_align_block_size,
|
||||
moe_fused_gate,
|
||||
prepare_moe_input,
|
||||
topk_softmax,
|
||||
)
|
||||
from sgl_kernel.sampling import (
|
||||
min_p_sampling_from_probs,
|
||||
top_k_mask_logits,
|
||||
top_k_renorm_prob,
|
||||
top_k_top_p_sampling_from_logits,
|
||||
top_k_top_p_sampling_from_probs,
|
||||
top_p_renorm_prob,
|
||||
top_p_sampling_from_probs,
|
||||
)
|
||||
from sgl_kernel.speculative import (
|
||||
build_tree_kernel_efficient,
|
||||
segment_packbits,
|
||||
tree_speculative_sampling_target_only,
|
||||
verify_tree_greedy,
|
||||
)
|
||||
from sgl_kernel.top_k import fast_topk
|
||||
from sgl_kernel.version import __version__
|
||||
|
||||
|
||||
def create_greenctx_stream_by_value(*args, **kwargs):
|
||||
from sgl_kernel.spatial import create_greenctx_stream_by_value as _impl
|
||||
|
||||
return _impl(*args, **kwargs)
|
||||
|
||||
|
||||
def get_sm_available(*args, **kwargs):
|
||||
from sgl_kernel.spatial import get_sm_available as _impl
|
||||
|
||||
return _impl(*args, **kwargs)
|
||||
@@ -0,0 +1,173 @@
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
if torch.version.hip is not None:
|
||||
# ROCM custom allreduce
|
||||
def init_custom_ar(
|
||||
meta: torch.Tensor,
|
||||
rank_data: torch.Tensor,
|
||||
handles: List[str],
|
||||
offsets: List[int],
|
||||
rank: int,
|
||||
full_nvlink: bool,
|
||||
) -> int:
|
||||
return torch.ops.sgl_kernel.init_custom_ar.default(
|
||||
meta, rank_data, handles, offsets, rank, full_nvlink
|
||||
)
|
||||
|
||||
def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
|
||||
torch.ops.sgl_kernel.all_reduce_reg.default(fa, inp, out)
|
||||
|
||||
def all_reduce_unreg(
|
||||
fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor
|
||||
) -> None:
|
||||
torch.ops.sgl_kernel.all_reduce_unreg.default(fa, inp, reg_buffer, out)
|
||||
|
||||
def dispose(fa: int) -> None:
|
||||
torch.ops.sgl_kernel.dispose.default(fa)
|
||||
|
||||
def meta_size() -> int:
|
||||
return torch.ops.sgl_kernel.meta_size.default()
|
||||
|
||||
def register_buffer(
|
||||
fa: int, t: torch.Tensor, handles: List[str], offsets: List[int]
|
||||
) -> None:
|
||||
return torch.ops.sgl_kernel.register_buffer.default(fa, t, handles, offsets)
|
||||
|
||||
def get_graph_buffer_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]:
|
||||
return torch.ops.sgl_kernel.get_graph_buffer_ipc_meta.default(fa)
|
||||
|
||||
def register_graph_buffers(
|
||||
fa: int, handles: List[str], offsets: List[List[int]]
|
||||
) -> None:
|
||||
torch.ops.sgl_kernel.register_graph_buffers.default(fa, handles, offsets)
|
||||
|
||||
def allocate_meta_buffer(size: int) -> torch.Tensor:
|
||||
return torch.ops.sgl_kernel.allocate_meta_buffer.default(size)
|
||||
|
||||
def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
|
||||
return torch.ops.sgl_kernel.get_meta_buffer_ipc_handle.default(inp)
|
||||
|
||||
# ROCM quick allreduce
|
||||
def init_custom_qr(
|
||||
rank: int, world_size: int, qr_max_size: Optional[int] = None
|
||||
) -> int:
|
||||
return torch.ops.sgl_kernel.init_custom_qr.default(
|
||||
world_size, rank, qr_max_size
|
||||
)
|
||||
|
||||
def qr_get_handle(fa: int) -> torch.Tensor:
|
||||
return torch.ops.sgl_kernel.qr_get_handle.default(fa)
|
||||
|
||||
def qr_open_handles(fa: int, handles: list[torch.Tensor]) -> None:
|
||||
torch.ops.sgl_kernel.qr_open_handles.default(fa, handles)
|
||||
|
||||
def qr_all_reduce(
|
||||
fa: int,
|
||||
profile: int,
|
||||
inp: torch.Tensor,
|
||||
out: torch.Tensor,
|
||||
cast_bf162half: bool,
|
||||
) -> None:
|
||||
torch.ops.sgl_kernel.qr_all_reduce.default(
|
||||
fa, profile, inp, out, cast_bf162half
|
||||
)
|
||||
|
||||
def qr_destroy(fa: int) -> None:
|
||||
torch.ops.sgl_kernel.qr_destroy.default(fa)
|
||||
|
||||
def qr_max_size() -> int:
|
||||
return torch.ops.sgl_kernel.qr_max_size.default()
|
||||
|
||||
# mscclpp
|
||||
def mscclpp_generate_unique_id() -> bytes:
|
||||
raise NotImplementedError()
|
||||
|
||||
def mscclpp_init_context(
|
||||
unique_id: bytes,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
scratch: torch.Tensor,
|
||||
put_buffer: torch.Tensor,
|
||||
nranks_per_node: int,
|
||||
rank_to_node: List[int],
|
||||
rank_to_ib: List[int],
|
||||
context_selection: int,
|
||||
) -> int:
|
||||
raise NotImplementedError()
|
||||
|
||||
def mscclpp_allreduce(
|
||||
context: int, inp: torch.Tensor, out: torch.Tensor, nthreads: int, nblocks: int
|
||||
) -> None:
|
||||
raise NotImplementedError()
|
||||
|
||||
else:
|
||||
|
||||
def init_custom_ar(
|
||||
ipc_tensors: List[int], rank_data: torch.Tensor, rank: int, full_nvlink: bool
|
||||
) -> int:
|
||||
return torch.ops.sgl_kernel.init_custom_ar.default(
|
||||
ipc_tensors, rank_data, rank, full_nvlink
|
||||
)
|
||||
|
||||
def dispose(fa: int) -> None:
|
||||
torch.ops.sgl_kernel.dispose.default(fa)
|
||||
|
||||
def all_reduce(
|
||||
fa: int,
|
||||
inp: torch.Tensor,
|
||||
out: torch.Tensor,
|
||||
reg_buffer: int,
|
||||
reg_buffer_sz_bytes: int,
|
||||
) -> None:
|
||||
torch.ops.sgl_kernel.all_reduce.default(
|
||||
fa, inp, out, reg_buffer, reg_buffer_sz_bytes
|
||||
)
|
||||
|
||||
def get_graph_buffer_ipc_meta(fa) -> Tuple[List[int], List[int]]:
|
||||
return torch.ops.sgl_kernel.get_graph_buffer_ipc_meta.default(fa)
|
||||
|
||||
def register_buffer(fa: int, fake_ipc_ptrs: List[int]) -> None:
|
||||
return torch.ops.sgl_kernel.register_buffer.default(fa, fake_ipc_ptrs)
|
||||
|
||||
def register_graph_buffers(
|
||||
fa: int, handles: List[List[int]], offsets: List[List[int]]
|
||||
) -> None:
|
||||
torch.ops.sgl_kernel.register_graph_buffers.default(fa, handles, offsets)
|
||||
|
||||
def meta_size() -> int:
|
||||
return torch.ops.sgl_kernel.meta_size.default()
|
||||
|
||||
def mscclpp_generate_unique_id() -> torch.Tensor:
|
||||
return torch.ops.sgl_kernel.mscclpp_generate_unique_id.default()
|
||||
|
||||
def mscclpp_init_context(
|
||||
unique_id: torch.Tensor,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
scratch: torch.Tensor,
|
||||
put_buffer: torch.Tensor,
|
||||
nranks_per_node: int,
|
||||
rank_to_node: List[int],
|
||||
rank_to_ib: List[int],
|
||||
context_selection: int,
|
||||
) -> int:
|
||||
return torch.ops.sgl_kernel.mscclpp_init_context.default(
|
||||
unique_id,
|
||||
rank,
|
||||
world_size,
|
||||
scratch,
|
||||
put_buffer,
|
||||
nranks_per_node,
|
||||
rank_to_node,
|
||||
rank_to_ib,
|
||||
context_selection,
|
||||
)
|
||||
|
||||
def mscclpp_allreduce(
|
||||
context: int, inp: torch.Tensor, out: torch.Tensor, nthreads: int, nblocks: int
|
||||
) -> None:
|
||||
torch.ops.sgl_kernel.mscclpp_allreduce.default(
|
||||
context, inp, out, nthreads, nblocks
|
||||
)
|
||||
@@ -0,0 +1,138 @@
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
|
||||
torch.ops.sgl_kernel.lightning_attention_decode.default(
|
||||
q, k, v, past_kv, slope, output, new_kv
|
||||
)
|
||||
|
||||
|
||||
def merge_state(
|
||||
v_a: torch.Tensor,
|
||||
s_a: torch.Tensor,
|
||||
v_b: torch.Tensor,
|
||||
s_b: torch.Tensor,
|
||||
v_merged: Optional[torch.Tensor] = None,
|
||||
s_merged: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
s_a = s_a.to(torch.float32)
|
||||
s_b = s_b.to(torch.float32)
|
||||
# Avoid creating new tensors if they are already provided
|
||||
if v_merged is None:
|
||||
v_merged = torch.empty_like(v_a)
|
||||
if s_merged is None:
|
||||
s_merged = torch.empty_like(s_a)
|
||||
torch.ops.sgl_kernel.merge_state.default(v_a, s_a, v_b, s_b, v_merged, s_merged)
|
||||
return v_merged, s_merged
|
||||
|
||||
|
||||
def merge_state_v2(
|
||||
v_a: torch.Tensor,
|
||||
s_a: torch.Tensor,
|
||||
v_b: torch.Tensor,
|
||||
s_b: torch.Tensor,
|
||||
v_merged: Optional[torch.Tensor] = None,
|
||||
s_merged: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
s_a = s_a.to(torch.float32)
|
||||
s_b = s_b.to(torch.float32)
|
||||
# TODO(DefTruth): Currently, the custom merge_attn_states kernel
|
||||
# does not support the FP8 data type and non - CUDA devices.
|
||||
# It may be necessary to fall back to using the Triton kernel.
|
||||
|
||||
# Avoid creating new tensors if they are already provided
|
||||
if v_merged is None:
|
||||
v_merged = torch.empty_like(v_a)
|
||||
if s_merged is None:
|
||||
s_merged = torch.empty_like(s_a)
|
||||
torch.ops.sgl_kernel.merge_state_v2.default(v_a, s_a, v_b, s_b, v_merged, s_merged)
|
||||
return v_merged, s_merged
|
||||
|
||||
|
||||
def cutlass_mla_decode(
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
page_table: torch.Tensor,
|
||||
workspace: torch.Tensor,
|
||||
sm_scale: float,
|
||||
num_kv_splits: int = 1, # Set to 1 to avoid cuda_graph issue by default.
|
||||
) -> torch.Tensor:
|
||||
assert q_nope.ndim == 3, f"q_nope must be a 3D tensor, but got {q_nope.ndim}"
|
||||
assert q_pe.ndim == 3, f"q_pe must be a 3D tensor, but got {q_pe.ndim}"
|
||||
assert (
|
||||
kv_c_and_k_pe_cache.ndim == 3
|
||||
), f"kv_c_and_k_pe_cache must be a 3D tensor, but got {kv_c_and_k_pe_cache.ndim}"
|
||||
|
||||
B_q, H, D_q_nope = q_nope.shape
|
||||
B_q_2, H_2, D_q_pe = q_pe.shape
|
||||
assert (B_q == B_q_2) and (H == H_2)
|
||||
|
||||
_, PAGE_SIZE, D_ckv = kv_c_and_k_pe_cache.shape
|
||||
|
||||
D_latent = 512
|
||||
D_rope = 64
|
||||
assert D_q_nope == D_latent
|
||||
assert D_q_pe == D_rope
|
||||
assert D_ckv == D_latent + D_rope
|
||||
|
||||
MAX_HEADS = 128
|
||||
assert H <= MAX_HEADS, f"H must be <= {MAX_HEADS}, but got {H}"
|
||||
if H < MAX_HEADS:
|
||||
q_nope_padded = q_nope.new_empty((B_q, MAX_HEADS, D_q_nope))
|
||||
q_nope_padded[:, :H] = q_nope
|
||||
q_nope = q_nope_padded
|
||||
|
||||
q_pe_padded = q_pe.new_empty((B_q, MAX_HEADS, D_q_pe))
|
||||
q_pe_padded[:, :H] = q_pe
|
||||
q_pe = q_pe_padded
|
||||
|
||||
assert len(page_table.shape) == 2
|
||||
B_block_table, block_num = page_table.shape
|
||||
assert B_block_table == B_q
|
||||
assert block_num > 0, f"block num must be greater than 0, got {block_num}"
|
||||
assert block_num % (128 / PAGE_SIZE) == 0
|
||||
|
||||
# TODO(kaixih@nvidia): support fp8
|
||||
assert q_nope.dtype in (
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
), f"q_nope.dtype needs to be fp16 or bf16 but got {q_nope.dtype}."
|
||||
assert q_nope.dtype == q_pe.dtype == kv_c_and_k_pe_cache.dtype
|
||||
assert (
|
||||
seq_lens.dtype == torch.int32
|
||||
), f"seq_lens.dtype needs to be int32 but got {seq_lens.dtype}."
|
||||
assert (
|
||||
page_table.dtype == torch.int32
|
||||
), f"page_table.dtype needs to be int32 but got {page_table.dtype}."
|
||||
|
||||
out = q_nope.new_empty((B_q, MAX_HEADS, D_latent))
|
||||
|
||||
torch.ops.sgl_kernel.cutlass_mla_decode.default(
|
||||
out,
|
||||
q_nope,
|
||||
q_pe,
|
||||
kv_c_and_k_pe_cache,
|
||||
seq_lens,
|
||||
page_table,
|
||||
workspace,
|
||||
sm_scale,
|
||||
num_kv_splits,
|
||||
)
|
||||
return out[:, :H].contiguous()
|
||||
|
||||
|
||||
def cutlass_mla_get_workspace_size(
|
||||
max_seq_len: int,
|
||||
num_batches: int,
|
||||
sm_count: int = 0,
|
||||
num_kv_splits: int = 1, # Set to 1 to avoid cuda_graph issue by default.
|
||||
) -> int:
|
||||
assert max_seq_len > 0, f"max_seq_len must be greater than 0, got {max_seq_len}"
|
||||
assert num_batches > 0, f"num_batches must be greater than 0, got {num_batches}"
|
||||
return torch.ops.sgl_kernel.cutlass_mla_get_workspace_size.default(
|
||||
max_seq_len, num_batches, sm_count, num_kv_splits
|
||||
)
|
||||
Binary file not shown.
@@ -0,0 +1,112 @@
|
||||
import torch
|
||||
|
||||
|
||||
def get_cutlass_w4a8_moe_mm_data(
|
||||
topk_ids: torch.Tensor,
|
||||
expert_offsets: torch.Tensor,
|
||||
problem_sizes1: torch.Tensor,
|
||||
problem_sizes2: torch.Tensor,
|
||||
input_permutation: torch.Tensor,
|
||||
output_permutation: torch.Tensor,
|
||||
num_experts: int,
|
||||
n: int,
|
||||
k: int,
|
||||
):
|
||||
"""
|
||||
Prepare data necessary to perform CUTLASS grouped matrix multiplications
|
||||
used in CUTLASS-based fused MoE.
|
||||
|
||||
The function takes in topk_ids (token-expert mapping) and uses it to
|
||||
compute:
|
||||
- expert_offsets: Indices that mark at which token index each expert begins
|
||||
its computation after the input is sorted with
|
||||
input_permutation. The number of tokens computed with
|
||||
expert E is expert_offsets[E + 1] - expert_offsets[E]
|
||||
- problem_sizes1, problem_sizes2: MxNxK sizes of each expert's
|
||||
multiplication in two grouped MMs used in
|
||||
the fused MoE operation.
|
||||
- input_permutation: Permutation that must be used to shuffle the input
|
||||
before executing the MMs.
|
||||
- output_permutation: Permutation that must be used to shuffle the output
|
||||
after executing the MMs.
|
||||
"""
|
||||
torch.ops.sgl_kernel.get_cutlass_w4a8_moe_mm_data.default(
|
||||
topk_ids,
|
||||
expert_offsets,
|
||||
problem_sizes1,
|
||||
problem_sizes2,
|
||||
input_permutation,
|
||||
output_permutation,
|
||||
num_experts,
|
||||
n,
|
||||
k,
|
||||
)
|
||||
|
||||
|
||||
def cutlass_w4a8_moe_mm(
|
||||
d: torch.Tensor,
|
||||
a: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
a_scales: torch.Tensor,
|
||||
b_scales: torch.Tensor,
|
||||
experts_offsets: torch.tensor,
|
||||
problem_sizes: torch.tensor,
|
||||
a_strides: torch.tensor,
|
||||
b_strides: torch.tensor,
|
||||
d_strides: torch.tensor,
|
||||
s_strides: torch.tensor,
|
||||
chunk_size: int = 128,
|
||||
topk: int = 8,
|
||||
):
|
||||
"""
|
||||
Perform grouped matrix multiplication between int4 weights and fp8 activations.
|
||||
|
||||
This function executes multiple GEMM operations in parallel, which is useful for
|
||||
scenarios like Mixture of Experts (MoE) where different inputs go through different
|
||||
experts. The implementation leverages NVIDIA Hopper architecture features for
|
||||
optimal performance with quantized weights.
|
||||
|
||||
Args:
|
||||
d: Output matrices of shape [total_m, total_n]
|
||||
a: Activation matrices in FP8 (float_e4m3_t) format
|
||||
Each tensor should be of shape [total_m, K] in row-major layout
|
||||
b: Weight matrices in packed int4 format
|
||||
Each tensor should be of shape [E, N, K//2] in column-major layout
|
||||
where each byte contains two 4-bit integers
|
||||
a_scales: Scale factors for the inputs
|
||||
b_scales: Scale factors for the quantized weights
|
||||
Each tensor should be of shape [E, K//512, N*8]
|
||||
experts_offsets: Tensor containing expert offsets for determining group boundaries
|
||||
problem_sizes: with shape [num_experts, 3] (M, N, K for each group) (int32)
|
||||
a_strides: Strides information for A matrices
|
||||
b_strides: Strides information for B matrices
|
||||
d_strides: Strides information for D matrices
|
||||
s_strides: Strides information for b_scales matrices
|
||||
chunk_size: Number of elements each scale value applies to (K//512), default to 128
|
||||
|
||||
Requirements:
|
||||
- All tensors must be on a CUDA device
|
||||
- Requires an NVIDIA Hopper GPU (H100)
|
||||
- A tensors must be in float8_e4m3fn format
|
||||
- B tensors must contain packed int4 values (stored as int8)
|
||||
|
||||
Note:
|
||||
The function computes: D = (A * (B * scales))
|
||||
for each group of tensors in parallel
|
||||
"""
|
||||
|
||||
torch.ops.sgl_kernel.cutlass_w4a8_moe_mm.default(
|
||||
d,
|
||||
a,
|
||||
b,
|
||||
a_scales,
|
||||
b_scales,
|
||||
experts_offsets,
|
||||
problem_sizes,
|
||||
a_strides,
|
||||
b_strides,
|
||||
d_strides,
|
||||
s_strides,
|
||||
chunk_size,
|
||||
topk,
|
||||
)
|
||||
@@ -0,0 +1,369 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from sgl_kernel.utils import get_cuda_stream, is_arch_support_pdl
|
||||
|
||||
|
||||
# These implementations extensively draw from and build upon the FlashInfer project https://github.com/flashinfer-ai/flashinfer
|
||||
# Kudos to @yzh119
|
||||
def rmsnorm(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
eps: float = 1e-6,
|
||||
out: Optional[torch.Tensor] = None,
|
||||
enable_pdl: Optional[bool] = None,
|
||||
) -> torch.Tensor:
|
||||
r"""Root mean square normalization.
|
||||
|
||||
``out[i] = (input[i] / RMS(input)) * weight[i]``
|
||||
|
||||
Parameters
|
||||
----------
|
||||
input: torch.Tensor
|
||||
Input tensor, shape (batch_size, hidden_size).
|
||||
weight: torch.Tensor
|
||||
Weight tensor, shape (hidden_size,).
|
||||
eps: float
|
||||
Epsilon for numerical stability.
|
||||
out: Optional[torch.Tensor]
|
||||
The output tensor, if specified, the kernel will update this tensor inplace.
|
||||
enable_pdl: Optional[bool]
|
||||
Whether to enable `programmatic dependent launch
|
||||
<https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_
|
||||
If None, will be automatically enabled on Hopper architecture.
|
||||
|
||||
Returns
|
||||
-------
|
||||
output: torch.Tensor
|
||||
Normalized tensor, shape (batch_size, hidden_size).
|
||||
"""
|
||||
if out is None:
|
||||
out = torch.empty_like(input)
|
||||
if enable_pdl is None:
|
||||
enable_pdl = is_arch_support_pdl()
|
||||
torch.ops.sgl_kernel.rmsnorm.default(out, input, weight, eps, enable_pdl)
|
||||
return out
|
||||
|
||||
|
||||
def fused_add_rmsnorm(
|
||||
input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
eps: float = 1e-6,
|
||||
enable_pdl: Optional[bool] = None,
|
||||
) -> None:
|
||||
r"""Fused add root mean square normalization.
|
||||
|
||||
Step 1:
|
||||
``residual[i] += input[i]``
|
||||
|
||||
Step 2:
|
||||
``input[i] = (residual[i] / RMS(residual)) * weight[i]``
|
||||
|
||||
Parameters
|
||||
----------
|
||||
input: torch.Tensor
|
||||
Input tensor, shape (batch_size, hidden_size).
|
||||
residual: torch.Tensor
|
||||
Residual tensor, shape (batch_size, hidden_size).
|
||||
weight: torch.Tensor
|
||||
Weight tensor, shape (hidden_size,).
|
||||
eps: float
|
||||
Epsilon for numerical stability.
|
||||
enable_pdl: Optional[bool]
|
||||
Whether to enable `programmatic dependent launch
|
||||
<https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_
|
||||
If None, will be automatically enabled on Hopper architecture.
|
||||
"""
|
||||
if enable_pdl is None:
|
||||
enable_pdl = is_arch_support_pdl()
|
||||
torch.ops.sgl_kernel.fused_add_rmsnorm.default(
|
||||
input, residual, weight, eps, enable_pdl
|
||||
)
|
||||
|
||||
|
||||
def gemma_rmsnorm(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
eps: float = 1e-6,
|
||||
out: Optional[torch.Tensor] = None,
|
||||
enable_pdl: Optional[bool] = None,
|
||||
) -> torch.Tensor:
|
||||
r"""Gemma-style root mean square normalization.
|
||||
|
||||
``out[i] = (input[i] / RMS(input)) * (weight[i] + 1)``
|
||||
|
||||
Parameters
|
||||
----------
|
||||
input: torch.Tensor
|
||||
Input tensor, shape (batch_size, hidden_size).
|
||||
weight: torch.Tensor
|
||||
Weight tensor, shape (hidden_size,).
|
||||
eps: float
|
||||
Epsilon for numerical stability.
|
||||
out: Optional[torch.Tensor]
|
||||
The output tensor, if specified, the kernel will update this tensor inplace.
|
||||
enable_pdl: Optional[bool]
|
||||
Whether to enable `programmatic dependent launch
|
||||
<https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_
|
||||
If None, will be automatically enabled on Hopper architecture.
|
||||
|
||||
Returns
|
||||
-------
|
||||
output: torch.Tensor
|
||||
Gemma Normalized tensor, shape (batch_size, hidden_size).
|
||||
"""
|
||||
if out is None:
|
||||
out = torch.empty_like(input)
|
||||
if enable_pdl is None:
|
||||
enable_pdl = is_arch_support_pdl()
|
||||
torch.ops.sgl_kernel.gemma_rmsnorm.default(out, input, weight, eps, enable_pdl)
|
||||
return out
|
||||
|
||||
|
||||
def gemma_fused_add_rmsnorm(
|
||||
input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
eps: float = 1e-6,
|
||||
enable_pdl: Optional[bool] = None,
|
||||
) -> None:
|
||||
r"""Gemma-style fused add root mean square normalization.
|
||||
|
||||
Step 1:
|
||||
``residual[i] += input[i]``
|
||||
|
||||
Step 2:
|
||||
``input[i] = (residual[i] / RMS(residual)) * (weight + 1)``
|
||||
|
||||
Parameters
|
||||
----------
|
||||
input: torch.Tensor
|
||||
Input tensor, shape (batch_size, hidden_size).
|
||||
residual: torch.Tensor
|
||||
Residual tensor, shape (batch_size, hidden_size).
|
||||
weight: torch.Tensor
|
||||
Weight tensor, shape (hidden_size,).
|
||||
eps: float
|
||||
Epsilon for numerical stability.
|
||||
enable_pdl: Optional[bool]
|
||||
Whether to enable `programmatic dependent launch
|
||||
<https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_
|
||||
If None, will be automatically enabled on Hopper architecture.
|
||||
"""
|
||||
if enable_pdl is None:
|
||||
enable_pdl = is_arch_support_pdl()
|
||||
torch.ops.sgl_kernel.gemma_fused_add_rmsnorm.default(
|
||||
input, residual, weight, eps, enable_pdl
|
||||
)
|
||||
|
||||
|
||||
def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None:
|
||||
assert input.ndim == output.ndim, f"{input.ndim} != {output.ndim}"
|
||||
assert (
|
||||
input.shape[:-1] == output.shape[:-1]
|
||||
), f"{input.shape[:-1]} != {output.shape[:-1]}"
|
||||
assert (
|
||||
input.shape[-1] == 2 * output.shape[-1]
|
||||
), f"{input.shape[-1]} != {2 * output.shape[-1]}"
|
||||
|
||||
|
||||
def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
|
||||
if input.shape[-1] * input.dtype.itemsize % 16 != 0:
|
||||
raise ValueError("The pointers must be multiple of 16 bytes.")
|
||||
if out is not None:
|
||||
_check_shape(input, out)
|
||||
else:
|
||||
out = torch.empty(
|
||||
input.shape[:-1] + (input.shape[-1] // 2,),
|
||||
device=input.device,
|
||||
dtype=input.dtype,
|
||||
)
|
||||
torch.ops.sgl_kernel.silu_and_mul.default(out, input)
|
||||
return out
|
||||
|
||||
|
||||
def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
|
||||
if input.shape[-1] * input.dtype.itemsize % 16 != 0:
|
||||
raise ValueError("The pointers must be multiple of 16 bytes.")
|
||||
if out is not None:
|
||||
_check_shape(input, out)
|
||||
else:
|
||||
out = torch.empty(
|
||||
input.shape[:-1] + (input.shape[-1] // 2,),
|
||||
device=input.device,
|
||||
dtype=input.dtype,
|
||||
)
|
||||
torch.ops.sgl_kernel.gelu_tanh_and_mul.default(out, input)
|
||||
return out
|
||||
|
||||
|
||||
def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
|
||||
if input.shape[-1] * input.dtype.itemsize % 16 != 0:
|
||||
raise ValueError("The pointers must be multiple of 16 bytes.")
|
||||
if out is not None:
|
||||
_check_shape(input, out)
|
||||
else:
|
||||
out = torch.empty(
|
||||
input.shape[:-1] + (input.shape[-1] // 2,),
|
||||
device=input.device,
|
||||
dtype=input.dtype,
|
||||
)
|
||||
torch.ops.sgl_kernel.gelu_and_mul.default(out, input)
|
||||
return out
|
||||
|
||||
|
||||
if torch.version.hip is not None:
|
||||
|
||||
def gelu_quick(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
|
||||
"""
|
||||
Quick-GELU: y = x * sigmoid(1.702 * x)
|
||||
|
||||
The CUDA/HIP kernel uses 128-bit (16-byte) vector loads & stores,
|
||||
so the last-dimension byte length must be a multiple of 16 bytes.
|
||||
"""
|
||||
if input.shape[-1] * input.dtype.itemsize % 16 != 0:
|
||||
raise ValueError(
|
||||
f"The last dimension ({input.shape[-1]}) x itemsize "
|
||||
f"({input.dtype.itemsize}) must be a multiple of 16 bytes."
|
||||
)
|
||||
|
||||
if out is not None:
|
||||
assert input.shape == out.shape, f"{input.shape} != {out.shape}"
|
||||
else:
|
||||
out = torch.empty_like(input)
|
||||
|
||||
torch.ops.sgl_kernel.gelu_quick(out, input)
|
||||
return out
|
||||
|
||||
|
||||
@dataclass
|
||||
class FusedSetKVBufferArg:
|
||||
"""
|
||||
value : Optional[torch.Tensor]
|
||||
Value tensor, shape: ``(nnz, num_v_heads * head_size)``.
|
||||
k_buffer : Optional[torch.Tensor]
|
||||
Buffer for keys, shape: ``(nnz, num_k_heads * head_size)``.
|
||||
v_buffer : Optional[torch.Tensor]
|
||||
Buffer for values, shape: ``(nnz, num_v_heads * head_size)``.
|
||||
k_scale : Optional[float]
|
||||
Scale factor for keys.
|
||||
v_scale : Optional[float]
|
||||
Scale factor for values.
|
||||
cache_loc : Optional[torch.Tensor]
|
||||
Cache location tensor, used for indexing kv cache.
|
||||
"""
|
||||
|
||||
value: torch.Tensor
|
||||
k_buffer: torch.Tensor
|
||||
v_buffer: torch.Tensor
|
||||
k_scale: Optional[float]
|
||||
v_scale: Optional[float]
|
||||
cache_loc: torch.Tensor
|
||||
|
||||
|
||||
def apply_rope_with_cos_sin_cache_inplace(
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
head_size: int,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
is_neox: bool = True,
|
||||
fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
|
||||
enable_pdl: Optional[bool] = None,
|
||||
) -> None:
|
||||
r"""
|
||||
Apply rotary embedding to keys and queries with precomputed cos/sin values.
|
||||
This is designed to be compatible with the SGL/vLLM implementation.
|
||||
The result is inplace applied to the input tensors.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
positions : torch.Tensor
|
||||
Position indices, shape: ``(nnz)``.
|
||||
query : torch.Tensor
|
||||
Query tensor, shape: ``(nnz, num_q_heads * head_size)``.
|
||||
key : torch.Tensor
|
||||
Key tensor, shape: ``(nnz, num_k_heads * head_size)``.
|
||||
cos_sin_cache : torch.Tensor
|
||||
Cosine and Sine cache tensor, shape: ``(max_seq_len, rotary_dim)``.
|
||||
Cosine is the first half and Sine is the second half on rotary_dim.
|
||||
is_neox : bool
|
||||
Whether to use Neox style RoPE, default: ``True``.
|
||||
|
||||
* If ``True``, the last dimension of the query/key tensor is not interleaved, i.e.,
|
||||
we rotate the first half dimensions ``([..., :head_dim//2])`` and the second half
|
||||
dimensions ``([..., head_dim//2:])``.
|
||||
|
||||
* If ``False``, the last dimension of the query/key tensor is interleaved, i.e.,
|
||||
we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``.
|
||||
fused_set_kv_buffer_arg : FusedSetKVBufferArg
|
||||
Fuse the set-kv-buffer operation into this kernel
|
||||
|
||||
Note
|
||||
----
|
||||
The rotary dimension is determined by the cosine cache and sine cache.
|
||||
"""
|
||||
if cos_sin_cache.dtype != torch.float32:
|
||||
raise ValueError("cos_sin_cache should be float32")
|
||||
|
||||
if enable_pdl is None:
|
||||
# the non-fused branch does not yet support PDL, but after we switch to our impl for that branch it will
|
||||
enable_pdl = is_arch_support_pdl() and (fused_set_kv_buffer_arg is not None)
|
||||
|
||||
if (a := fused_set_kv_buffer_arg) is not None:
|
||||
assert a.k_scale is None, "k_scale is not yet supported"
|
||||
assert a.v_scale is None, "v_scale is not yet supported"
|
||||
assert a.cache_loc.dtype == torch.int64, f"{a.cache_loc.dtype=}"
|
||||
|
||||
def _view_3d(x):
|
||||
return x.view(x.shape[0], -1, head_size)
|
||||
|
||||
torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache.default(
|
||||
_view_3d(query),
|
||||
_view_3d(key),
|
||||
_view_3d(query),
|
||||
_view_3d(key),
|
||||
cos_sin_cache,
|
||||
positions.long(),
|
||||
(not is_neox),
|
||||
enable_pdl,
|
||||
get_cuda_stream(),
|
||||
(
|
||||
_view_3d(fused_set_kv_buffer_arg.value)
|
||||
if fused_set_kv_buffer_arg is not None
|
||||
else None
|
||||
),
|
||||
(
|
||||
_view_3d(fused_set_kv_buffer_arg.k_buffer)
|
||||
if fused_set_kv_buffer_arg is not None
|
||||
else None
|
||||
),
|
||||
(
|
||||
_view_3d(fused_set_kv_buffer_arg.v_buffer)
|
||||
if fused_set_kv_buffer_arg is not None
|
||||
else None
|
||||
),
|
||||
(
|
||||
fused_set_kv_buffer_arg.cache_loc
|
||||
if fused_set_kv_buffer_arg is not None
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downcast_fp8(
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
k_out: torch.Tensor,
|
||||
v_out: torch.Tensor,
|
||||
k_scale: torch.Tensor,
|
||||
v_scale: torch.Tensor,
|
||||
loc: torch.Tensor,
|
||||
mult: int = 1,
|
||||
offset: int = 0,
|
||||
) -> None:
|
||||
torch.ops.sgl_kernel.downcast_fp8(
|
||||
k, v, k_out, v_out, k_scale, v_scale, loc, mult, offset, get_cuda_stream()
|
||||
)
|
||||
@@ -0,0 +1,288 @@
|
||||
from functools import lru_cache
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
try:
|
||||
from sgl_kernel import flash_ops
|
||||
except:
|
||||
raise ImportError("Can not import sgl_kernel. Please check your installation.")
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def is_fa3_supported(device=None) -> bool:
|
||||
# There some fa3 FYI
|
||||
# FA3 can fail without a enough shared memory for a some shapes, such as higher
|
||||
# hidden_dim or some special cases.
|
||||
# Right now, fa3 is supported for sm80/sm87 and sm86/sm89. The main different
|
||||
# Between sm80/sm87 and sm86/sm89 is the shared memory size. you can follow the link below for more information
|
||||
# https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x
|
||||
# And for sgl-kernel right now, we can build fa3 on sm80/sm86/sm89/sm90a.
|
||||
# That means if you use A100/A*0/L20/L40/L40s/4090 you can use fa3.
|
||||
return (torch.version.cuda >= "12.3") and (
|
||||
torch.cuda.get_device_capability(device)[0] == 9
|
||||
or torch.cuda.get_device_capability(device)[0] == 8
|
||||
)
|
||||
|
||||
|
||||
def maybe_contiguous(x):
|
||||
return x.contiguous() if x is not None and x.stride(-1) != 1 else x
|
||||
|
||||
|
||||
def flash_attn_with_kvcache(
|
||||
q,
|
||||
k_cache,
|
||||
v_cache,
|
||||
k=None,
|
||||
v=None,
|
||||
qv=None,
|
||||
rotary_cos=None,
|
||||
rotary_sin=None,
|
||||
cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
|
||||
cache_batch_idx: Optional[torch.Tensor] = None,
|
||||
cache_leftpad: Optional[torch.Tensor] = None,
|
||||
page_table: Optional[torch.Tensor] = None,
|
||||
cu_seqlens_q: Optional[torch.Tensor] = None,
|
||||
cu_seqlens_k_new: Optional[torch.Tensor] = None,
|
||||
max_seqlen_q: Optional[int] = None,
|
||||
rotary_seqlens: Optional[torch.Tensor] = None,
|
||||
q_descale: Optional[torch.Tensor] = None,
|
||||
k_descale: Optional[torch.Tensor] = None,
|
||||
v_descale: Optional[torch.Tensor] = None,
|
||||
softmax_scale=None,
|
||||
causal=False,
|
||||
window_size=(-1, -1), # -1 means infinite context window
|
||||
softcap=0.0, # 0.0 means deactivated
|
||||
rotary_interleaved=True,
|
||||
scheduler_metadata=None,
|
||||
num_splits=0, # Can be tuned for speed
|
||||
pack_gqa=None, # Can be tuned for speed
|
||||
sm_margin=0, # Can be tuned if some SMs are used for communication
|
||||
return_softmax_lse=False,
|
||||
sinks=None,
|
||||
):
|
||||
"""
|
||||
If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
|
||||
k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
|
||||
the previous step, and update them with the new keys/values from the current step, and do
|
||||
attention with the updated cache, all in 1 kernel.
|
||||
|
||||
If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
|
||||
For example, the KV cache could be pre-allocated with the max sequence length, and you can use
|
||||
cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
|
||||
|
||||
Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
|
||||
rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
|
||||
If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
|
||||
and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
|
||||
If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
|
||||
indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
|
||||
|
||||
See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
|
||||
|
||||
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
|
||||
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
|
||||
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
|
||||
0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
|
||||
|
||||
If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
|
||||
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
|
||||
1 1 1 1 0
|
||||
1 1 1 1 1
|
||||
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
|
||||
0 0
|
||||
0 0
|
||||
0 0
|
||||
1 0
|
||||
1 1
|
||||
If the row of the mask is all zero, the output will be zero.
|
||||
|
||||
If window_size != (-1, -1), implements sliding window local attention. Query at position i
|
||||
will only attend to keys between
|
||||
[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
|
||||
|
||||
Note: Does not support backward pass.
|
||||
|
||||
Arguments:
|
||||
q: (batch_size, seqlen, nheads, headdim)
|
||||
k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,
|
||||
or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache)
|
||||
page_block_size must be a multiple of 256.
|
||||
v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table,
|
||||
or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache)
|
||||
k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
|
||||
k with k_cache, starting at the indices specified by cache_seqlens.
|
||||
v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k.
|
||||
qv [optional]: (batch_size, seqlen, nheads, headdim_v)
|
||||
rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
|
||||
to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
|
||||
rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
|
||||
cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
|
||||
KV cache.
|
||||
cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
|
||||
If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
|
||||
If the indices are not distinct, and k and v are provided, the values updated in the cache
|
||||
might come from any of the duplicate indices.
|
||||
cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0.
|
||||
page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
|
||||
softmax_scale: float. The scaling of QK^T before applying softmax.
|
||||
Default to 1 / sqrt(headdim).
|
||||
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
|
||||
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
|
||||
softcap: float. Anything > 0 activates softcapping attention.
|
||||
rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
|
||||
If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
|
||||
rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
|
||||
(i.e. GPT-NeoX style).
|
||||
num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
|
||||
If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
|
||||
to automatically determine the number of splits.
|
||||
Don't change this unless you know what you are doing.
|
||||
return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.
|
||||
|
||||
Return:
|
||||
out: (batch_size, seqlen, nheads, headdim).
|
||||
softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
|
||||
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
|
||||
normalization factor).
|
||||
"""
|
||||
assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
|
||||
assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
|
||||
if softmax_scale is None:
|
||||
softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (
|
||||
-0.5
|
||||
)
|
||||
if cache_seqlens is not None and isinstance(cache_seqlens, int):
|
||||
cache_seqlens = torch.full(
|
||||
(k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
|
||||
)
|
||||
cache_seqlens = maybe_contiguous(cache_seqlens)
|
||||
|
||||
q, k_cache, k, v = [maybe_contiguous(x) for x in (q, k_cache, k, v)]
|
||||
v_cache = (
|
||||
v_cache.contiguous()
|
||||
if v_cache.stride(-1) != 1 and v_cache.stride(-3) != 1
|
||||
else v_cache
|
||||
)
|
||||
cu_seqlens_q, cu_seqlens_k_new = [
|
||||
maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k_new)
|
||||
]
|
||||
page_table, cache_batch_idx, cache_leftpad = [
|
||||
maybe_contiguous(x) for x in (page_table, cache_batch_idx, cache_leftpad)
|
||||
]
|
||||
rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
|
||||
rotary_seqlens = maybe_contiguous(rotary_seqlens)
|
||||
|
||||
out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default(
|
||||
q,
|
||||
k_cache,
|
||||
v_cache,
|
||||
k,
|
||||
v,
|
||||
qv,
|
||||
None, # out
|
||||
cu_seqlens_q,
|
||||
None, # cu_seqlens_k
|
||||
cu_seqlens_k_new,
|
||||
None, # seqused_q
|
||||
cache_seqlens,
|
||||
max_seqlen_q,
|
||||
None, # max_seqlen_k
|
||||
page_table,
|
||||
cache_batch_idx,
|
||||
cache_leftpad,
|
||||
rotary_cos,
|
||||
rotary_sin,
|
||||
rotary_seqlens,
|
||||
q_descale,
|
||||
k_descale,
|
||||
v_descale,
|
||||
softmax_scale,
|
||||
causal,
|
||||
window_size[0],
|
||||
window_size[1],
|
||||
softcap,
|
||||
rotary_interleaved,
|
||||
scheduler_metadata,
|
||||
num_splits,
|
||||
pack_gqa,
|
||||
sm_margin,
|
||||
sinks,
|
||||
)
|
||||
# return (out, softmax_lse) if return_softmax_lse else out
|
||||
return (out, softmax_lse, *rest) if return_softmax_lse else out
|
||||
|
||||
|
||||
def flash_attn_varlen_func(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
seqused_q=None,
|
||||
seqused_k=None,
|
||||
softmax_scale=None,
|
||||
causal=False,
|
||||
qv=None,
|
||||
q_descale=None,
|
||||
k_descale=None,
|
||||
v_descale=None,
|
||||
window_size=(-1, -1),
|
||||
softcap=0.0,
|
||||
num_splits=1,
|
||||
pack_gqa=None,
|
||||
sm_margin=0,
|
||||
return_softmax_lse=False,
|
||||
sinks=None,
|
||||
):
|
||||
if not is_fa3_supported():
|
||||
raise NotImplementedError(
|
||||
"flash_attn at sgl-kernel is only supported on sm90 and above"
|
||||
)
|
||||
|
||||
if softmax_scale is None:
|
||||
softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (
|
||||
-0.5
|
||||
)
|
||||
|
||||
out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
None, # k_new
|
||||
None, # v_new
|
||||
qv, # qv
|
||||
None, # out
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
None, # cu_seqlens_k_new
|
||||
seqused_q,
|
||||
seqused_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
None, # page_table,
|
||||
None, # kv_batch_idx
|
||||
None, # leftpad_k
|
||||
None, # rotary cos
|
||||
None, # rotary sin
|
||||
None, # seqlens_rotary
|
||||
q_descale,
|
||||
k_descale,
|
||||
v_descale,
|
||||
softmax_scale,
|
||||
causal,
|
||||
window_size[0],
|
||||
window_size[1],
|
||||
softcap,
|
||||
is_rotary_interleaved=False,
|
||||
scheduler_metadata=None,
|
||||
num_splits=num_splits,
|
||||
pack_gqa=pack_gqa,
|
||||
sm_margin=sm_margin,
|
||||
sinks=sinks,
|
||||
)
|
||||
|
||||
return (out, softmax_lse, *rest) if return_softmax_lse else out
|
||||
@@ -0,0 +1,225 @@
|
||||
import functools
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from sgl_kernel import silu_and_mul
|
||||
|
||||
|
||||
def get_scalar_type(num_bits: int, has_zp: bool):
|
||||
from sgl_kernel.scalar_type import scalar_types
|
||||
|
||||
if has_zp:
|
||||
assert num_bits == 4
|
||||
return scalar_types.uint4
|
||||
else:
|
||||
return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128
|
||||
|
||||
|
||||
def fused_marlin_moe(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
gating_output: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
g_idx1: Optional[torch.Tensor] = None,
|
||||
g_idx2: Optional[torch.Tensor] = None,
|
||||
sort_indices1: Optional[torch.Tensor] = None,
|
||||
sort_indices2: Optional[torch.Tensor] = None,
|
||||
w1_zeros: Optional[torch.Tensor] = None,
|
||||
w2_zeros: Optional[torch.Tensor] = None,
|
||||
workspace: Optional[torch.Tensor] = None,
|
||||
num_bits: int = 8,
|
||||
is_k_full: bool = True,
|
||||
inplace: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
This function computes a Mixture of Experts (MoE) layer using two sets of
|
||||
weights, w1 and w2, and top-k gating mechanism.
|
||||
|
||||
Parameters:
|
||||
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
|
||||
- w1 (torch.Tensor): The first set of expert weights.
|
||||
- w2 (torch.Tensor): The second set of expert weights.
|
||||
- w1_scale (torch.Tensor): Scale to be used for w1.
|
||||
- w2_scale (torch.Tensor): Scale to be used for w2.
|
||||
- gating_output (torch.Tensor): The output of the gating operation
|
||||
(before softmax).
|
||||
- g_idx1 (Optional[torch.Tensor]): The first set of act_order indices.
|
||||
- g_idx2 (Optional[torch.Tensor]): The second set of act_order indices.
|
||||
- sort_indices1 (Optional[torch.Tensor]): The first act_order input
|
||||
permutation.
|
||||
- sort_indices2 (Optional[torch.Tensor]): The second act_order input
|
||||
permutation.
|
||||
- topk_weights (torch.Tensor): Top-k weights.
|
||||
- topk_ids (torch.Tensor): Indices of topk-k elements.
|
||||
- w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1.
|
||||
- w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2.
|
||||
- num_bits (bool): The number of bits in expert weights quantization.
|
||||
|
||||
Returns:
|
||||
- torch.Tensor: The output tensor after applying the MoE layer.
|
||||
"""
|
||||
# Delay the import to avoid circular dependency
|
||||
from sglang.srt.layers.moe.fused_moe_triton import (
|
||||
moe_align_block_size,
|
||||
try_get_optimal_moe_config,
|
||||
)
|
||||
|
||||
# Check constraints.
|
||||
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
|
||||
assert hidden_states.shape[1] == w1.shape[1] * 16, "Hidden size mismatch w1"
|
||||
assert hidden_states.shape[1] == w2.shape[2] // (
|
||||
num_bits // 2
|
||||
), "Hidden size mismatch w2"
|
||||
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
|
||||
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
|
||||
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
|
||||
assert hidden_states.dtype in [torch.float16, torch.bfloat16]
|
||||
assert num_bits in [4, 8]
|
||||
|
||||
M, K = hidden_states.shape
|
||||
E = w1.shape[0]
|
||||
N = w2.shape[1] * 16
|
||||
topk = topk_ids.shape[1]
|
||||
|
||||
get_config_func = functools.partial(
|
||||
try_get_optimal_moe_config,
|
||||
w1.shape,
|
||||
w2.shape,
|
||||
topk_ids.shape[1],
|
||||
None,
|
||||
is_marlin=True,
|
||||
)
|
||||
config = get_config_func(M)
|
||||
|
||||
block_size_m = config["BLOCK_SIZE_M"]
|
||||
|
||||
if global_num_experts == -1:
|
||||
global_num_experts = E
|
||||
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
|
||||
topk_ids, block_size_m, global_num_experts
|
||||
)
|
||||
|
||||
if workspace is None:
|
||||
max_workspace_size = (max(2 * N, K) // 64) * (
|
||||
sorted_token_ids.size(0) // block_size_m
|
||||
)
|
||||
device = hidden_states.device
|
||||
sms = torch.cuda.get_device_properties(device).multi_processor_count
|
||||
max_workspace_size = min(max_workspace_size, sms * 4)
|
||||
workspace = torch.zeros(
|
||||
max_workspace_size, dtype=torch.int, device=device, requires_grad=False
|
||||
)
|
||||
|
||||
scalar_type1 = get_scalar_type(num_bits, w1_zeros is not None)
|
||||
scalar_type2 = get_scalar_type(num_bits, w2_zeros is not None)
|
||||
|
||||
intermediate_cache2 = torch.empty(
|
||||
(M * topk_ids.shape[1], N),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
intermediate_cache13 = torch.empty(
|
||||
(M * topk_ids.shape[1] * max(2 * N, K),),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
intermediate_cache1 = intermediate_cache13[: M * topk_ids.shape[1] * 2 * N]
|
||||
intermediate_cache1 = intermediate_cache1.view(-1, 2 * N)
|
||||
intermediate_cache3 = intermediate_cache13[: M * topk_ids.shape[1] * K]
|
||||
intermediate_cache3 = intermediate_cache3.view(-1, K)
|
||||
|
||||
use_atomic_add = (
|
||||
hidden_states.dtype == torch.half
|
||||
or torch.cuda.get_device_capability(hidden_states.device)[0] >= 9
|
||||
)
|
||||
|
||||
intermediate_cache1 = torch.ops.sgl_kernel.moe_wna16_marlin_gemm.default(
|
||||
hidden_states,
|
||||
intermediate_cache1,
|
||||
w1,
|
||||
w1_scale,
|
||||
w1_zeros,
|
||||
g_idx1,
|
||||
sort_indices1,
|
||||
workspace,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
topk_weights,
|
||||
moe_block_size=block_size_m,
|
||||
top_k=topk,
|
||||
mul_topk_weights=False,
|
||||
is_ep=expert_map is not None,
|
||||
b_q_type_id=scalar_type1.id,
|
||||
size_m=M,
|
||||
size_n=2 * N,
|
||||
size_k=K,
|
||||
is_k_full=is_k_full,
|
||||
use_atomic_add=use_atomic_add,
|
||||
use_fp32_reduce=True,
|
||||
is_zp_float=False,
|
||||
)
|
||||
|
||||
silu_and_mul(intermediate_cache1.view(-1, 2 * N), intermediate_cache2)
|
||||
|
||||
if expert_map is not None:
|
||||
intermediate_cache3.zero_()
|
||||
|
||||
intermediate_cache3 = torch.ops.sgl_kernel.moe_wna16_marlin_gemm.default(
|
||||
intermediate_cache2,
|
||||
intermediate_cache3,
|
||||
w2,
|
||||
w2_scale,
|
||||
w2_zeros,
|
||||
g_idx2,
|
||||
sort_indices2,
|
||||
workspace,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
topk_weights,
|
||||
moe_block_size=block_size_m,
|
||||
top_k=1,
|
||||
mul_topk_weights=True,
|
||||
is_ep=expert_map is not None,
|
||||
b_q_type_id=scalar_type2.id,
|
||||
size_m=M * topk,
|
||||
size_n=K,
|
||||
size_k=N,
|
||||
is_k_full=is_k_full,
|
||||
use_atomic_add=use_atomic_add,
|
||||
use_fp32_reduce=True,
|
||||
is_zp_float=False,
|
||||
).view(-1, topk, K)
|
||||
|
||||
output = hidden_states if inplace else torch.empty_like(hidden_states)
|
||||
return torch.sum(
|
||||
intermediate_cache3.view(*intermediate_cache3.shape), dim=1, out=output
|
||||
)
|
||||
|
||||
|
||||
def fused_marlin_moe_fake(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
gating_output: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
g_idx1: Optional[torch.Tensor] = None,
|
||||
g_idx2: Optional[torch.Tensor] = None,
|
||||
sort_indices1: Optional[torch.Tensor] = None,
|
||||
sort_indices2: Optional[torch.Tensor] = None,
|
||||
w1_zeros: Optional[torch.Tensor] = None,
|
||||
w2_zeros: Optional[torch.Tensor] = None,
|
||||
num_bits: int = 8,
|
||||
is_k_full: bool = True,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(hidden_states)
|
||||
550
sgl-kernel/build/lib.linux-x86_64-cpython-310/sgl_kernel/gemm.py
Normal file
550
sgl-kernel/build/lib.linux-x86_64-cpython-310/sgl_kernel/gemm.py
Normal file
@@ -0,0 +1,550 @@
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from sgl_kernel.scalar_type import ScalarType
|
||||
from sgl_kernel.utils import _get_cache_buf, get_cuda_stream
|
||||
|
||||
|
||||
def awq_dequantize(
|
||||
qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
|
||||
) -> torch.ByteTensor:
|
||||
return torch.ops.sgl_kernel.awq_dequantize.default(qweight, scales, qzeros)
|
||||
|
||||
|
||||
def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
|
||||
return torch.ops.sgl_kernel.int8_scaled_mm.default(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scales_a,
|
||||
scales_b,
|
||||
out_dtype,
|
||||
bias,
|
||||
)
|
||||
|
||||
|
||||
def fp8_blockwise_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype):
|
||||
return torch.ops.sgl_kernel.fp8_blockwise_scaled_mm.default(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scales_a,
|
||||
scales_b,
|
||||
out_dtype,
|
||||
)
|
||||
|
||||
|
||||
def fp8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
|
||||
return torch.ops.sgl_kernel.fp8_scaled_mm.default(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scales_a,
|
||||
scales_b,
|
||||
out_dtype,
|
||||
bias,
|
||||
)
|
||||
|
||||
|
||||
def _bmm_fp8_internal(
|
||||
workspace_buffer: torch.Tensor,
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
D: torch.Tensor,
|
||||
A_scale: torch.Tensor,
|
||||
B_scale: torch.Tensor,
|
||||
) -> None:
|
||||
cublas_handle = torch.cuda.current_blas_handle()
|
||||
torch.ops.sgl_kernel.bmm_fp8.default(
|
||||
A,
|
||||
B,
|
||||
D,
|
||||
A_scale,
|
||||
B_scale,
|
||||
workspace_buffer,
|
||||
cublas_handle,
|
||||
get_cuda_stream(),
|
||||
)
|
||||
|
||||
|
||||
def bmm_fp8(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
A_scale: torch.Tensor,
|
||||
B_scale: torch.Tensor,
|
||||
dtype: torch.dtype,
|
||||
out: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if out is None:
|
||||
out = torch.empty(
|
||||
(A.shape[0], A.shape[1], B.shape[2]),
|
||||
device=A.device,
|
||||
dtype=dtype,
|
||||
)
|
||||
workspace_buffer = _get_cache_buf("bmm_fp8_workspace", 32 * 1024 * 1024, A.device)
|
||||
_bmm_fp8_internal(workspace_buffer, A, B, out, A_scale, B_scale)
|
||||
return out
|
||||
|
||||
|
||||
def dsv3_fused_a_gemm(
|
||||
mat_a: torch.Tensor,
|
||||
mat_b: torch.Tensor,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if output is None:
|
||||
output = torch.empty(
|
||||
(mat_a.shape[0], mat_b.shape[1]),
|
||||
device=mat_a.device,
|
||||
dtype=mat_a.dtype,
|
||||
)
|
||||
torch.ops.sgl_kernel.dsv3_fused_a_gemm.default(output, mat_a, mat_b)
|
||||
return output
|
||||
|
||||
|
||||
def sgl_per_token_group_quant_fp8(
|
||||
input: torch.Tensor,
|
||||
output_q: torch.Tensor,
|
||||
output_s: torch.Tensor,
|
||||
group_size: int,
|
||||
eps: float,
|
||||
fp8_min: float,
|
||||
fp8_max: float,
|
||||
scale_ue8m0: bool,
|
||||
) -> None:
|
||||
torch.ops.sgl_kernel.sgl_per_token_group_quant_fp8.default(
|
||||
input, output_q, output_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
|
||||
)
|
||||
|
||||
|
||||
def sgl_per_token_group_quant_int8(
|
||||
input: torch.Tensor,
|
||||
output_q: torch.Tensor,
|
||||
output_s: torch.Tensor,
|
||||
group_size: int,
|
||||
eps: float,
|
||||
int8_min: float,
|
||||
int8_max: float,
|
||||
) -> None:
|
||||
torch.ops.sgl_kernel.sgl_per_token_group_quant_int8.default(
|
||||
input, output_q, output_s, group_size, eps, int8_min, int8_max
|
||||
)
|
||||
|
||||
|
||||
def sgl_per_tensor_quant_fp8(
|
||||
input: torch.Tensor,
|
||||
output_q: torch.Tensor,
|
||||
output_s: torch.Tensor,
|
||||
is_static: bool,
|
||||
) -> None:
|
||||
torch.ops.sgl_kernel.sgl_per_tensor_quant_fp8.default(
|
||||
input, output_q, output_s, is_static
|
||||
)
|
||||
|
||||
|
||||
def sgl_per_token_quant_fp8(
|
||||
input: torch.Tensor,
|
||||
output_q: torch.Tensor,
|
||||
output_s: torch.Tensor,
|
||||
) -> None:
|
||||
torch.ops.sgl_kernel.sgl_per_token_quant_fp8.default(input, output_q, output_s)
|
||||
|
||||
|
||||
def cutlass_scaled_fp4_mm(
|
||||
a: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
block_scale_a: torch.Tensor,
|
||||
block_scale_b: torch.Tensor,
|
||||
alpha: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
) -> torch.Tensor:
|
||||
assert a.ndim == 2 and b.ndim == 2
|
||||
m, n = a.shape[0], b.shape[0]
|
||||
out = torch.empty((m, n), dtype=out_dtype, device=a.device)
|
||||
torch.ops.sgl_kernel.cutlass_scaled_fp4_mm.default(
|
||||
out, a, b, block_scale_a, block_scale_b, alpha
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def scaled_fp4_quant(
|
||||
input: torch.Tensor, input_global_scale: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Quantize input tensor to FP4 and return quantized tensor and scale.
|
||||
|
||||
This function quantizes the last dimension of the given tensor `input`. For
|
||||
every 16 consecutive elements, a single dynamically computed scaling factor
|
||||
is shared. This scaling factor is quantized using the `input_global_scale`
|
||||
and is stored in a swizzled layout (see
|
||||
https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x).
|
||||
|
||||
Args:
|
||||
input: The input tensor to be quantized to FP4
|
||||
input_global_scale: A scalar scaling factor for the entire tensor.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP4 but every
|
||||
two values are packed into a uint8 and float8_e4m3 scaling factors
|
||||
in a sizzled layout.
|
||||
"""
|
||||
assert input.ndim >= 1, f"input.ndim needs to be >= 1, but got {input.ndim}."
|
||||
other_dims = 1 if input.ndim == 1 else -1
|
||||
input = input.reshape(other_dims, input.shape[-1])
|
||||
m, n = input.shape
|
||||
block_size = 16
|
||||
device = input.device
|
||||
|
||||
assert n % block_size == 0, f"last dim has to be multiple of 16, but got {n}."
|
||||
assert input.dtype in (
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
), f"input.dtype needs to be fp16 or bf16 but got {input.dtype}."
|
||||
|
||||
# Two fp4 values will be packed into an uint8.
|
||||
output = torch.empty((m, n // 2), device=device, dtype=torch.uint8)
|
||||
|
||||
# We use the rounded values to store the swizzled values. Then, the scaling
|
||||
# factors in float8_e4m3fn are packed into an int32 for every 4 values.
|
||||
rounded_m = ((m + 128 - 1) // 128) * 128
|
||||
scale_n = n // block_size
|
||||
rounded_n = ((scale_n + 4 - 1) // 4) * 4
|
||||
# padded part should be zeroed out
|
||||
if rounded_n > scale_n:
|
||||
output_scale = torch.zeros(
|
||||
(rounded_m, rounded_n // 4), device=device, dtype=torch.int32
|
||||
)
|
||||
else:
|
||||
output_scale = torch.empty(
|
||||
(rounded_m, rounded_n // 4), device=device, dtype=torch.int32
|
||||
)
|
||||
|
||||
torch.ops.sgl_kernel.scaled_fp4_quant.default(
|
||||
output, input, output_scale, input_global_scale
|
||||
)
|
||||
output_scale = output_scale.view(torch.float8_e4m3fn)
|
||||
return output, output_scale
|
||||
|
||||
|
||||
def qserve_w4a8_per_chn_gemm(
|
||||
in_feats: torch.Tensor,
|
||||
kernel: torch.Tensor,
|
||||
wscales: torch.Tensor,
|
||||
ascales: torch.Tensor,
|
||||
w_szs: torch.Tensor,
|
||||
a_ssums: torch.Tensor,
|
||||
out_feats: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if out_feats is None:
|
||||
# NOTE(HandH1998): qserve_w4a8_per_chn_gemm only supports out dtype=torch.float16 now
|
||||
out_feats = torch.empty(
|
||||
(in_feats.shape[0], kernel.shape[0]),
|
||||
device=in_feats.device,
|
||||
dtype=torch.float16,
|
||||
)
|
||||
torch.ops.sgl_kernel.qserve_w4a8_per_chn_gemm.default(
|
||||
in_feats, kernel, wscales, ascales, w_szs, a_ssums, out_feats
|
||||
)
|
||||
return out_feats
|
||||
|
||||
|
||||
def qserve_w4a8_per_group_gemm(
|
||||
in_feats: torch.Tensor,
|
||||
kernel: torch.Tensor,
|
||||
zeros: torch.Tensor,
|
||||
scales_i8: torch.Tensor,
|
||||
wscales: torch.Tensor,
|
||||
ascales: torch.Tensor,
|
||||
out_feats: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if out_feats is None:
|
||||
# NOTE(HandH1998): qserve_w4a8_per_group_gemm only supports out dtype=torch.float16 now
|
||||
out_feats = torch.empty(
|
||||
(in_feats.shape[0], kernel.shape[0]),
|
||||
device=in_feats.device,
|
||||
dtype=torch.float16,
|
||||
)
|
||||
torch.ops.sgl_kernel.qserve_w4a8_per_group_gemm.default(
|
||||
in_feats, kernel, zeros, scales_i8, wscales, ascales, out_feats
|
||||
)
|
||||
return out_feats
|
||||
|
||||
|
||||
def dsv3_router_gemm(
|
||||
hidden_states: torch.Tensor,
|
||||
router_weights: torch.Tensor,
|
||||
out_dtype: torch.dtype = torch.bfloat16,
|
||||
) -> torch.Tensor:
|
||||
output = torch.empty(
|
||||
hidden_states.shape[0],
|
||||
router_weights.shape[0],
|
||||
device=hidden_states.device,
|
||||
dtype=out_dtype,
|
||||
)
|
||||
torch.ops.sgl_kernel.dsv3_router_gemm(
|
||||
output,
|
||||
hidden_states,
|
||||
router_weights,
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
def shuffle_rows(input_tensor, dst2src_map, output_tensor_shape):
|
||||
output_tensor = torch.empty(
|
||||
output_tensor_shape,
|
||||
device=input_tensor.device,
|
||||
dtype=input_tensor.dtype,
|
||||
)
|
||||
torch.ops.sgl_kernel.shuffle_rows.default(input_tensor, dst2src_map, output_tensor)
|
||||
return output_tensor
|
||||
|
||||
|
||||
def scaled_fp4_grouped_quant(
|
||||
input_tensor: torch.Tensor,
|
||||
input_global_scale: torch.Tensor,
|
||||
mask: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Quantize input tensor to FP4 and return quantized tensor and scale, for
|
||||
grouped gemm inputs (e.g., grouped_gemm_nt_masked for flashinfer).
|
||||
Args:
|
||||
input: The input tensor to be quantized to FP4, with shape (l, m, k)
|
||||
l is number of groups, m is number of tokens per group, k is number of features.
|
||||
input_global_scale: A scalar scaling factor for the entire tensor, with
|
||||
shape (l,).
|
||||
Outputs:
|
||||
output: The quantized tensor in FP4, with shape (m, k // 2, l) but the physical
|
||||
layout is (l, m, k // 2). `// 2` is because two fp4 values are packed into
|
||||
an uint8.
|
||||
output_scales: The blockscale tensor in FP8-E4M3, with shape (32, 4, rm, 4, rk, l)
|
||||
but the physical layout is (l, rm, rk, 32, 4, 4).
|
||||
Note:
|
||||
For the shape of output_scales, `32 * 4 * rm` is a padded m to nearest multiple of 128.
|
||||
`4 * rk` is a padded `k // 16` to nearest multiple of 4. These layout constants are
|
||||
required by the NVIDIA Blackwell MMA operations.
|
||||
"""
|
||||
device = input_tensor.device
|
||||
l, m, k = input_tensor.shape
|
||||
sf_vec_size = 16
|
||||
assert k % sf_vec_size == 0, f"k must be multiple of 16, but got {k}."
|
||||
|
||||
scale_k = k // sf_vec_size
|
||||
padded_k = (scale_k + (4 - 1)) // 4 * 4
|
||||
padded_k_int32 = padded_k // 4
|
||||
padded_m = (m + (128 - 1)) // 128 * 128
|
||||
output = torch.empty(l, m, k // 2, device=device, dtype=torch.uint8)
|
||||
output_scales = torch.empty(
|
||||
l, padded_m, padded_k_int32, device=device, dtype=torch.int32
|
||||
)
|
||||
|
||||
torch.ops.sgl_kernel.silu_and_mul_scaled_fp4_experts_quant.default(
|
||||
output.view(l * m, k // 2),
|
||||
output_scales.view(l * padded_m, padded_k_int32),
|
||||
input_tensor.view(l * m, k),
|
||||
input_global_scale,
|
||||
mask,
|
||||
use_silu_and_mul=False,
|
||||
)
|
||||
# The physical layout of the output is (l, m, k // 2), but we want to return a
|
||||
# logical layout (m, k // 2, l) required by the flashinfer masked group gemm.
|
||||
output = output.permute(1, 2, 0)
|
||||
# The physical layout of the output scales is already swizzled as (l, rm, rk, 32, 4, 4), a
|
||||
# requirement for the flashinfer masked group gemm, where rm=m/128 and rk=k/4. The logic
|
||||
# layout is (32, 4, rm, 4, rk, l).
|
||||
output_scales = output_scales.view(torch.float8_e4m3fn).view(
|
||||
l, padded_m // 128, padded_k // 4, 32, 4, 4
|
||||
)
|
||||
output_scales = output_scales.permute(3, 4, 1, 5, 2, 0)
|
||||
return output, output_scales
|
||||
|
||||
|
||||
def silu_and_mul_scaled_fp4_grouped_quant(
|
||||
input_tensor: torch.Tensor,
|
||||
input_global_scale: torch.Tensor,
|
||||
mask: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Quantize input tensor to FP4 and return quantized tensor and scale, for
|
||||
grouped gemm inputs (e.g., grouped_gemm_nt_masked for flashinfer).
|
||||
Args:
|
||||
input: The input tensor to be quantized to FP4, with shape (l, m, k * 2)
|
||||
l is number of groups, m is number of tokens per group, k is number of features.
|
||||
input_global_scale: A scalar scaling factor for the entire tensor, with
|
||||
shape (l,).
|
||||
mask: The mask tensor, with shape (l,)
|
||||
Outputs:
|
||||
output: The quantized tensor in FP4, with shape (m, k // 2, l) but the physical
|
||||
layout is (l, m, k // 2). `// 2` is because two fp4 values are packed into
|
||||
an uint8.
|
||||
output_scales: The blockscale tensor in FP8-E4M3, with shape (32, 4, rm, 4, rk, l)
|
||||
but the physical layout is (l, rm, rk, 32, 4, 4).
|
||||
Note:
|
||||
For the shape of output_scales, `32 * 4 * rm` is a padded m to nearest multiple of 128.
|
||||
`4 * rk` is a padded `k // 16` to nearest multiple of 4. These layout constants are
|
||||
required by the NVIDIA Blackwell MMA operations.
|
||||
"""
|
||||
device = input_tensor.device
|
||||
l, m, k_by_2 = input_tensor.shape
|
||||
k = k_by_2 // 2
|
||||
sf_vec_size = 16
|
||||
assert k % sf_vec_size == 0, f"k must be multiple of 16, but got {k}."
|
||||
|
||||
scale_k = k // sf_vec_size
|
||||
padded_k = (scale_k + (4 - 1)) // 4 * 4
|
||||
padded_k_int32 = padded_k // 4
|
||||
padded_m = (m + (128 - 1)) // 128 * 128
|
||||
output = torch.empty(l, m, k // 2, device=device, dtype=torch.uint8)
|
||||
output_scales = torch.empty(
|
||||
l, padded_m, padded_k_int32, device=device, dtype=torch.int32
|
||||
)
|
||||
|
||||
torch.ops.sgl_kernel.silu_and_mul_scaled_fp4_experts_quant.default(
|
||||
output.view(l * m, k // 2),
|
||||
output_scales.view(l * padded_m, padded_k_int32),
|
||||
input_tensor.view(l * m, k_by_2),
|
||||
input_global_scale,
|
||||
mask,
|
||||
use_silu_and_mul=True,
|
||||
)
|
||||
# The physical layout of the output is (l, m, k // 2), but we want to return a
|
||||
# logical layout (m, k // 2, l) required by the flashinfer masked group gemm.
|
||||
output = output.permute(1, 2, 0)
|
||||
# The physical layout of the output scales is already swizzled as (l, rm, rk, 32, 4, 4), a
|
||||
# requirement for the flashinfer masked group gemm, where rm=m/128 and rk=k/4. The logic
|
||||
# layout is (32, 4, rm, 4, rk, l).
|
||||
output_scales = output_scales.view(torch.float8_e4m3fn).view(
|
||||
l, padded_m // 128, padded_k // 4, 32, 4, 4
|
||||
)
|
||||
output_scales = output_scales.permute(3, 4, 1, 5, 2, 0)
|
||||
return output, output_scales
|
||||
|
||||
|
||||
def scaled_fp4_experts_quant(
|
||||
input_tensor: torch.Tensor,
|
||||
input_global_scale: torch.Tensor,
|
||||
expert_offsets: torch.Tensor,
|
||||
blockscale_offsets: torch.Tensor,
|
||||
topk: int,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Quantize input tensor to FP4 and return quantized tensor and scale, for
|
||||
packed MoE Inputs.
|
||||
Args:
|
||||
input: The input tensor to be quantized to FP4
|
||||
expert_map: The expert map tensor
|
||||
input_global_scale: A scalar scaling factor for the entire tensor.
|
||||
expert_offsets: The expert offsets tensor
|
||||
blockscale_offsets: The blockscale offsets tensor
|
||||
Outputs:
|
||||
output: The quantized tensor in FP4
|
||||
output_scales: The blockscale tensor in FP8-E4M3
|
||||
"""
|
||||
assert (
|
||||
input_tensor.ndim == 2
|
||||
), f"input.ndim needs to be == 2, but got {input_tensor.ndim}."
|
||||
if expert_map is not None:
|
||||
(m, k) = input_tensor.shape
|
||||
output_tensor_shape = (m * topk, k)
|
||||
input_tensor = shuffle_rows(input_tensor, expert_map, output_tensor_shape)
|
||||
m_numtopk, k = input_tensor.shape
|
||||
# Control the maximum number of tokens per expert supported by the
|
||||
# NVFP4 MoE Expert Quantization. This is used to prevent the kernel
|
||||
# from running out of memory. This value can also be increased to support
|
||||
# larger models.
|
||||
import os
|
||||
|
||||
MAX_TOKENS_PER_EXPERT = os.environ.get("MODELOPT_MAX_TOKENS_PER_EXPERT", 65536)
|
||||
assert m_numtopk <= MAX_TOKENS_PER_EXPERT * topk, (
|
||||
f"m_numtopk must be less than MAX_TOKENS_PER_EXPERT("
|
||||
f"{MAX_TOKENS_PER_EXPERT})"
|
||||
f" for cutlass_moe_fp4, observed m_numtopk = {m_numtopk}. Use"
|
||||
f" MODELOPT_MAX_TOKENS_PER_EXPERT to set this value."
|
||||
)
|
||||
scales_k = k // 16
|
||||
padded_k = (scales_k + (4 - 1)) // 4
|
||||
|
||||
# output is uint8 and packed fp4 values
|
||||
output = torch.empty(
|
||||
m_numtopk, k // 2, device=input_tensor.device, dtype=torch.uint8
|
||||
)
|
||||
# padded part should be zeroed out
|
||||
if padded_k > scales_k:
|
||||
output_scales = torch.zeros(
|
||||
MAX_TOKENS_PER_EXPERT * topk,
|
||||
padded_k,
|
||||
dtype=torch.int32,
|
||||
device=input_tensor.device,
|
||||
)
|
||||
else:
|
||||
output_scales = torch.empty(
|
||||
MAX_TOKENS_PER_EXPERT * topk,
|
||||
padded_k,
|
||||
dtype=torch.int32,
|
||||
device=input_tensor.device,
|
||||
)
|
||||
torch.ops.sgl_kernel.scaled_fp4_experts_quant.default(
|
||||
output,
|
||||
output_scales,
|
||||
input_tensor,
|
||||
input_global_scale,
|
||||
expert_offsets,
|
||||
blockscale_offsets,
|
||||
)
|
||||
output_scales = output_scales.view(torch.float8_e4m3fn)
|
||||
return output, output_scales
|
||||
|
||||
|
||||
# GPTQ kernels
|
||||
def gptq_marlin_gemm(
|
||||
a: torch.Tensor,
|
||||
c: Optional[torch.Tensor],
|
||||
b_q_weight: torch.Tensor,
|
||||
b_scales: torch.Tensor,
|
||||
global_scale: Optional[torch.Tensor],
|
||||
b_zeros: Optional[torch.Tensor],
|
||||
g_idx: Optional[torch.Tensor],
|
||||
perm: Optional[torch.Tensor],
|
||||
workspace: torch.Tensor,
|
||||
b_q_type: ScalarType,
|
||||
size_m: int,
|
||||
size_n: int,
|
||||
size_k: int,
|
||||
is_k_full: bool = True,
|
||||
use_atomic_add: bool = False,
|
||||
use_fp32_reduce: bool = False,
|
||||
is_zp_float: bool = False,
|
||||
) -> torch.Tensor:
|
||||
return torch.ops.sgl_kernel.gptq_marlin_gemm(
|
||||
a,
|
||||
c,
|
||||
b_q_weight,
|
||||
b_scales,
|
||||
global_scale,
|
||||
b_zeros,
|
||||
g_idx,
|
||||
perm,
|
||||
workspace,
|
||||
b_q_type.id,
|
||||
size_m,
|
||||
size_n,
|
||||
size_k,
|
||||
is_k_full,
|
||||
use_atomic_add,
|
||||
use_fp32_reduce,
|
||||
is_zp_float,
|
||||
)
|
||||
|
||||
|
||||
def gptq_gemm(
|
||||
a: torch.Tensor,
|
||||
b_q_weight: torch.Tensor,
|
||||
b_gptq_qzeros: torch.Tensor,
|
||||
b_gptq_scales: torch.Tensor,
|
||||
b_g_idx: torch.Tensor,
|
||||
use_shuffle: bool,
|
||||
bit: int,
|
||||
) -> torch.Tensor:
|
||||
return torch.ops.sgl_kernel.gptq_gemm(
|
||||
a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, use_shuffle, bit
|
||||
)
|
||||
|
||||
|
||||
def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, bit: int) -> None:
|
||||
torch.torch.ops.sgl_kernel.gptq_shuffle(q_weight, q_perm, bit)
|
||||
@@ -0,0 +1,15 @@
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def apply_token_bitmask_inplace_cuda(
|
||||
logits: torch.Tensor,
|
||||
bitmask: torch.Tensor,
|
||||
indices: Optional[Union[List[int], torch.Tensor]] = None,
|
||||
) -> None:
|
||||
if isinstance(indices, list):
|
||||
indices = torch.tensor(indices, dtype=torch.int32, device=logits.device)
|
||||
if indices is not None:
|
||||
indices = indices.to(logits.device)
|
||||
torch.ops.sgl_kernel.apply_token_bitmask_inplace_cuda(logits, bitmask, indices)
|
||||
@@ -0,0 +1,218 @@
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def is_hip() -> bool:
|
||||
return torch.version.hip is not None
|
||||
|
||||
|
||||
_is_hip = is_hip()
|
||||
|
||||
|
||||
def transfer_kv_per_layer(
|
||||
src_k: torch.Tensor,
|
||||
dst_k: torch.Tensor,
|
||||
src_v: torch.Tensor,
|
||||
dst_v: torch.Tensor,
|
||||
src_indices: torch.Tensor,
|
||||
dst_indices: torch.Tensor,
|
||||
item_size: int,
|
||||
block_quota: int = 2,
|
||||
num_warps_per_block: int = 16 if _is_hip else 32,
|
||||
):
|
||||
torch.ops.sgl_kernel.transfer_kv_per_layer(
|
||||
src_k,
|
||||
dst_k,
|
||||
src_v,
|
||||
dst_v,
|
||||
src_indices,
|
||||
dst_indices,
|
||||
item_size,
|
||||
block_quota,
|
||||
num_warps_per_block,
|
||||
)
|
||||
|
||||
|
||||
def transfer_kv_per_layer_pf_lf(
|
||||
src_k: torch.Tensor,
|
||||
dst_k: torch.Tensor,
|
||||
src_v: torch.Tensor,
|
||||
dst_v: torch.Tensor,
|
||||
src_indices: torch.Tensor,
|
||||
dst_indices: torch.Tensor,
|
||||
layer_id: int,
|
||||
item_size: int,
|
||||
src_layout_dim: int,
|
||||
block_quota: int = 2,
|
||||
num_warps_per_block: int = 16 if _is_hip else 32,
|
||||
):
|
||||
torch.ops.sgl_kernel.transfer_kv_per_layer_pf_lf(
|
||||
src_k,
|
||||
dst_k,
|
||||
src_v,
|
||||
dst_v,
|
||||
src_indices,
|
||||
dst_indices,
|
||||
layer_id,
|
||||
item_size,
|
||||
src_layout_dim,
|
||||
block_quota,
|
||||
num_warps_per_block,
|
||||
)
|
||||
|
||||
|
||||
def transfer_kv_all_layer(
|
||||
src_k_layers: torch.Tensor,
|
||||
dst_k_layers: torch.Tensor,
|
||||
src_v_layers: torch.Tensor,
|
||||
dst_v_layers: torch.Tensor,
|
||||
src_indices: torch.Tensor,
|
||||
dst_indices: torch.Tensor,
|
||||
item_size: int,
|
||||
num_layers: int,
|
||||
block_quota: int = 2,
|
||||
num_warps_per_block: int = 16 if _is_hip else 32,
|
||||
):
|
||||
torch.ops.sgl_kernel.transfer_kv_all_layer(
|
||||
src_k_layers,
|
||||
dst_k_layers,
|
||||
src_v_layers,
|
||||
dst_v_layers,
|
||||
src_indices,
|
||||
dst_indices,
|
||||
item_size,
|
||||
num_layers,
|
||||
block_quota,
|
||||
num_warps_per_block,
|
||||
)
|
||||
|
||||
|
||||
def transfer_kv_all_layer_lf_pf(
|
||||
src_k_layers: torch.Tensor,
|
||||
dst_k: torch.Tensor,
|
||||
src_v_layers: torch.Tensor,
|
||||
dst_v: torch.Tensor,
|
||||
src_indices: torch.Tensor,
|
||||
dst_indices: torch.Tensor,
|
||||
item_size: int,
|
||||
dst_layout_dim: int,
|
||||
num_layers: int,
|
||||
block_quota: int = 2,
|
||||
num_warps_per_block: int = 16 if _is_hip else 32,
|
||||
):
|
||||
torch.ops.sgl_kernel.transfer_kv_all_layer_lf_pf(
|
||||
src_k_layers,
|
||||
dst_k,
|
||||
src_v_layers,
|
||||
dst_v,
|
||||
src_indices,
|
||||
dst_indices,
|
||||
item_size,
|
||||
dst_layout_dim,
|
||||
num_layers,
|
||||
block_quota,
|
||||
num_warps_per_block,
|
||||
)
|
||||
|
||||
|
||||
def transfer_kv_direct(
|
||||
src_layers: List[torch.Tensor],
|
||||
dst_layers: List[torch.Tensor],
|
||||
src_indices: torch.Tensor,
|
||||
dst_indices: torch.Tensor,
|
||||
page_size: int,
|
||||
):
|
||||
torch.ops.sgl_kernel.transfer_kv_direct(
|
||||
src_layers, dst_layers, src_indices, dst_indices, page_size
|
||||
)
|
||||
|
||||
|
||||
def transfer_kv_per_layer_mla(
|
||||
src: torch.Tensor,
|
||||
dst: torch.Tensor,
|
||||
src_indices: torch.Tensor,
|
||||
dst_indices: torch.Tensor,
|
||||
item_size: int,
|
||||
block_quota: int = 2,
|
||||
num_warps_per_block: int = 16 if _is_hip else 32,
|
||||
):
|
||||
torch.ops.sgl_kernel.transfer_kv_per_layer_mla(
|
||||
src,
|
||||
dst,
|
||||
src_indices,
|
||||
dst_indices,
|
||||
item_size,
|
||||
block_quota,
|
||||
num_warps_per_block,
|
||||
)
|
||||
|
||||
|
||||
def transfer_kv_per_layer_mla_pf_lf(
|
||||
src: torch.Tensor,
|
||||
dst: torch.Tensor,
|
||||
src_indices: torch.Tensor,
|
||||
dst_indices: torch.Tensor,
|
||||
layer_id: int,
|
||||
item_size: int,
|
||||
src_layout_dim: int,
|
||||
block_quota: int = 2,
|
||||
num_warps_per_block: int = 16 if _is_hip else 32,
|
||||
):
|
||||
torch.ops.sgl_kernel.transfer_kv_per_layer_mla_pf_lf(
|
||||
src,
|
||||
dst,
|
||||
src_indices,
|
||||
dst_indices,
|
||||
layer_id,
|
||||
item_size,
|
||||
src_layout_dim,
|
||||
block_quota,
|
||||
num_warps_per_block,
|
||||
)
|
||||
|
||||
|
||||
def transfer_kv_all_layer_mla(
|
||||
src_layers: torch.Tensor,
|
||||
dst_layers: torch.Tensor,
|
||||
src_indices: torch.Tensor,
|
||||
dst_indices: torch.Tensor,
|
||||
item_size: int,
|
||||
num_layers: int,
|
||||
block_quota: int = 2,
|
||||
num_warps_per_block: int = 16 if _is_hip else 32,
|
||||
):
|
||||
torch.ops.sgl_kernel.transfer_kv_all_layer_mla(
|
||||
src_layers,
|
||||
dst_layers,
|
||||
src_indices,
|
||||
dst_indices,
|
||||
item_size,
|
||||
num_layers,
|
||||
block_quota,
|
||||
num_warps_per_block,
|
||||
)
|
||||
|
||||
|
||||
def transfer_kv_all_layer_mla_lf_pf(
|
||||
src_layers: torch.Tensor,
|
||||
dst: torch.Tensor,
|
||||
src_indices: torch.Tensor,
|
||||
dst_indices: torch.Tensor,
|
||||
item_size: int,
|
||||
dst_layout_dim: int,
|
||||
num_layers: int,
|
||||
block_quota: int = 2,
|
||||
num_warps_per_block: int = 16 if _is_hip else 32,
|
||||
):
|
||||
torch.ops.sgl_kernel.transfer_kv_all_layer_mla_lf_pf(
|
||||
src_layers,
|
||||
dst,
|
||||
src_indices,
|
||||
dst_indices,
|
||||
item_size,
|
||||
dst_layout_dim,
|
||||
num_layers,
|
||||
block_quota,
|
||||
num_warps_per_block,
|
||||
)
|
||||
@@ -0,0 +1,44 @@
|
||||
import torch
|
||||
|
||||
|
||||
def gptq_marlin_repack(
|
||||
b_q_weight,
|
||||
perm,
|
||||
size_k,
|
||||
size_n,
|
||||
num_bits,
|
||||
) -> torch.Tensor:
|
||||
return torch.ops.sgl_kernel.gptq_marlin_repack(
|
||||
b_q_weight,
|
||||
perm,
|
||||
size_k,
|
||||
size_n,
|
||||
num_bits,
|
||||
)
|
||||
|
||||
|
||||
def awq_marlin_repack(
|
||||
b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int
|
||||
) -> torch.Tensor:
|
||||
return torch.ops.sgl_kernel.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits)
|
||||
|
||||
|
||||
def awq_marlin_moe_repack(
|
||||
b_q_weight: torch.Tensor,
|
||||
perm: torch.Tensor,
|
||||
size_k: int,
|
||||
size_n: int,
|
||||
num_bits: int,
|
||||
) -> torch.Tensor:
|
||||
num_experts = b_q_weight.shape[0]
|
||||
assert size_k % 16 == 0
|
||||
output = torch.empty(
|
||||
(num_experts, size_k // 16, size_n * (num_bits // 2)),
|
||||
device=b_q_weight.device,
|
||||
dtype=b_q_weight.dtype,
|
||||
)
|
||||
for e in range(num_experts):
|
||||
output[e] = torch.ops.sgl_kernel.awq_marlin_repack(
|
||||
b_q_weight[e], size_k, size_n, num_bits
|
||||
)
|
||||
return output
|
||||
@@ -0,0 +1,18 @@
|
||||
import torch
|
||||
|
||||
|
||||
def set_kv_buffer_kernel(
|
||||
k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor,
|
||||
loc: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
fallback: bool = False,
|
||||
):
|
||||
try:
|
||||
if fallback:
|
||||
raise RuntimeError("Fallback to torch implementation")
|
||||
torch.ops.sgl_kernel.store_kv_cache(k_cache, v_cache, loc, k, v)
|
||||
except RuntimeError: # ok, fallback to torch implementation
|
||||
k_cache[loc] = k
|
||||
v_cache[loc] = v
|
||||
261
sgl-kernel/build/lib.linux-x86_64-cpython-310/sgl_kernel/moe.py
Normal file
261
sgl-kernel/build/lib.linux-x86_64-cpython-310/sgl_kernel/moe.py
Normal file
@@ -0,0 +1,261 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def moe_align_block_size(
|
||||
topk_ids,
|
||||
num_experts,
|
||||
block_size,
|
||||
sorted_token_ids,
|
||||
experts_ids,
|
||||
num_tokens_post_pad,
|
||||
cumsum_buffer,
|
||||
pad_sorted_token_ids=False,
|
||||
):
|
||||
torch.ops.sgl_kernel.moe_align_block_size.default(
|
||||
topk_ids,
|
||||
num_experts,
|
||||
block_size,
|
||||
sorted_token_ids,
|
||||
experts_ids,
|
||||
num_tokens_post_pad,
|
||||
cumsum_buffer,
|
||||
pad_sorted_token_ids,
|
||||
)
|
||||
|
||||
|
||||
def topk_softmax(
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
gating_output: float,
|
||||
renormalize: bool = False,
|
||||
) -> None:
|
||||
torch.ops.sgl_kernel.topk_softmax.default(
|
||||
topk_weights, topk_ids, gating_output, renormalize
|
||||
)
|
||||
|
||||
|
||||
def moe_fused_gate(
|
||||
input_tensor,
|
||||
bias,
|
||||
num_expert_group,
|
||||
topk_group,
|
||||
topk,
|
||||
num_fused_shared_experts=0,
|
||||
routed_scaling_factor=0,
|
||||
apply_routed_scaling_factor_on_output=False,
|
||||
):
|
||||
# This fused kernel function is used to select topk expert in a hierarchical 2-layer fashion
|
||||
# it split group of expert into num_expert_group, and use top2 expert weight sum in each group
|
||||
# as the group weight to select expert groups and then select topk experts within the selected groups
|
||||
# the #experts is decided by the input tensor shape and we currently only support power of 2 #experts
|
||||
# and #experts should be divisible by num_expert_group. #expert/num_expert_group <= 32 is limited for now.
|
||||
# for non-supported case, we suggest to use the biased_grouped_topk func in sglang.srt.layers.moe.topk
|
||||
# num_fused_shared_experts: if > 0, the last several experts will be
|
||||
# replaced with shared experts. the shared experts will be divided by the
|
||||
# routed_scaling_factor - this is intended to cancel out later when routed+shared
|
||||
# output is scaled so that shared experts are not scaled.
|
||||
# routed_scaling_factor: if > 0, the experts will be scaled by this factor
|
||||
# apply_routed_scaling_factor_on_output: if true, output will be
|
||||
# scaled by the routed_scaling_factor
|
||||
return torch.ops.sgl_kernel.moe_fused_gate.default(
|
||||
input_tensor,
|
||||
bias,
|
||||
num_expert_group,
|
||||
topk_group,
|
||||
topk,
|
||||
num_fused_shared_experts,
|
||||
routed_scaling_factor,
|
||||
apply_routed_scaling_factor_on_output,
|
||||
)
|
||||
|
||||
|
||||
def ep_moe_pre_reorder(
|
||||
input_tensor,
|
||||
gateup_input,
|
||||
src2dst,
|
||||
topk_ids,
|
||||
a1_scales,
|
||||
start_expert_id,
|
||||
end_expert_id,
|
||||
topk,
|
||||
use_per_token_if_dynamic,
|
||||
):
|
||||
return torch.ops.sgl_kernel.ep_moe_pre_reorder.default(
|
||||
input_tensor,
|
||||
gateup_input,
|
||||
src2dst,
|
||||
topk_ids,
|
||||
a1_scales,
|
||||
start_expert_id,
|
||||
end_expert_id,
|
||||
topk,
|
||||
use_per_token_if_dynamic,
|
||||
)
|
||||
|
||||
|
||||
def ep_moe_silu_and_mul(
|
||||
gateup_output,
|
||||
down_input,
|
||||
reorder_topk_ids,
|
||||
scales,
|
||||
start_expert_id,
|
||||
end_expert_id,
|
||||
):
|
||||
return torch.ops.sgl_kernel.ep_moe_silu_and_mul.default(
|
||||
gateup_output,
|
||||
down_input,
|
||||
reorder_topk_ids,
|
||||
scales,
|
||||
start_expert_id,
|
||||
end_expert_id,
|
||||
)
|
||||
|
||||
|
||||
def ep_moe_post_reorder(
|
||||
down_output,
|
||||
output,
|
||||
src2dst,
|
||||
topk_ids,
|
||||
topk_weights,
|
||||
start_expert_id,
|
||||
end_expert_id,
|
||||
topk,
|
||||
):
|
||||
return torch.ops.sgl_kernel.ep_moe_post_reorder.default(
|
||||
down_output,
|
||||
output,
|
||||
src2dst,
|
||||
topk_ids,
|
||||
topk_weights,
|
||||
start_expert_id,
|
||||
end_expert_id,
|
||||
topk,
|
||||
)
|
||||
|
||||
|
||||
def fp8_blockwise_scaled_grouped_mm(
|
||||
output,
|
||||
a_ptrs,
|
||||
b_ptrs,
|
||||
out_ptrs,
|
||||
a_scales_ptrs,
|
||||
b_scales_ptrs,
|
||||
a,
|
||||
b,
|
||||
scales_a,
|
||||
scales_b,
|
||||
stride_a,
|
||||
stride_b,
|
||||
stride_c,
|
||||
layout_sfa,
|
||||
layout_sfb,
|
||||
problem_sizes,
|
||||
expert_offsets,
|
||||
workspace,
|
||||
):
|
||||
torch.ops.sgl_kernel.fp8_blockwise_scaled_grouped_mm.default(
|
||||
output,
|
||||
a_ptrs,
|
||||
b_ptrs,
|
||||
out_ptrs,
|
||||
a_scales_ptrs,
|
||||
b_scales_ptrs,
|
||||
a,
|
||||
b,
|
||||
scales_a,
|
||||
scales_b,
|
||||
stride_a,
|
||||
stride_b,
|
||||
stride_c,
|
||||
layout_sfa,
|
||||
layout_sfb,
|
||||
problem_sizes,
|
||||
expert_offsets,
|
||||
workspace,
|
||||
)
|
||||
|
||||
|
||||
def prepare_moe_input(
|
||||
topk_ids,
|
||||
expert_offsets,
|
||||
problem_sizes1,
|
||||
problem_sizes2,
|
||||
input_permutation,
|
||||
output_permutation,
|
||||
num_experts,
|
||||
n,
|
||||
k,
|
||||
blockscale_offsets: Optional[torch.Tensor] = None,
|
||||
):
|
||||
torch.ops.sgl_kernel.prepare_moe_input.default(
|
||||
topk_ids,
|
||||
expert_offsets,
|
||||
blockscale_offsets,
|
||||
problem_sizes1,
|
||||
problem_sizes2,
|
||||
input_permutation,
|
||||
output_permutation,
|
||||
num_experts,
|
||||
n,
|
||||
k,
|
||||
)
|
||||
|
||||
|
||||
def apply_shuffle_mul_sum(
|
||||
input,
|
||||
output,
|
||||
permutation,
|
||||
factors,
|
||||
):
|
||||
torch.ops.sgl_kernel.apply_shuffle_mul_sum.default(
|
||||
input, output, permutation, factors
|
||||
)
|
||||
|
||||
|
||||
def cutlass_fp4_group_mm(
|
||||
a_fp4,
|
||||
b_fp4,
|
||||
a_blockscale,
|
||||
b_blockscale,
|
||||
alphas,
|
||||
out_dtype,
|
||||
device,
|
||||
params: Dict[str, Any],
|
||||
):
|
||||
"""
|
||||
An FP4 Blockscaled Group Gemm that takes in a_tensors, b_tensors and runs
|
||||
the gemms for each combination based on the specified problem sizes.
|
||||
|
||||
This is used as the MoE gemm during NVFP4 Quantized FusedMoE forward.
|
||||
- a/b_tensors: the NVFP4 a_ptrs and b_ptrs tensors which are quantized
|
||||
input and expert weights.
|
||||
- a_/b_scales: The blockscales in FP8-E4M3 precision
|
||||
- ab_strides/c_strides: Strides for the a/b tensors between rows.
|
||||
- expert_offsets/sf_offsets: Indices that mark at which token index
|
||||
each expert begins its computation. The number of tokens
|
||||
computed with expert E is expert_offsets[E + 1] -
|
||||
expert_offsets[E] And the sf_size per expert is
|
||||
sf_offset[E+1] - sf_offset[E]
|
||||
- problem_sizes: MxNxK sizes of each expert's multiplication in two grouped
|
||||
MMs used in the fused MoE operation.
|
||||
"""
|
||||
m_topk = a_fp4.shape[0]
|
||||
n = b_fp4.shape[1]
|
||||
c_shape = (m_topk, n)
|
||||
c = torch.empty(c_shape, device=device, dtype=out_dtype)
|
||||
torch.ops.sgl_kernel.cutlass_fp4_group_mm.default(
|
||||
c,
|
||||
a_fp4,
|
||||
b_fp4,
|
||||
a_blockscale,
|
||||
b_blockscale,
|
||||
alphas,
|
||||
params["ab_strides"],
|
||||
params["c_strides"],
|
||||
params["problem_sizes"],
|
||||
params["expert_offsets"],
|
||||
params["blockscale_offsets"],
|
||||
)
|
||||
return c.to(dtype=out_dtype)
|
||||
@@ -0,0 +1,543 @@
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from sgl_kernel.utils import _to_tensor_scalar_tuple
|
||||
|
||||
|
||||
def _top_k_renorm_probs_internal(
|
||||
probs: torch.Tensor,
|
||||
maybe_top_k_arr: Optional[torch.Tensor],
|
||||
top_k_val: int,
|
||||
) -> torch.Tensor:
|
||||
probs = probs.float()
|
||||
maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
|
||||
renorm_probs = torch.empty_like(probs)
|
||||
torch.ops.sgl_kernel.top_k_renorm_probs.default(
|
||||
probs, renorm_probs, maybe_top_k_arr, top_k_val
|
||||
)
|
||||
return renorm_probs
|
||||
|
||||
|
||||
def top_k_renorm_probs(
|
||||
probs: torch.Tensor,
|
||||
top_k: Union[torch.Tensor, int],
|
||||
) -> torch.Tensor:
|
||||
r"""Adapt from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
|
||||
Fused GPU kernel for renormalizing probabilities by top-k thresholding.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
probs: torch.Tensor
|
||||
Probabilities, shape ``(batch_size, num_classes)``.
|
||||
top_k: Union[torch.Tensor, int]
|
||||
Either a scalar or a tensor of shape ``(batch_size,)``, representing the top-k threshold for for
|
||||
for re-normalizing probabilities, should be in ``(0, num_classes)``.
|
||||
If a scalar, the same threshold is used for all requests.
|
||||
If a tensor, each request has its own threshold.
|
||||
We keep the top-k probabilities, set the rest to zero, and renormalize the probabilities.
|
||||
|
||||
Returns
|
||||
-------
|
||||
renorm_probs: torch.Tensor
|
||||
Renormalized probabilities, shape ``(batch_size, num_classes)``.
|
||||
|
||||
Note
|
||||
----
|
||||
This combination of ``top_k_renorm_probs`` and ``sampling_from_probs`` should be equivalent to
|
||||
``top_k_sampling_from_probs``.
|
||||
"""
|
||||
return _top_k_renorm_probs_internal(probs, *_to_tensor_scalar_tuple(top_k))
|
||||
|
||||
|
||||
top_k_renorm_prob = top_k_renorm_probs
|
||||
|
||||
|
||||
def _top_p_renorm_probs_internal(
|
||||
probs: torch.Tensor,
|
||||
maybe_top_p_arr: Optional[torch.Tensor],
|
||||
top_p_val: float,
|
||||
) -> torch.Tensor:
|
||||
probs = probs.float()
|
||||
maybe_top_p_arr = maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
|
||||
renorm_probs = torch.empty_like(probs)
|
||||
torch.ops.sgl_kernel.top_p_renorm_probs.default(
|
||||
probs, renorm_probs, maybe_top_p_arr, top_p_val
|
||||
)
|
||||
return renorm_probs
|
||||
|
||||
|
||||
def top_p_renorm_probs(
|
||||
probs: torch.Tensor,
|
||||
top_p: Union[torch.Tensor, float],
|
||||
) -> torch.Tensor:
|
||||
r"""Adapt from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
|
||||
Fused GPU kernel for renormalizing probabilities by top-p thresholding.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
probs: torch.Tensor
|
||||
Probabilities, shape ``(batch_size, num_classes)``.
|
||||
top_p: Union[torch.Tensor, float]
|
||||
Either a scalar or a tensor of shape ``(batch_size,)``, representing the top-p threshold for for
|
||||
re-normalizing probabilities, should be in ``(0, 1)``.
|
||||
If a scalar, the same threshold is used for all requests.
|
||||
If a tensor, each request has its own threshold.
|
||||
We mask out the probabilities less than `threshold` where the cumulative sum
|
||||
of ``probs[probs >= threshold]`` is `top_p`, and renormalize the probabilities.
|
||||
|
||||
Returns
|
||||
-------
|
||||
renorm_probs: torch.Tensor
|
||||
Renormalized probabilities, shape ``(batch_size, num_classes)``.
|
||||
|
||||
Note
|
||||
----
|
||||
This combination of ``top_p_renorm_probs`` and ``sampling_from_probs`` should be equivalent to
|
||||
``top_p_sampling_from_probs``.
|
||||
|
||||
"""
|
||||
return _top_p_renorm_probs_internal(probs, *_to_tensor_scalar_tuple(top_p))
|
||||
|
||||
|
||||
top_p_renorm_prob = top_p_renorm_probs
|
||||
|
||||
|
||||
def _top_p_sampling_from_probs_internal(
|
||||
probs: torch.Tensor,
|
||||
indices: Optional[torch.Tensor],
|
||||
maybe_top_p_arr: Optional[torch.Tensor],
|
||||
top_p_val: float,
|
||||
deterministic: bool,
|
||||
generator: Optional[torch.Generator],
|
||||
) -> torch.Tensor:
|
||||
with probs.device as device:
|
||||
probs = probs.float()
|
||||
maybe_top_p_arr = (
|
||||
maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
|
||||
)
|
||||
samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
|
||||
torch.ops.sgl_kernel.top_p_sampling_from_probs.default(
|
||||
probs,
|
||||
samples,
|
||||
indices,
|
||||
maybe_top_p_arr,
|
||||
top_p_val,
|
||||
deterministic,
|
||||
generator,
|
||||
)
|
||||
return samples
|
||||
|
||||
|
||||
def top_p_sampling_from_probs(
|
||||
probs: torch.Tensor,
|
||||
top_p: Union[torch.Tensor, float],
|
||||
indices: Optional[torch.Tensor] = None,
|
||||
deterministic: bool = True,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
check_nan: bool = False,
|
||||
) -> torch.Tensor:
|
||||
r"""Adapt from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
|
||||
Fused GPU kernel for top-p sampling (nucleus sampling) from probabilities,
|
||||
this operator implements GPU-based rejection sampling without explicit sorting.
|
||||
Check the `blog post <https://flashinfer.ai/2025/03/10/sampling.html>`_ for more details.
|
||||
|
||||
The multiple rounds of rejection sampling are implemented in a single CUDA kernel,
|
||||
which is more efficient than the naive implementation that launches a series of kernels.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
probs: torch.Tensor
|
||||
Probabilities for sampling. When indices is not provided, shape should be ``(batch_size, num_classes)``
|
||||
and the i-th output will be sampled from the i-th row of probabilities. When indices is provided,
|
||||
shape should be ``(unique_batch_size, num_classes)`` where unique_batch_size is the number of unique
|
||||
probability distributions.
|
||||
top_p: Union[torch.Tensor, float]
|
||||
Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for top-p sampling.
|
||||
If a scalar, the same threshold is used for all requests.
|
||||
If a tensor, each request has its own threshold.
|
||||
indices: Optional[torch.Tensor]
|
||||
Optional indices tensor of shape ``(batch_size,)`` that maps each output to a row in probs.
|
||||
For example, if indices[i] = j, then the i-th output will be sampled from probs[j].
|
||||
This allows reusing the same probability distribution for multiple outputs.
|
||||
If indices is not provided, the i-th output will be sampled from the i-th row of probs.
|
||||
deterministic: bool
|
||||
Whether to use deterministic kernel implementation, default is ``True``.
|
||||
generator: Optional[torch.Generator]
|
||||
A random number generator for the operation.
|
||||
check_nan: bool
|
||||
Whether to check nan in :attr:`probs`, default is ``False``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
samples: torch.Tensor
|
||||
Sampled categories, shape ``(batch_size,)``.
|
||||
|
||||
Note
|
||||
----
|
||||
This function expects float32 inputs, and the output is int32.
|
||||
|
||||
"""
|
||||
if check_nan:
|
||||
if torch.any(torch.isnan(probs)):
|
||||
raise ValueError("Input probs contains NaN.")
|
||||
return _top_p_sampling_from_probs_internal(
|
||||
probs, indices, *_to_tensor_scalar_tuple(top_p), deterministic, generator
|
||||
)
|
||||
|
||||
|
||||
def _top_k_top_p_sampling_from_probs_internal(
|
||||
probs: torch.Tensor,
|
||||
indices: Optional[torch.Tensor],
|
||||
maybe_top_k_arr: Optional[torch.Tensor],
|
||||
top_k_val: int,
|
||||
maybe_top_p_arr: Optional[torch.Tensor],
|
||||
top_p_val: float,
|
||||
deterministic: bool,
|
||||
generator: Optional[torch.Generator],
|
||||
) -> torch.Tensor:
|
||||
with probs.device as device:
|
||||
probs = probs.float()
|
||||
maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
|
||||
maybe_top_p_arr = (
|
||||
maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
|
||||
)
|
||||
samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
|
||||
torch.ops.sgl_kernel.top_k_top_p_sampling_from_probs.default(
|
||||
probs,
|
||||
samples,
|
||||
indices,
|
||||
maybe_top_k_arr,
|
||||
top_k_val,
|
||||
maybe_top_p_arr,
|
||||
top_p_val,
|
||||
deterministic,
|
||||
generator,
|
||||
)
|
||||
return samples
|
||||
|
||||
|
||||
def top_k_top_p_sampling_from_probs(
|
||||
probs: torch.Tensor,
|
||||
top_k: Union[torch.Tensor, int],
|
||||
top_p: Union[torch.Tensor, float],
|
||||
indices: Optional[torch.Tensor] = None,
|
||||
filter_apply_order: str = "top_k_first",
|
||||
deterministic: bool = True,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
check_nan: bool = False,
|
||||
) -> torch.Tensor:
|
||||
r"""Adapt from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
|
||||
Fused GPU kernel for top-k and top-p sampling from probabilities,
|
||||
|
||||
this operator implements GPU-based rejection sampling without explicit sorting.
|
||||
Check the `blog post <https://flashinfer.ai/2025/03/10/sampling.html>`_ for more details.
|
||||
|
||||
The multiple rounds of rejection sampling are implemented in a single CUDA kernel,
|
||||
which is more efficient than the naive implementation that launches a series of kernels.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
probs: torch.Tensor
|
||||
Probabilities for sampling. When indices is not provided, shape should be ``(batch_size, num_classes)``
|
||||
and the i-th output will be sampled from the i-th row of probabilities. When indices is provided,
|
||||
shape should be ``(unique_batch_size, num_classes)`` where unique_batch_size is the number of unique
|
||||
probability distributions.
|
||||
top_k: Union[torch.Tensor, int]
|
||||
Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for top-k sampling.
|
||||
If a scalar, the same threshold is used for all requests.
|
||||
If a tensor, each request has its own threshold.
|
||||
top_p: Union[torch.Tensor, float]
|
||||
Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for top-p sampling.
|
||||
If a scalar, the same threshold is used for all requests.
|
||||
If a tensor, each request has its own threshold.
|
||||
indices: Optional[torch.Tensor]
|
||||
Optional indices tensor of shape ``(batch_size,)`` that maps each output to a row in probs.
|
||||
For example, if indices[i] = j, then the i-th output will be sampled from probs[j].
|
||||
This allows reusing the same probability distribution for multiple outputs.
|
||||
If indices is not provided, the i-th output will be sampled from the i-th row of probs.
|
||||
filter_apply_order: str
|
||||
The order of applying top-k and top-p sampling, should be either ``"top_k_first"`` or ``"joint"``.
|
||||
If ``"top_k_first"``, we first apply top-k filter, then apply top-p sampling on the top-k results.
|
||||
If ``"joint"``, we apply top-k and top-p filter simultaneously in each round. Default is ``"top_k_first"``.
|
||||
deterministic: bool
|
||||
Whether to use deterministic kernel implementation, default is ``True``.
|
||||
generator: Optional[torch.Generator]
|
||||
A random number generator for the operation.
|
||||
check_nan: bool
|
||||
Whether to check nan in :attr:`probs`, default is ``False``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
samples: torch.Tensor
|
||||
Sampled categories, shape ``(batch_size,)``.
|
||||
|
||||
Note
|
||||
----
|
||||
This function expects float32 inputs, and the output is int32.
|
||||
|
||||
"""
|
||||
if filter_apply_order == "top_k_first":
|
||||
renorm_probs = top_k_renorm_probs(probs, top_k)
|
||||
return top_p_sampling_from_probs(
|
||||
renorm_probs,
|
||||
top_p,
|
||||
indices,
|
||||
deterministic,
|
||||
check_nan=check_nan,
|
||||
generator=generator,
|
||||
)
|
||||
elif filter_apply_order == "joint":
|
||||
if check_nan:
|
||||
if torch.any(torch.isnan(probs)):
|
||||
raise ValueError("Input probs contains NaN.")
|
||||
return _top_k_top_p_sampling_from_probs_internal(
|
||||
probs,
|
||||
indices,
|
||||
*_to_tensor_scalar_tuple(top_k),
|
||||
*_to_tensor_scalar_tuple(top_p),
|
||||
deterministic,
|
||||
generator,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid filter_apply_order: {filter_apply_order}")
|
||||
|
||||
|
||||
def _min_p_sampling_from_probs_internal(
|
||||
probs: torch.Tensor,
|
||||
indices: Optional[torch.Tensor],
|
||||
maybe_min_p_arr: Optional[torch.Tensor],
|
||||
min_p_val: float,
|
||||
deterministic: bool,
|
||||
generator: Optional[torch.Generator],
|
||||
) -> torch.Tensor:
|
||||
with probs.device as device:
|
||||
probs = probs.float()
|
||||
maybe_min_p_arr = (
|
||||
maybe_min_p_arr.float() if maybe_min_p_arr is not None else None
|
||||
)
|
||||
samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
|
||||
torch.ops.sgl_kernel.min_p_sampling_from_probs.default(
|
||||
probs,
|
||||
samples,
|
||||
indices,
|
||||
maybe_min_p_arr,
|
||||
min_p_val,
|
||||
deterministic,
|
||||
generator,
|
||||
)
|
||||
return samples
|
||||
|
||||
|
||||
def min_p_sampling_from_probs(
|
||||
probs: torch.Tensor,
|
||||
min_p: Union[torch.Tensor, float],
|
||||
indices: Optional[torch.Tensor] = None,
|
||||
deterministic: bool = True,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
check_nan: bool = False,
|
||||
) -> torch.Tensor:
|
||||
r"""Adapt from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
|
||||
Fused GPU kernel for `min_p sampling <https://arxiv.org/abs/2407.01082>`_ from probabilities,
|
||||
|
||||
this operator implements GPU-based rejection sampling without explicit sorting.
|
||||
Check the `blog post <https://flashinfer.ai/2025/03/10/sampling.html>`_ for more details.
|
||||
|
||||
The multiple rounds of rejection sampling are implemented in a single CUDA kernel,
|
||||
which is more efficient than the naive implementation that launches a series of kernels.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
probs: torch.Tensor
|
||||
Probabilities for sampling. When indices is not provided, shape should be ``(batch_size, num_classes)``
|
||||
and the i-th output will be sampled from the i-th row of probabilities. When indices is provided,
|
||||
shape should be ``(unique_batch_size, num_classes)`` where unique_batch_size is the number of unique
|
||||
probability distributions.
|
||||
min_p: Union[torch.Tensor, float]
|
||||
Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for min-p sampling.
|
||||
If a scalar, the same threshold is used for all requests.
|
||||
If a tensor, each request has its own threshold.
|
||||
indices: Optional[torch.Tensor]
|
||||
Optional indices tensor of shape ``(batch_size,)`` that maps each output to a row in probs.
|
||||
For example, if indices[i] = j, then the i-th output will be sampled from probs[j].
|
||||
This allows reusing the same probability distribution for multiple outputs.
|
||||
If indices is not provided, the i-th output will be sampled from the i-th row of probs.
|
||||
deterministic: bool
|
||||
Whether to use deterministic kernel implementation, default is ``True``.
|
||||
generator: Optional[torch.Generator]
|
||||
A random number generator for the operation.
|
||||
check_nan: bool
|
||||
Whether to check nan in :attr:`probs`, default is ``False``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
samples: torch.Tensor
|
||||
Sampled categories, shape ``(batch_size,)``.
|
||||
|
||||
Note
|
||||
----
|
||||
This function expects float32 inputs, and the output is int32.
|
||||
"""
|
||||
if check_nan:
|
||||
if torch.any(torch.isnan(probs)):
|
||||
raise ValueError("Input probs contains NaN.")
|
||||
return _min_p_sampling_from_probs_internal(
|
||||
probs, indices, *_to_tensor_scalar_tuple(min_p), deterministic, generator
|
||||
)
|
||||
|
||||
|
||||
def _top_k_mask_logits_internal(
|
||||
logits: torch.Tensor,
|
||||
maybe_top_k_arr: Optional[torch.Tensor],
|
||||
top_k_val: int,
|
||||
) -> torch.Tensor:
|
||||
logits = logits.float()
|
||||
maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
|
||||
mask_logits = torch.empty_like(logits)
|
||||
torch.ops.sgl_kernel.top_k_mask_logits.default(
|
||||
logits, mask_logits, maybe_top_k_arr, top_k_val
|
||||
)
|
||||
return mask_logits
|
||||
|
||||
|
||||
def top_k_mask_logits(
|
||||
logits: torch.Tensor,
|
||||
top_k: Union[torch.Tensor, int],
|
||||
) -> torch.Tensor:
|
||||
r"""Adapt from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
|
||||
Fused GPU kernel for masking logits by top-k thresholding.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
logits: torch.Tensor
|
||||
Logits before softmax, shape ``(batch_size, num_classes)``.
|
||||
top_k: Union[torch.Tensor, int]
|
||||
Either a scalar or a tensor of shape ``(batch_size,)``, representing the top-k threshold for for
|
||||
for masking logits, should be in ``(0, num_classes)``.
|
||||
If a scalar, the same threshold is used for all requests.
|
||||
If a tensor, each request has its own threshold.
|
||||
We keep the top-k logits, set the rest to negative infinity.
|
||||
|
||||
Returns
|
||||
-------
|
||||
masked_logits: torch.Tensor
|
||||
Masked logits, shape ``(batch_size, num_classes)``.
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
>>> import torch
|
||||
>>> import flashinfer
|
||||
>>> torch.manual_seed(42)
|
||||
>>> batch_size = 4
|
||||
>>> vocab_size = 5
|
||||
>>> top_k = 3
|
||||
>>> logits = torch.randn(batch_size, vocab_size).to(0)
|
||||
>>> logits
|
||||
tensor([[ 1.9269, 1.4873, 0.9007, -2.1055, -0.7581],
|
||||
[ 1.0783, 0.8008, 1.6806, 0.3559, -0.6866],
|
||||
[-0.4934, 0.2415, -0.2316, 0.0418, -0.2516],
|
||||
[ 0.8599, -0.3097, -0.3957, 0.8034, -0.6216]], device='cuda:0')
|
||||
>>> masked_logits = flashinfer.sampling.top_k_mask_logits(logits, top_k)
|
||||
>>> masked_logits
|
||||
tensor([[ 1.9269, 1.4873, 0.9007, -inf, -inf],
|
||||
[ 1.0783, 0.8008, 1.6806, -inf, -inf],
|
||||
[ -inf, 0.2415, -0.2316, 0.0418, -inf],
|
||||
[ 0.8599, -0.3097, -inf, 0.8034, -inf]], device='cuda:0')
|
||||
|
||||
Note
|
||||
----
|
||||
The combination of ``top_k_mask_logits`` and ``softmax`` should be equivalent to ``top_k_renorm_probs``.
|
||||
|
||||
See Also
|
||||
--------
|
||||
top_k_renorm_probs
|
||||
"""
|
||||
return _top_k_mask_logits_internal(logits, *_to_tensor_scalar_tuple(top_k))
|
||||
|
||||
|
||||
def top_k_top_p_sampling_from_logits(
|
||||
logits: torch.Tensor,
|
||||
top_k: Union[torch.Tensor, int],
|
||||
top_p: Union[torch.Tensor, float],
|
||||
indices: Optional[torch.Tensor] = None,
|
||||
filter_apply_order: str = "top_k_first",
|
||||
deterministic: bool = True,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
check_nan: bool = False,
|
||||
) -> torch.Tensor:
|
||||
r"""Adapt from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
|
||||
Fused GPU kernel for top-k and top-p sampling from probabilities,
|
||||
|
||||
this operator implements GPU-based rejection sampling without explicit sorting.
|
||||
Check the `blog post <https://flashinfer.ai/2025/03/10/sampling.html>`_ for more details.
|
||||
|
||||
The multiple rounds of rejection sampling are implemented in a single CUDA kernel,
|
||||
which is more efficient than the naive implementation that launches a series of kernels.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
logits: torch.Tensor
|
||||
Pre-softmax logits for sampling. When indices is not provided, shape should be ``(batch_size, num_classes)``
|
||||
and the i-th output will be sampled from the i-th row of logits. When indices is provided,
|
||||
shape should be ``(unique_batch_size, num_classes)`` where unique_batch_size is the number of unique
|
||||
probability distributions.
|
||||
top_k: Union[torch.Tensor, int]
|
||||
Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for top-k sampling.
|
||||
If a scalar, the same threshold is used for all requests.
|
||||
If a tensor, each request has its own threshold.
|
||||
top_p: Union[torch.Tensor, float]
|
||||
Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for top-p sampling.
|
||||
If a scalar, the same threshold is used for all requests.
|
||||
If a tensor, each request has its own threshold.
|
||||
indices: Optional[torch.Tensor]
|
||||
Optional indices tensor of shape ``(batch_size,)`` that maps each output to a row in probs.
|
||||
For example, if indices[i] = j, then the i-th output will be sampled from probs[j].
|
||||
This allows reusing the same probability distribution for multiple outputs.
|
||||
If indices is not provided, the i-th output will be sampled from the i-th row of probs.
|
||||
filter_apply_order: str
|
||||
The order of applying top-k and top-p sampling, should be either ``"top_k_first"`` or ``"joint"``.
|
||||
If ``"top_k_first"``, we first apply top-k filter, then apply top-p sampling on the top-k results.
|
||||
If ``"joint"``, we apply top-k and top-p filter simultaneously in each round. Default is ``"top_k_first"``.
|
||||
deterministic: bool
|
||||
Whether to use deterministic kernel implementation, default is ``True``.
|
||||
generator: Optional[torch.Generator]
|
||||
A random number generator for the operation.
|
||||
check_nan: bool
|
||||
Whether to check nan in :attr:`probs`, default is ``False``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
samples: torch.Tensor
|
||||
Sampled categories, shape ``(batch_size,)``.
|
||||
|
||||
Note
|
||||
----
|
||||
This function expects float32 inputs, and the output is int32.
|
||||
|
||||
"""
|
||||
if filter_apply_order == "top_k_first":
|
||||
masked_logits = top_k_mask_logits(logits, top_k)
|
||||
probs = torch.softmax(masked_logits, dim=-1)
|
||||
return top_p_sampling_from_probs(
|
||||
probs,
|
||||
top_p,
|
||||
indices,
|
||||
deterministic,
|
||||
check_nan=check_nan,
|
||||
generator=generator,
|
||||
)
|
||||
elif filter_apply_order == "joint":
|
||||
probs = torch.softmax(logits, dim=-1)
|
||||
if check_nan:
|
||||
if torch.any(torch.isnan(probs)):
|
||||
raise ValueError("Input probs contains NaN.")
|
||||
return _top_k_top_p_sampling_from_probs_internal(
|
||||
probs,
|
||||
indices,
|
||||
*_to_tensor_scalar_tuple(top_k),
|
||||
*_to_tensor_scalar_tuple(top_p),
|
||||
deterministic,
|
||||
generator,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid filter_apply_order: {filter_apply_order}")
|
||||
@@ -0,0 +1,352 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import functools
|
||||
import struct
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Optional, Union
|
||||
|
||||
_SCALAR_TYPES_ID_MAP = {}
|
||||
|
||||
|
||||
# Mirrors enum in `core/scalar_type.hpp`
|
||||
class NanRepr(Enum):
|
||||
NONE = 0 # nans are not supported
|
||||
IEEE_754 = 1 # nans are: Exp all 1s, mantissa not all 0s
|
||||
EXTD_RANGE_MAX_MIN = 2 # nans are: Exp all 1s, mantissa all 1s
|
||||
|
||||
|
||||
# This ScalarType class is a parallel implementation of the C++ ScalarType
|
||||
# class found in csrc/core/scalar_type.hpp. These two classes should be kept
|
||||
# in sync until the inductor fully supports custom C++ classes.
|
||||
@dataclass(frozen=True)
|
||||
class ScalarType:
|
||||
"""
|
||||
ScalarType can represent a wide range of floating point and integer
|
||||
types, in particular it can be used to represent sub-byte data types
|
||||
(something that torch.dtype currently does not support). It is also
|
||||
capable of representing types with a bias, i.e.:
|
||||
`stored_value = value + bias`,
|
||||
this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias
|
||||
of 8). The implementation for this class can be found in
|
||||
csrc/core/scalar_type.hpp, these type signatures should be kept in sync
|
||||
with that file.
|
||||
"""
|
||||
|
||||
exponent: int
|
||||
"""
|
||||
Number of bits in the exponent if this is a floating point type
|
||||
(zero if this an integer type)
|
||||
"""
|
||||
|
||||
mantissa: int
|
||||
"""
|
||||
Number of bits in the mantissa if this is a floating point type,
|
||||
or the number bits representing an integer excluding the sign bit if
|
||||
this an integer type.
|
||||
"""
|
||||
|
||||
signed: bool
|
||||
"If the type is signed (i.e. has a sign bit)"
|
||||
|
||||
bias: int
|
||||
"""
|
||||
bias used to encode the values in this scalar type
|
||||
(value = stored_value - bias, default 0) for example if we store the
|
||||
type as an unsigned integer with a bias of 128 then the value 0 will be
|
||||
stored as 128 and -1 will be stored as 127 and 1 will be stored as 129.
|
||||
"""
|
||||
|
||||
_finite_values_only: bool = False
|
||||
"""
|
||||
Private: if infs are supported, used `has_infs()` instead.
|
||||
"""
|
||||
|
||||
nan_repr: NanRepr = NanRepr.IEEE_754
|
||||
"""
|
||||
How NaNs are represent in this scalar type, returns NanRepr value.
|
||||
(not applicable for integer types)
|
||||
"""
|
||||
|
||||
def _floating_point_max_int(self) -> int:
|
||||
assert (
|
||||
self.mantissa <= 52 and self.exponent <= 11
|
||||
), f"Cannot represent max/min as a double for type {self.__str__()}"
|
||||
|
||||
max_mantissa = (1 << self.mantissa) - 1
|
||||
if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN:
|
||||
max_mantissa = max_mantissa - 1
|
||||
|
||||
max_exponent = (1 << self.exponent) - 2
|
||||
if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN or self.nan_repr == NanRepr.NONE:
|
||||
assert (
|
||||
self.exponent < 11
|
||||
), f"Cannot represent max/min as a double for type {self.__str__()}"
|
||||
max_exponent = max_exponent + 1
|
||||
|
||||
# adjust the exponent to match that of a double
|
||||
# for now we assume the exponent bias is the standard 2^(e-1) -1, (where
|
||||
# e is the exponent bits), there is some precedent for non-standard
|
||||
# biases, example `float8_e4m3b11fnuz` here:
|
||||
# https://github.com/jax-ml/ml_dtypes but to avoid premature over
|
||||
# complication we are just assuming the standard exponent bias until
|
||||
# there is a need to support non-standard biases
|
||||
exponent_bias = (1 << (self.exponent - 1)) - 1
|
||||
exponent_bias_double = (1 << 10) - 1 # double e = 11
|
||||
|
||||
max_exponent_double = max_exponent - exponent_bias + exponent_bias_double
|
||||
|
||||
# shift the mantissa and exponent into the proper positions for an
|
||||
# IEEE double and bitwise-or them together.
|
||||
return (max_mantissa << (52 - self.mantissa)) | (max_exponent_double << 52)
|
||||
|
||||
def _floating_point_max(self) -> float:
|
||||
double_raw = self._floating_point_max_int()
|
||||
return struct.unpack("!d", struct.pack("!Q", double_raw))[0]
|
||||
|
||||
def _raw_max(self) -> Union[int, float]:
|
||||
if self.is_floating_point():
|
||||
return self._floating_point_max()
|
||||
else:
|
||||
assert (
|
||||
self.size_bits < 64 or self.size_bits == 64 and self.is_signed()
|
||||
), "Cannot represent max as an int"
|
||||
return (1 << self.mantissa) - 1
|
||||
|
||||
def _raw_min(self) -> Union[int, float]:
|
||||
if self.is_floating_point():
|
||||
assert (
|
||||
self.is_signed()
|
||||
), "We currently assume all floating point types are signed"
|
||||
sign_bit_double = 1 << 63
|
||||
|
||||
max_raw = self._floating_point_max_int()
|
||||
min_raw = max_raw | sign_bit_double
|
||||
return struct.unpack("!d", struct.pack("!Q", min_raw))[0]
|
||||
else:
|
||||
assert (
|
||||
not self.is_signed() or self.size_bits <= 64
|
||||
), "Cannot represent min as a int64_t"
|
||||
|
||||
if self.is_signed():
|
||||
return -(1 << (self.size_bits - 1))
|
||||
else:
|
||||
return 0
|
||||
|
||||
@functools.cached_property
|
||||
def id(self) -> int:
|
||||
"""
|
||||
Convert the ScalarType to an int which can be passed to pytorch custom
|
||||
ops. This layout of the int must be kept in sync with the C++
|
||||
ScalarType's from_id method.
|
||||
"""
|
||||
val = 0
|
||||
offset = 0
|
||||
|
||||
def or_and_advance(member, bit_width):
|
||||
nonlocal val
|
||||
nonlocal offset
|
||||
bit_mask = (1 << bit_width) - 1
|
||||
val = val | (int(member) & bit_mask) << offset
|
||||
offset = offset + bit_width
|
||||
|
||||
or_and_advance(self.exponent, 8)
|
||||
or_and_advance(self.mantissa, 8)
|
||||
or_and_advance(self.signed, 1)
|
||||
or_and_advance(self.bias, 32)
|
||||
or_and_advance(self._finite_values_only, 1)
|
||||
or_and_advance(self.nan_repr.value, 8)
|
||||
|
||||
assert offset <= 64, f"ScalarType fields too big {offset} to fit into an int64"
|
||||
|
||||
_SCALAR_TYPES_ID_MAP[val] = self
|
||||
|
||||
return val
|
||||
|
||||
@property
|
||||
def size_bits(self) -> int:
|
||||
return self.exponent + self.mantissa + int(self.signed)
|
||||
|
||||
def min(self) -> Union[int, float]:
|
||||
"""
|
||||
Min representable value for this scalar type.
|
||||
(accounting for bias if there is one)
|
||||
"""
|
||||
return self._raw_min() - self.bias
|
||||
|
||||
def max(self) -> Union[int, float]:
|
||||
"""
|
||||
Max representable value for this scalar type.
|
||||
(accounting for bias if there is one)
|
||||
"""
|
||||
return self._raw_max() - self.bias
|
||||
|
||||
def is_signed(self) -> bool:
|
||||
"""
|
||||
If the type is signed (i.e. has a sign bit), same as `signed`
|
||||
added for consistency with:
|
||||
https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
|
||||
"""
|
||||
return self.signed
|
||||
|
||||
def is_floating_point(self) -> bool:
|
||||
"If the type is a floating point type"
|
||||
return self.exponent != 0
|
||||
|
||||
def is_integer(self) -> bool:
|
||||
"If the type is an integer type"
|
||||
return self.exponent == 0
|
||||
|
||||
def has_bias(self) -> bool:
|
||||
"If the type has a non-zero bias"
|
||||
return self.bias != 0
|
||||
|
||||
def has_infs(self) -> bool:
|
||||
"If the type is floating point and supports infinity"
|
||||
return not self._finite_values_only
|
||||
|
||||
def has_nans(self) -> bool:
|
||||
return self.nan_repr != NanRepr.NONE
|
||||
|
||||
def is_ieee_754(self) -> bool:
|
||||
"""
|
||||
If the type is a floating point type that follows IEEE 754
|
||||
conventions
|
||||
"""
|
||||
return self.nan_repr == NanRepr.IEEE_754 and not self._finite_values_only
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""
|
||||
naming generally follows: https://github.com/jax-ml/ml_dtypes
|
||||
for floating point types (leading f) the scheme is:
|
||||
`float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
|
||||
flags:
|
||||
- no-flags: means it follows IEEE 754 conventions
|
||||
- f: means finite values only (no infinities)
|
||||
- n: means nans are supported (non-standard encoding)
|
||||
for integer types the scheme is:
|
||||
`[u]int<size_bits>[b<bias>]`
|
||||
- if bias is not present it means its zero
|
||||
"""
|
||||
if self.is_floating_point():
|
||||
ret = (
|
||||
"float"
|
||||
+ str(self.size_bits)
|
||||
+ "_e"
|
||||
+ str(self.exponent)
|
||||
+ "m"
|
||||
+ str(self.mantissa)
|
||||
)
|
||||
|
||||
if not self.is_ieee_754():
|
||||
if self._finite_values_only:
|
||||
ret = ret + "f"
|
||||
if self.nan_repr != NanRepr.NONE:
|
||||
ret = ret + "n"
|
||||
|
||||
return ret
|
||||
else:
|
||||
ret = ("int" if self.is_signed() else "uint") + str(self.size_bits)
|
||||
if self.has_bias():
|
||||
ret = ret + "b" + str(self.bias)
|
||||
return ret
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "ScalarType." + self.__str__()
|
||||
|
||||
# __len__ needs to be defined (and has to throw TypeError) for pytorch's
|
||||
# opcheck to work.
|
||||
def __len__(self) -> int:
|
||||
raise TypeError
|
||||
|
||||
#
|
||||
# Convenience Constructors
|
||||
#
|
||||
|
||||
@classmethod
|
||||
def int_(cls, size_bits: int, bias: Optional[int]) -> "ScalarType":
|
||||
"Create a signed integer scalar type (size_bits includes sign-bit)."
|
||||
ret = cls(0, size_bits - 1, True, bias if bias else 0)
|
||||
ret.id # noqa B018: make sure the id is cached
|
||||
return ret
|
||||
|
||||
@classmethod
|
||||
def uint(cls, size_bits: int, bias: Optional[int]) -> "ScalarType":
|
||||
"""Create a unsigned integer scalar type."""
|
||||
ret = cls(0, size_bits, False, bias if bias else 0)
|
||||
ret.id # noqa B018: make sure the id is cached
|
||||
return ret
|
||||
|
||||
@classmethod
|
||||
def float_IEEE754(cls, exponent: int, mantissa: int) -> "ScalarType":
|
||||
"""
|
||||
Create a standard floating point type
|
||||
(i.e. follows IEEE 754 conventions).
|
||||
"""
|
||||
assert mantissa > 0 and exponent > 0
|
||||
ret = cls(exponent, mantissa, True, 0)
|
||||
ret.id # noqa B018: make sure the id is cached
|
||||
return ret
|
||||
|
||||
@classmethod
|
||||
def float_(
|
||||
cls, exponent: int, mantissa: int, finite_values_only: bool, nan_repr: NanRepr
|
||||
) -> "ScalarType":
|
||||
"""
|
||||
Create a non-standard floating point type
|
||||
(i.e. does not follow IEEE 754 conventions).
|
||||
"""
|
||||
assert mantissa > 0 and exponent > 0
|
||||
assert nan_repr != NanRepr.IEEE_754, (
|
||||
"use `float_IEEE754` constructor for floating point types that "
|
||||
"follow IEEE 754 conventions"
|
||||
)
|
||||
ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr)
|
||||
ret.id # noqa B018: make sure the id is cached
|
||||
return ret
|
||||
|
||||
@classmethod
|
||||
def from_id(cls, scalar_type_id: int):
|
||||
if scalar_type_id not in _SCALAR_TYPES_ID_MAP:
|
||||
raise ValueError(f"scalar_type_id {scalar_type_id} doesn't exists.")
|
||||
return _SCALAR_TYPES_ID_MAP[scalar_type_id]
|
||||
|
||||
|
||||
# naming generally follows: https://github.com/jax-ml/ml_dtypes
|
||||
# for floating point types (leading f) the scheme is:
|
||||
# `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
|
||||
# flags:
|
||||
# - no-flags: means it follows IEEE 754 conventions
|
||||
# - f: means finite values only (no infinities)
|
||||
# - n: means nans are supported (non-standard encoding)
|
||||
# for integer types the scheme is:
|
||||
# `[u]int<size_bits>[b<bias>]`
|
||||
# - if bias is not present it means its zero
|
||||
|
||||
|
||||
class scalar_types:
|
||||
int4 = ScalarType.int_(4, None)
|
||||
uint4 = ScalarType.uint(4, None)
|
||||
int8 = ScalarType.int_(8, None)
|
||||
uint8 = ScalarType.uint(8, None)
|
||||
float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN)
|
||||
float8_e5m2 = ScalarType.float_IEEE754(5, 2)
|
||||
float16_e8m7 = ScalarType.float_IEEE754(8, 7)
|
||||
float16_e5m10 = ScalarType.float_IEEE754(5, 10)
|
||||
|
||||
# fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main
|
||||
float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE)
|
||||
|
||||
# fp4, https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
|
||||
float4_e2m1f = ScalarType.float_(2, 1, True, NanRepr.NONE)
|
||||
|
||||
# "gptq" types
|
||||
uint2b2 = ScalarType.uint(2, 2)
|
||||
uint3b4 = ScalarType.uint(3, 4)
|
||||
uint4b8 = ScalarType.uint(4, 8)
|
||||
uint8b128 = ScalarType.uint(8, 128)
|
||||
|
||||
# colloquial names
|
||||
bfloat16 = float16_e8m7
|
||||
float16 = float16_e5m10
|
||||
@@ -0,0 +1,293 @@
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def maybe_contiguous(x):
|
||||
return x.contiguous() if x is not None and x.stride(-1) != 1 else x
|
||||
|
||||
|
||||
# Sparse attention utils
|
||||
def convert_vertical_slash_indexes(
|
||||
q_seqlens: torch.Tensor, # [BATCH, ]
|
||||
kv_seqlens: torch.Tensor, # [BATCH, ]
|
||||
vertical_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_V]
|
||||
slash_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_S]
|
||||
context_size: int,
|
||||
block_size_M: int,
|
||||
block_size_N: int,
|
||||
causal: bool = True,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
batch_size = slash_indexes.size(0)
|
||||
num_heads = slash_indexes.size(1)
|
||||
nnz_slash = slash_indexes.size(2)
|
||||
nnz_vertical = vertical_indexes.size(2)
|
||||
num_rows = (context_size + block_size_M - 1) // block_size_M
|
||||
|
||||
block_count = torch.zeros(
|
||||
batch_size, num_heads, num_rows, dtype=q_seqlens.dtype, device=q_seqlens.device
|
||||
)
|
||||
block_offset = torch.zeros(
|
||||
batch_size,
|
||||
num_heads,
|
||||
num_rows,
|
||||
nnz_slash,
|
||||
dtype=q_seqlens.dtype,
|
||||
device=q_seqlens.device,
|
||||
)
|
||||
column_count = torch.zeros(
|
||||
batch_size, num_heads, num_rows, dtype=q_seqlens.dtype, device=q_seqlens.device
|
||||
)
|
||||
column_index = torch.zeros(
|
||||
batch_size,
|
||||
num_heads,
|
||||
num_rows,
|
||||
nnz_vertical,
|
||||
dtype=q_seqlens.dtype,
|
||||
device=q_seqlens.device,
|
||||
)
|
||||
|
||||
torch.ops.sgl_kernel.convert_vertical_slash_indexes.default(
|
||||
block_count,
|
||||
block_offset,
|
||||
column_count,
|
||||
column_index,
|
||||
q_seqlens,
|
||||
kv_seqlens,
|
||||
vertical_indexes,
|
||||
slash_indexes,
|
||||
context_size,
|
||||
block_size_M,
|
||||
block_size_N,
|
||||
causal,
|
||||
)
|
||||
return block_count, block_offset, column_count, column_index
|
||||
|
||||
|
||||
def convert_vertical_slash_indexes_mergehead(
|
||||
q_seqlens: torch.Tensor, # [BATCH, ]
|
||||
kv_seqlens: torch.Tensor, # [BATCH, ]
|
||||
vertical_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_V]
|
||||
slash_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_S]
|
||||
# [N_HEADS] : different head use different number of indices
|
||||
vertical_indices_count: torch.Tensor,
|
||||
slash_indices_count: torch.Tensor,
|
||||
context_size: int,
|
||||
block_size_M: int,
|
||||
block_size_N: int,
|
||||
causal: bool = True,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
batch_size = slash_indexes.size(0)
|
||||
num_heads = slash_indexes.size(1)
|
||||
nnz_slash = slash_indexes.size(2)
|
||||
nnz_vertical = vertical_indexes.size(2)
|
||||
num_rows = (context_size + block_size_M - 1) // block_size_M
|
||||
|
||||
block_count = torch.empty(
|
||||
batch_size, num_heads, num_rows, dtype=q_seqlens.dtype, device=q_seqlens.device
|
||||
)
|
||||
block_offset = torch.empty(
|
||||
batch_size,
|
||||
num_heads,
|
||||
num_rows,
|
||||
nnz_slash,
|
||||
dtype=q_seqlens.dtype,
|
||||
device=q_seqlens.device,
|
||||
)
|
||||
column_count = torch.empty(
|
||||
batch_size, num_heads, num_rows, dtype=q_seqlens.dtype, device=q_seqlens.device
|
||||
)
|
||||
column_index = torch.empty(
|
||||
batch_size,
|
||||
num_heads,
|
||||
num_rows,
|
||||
nnz_vertical,
|
||||
dtype=q_seqlens.dtype,
|
||||
device=q_seqlens.device,
|
||||
)
|
||||
|
||||
torch.ops.sgl_kernel.convert_vertical_slash_indexes_mergehead.default(
|
||||
block_count,
|
||||
block_offset,
|
||||
column_count,
|
||||
column_index,
|
||||
q_seqlens,
|
||||
kv_seqlens,
|
||||
vertical_indexes,
|
||||
slash_indexes,
|
||||
vertical_indices_count,
|
||||
slash_indices_count,
|
||||
context_size,
|
||||
block_size_M,
|
||||
block_size_N,
|
||||
causal,
|
||||
)
|
||||
return block_count, block_offset, column_count, column_index
|
||||
|
||||
|
||||
def sparse_attn_func(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
block_count,
|
||||
block_offset,
|
||||
column_count,
|
||||
column_index,
|
||||
dropout_p=0.0,
|
||||
softmax_scale=None,
|
||||
causal=False,
|
||||
softcap=0.0, # 0.0 means deactivated
|
||||
alibi_slopes=None,
|
||||
deterministic=False,
|
||||
return_attn_probs=False,
|
||||
*,
|
||||
return_softmax_lse=False,
|
||||
out=None,
|
||||
):
|
||||
"""Compute attention with vertical and slash sparsity patterns.
|
||||
Most Arguments are the same with the flash_attn_func interface, except for 4 extra args:
|
||||
block_count and block_offset for slash sparsity patterns, and
|
||||
column_count and column_index for vertical sparsity patterns.
|
||||
For more details please refer to Appendix C.4.2 of paper https://arxiv.org/abs/2407.02490.
|
||||
|
||||
Arguments:
|
||||
q: (batch_size, seqlen, nheads, headdim)
|
||||
k: (batch_size, seqlen, nheads_k, headdim)
|
||||
v: (batch_size, seqlen, nheads_k, headdim)
|
||||
block_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
|
||||
block_offset: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_S)
|
||||
column_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
|
||||
column_index: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_V)
|
||||
dropout_p: float. Dropout probability.
|
||||
softmax_scale: float. The scaling of QK^T before applying softmax.
|
||||
Default to 1 / sqrt(headdim).
|
||||
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
|
||||
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
|
||||
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
|
||||
is added to the attention score of query i and key j.
|
||||
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
|
||||
which is slightly slower and uses more memory. The forward pass is always deterministic.
|
||||
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
|
||||
testing only. The returned probabilities are not guaranteed to be correct
|
||||
(they might not have the right scaling).
|
||||
Return:
|
||||
out: (batch_size, seqlen, nheads, headdim).
|
||||
softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
|
||||
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
|
||||
normalization factor).
|
||||
"""
|
||||
if softmax_scale is None:
|
||||
softmax_scale = q.shape[-1] ** (-0.5)
|
||||
|
||||
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
|
||||
out, softmax_lse = torch.ops.sgl_kernel.fwd_sparse.default(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
block_count,
|
||||
block_offset,
|
||||
column_count,
|
||||
column_index,
|
||||
out,
|
||||
alibi_slopes,
|
||||
dropout_p,
|
||||
softmax_scale,
|
||||
causal,
|
||||
softcap,
|
||||
return_attn_probs and dropout_p > 0,
|
||||
None,
|
||||
)
|
||||
return (out, softmax_lse) if return_softmax_lse else out
|
||||
|
||||
|
||||
def sparse_attn_varlen_func(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
block_count,
|
||||
block_offset,
|
||||
column_count,
|
||||
column_index,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
dropout_p=0.0,
|
||||
softmax_scale=None,
|
||||
causal=False,
|
||||
softcap=0.0, # 0.0 means deactivated
|
||||
alibi_slopes=None,
|
||||
deterministic=False,
|
||||
return_attn_probs=False,
|
||||
*,
|
||||
return_softmax_lse=False,
|
||||
out=None,
|
||||
):
|
||||
"""Compute attention with vertical and slash sparsity patterns.
|
||||
Most Arguments are the same with the flash_attn_varlen_func interface, except for 4 extra args:
|
||||
block_count and block_offset for slash sparsity patterns, and
|
||||
column_count and column_index for vertical sparsity patterns.
|
||||
For more details please refer to Appendix C.4.2 of paper https://arxiv.org/abs/2407.02490.
|
||||
|
||||
Arguments:
|
||||
q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
|
||||
k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
|
||||
v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
|
||||
block_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
|
||||
block_offset: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_S)
|
||||
column_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
|
||||
column_index: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_V)
|
||||
cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
|
||||
of the sequences in the batch, used to index into q.
|
||||
cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
|
||||
of the sequences in the batch, used to index into kv.
|
||||
max_seqlen_q: int. Maximum query sequence length in the batch.
|
||||
max_seqlen_k: int. Maximum key sequence length in the batch.
|
||||
dropout_p: float. Dropout probability.
|
||||
softmax_scale: float. The scaling of QK^T before applying softmax.
|
||||
Default to 1 / sqrt(headdim).
|
||||
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
|
||||
softcap: float. Anything > 0 activates softcapping attention.
|
||||
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
|
||||
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
|
||||
is added to the attention score of query i and key j.
|
||||
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
|
||||
which is slightly slower and uses more memory. The forward pass is always deterministic.
|
||||
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
|
||||
testing only. The returned probabilities are not guaranteed to be correct
|
||||
(they might not have the right scaling).
|
||||
Return:
|
||||
out: (total, nheads, headdim).
|
||||
softmax_lse [optional, if return_softmax_lse=True]: (nheads, total_q_seqlen). The
|
||||
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
|
||||
normalization factor).
|
||||
"""
|
||||
if softmax_scale is None:
|
||||
softmax_scale = q.shape[-1] ** (-0.5)
|
||||
|
||||
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
|
||||
out, softmax_lse = torch.ops.sgl_kernel.varlen_fwd_sparse.default(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
block_count,
|
||||
block_offset,
|
||||
column_count,
|
||||
column_index,
|
||||
out,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
None,
|
||||
alibi_slopes,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
dropout_p,
|
||||
softmax_scale,
|
||||
False,
|
||||
causal,
|
||||
softcap,
|
||||
return_attn_probs and dropout_p > 0,
|
||||
None,
|
||||
)
|
||||
return (out, softmax_lse) if return_softmax_lse else out
|
||||
@@ -0,0 +1,63 @@
|
||||
import torch
|
||||
from torch.cuda.streams import ExternalStream
|
||||
|
||||
try:
|
||||
from . import spatial_ops # triggers TORCH extension registration
|
||||
except Exception as _e:
|
||||
_spatial_import_error = _e
|
||||
else:
|
||||
_spatial_import_error = None
|
||||
|
||||
_IMPORT_ERROR = ImportError(
|
||||
"Failed to load sgl_kernel.spatial_ops extension. Ensure CUDA Driver >= 12.4"
|
||||
)
|
||||
|
||||
|
||||
def create_greenctx_stream_by_value(
|
||||
SM_a: int, SM_b: int, device_id: int = None
|
||||
) -> tuple[ExternalStream, ExternalStream]:
|
||||
"""
|
||||
Create two streams for greenctx.
|
||||
Args:
|
||||
sm_A (int): The SM of stream A.
|
||||
sm_B (int): The weight of stream B.
|
||||
device_id (int): The device id.
|
||||
Returns:
|
||||
tuple[ExternalStream, ExternalStream]: The two streams.
|
||||
"""
|
||||
if _spatial_import_error is not None:
|
||||
raise _IMPORT_ERROR from _spatial_import_error
|
||||
if device_id is None:
|
||||
device_id = torch.cuda.current_device()
|
||||
|
||||
res = torch.ops.sgl_kernel.create_greenctx_stream_by_value(SM_a, SM_b, device_id)
|
||||
|
||||
stream_a = ExternalStream(
|
||||
stream_ptr=res[0], device=torch.device(f"cuda:{device_id}")
|
||||
)
|
||||
stream_b = ExternalStream(
|
||||
stream_ptr=res[1], device=torch.device(f"cuda:{device_id}")
|
||||
)
|
||||
|
||||
return stream_a, stream_b
|
||||
|
||||
|
||||
def get_sm_available(device_id: int = None) -> int:
|
||||
"""
|
||||
Get the SMs available on the device.
|
||||
Args:
|
||||
device_id (int): The device id.
|
||||
Returns:
|
||||
int: The SMs available.
|
||||
"""
|
||||
if _spatial_import_error is not None:
|
||||
raise _IMPORT_ERROR from _spatial_import_error
|
||||
if device_id is None:
|
||||
device_id = torch.cuda.current_device()
|
||||
|
||||
device_props = torch.cuda.get_device_properties(device_id)
|
||||
|
||||
# Get the number of Streaming Multiprocessors (SMs)
|
||||
sm_count = device_props.multi_processor_count
|
||||
|
||||
return sm_count
|
||||
@@ -0,0 +1,107 @@
|
||||
import torch
|
||||
from sgl_kernel.utils import get_cuda_stream
|
||||
|
||||
|
||||
def tree_speculative_sampling_target_only(
|
||||
predicts: torch.Tensor, # mutable
|
||||
accept_index: torch.Tensor, # mutable
|
||||
accept_token_num: torch.Tensor, # mutable
|
||||
candidates: torch.Tensor,
|
||||
retrive_index: torch.Tensor,
|
||||
retrive_next_token: torch.Tensor,
|
||||
retrive_next_sibling: torch.Tensor,
|
||||
uniform_samples: torch.Tensor,
|
||||
uniform_samples_for_final_sampling: torch.Tensor,
|
||||
target_probs: torch.Tensor,
|
||||
draft_probs: torch.Tensor,
|
||||
threshold_single: float = 1.0,
|
||||
threshold_acc: float = 1.0,
|
||||
deterministic: bool = True,
|
||||
) -> None:
|
||||
torch.ops.sgl_kernel.tree_speculative_sampling_target_only.default(
|
||||
predicts,
|
||||
accept_index,
|
||||
accept_token_num,
|
||||
candidates,
|
||||
retrive_index,
|
||||
retrive_next_token,
|
||||
retrive_next_sibling,
|
||||
uniform_samples,
|
||||
uniform_samples_for_final_sampling,
|
||||
target_probs,
|
||||
draft_probs,
|
||||
threshold_single,
|
||||
threshold_acc,
|
||||
deterministic,
|
||||
get_cuda_stream(),
|
||||
)
|
||||
|
||||
|
||||
def verify_tree_greedy(
|
||||
predicts: torch.Tensor, # mutable
|
||||
accept_index: torch.Tensor, # mutable
|
||||
accept_token_num: torch.Tensor, # mutable
|
||||
candidates: torch.Tensor,
|
||||
retrive_index: torch.Tensor,
|
||||
retrive_next_token: torch.Tensor,
|
||||
retrive_next_sibling: torch.Tensor,
|
||||
target_predict: torch.Tensor,
|
||||
) -> None:
|
||||
torch.ops.sgl_kernel.verify_tree_greedy.default(
|
||||
predicts,
|
||||
accept_index,
|
||||
accept_token_num,
|
||||
candidates,
|
||||
retrive_index,
|
||||
retrive_next_token,
|
||||
retrive_next_sibling,
|
||||
target_predict,
|
||||
get_cuda_stream(),
|
||||
)
|
||||
|
||||
|
||||
def build_tree_kernel_efficient(
|
||||
parent_list: torch.Tensor,
|
||||
selected_index: torch.Tensor,
|
||||
verified_seq_len: torch.Tensor,
|
||||
tree_mask: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
retrive_index: torch.Tensor,
|
||||
retrive_next_token: torch.Tensor,
|
||||
retrive_next_sibling: torch.Tensor,
|
||||
topk: int,
|
||||
depth: int,
|
||||
draft_token_num: int,
|
||||
tree_mask_mode: int,
|
||||
) -> None:
|
||||
torch.ops.sgl_kernel.build_tree_kernel_efficient.default(
|
||||
parent_list,
|
||||
selected_index,
|
||||
verified_seq_len,
|
||||
tree_mask,
|
||||
positions,
|
||||
retrive_index,
|
||||
retrive_next_token,
|
||||
retrive_next_sibling,
|
||||
topk,
|
||||
depth,
|
||||
draft_token_num,
|
||||
tree_mask_mode,
|
||||
)
|
||||
|
||||
|
||||
def segment_packbits(
|
||||
x: torch.Tensor,
|
||||
input_indptr: torch.Tensor,
|
||||
output_indptr: torch.Tensor,
|
||||
y: torch.Tensor,
|
||||
batch_size: int,
|
||||
) -> None:
|
||||
torch.ops.sgl_kernel.segment_packbits.default(
|
||||
x,
|
||||
input_indptr,
|
||||
output_indptr,
|
||||
y,
|
||||
batch_size,
|
||||
torch.cuda.current_stream().cuda_stream,
|
||||
)
|
||||
@@ -0,0 +1,217 @@
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from sgl_kernel import FusedSetKVBufferArg, apply_rope_with_cos_sin_cache_inplace
|
||||
|
||||
|
||||
# vLLM torch native
|
||||
def _apply_rotary_emb(
|
||||
x: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
is_neox_style: bool,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x: [num_tokens, num_heads, head_size]
|
||||
cos: [num_tokens, head_size // 2]
|
||||
sin: [num_tokens, head_size // 2]
|
||||
is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
|
||||
positional embeddings.
|
||||
"""
|
||||
cos = cos.unsqueeze(-2).to(x.dtype)
|
||||
sin = sin.unsqueeze(-2).to(x.dtype)
|
||||
if is_neox_style:
|
||||
x1, x2 = torch.chunk(x, 2, dim=-1)
|
||||
else:
|
||||
x1 = x[..., ::2]
|
||||
x2 = x[..., 1::2]
|
||||
o1 = x1 * cos - x2 * sin
|
||||
o2 = x2 * cos + x1 * sin
|
||||
if is_neox_style:
|
||||
return torch.cat((o1, o2), dim=-1)
|
||||
else:
|
||||
return torch.stack((o1, o2), dim=-1).flatten(-2)
|
||||
|
||||
|
||||
class RotaryEmbedding(torch.nn.Module):
|
||||
# Reference: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/rotary_embedding.py
|
||||
def __init__(
|
||||
self,
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
max_position_embeddings: int,
|
||||
base: int,
|
||||
is_neox_style: bool,
|
||||
dtype: torch.dtype,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.head_size = head_size
|
||||
self.rotary_dim = rotary_dim
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.base = base
|
||||
self.is_neox_style = is_neox_style
|
||||
self.dtype = dtype
|
||||
|
||||
cache = self._compute_cos_sin_cache()
|
||||
self.cos_sin_cache: torch.Tensor
|
||||
self.register_buffer("cos_sin_cache", cache, persistent=False)
|
||||
|
||||
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
|
||||
inv_freq = 1.0 / (
|
||||
base
|
||||
** (
|
||||
torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim
|
||||
)
|
||||
)
|
||||
return inv_freq
|
||||
|
||||
def _compute_cos_sin_cache(self) -> torch.Tensor:
|
||||
"""Compute the cos and sin cache."""
|
||||
inv_freq = self._compute_inv_freq(self.base)
|
||||
t = torch.arange(self.max_position_embeddings, dtype=torch.float)
|
||||
|
||||
freqs = torch.einsum("i,j -> ij", t, inv_freq)
|
||||
cos = freqs.cos()
|
||||
sin = freqs.sin()
|
||||
cache = torch.cat((cos, sin), dim=-1)
|
||||
return cache
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""A PyTorch-native implementation of forward()."""
|
||||
if offsets is not None:
|
||||
positions = positions + offsets
|
||||
|
||||
positions = positions.flatten()
|
||||
num_tokens = positions.shape[0]
|
||||
cos_sin = self.cos_sin_cache.index_select(0, positions)
|
||||
|
||||
# Modification: float32 is required for the rotary embedding to work correctly
|
||||
query = query.to(torch.float32)
|
||||
key = key.to(torch.float32)
|
||||
|
||||
cos, sin = cos_sin.chunk(2, dim=-1)
|
||||
|
||||
query_shape = query.shape
|
||||
query = query.view(num_tokens, -1, self.head_size)
|
||||
query_rot = query[..., : self.rotary_dim]
|
||||
query_pass = query[..., self.rotary_dim :]
|
||||
query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
|
||||
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
|
||||
|
||||
key_shape = key.shape
|
||||
key = key.view(num_tokens, -1, self.head_size)
|
||||
key_rot = key[..., : self.rotary_dim]
|
||||
key_pass = key[..., self.rotary_dim :]
|
||||
key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
|
||||
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
||||
|
||||
# Modification: convert to the correct dtype
|
||||
query = query.to(self.dtype)
|
||||
key = key.to(self.dtype)
|
||||
return query, key
|
||||
|
||||
|
||||
class FlashInferRotaryEmbedding(RotaryEmbedding):
|
||||
def forward_cuda(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
apply_rope_with_cos_sin_cache_inplace(
|
||||
positions=positions,
|
||||
query=query,
|
||||
key=key,
|
||||
fused_set_kv_buffer_arg=fused_set_kv_buffer_arg,
|
||||
head_size=self.head_size,
|
||||
cos_sin_cache=self.cos_sin_cache,
|
||||
is_neox=self.is_neox_style,
|
||||
)
|
||||
|
||||
return query, key
|
||||
|
||||
|
||||
class MHATokenToKVPool:
|
||||
KV_POOL_SIZE = 16384
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
head_num: int,
|
||||
head_dim: int,
|
||||
):
|
||||
self.head_num = head_num
|
||||
self.head_dim = head_dim
|
||||
self.size = MHATokenToKVPool.KV_POOL_SIZE
|
||||
self.page_size = 1
|
||||
self.store_dtype = torch.bfloat16
|
||||
self.device = "cuda"
|
||||
self.layer_num = 1
|
||||
self.start_layer = 0
|
||||
self._create_buffers()
|
||||
|
||||
def _create_buffers(self):
|
||||
self.k_buffer = [
|
||||
torch.zeros(
|
||||
(self.size + self.page_size, self.head_num, self.head_dim),
|
||||
dtype=self.store_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
for _ in range(self.layer_num)
|
||||
]
|
||||
self.v_buffer = [
|
||||
torch.zeros(
|
||||
(self.size + self.page_size, self.head_num, self.head_dim),
|
||||
dtype=self.store_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
for _ in range(self.layer_num)
|
||||
]
|
||||
|
||||
def set_kv_buffer(
|
||||
self,
|
||||
loc: torch.Tensor,
|
||||
cache_k: torch.Tensor,
|
||||
cache_v: torch.Tensor,
|
||||
):
|
||||
layer_id = 0
|
||||
self.k_buffer[layer_id - self.start_layer][loc] = cache_k
|
||||
self.v_buffer[layer_id - self.start_layer][loc] = cache_v
|
||||
|
||||
|
||||
def create_inputs(
|
||||
head_size: int,
|
||||
batch_size: int,
|
||||
seq_len: int,
|
||||
device,
|
||||
dtype: torch.dtype,
|
||||
num_q_heads: int,
|
||||
num_kv_heads: int,
|
||||
):
|
||||
pos_ids = torch.arange(seq_len, device=device).repeat(batch_size)
|
||||
query = torch.randn(
|
||||
batch_size * seq_len, num_q_heads * head_size, dtype=dtype, device=device
|
||||
)
|
||||
key = torch.randn(
|
||||
batch_size * seq_len, num_kv_heads * head_size, dtype=dtype, device=device
|
||||
)
|
||||
value = torch.randn(
|
||||
batch_size * seq_len, num_kv_heads * head_size, dtype=dtype, device=device
|
||||
)
|
||||
out_cache_loc = torch.randperm(
|
||||
MHATokenToKVPool.KV_POOL_SIZE, dtype=torch.int64, device=device
|
||||
)[: batch_size * seq_len].clone()
|
||||
|
||||
return dict(
|
||||
pos_ids=pos_ids, query=query, key=key, value=value, out_cache_loc=out_cache_loc
|
||||
)
|
||||
@@ -0,0 +1,11 @@
|
||||
import torch
|
||||
|
||||
|
||||
def fast_topk(values, topk, dim):
|
||||
if topk == 1:
|
||||
# Use max along the specified dimension to get both value and index
|
||||
return torch.max(values, dim=dim, keepdim=True)
|
||||
else:
|
||||
# Use topk for efficiency with larger k values
|
||||
# TODO: implement faster cuda kernels for large vocab sizes
|
||||
return torch.topk(values, topk, dim=dim)
|
||||
@@ -0,0 +1,50 @@
|
||||
# Copyright 2025 SGLang Team. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import functools
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def get_cuda_stream() -> int:
|
||||
return torch.cuda.current_stream().cuda_stream
|
||||
|
||||
|
||||
_cache_buf: Dict[Tuple[str, torch.device], torch.Tensor] = {}
|
||||
|
||||
|
||||
def _get_cache_buf(name: str, bytes: int, device: torch.device) -> torch.Tensor:
|
||||
key = (name, device)
|
||||
buf = _cache_buf.get(key)
|
||||
if buf is None:
|
||||
buf = torch.empty(bytes, dtype=torch.uint8, device=device)
|
||||
_cache_buf[key] = buf
|
||||
return buf
|
||||
|
||||
|
||||
def _to_tensor_scalar_tuple(x):
|
||||
if isinstance(x, torch.Tensor):
|
||||
return (x, 0)
|
||||
else:
|
||||
return (None, x)
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def is_arch_support_pdl() -> bool:
|
||||
# Hopper arch's compute capability == 9.0
|
||||
device = torch.cuda.current_device()
|
||||
major, minor = torch.cuda.get_device_capability(device)
|
||||
return major >= 9
|
||||
@@ -0,0 +1 @@
|
||||
__version__ = "0.3.8"
|
||||
BIN
sgl-kernel/build/temp.linux-x86_64-cpython-310/.ninja_deps
Normal file
BIN
sgl-kernel/build/temp.linux-x86_64-cpython-310/.ninja_deps
Normal file
Binary file not shown.
10
sgl-kernel/build/temp.linux-x86_64-cpython-310/.ninja_log
Normal file
10
sgl-kernel/build/temp.linux-x86_64-cpython-310/.ninja_log
Normal file
@@ -0,0 +1,10 @@
|
||||
# ninja log v5
|
||||
4 12793 1756950714004023222 /home/git_sglang/sglang/sgl-kernel/build/temp.linux-x86_64-cpython-310/csrc/speculative/eagle_utils.o c4ef5c8f5ca38169
|
||||
4 22797 1756950724004023564 /home/git_sglang/sglang/sgl-kernel/build/temp.linux-x86_64-cpython-310/csrc/grammar/apply_token_bitmask_inplace_cuda.o 983ca8e755dab2fa
|
||||
3 24387 1756950725592023618 /home/git_sglang/sglang/sgl-kernel/build/temp.linux-x86_64-cpython-310/csrc/allreduce/custom_all_reduce.o 95a36ff854253806
|
||||
4 29769 1756950730660023792 /home/git_sglang/sglang/sgl-kernel/build/temp.linux-x86_64-cpython-310/csrc/common_extension_rocm.o d99ffa7422128b8b
|
||||
3 47747 1756950748952024417 /home/git_sglang/sglang/sgl-kernel/build/temp.linux-x86_64-cpython-310/csrc/allreduce/quick_all_reduce.o ffc38859847f9f7
|
||||
32 22789 1756951323296044067 /home/git_sglang/sglang/sgl-kernel/build/temp.linux-x86_64-cpython-310/csrc/kvcacheio/transfer.o 51f92c934bf3b3a4
|
||||
32 22927 1756951323432044072 /home/git_sglang/sglang/sgl-kernel/build/temp.linux-x86_64-cpython-310/csrc/moe/moe_align_kernel.o 3cbca550065cd70f
|
||||
32 23521 1756951324028044092 /home/git_sglang/sglang/sgl-kernel/build/temp.linux-x86_64-cpython-310/csrc/elementwise/activation.o c8aa216837c116a0
|
||||
33 26809 1756951327312044205 /home/git_sglang/sglang/sgl-kernel/build/temp.linux-x86_64-cpython-310/csrc/moe/moe_topk_softmax_kernels.o a28bfee01bc84adc
|
||||
38
sgl-kernel/build/temp.linux-x86_64-cpython-310/build.ninja
Normal file
38
sgl-kernel/build/temp.linux-x86_64-cpython-310/build.ninja
Normal file
@@ -0,0 +1,38 @@
|
||||
ninja_required_version = 1.3
|
||||
cxx = c++
|
||||
nvcc = /opt/dtk/bin/hipcc
|
||||
|
||||
cflags = -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -g -fwrapv -O2 -fPIC -I/home/git_sglang/sglang/sgl-kernel/include -I/home/git_sglang/sglang/sgl-kernel/csrc -I/usr/local/lib/python3.10/dist-packages/torch/include -I/usr/local/lib/python3.10/dist-packages/torch/include/torch/csrc/api/include -I/usr/local/lib/python3.10/dist-packages/torch/include/TH -I/usr/local/lib/python3.10/dist-packages/torch/include/THC -I/usr/local/lib/python3.10/dist-packages/torch/include/THH -I/opt/dtk/include -I/usr/include/python3.10 -c
|
||||
post_cflags = -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -O3 -Wno-switch-bool -Wno-macro-redefined -Wno-deprecated-declarations -w -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=common_ops -D_GLIBCXX_USE_CXX11_ABI=1 -std=c++17
|
||||
cuda_cflags = -I/home/git_sglang/sglang/sgl-kernel/include -I/home/git_sglang/sglang/sgl-kernel/csrc -I/usr/local/lib/python3.10/dist-packages/torch/include -I/usr/local/lib/python3.10/dist-packages/torch/include/torch/csrc/api/include -I/usr/local/lib/python3.10/dist-packages/torch/include/TH -I/usr/local/lib/python3.10/dist-packages/torch/include/THC -I/usr/local/lib/python3.10/dist-packages/torch/include/THH -I/opt/dtk/include -I/usr/include/python3.10 -c
|
||||
cuda_post_cflags = -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -fPIC -O3 -std=c++17 -D__HIP_PLATFORM_HCC__=1 --offload-arch=gfx928 --offload-arch=gfx936 --gpu-max-threads-per-block=1024 -Wno-macro-redefined '' -funroll-loops -Rpass-analysis=unroll-loops -w -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=common_ops -D_GLIBCXX_USE_CXX11_ABI=1 -fno-gpu-rdc
|
||||
cuda_dlink_post_cflags =
|
||||
ldflags =
|
||||
|
||||
rule compile
|
||||
command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags
|
||||
depfile = $out.d
|
||||
deps = gcc
|
||||
|
||||
rule cuda_compile
|
||||
command = $nvcc $cuda_cflags -c $in -o $out $cuda_post_cflags
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
build /home/git_sglang/sglang/sgl-kernel/build/temp.linux-x86_64-cpython-310/csrc/allreduce/custom_all_reduce.o: cuda_compile /home/git_sglang/sglang/sgl-kernel/csrc/allreduce/custom_all_reduce.hip
|
||||
build /home/git_sglang/sglang/sgl-kernel/build/temp.linux-x86_64-cpython-310/csrc/allreduce/quick_all_reduce.o: cuda_compile /home/git_sglang/sglang/sgl-kernel/csrc/allreduce/quick_all_reduce.hip
|
||||
build /home/git_sglang/sglang/sgl-kernel/build/temp.linux-x86_64-cpython-310/csrc/common_extension_rocm.o: compile /home/git_sglang/sglang/sgl-kernel/csrc/common_extension_rocm.cc
|
||||
build /home/git_sglang/sglang/sgl-kernel/build/temp.linux-x86_64-cpython-310/csrc/elementwise/activation.o: cuda_compile /home/git_sglang/sglang/sgl-kernel/csrc/elementwise/activation.hip
|
||||
build /home/git_sglang/sglang/sgl-kernel/build/temp.linux-x86_64-cpython-310/csrc/grammar/apply_token_bitmask_inplace_cuda.o: cuda_compile /home/git_sglang/sglang/sgl-kernel/csrc/grammar/apply_token_bitmask_inplace_cuda.hip
|
||||
build /home/git_sglang/sglang/sgl-kernel/build/temp.linux-x86_64-cpython-310/csrc/kvcacheio/transfer.o: cuda_compile /home/git_sglang/sglang/sgl-kernel/csrc/kvcacheio/transfer.hip
|
||||
build /home/git_sglang/sglang/sgl-kernel/build/temp.linux-x86_64-cpython-310/csrc/moe/moe_align_kernel.o: cuda_compile /home/git_sglang/sglang/sgl-kernel/csrc/moe/moe_align_kernel.hip
|
||||
build /home/git_sglang/sglang/sgl-kernel/build/temp.linux-x86_64-cpython-310/csrc/moe/moe_topk_softmax_kernels.o: cuda_compile /home/git_sglang/sglang/sgl-kernel/csrc/moe/moe_topk_softmax_kernels.hip
|
||||
build /home/git_sglang/sglang/sgl-kernel/build/temp.linux-x86_64-cpython-310/csrc/speculative/eagle_utils.o: cuda_compile /home/git_sglang/sglang/sgl-kernel/csrc/speculative/eagle_utils.hip
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
21
sgl-kernel/cmake/utils.cmake
Normal file
21
sgl-kernel/cmake/utils.cmake
Normal file
@@ -0,0 +1,21 @@
|
||||
# Adapt from: https://github.com/neuralmagic/vllm-flash-attention/blob/main/cmake/utils.cmake
|
||||
#
|
||||
# Clear all `-gencode` flags from `CMAKE_CUDA_FLAGS` and store them in
|
||||
# `CUDA_ARCH_FLAGS`.
|
||||
#
|
||||
# Example:
|
||||
# CMAKE_CUDA_FLAGS="-Wall -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75"
|
||||
# clear_cuda_arches(CUDA_ARCH_FLAGS)
|
||||
# CUDA_ARCH_FLAGS="-gencode arch=compute_70,code=sm_70;-gencode arch=compute_75,code=sm_75"
|
||||
# CMAKE_CUDA_FLAGS="-Wall"
|
||||
#
|
||||
macro(clear_cuda_arches CUDA_ARCH_FLAGS)
|
||||
# Extract all `-gencode` flags from `CMAKE_CUDA_FLAGS`
|
||||
string(REGEX MATCHALL "-gencode arch=[^ ]+" CUDA_ARCH_FLAGS
|
||||
${CMAKE_CUDA_FLAGS})
|
||||
|
||||
# Remove all `-gencode` flags from `CMAKE_CUDA_FLAGS` since they will be modified
|
||||
# and passed back via the `CUDA_ARCHITECTURES` property.
|
||||
string(REGEX REPLACE "-gencode arch=[^ ]+ *" "" CMAKE_CUDA_FLAGS
|
||||
${CMAKE_CUDA_FLAGS})
|
||||
endmacro()
|
||||
137
sgl-kernel/csrc/allreduce/custom_all_reduce.cu
Normal file
137
sgl-kernel/csrc/allreduce/custom_all_reduce.cu
Normal file
@@ -0,0 +1,137 @@
|
||||
// Adapted from: https://github.com/vllm-project/vllm/blob/v0.8.2/csrc/custom_all_reduce.cu
|
||||
#include <ATen/cuda/Exceptions.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include "custom_all_reduce.cuh"
|
||||
|
||||
// Fake pointer type, must match fptr_t type in ops.h.
|
||||
// We use this type alias to indicate when pointers are passed in as int64_t.
|
||||
using fptr_t = int64_t;
|
||||
static_assert(sizeof(void*) == sizeof(fptr_t));
|
||||
|
||||
fptr_t
|
||||
init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs, torch::Tensor& rank_data, int64_t rank, bool full_nvlink) {
|
||||
int world_size = fake_ipc_ptrs.size();
|
||||
if (world_size > 8) throw std::invalid_argument("world size > 8 is not supported");
|
||||
if (world_size % 2 != 0) throw std::invalid_argument("Odd num gpus is not supported for now");
|
||||
if (rank < 0 || rank >= world_size) throw std::invalid_argument("invalid rank passed in");
|
||||
|
||||
sglang::Signal* ipc_ptrs[8];
|
||||
for (int i = 0; i < world_size; i++) {
|
||||
ipc_ptrs[i] = reinterpret_cast<sglang::Signal*>(fake_ipc_ptrs[i]);
|
||||
}
|
||||
return (fptr_t) new sglang::CustomAllreduce(
|
||||
ipc_ptrs, rank_data.data_ptr(), rank_data.numel(), rank, world_size, full_nvlink);
|
||||
}
|
||||
|
||||
/**
|
||||
* Make sure tensor t's data lies completely within ((char)t.data_ptr()) +
|
||||
* t.numel() * t.element_size(). This is slightly weaker than t.is_contiguous()
|
||||
* because it allows transpose of contiguous slice (i.e. slicing the first
|
||||
* dimension). Currently, we require this because stride information is not
|
||||
* passed into the kernels and we treat input tensors as flat.
|
||||
*
|
||||
* Examples
|
||||
* A = torch.zeros(3, 3, 3)
|
||||
* 1. A: OK
|
||||
* 2. A[1:]: OK
|
||||
* 3. A.permute(2, 0, 1): OK
|
||||
* 4. A[1:].permute(2, 0, 1): OK
|
||||
* 5. A[None].expand(2, -1, -1, -1): Not OK
|
||||
* 6. A[:, 1:, 1:]: Not OK
|
||||
*/
|
||||
bool _is_weak_contiguous(torch::Tensor& t) {
|
||||
return t.is_contiguous() ||
|
||||
(t.storage().nbytes() - t.storage_offset() * t.element_size() == t.numel() * t.element_size());
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs an out-of-place allreduce and stores result in out.
|
||||
*
|
||||
* If _reg_buffer is null, assumes inp.data_ptr() is already IPC-registered.
|
||||
* Otherwise, _reg_buffer is assumed to be IPC-registered and inp is first
|
||||
* copied into _reg_buffer.
|
||||
*/
|
||||
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes) {
|
||||
auto fa = reinterpret_cast<sglang::CustomAllreduce*>(_fa);
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
|
||||
auto stream = c10::cuda::getCurrentCUDAStream().stream();
|
||||
|
||||
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
|
||||
TORCH_CHECK_EQ(inp.numel(), out.numel());
|
||||
TORCH_CHECK(_is_weak_contiguous(out));
|
||||
TORCH_CHECK(_is_weak_contiguous(inp));
|
||||
auto input_size = inp.numel() * inp.element_size();
|
||||
auto reg_buffer = reinterpret_cast<void*>(_reg_buffer);
|
||||
if (reg_buffer) {
|
||||
TORCH_CHECK_LE(input_size, reg_buffer_sz_bytes);
|
||||
AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer, inp.data_ptr(), input_size, cudaMemcpyDeviceToDevice, stream));
|
||||
} else {
|
||||
reg_buffer = inp.data_ptr();
|
||||
}
|
||||
switch (out.scalar_type()) {
|
||||
case at::ScalarType::Float: {
|
||||
fa->allreduce<float>(
|
||||
stream, reinterpret_cast<float*>(reg_buffer), reinterpret_cast<float*>(out.data_ptr()), out.numel());
|
||||
break;
|
||||
}
|
||||
case at::ScalarType::Half: {
|
||||
fa->allreduce<half>(
|
||||
stream, reinterpret_cast<half*>(reg_buffer), reinterpret_cast<half*>(out.data_ptr()), out.numel());
|
||||
break;
|
||||
}
|
||||
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
|
||||
case at::ScalarType::BFloat16: {
|
||||
fa->allreduce<nv_bfloat16>(
|
||||
stream,
|
||||
reinterpret_cast<nv_bfloat16*>(reg_buffer),
|
||||
reinterpret_cast<nv_bfloat16*>(out.data_ptr()),
|
||||
out.numel());
|
||||
break;
|
||||
}
|
||||
#endif
|
||||
default:
|
||||
throw std::runtime_error("custom allreduce only supports float32, float16 and bfloat16");
|
||||
}
|
||||
}
|
||||
|
||||
void dispose(fptr_t _fa) {
|
||||
delete reinterpret_cast<sglang::CustomAllreduce*>(_fa);
|
||||
}
|
||||
|
||||
int64_t meta_size() {
|
||||
return sizeof(sglang::Signal);
|
||||
}
|
||||
|
||||
void register_buffer(fptr_t _fa, const std::vector<fptr_t>& fake_ipc_ptrs) {
|
||||
auto fa = reinterpret_cast<sglang::CustomAllreduce*>(_fa);
|
||||
TORCH_CHECK(fake_ipc_ptrs.size() == fa->world_size_);
|
||||
void* ipc_ptrs[8];
|
||||
for (int i = 0; i < fake_ipc_ptrs.size(); i++) {
|
||||
ipc_ptrs[i] = reinterpret_cast<void*>(fake_ipc_ptrs[i]);
|
||||
}
|
||||
fa->register_buffer(ipc_ptrs);
|
||||
}
|
||||
|
||||
// Use vector<int64_t> to represent byte data for python binding compatibility.
|
||||
std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa) {
|
||||
auto fa = reinterpret_cast<sglang::CustomAllreduce*>(_fa);
|
||||
auto [handle, offsets] = fa->get_graph_buffer_ipc_meta();
|
||||
std::vector<int64_t> bytes(handle.begin(), handle.end());
|
||||
return std::make_tuple(bytes, offsets);
|
||||
}
|
||||
|
||||
// Use vector<int64_t> to represent byte data for python binding compatibility.
|
||||
void register_graph_buffers(
|
||||
fptr_t _fa, const std::vector<std::vector<int64_t>>& handles, const std::vector<std::vector<int64_t>>& offsets) {
|
||||
auto fa = reinterpret_cast<sglang::CustomAllreduce*>(_fa);
|
||||
std::vector<std::string> bytes;
|
||||
bytes.reserve(handles.size());
|
||||
for (int i = 0; i < handles.size(); i++) {
|
||||
bytes.emplace_back(handles[i].begin(), handles[i].end());
|
||||
}
|
||||
bytes.reserve(handles.size());
|
||||
fa->register_graph_buffers(bytes, offsets);
|
||||
}
|
||||
489
sgl-kernel/csrc/allreduce/custom_all_reduce.cuh
Normal file
489
sgl-kernel/csrc/allreduce/custom_all_reduce.cuh
Normal file
@@ -0,0 +1,489 @@
|
||||
// Adapted from https://github.com/vllm-project/vllm/blob/v0.8.2/csrc/custom_all_reduce.cuh
|
||||
#pragma once
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <array>
|
||||
#include <iostream>
|
||||
#include <limits>
|
||||
#include <map>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
namespace sglang {
|
||||
|
||||
constexpr int kMaxBlocks = 36;
|
||||
// Counter may overflow, but it's fine since unsigned int overflow is
|
||||
// well-defined behavior.
|
||||
using FlagType = uint32_t;
|
||||
struct Signal {
|
||||
alignas(128) FlagType self_counter[kMaxBlocks][8];
|
||||
// Two sets of peer counters are needed for two syncs. The reason is that
|
||||
// it's possible for peer GPU block to arrive at the second sync point while
|
||||
// the current GPU block haven't passed the first sync point. Thus, peer GPU
|
||||
// may write counter+1 while current GPU is busy waiting for counter. We use
|
||||
// alternating counter array to avoid this possibility.
|
||||
alignas(128) FlagType peer_counter[2][kMaxBlocks][8];
|
||||
};
|
||||
|
||||
struct __align__(16) RankData {
|
||||
const void* __restrict__ ptrs[8];
|
||||
};
|
||||
|
||||
struct __align__(16) RankSignals {
|
||||
Signal* signals[8];
|
||||
};
|
||||
|
||||
// like std::array, but aligned
|
||||
template <typename T, int sz>
|
||||
struct __align__(alignof(T) * sz) array_t {
|
||||
T data[sz];
|
||||
using type = T;
|
||||
static constexpr int size = sz;
|
||||
};
|
||||
|
||||
// use packed type to maximize memory efficiency
|
||||
// goal: generate ld.128 and st.128 instructions
|
||||
template <typename T>
|
||||
struct packed_t {
|
||||
// the (P)acked type for load/store
|
||||
using P = array_t<T, 16 / sizeof(T)>;
|
||||
// the (A)ccumulator type for reduction
|
||||
using A = array_t<float, 16 / sizeof(T)>;
|
||||
};
|
||||
|
||||
#define DINLINE __device__ __forceinline__
|
||||
|
||||
// scalar cast functions
|
||||
DINLINE float upcast_s(half val) {
|
||||
return __half2float(val);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
DINLINE T downcast_s(float val);
|
||||
template <>
|
||||
DINLINE half downcast_s(float val) {
|
||||
return __float2half(val);
|
||||
}
|
||||
|
||||
// scalar add functions
|
||||
// for some reason when compiling with Pytorch, the + operator for half and
|
||||
// bfloat is disabled so we call the intrinsics directly
|
||||
DINLINE half& assign_add(half& a, half b) {
|
||||
a = __hadd(a, b);
|
||||
return a;
|
||||
}
|
||||
DINLINE float& assign_add(float& a, float b) {
|
||||
return a += b;
|
||||
}
|
||||
|
||||
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
|
||||
DINLINE float upcast_s(nv_bfloat16 val) {
|
||||
return __bfloat162float(val);
|
||||
}
|
||||
template <>
|
||||
DINLINE nv_bfloat16 downcast_s(float val) {
|
||||
return __float2bfloat16(val);
|
||||
}
|
||||
DINLINE nv_bfloat16& assign_add(nv_bfloat16& a, nv_bfloat16 b) {
|
||||
a = __hadd(a, b);
|
||||
return a;
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename T, int N>
|
||||
DINLINE array_t<T, N>& packed_assign_add(array_t<T, N>& a, array_t<T, N> b) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N; i++) {
|
||||
assign_add(a.data[i], b.data[i]);
|
||||
}
|
||||
return a;
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
DINLINE array_t<float, N> upcast(array_t<T, N> val) {
|
||||
if constexpr (std::is_same<T, float>::value) {
|
||||
return val;
|
||||
} else {
|
||||
array_t<float, N> out;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N; i++) {
|
||||
out.data[i] = upcast_s(val.data[i]);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename O>
|
||||
DINLINE O downcast(array_t<float, O::size> val) {
|
||||
if constexpr (std::is_same<typename O::type, float>::value) {
|
||||
return val;
|
||||
} else {
|
||||
O out;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < O::size; i++) {
|
||||
out.data[i] = downcast_s<typename O::type>(val.data[i]);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
}
|
||||
|
||||
static DINLINE void st_flag_release(FlagType* flag_addr, FlagType flag) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
|
||||
asm volatile("st.release.sys.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
|
||||
#else
|
||||
asm volatile("membar.sys; st.volatile.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
|
||||
#endif
|
||||
}
|
||||
|
||||
static DINLINE FlagType ld_flag_acquire(FlagType* flag_addr) {
|
||||
FlagType flag;
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
|
||||
asm volatile("ld.acquire.sys.global.u32 %0, [%1];" : "=r"(flag) : "l"(flag_addr));
|
||||
#else
|
||||
asm volatile("ld.volatile.global.u32 %0, [%1]; membar.gl;" : "=r"(flag) : "l"(flag_addr));
|
||||
#endif
|
||||
return flag;
|
||||
}
|
||||
|
||||
static DINLINE void st_flag_volatile(FlagType* flag_addr, FlagType flag) {
|
||||
asm volatile("st.volatile.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
|
||||
}
|
||||
|
||||
static DINLINE FlagType ld_flag_volatile(FlagType* flag_addr) {
|
||||
FlagType flag;
|
||||
asm volatile("ld.volatile.global.u32 %0, [%1];" : "=r"(flag) : "l"(flag_addr));
|
||||
return flag;
|
||||
}
|
||||
|
||||
// is_start: whether this is the very first synchronization barrier.
|
||||
// need_fence: whether a memory fence is needed. If true, a release-acquire
|
||||
// semantic is used to enforce memory access order before and after this
|
||||
// barrier.
|
||||
template <int ngpus, bool is_start, bool need_fence = false>
|
||||
DINLINE void multi_gpu_barrier(const RankSignals& sg, Signal* self_sg, int rank) {
|
||||
if constexpr (!is_start) __syncthreads();
|
||||
static_assert(!(is_start && need_fence)); // Start barrier shouldn't need fence.
|
||||
if (threadIdx.x < ngpus) {
|
||||
// Increment the counter. Technically we only need one counter, but we use
|
||||
// multiple per block to eliminate the need to share the counter via smem.
|
||||
auto val = self_sg->self_counter[blockIdx.x][threadIdx.x] += 1;
|
||||
// Write the expected counter value to peer and wait for correct value from
|
||||
// peer.
|
||||
auto peer_counter_ptr = &sg.signals[threadIdx.x]->peer_counter[val % 2][blockIdx.x][rank];
|
||||
auto self_counter_ptr = &self_sg->peer_counter[val % 2][blockIdx.x][threadIdx.x];
|
||||
if constexpr (need_fence) {
|
||||
st_flag_release(peer_counter_ptr, val);
|
||||
while (ld_flag_acquire(self_counter_ptr) != val)
|
||||
;
|
||||
} else {
|
||||
st_flag_volatile(peer_counter_ptr, val);
|
||||
while (ld_flag_volatile(self_counter_ptr) != val)
|
||||
;
|
||||
}
|
||||
}
|
||||
if constexpr (is_start || need_fence) __syncthreads();
|
||||
}
|
||||
|
||||
template <typename P, int ngpus, typename A>
|
||||
DINLINE P packed_reduce(const P* ptrs[], int idx) {
|
||||
A tmp = upcast(ptrs[0][idx]);
|
||||
#pragma unroll
|
||||
for (int i = 1; i < ngpus; i++) {
|
||||
packed_assign_add(tmp, upcast(ptrs[i][idx]));
|
||||
}
|
||||
return downcast<P>(tmp);
|
||||
}
|
||||
|
||||
template <typename T, int ngpus>
|
||||
__global__ void __launch_bounds__(512, 1) cross_device_reduce_1stage(
|
||||
RankData* _dp, RankSignals sg, Signal* self_sg, T* __restrict__ result, int rank, int size) {
|
||||
using P = typename packed_t<T>::P;
|
||||
using A = typename packed_t<T>::A;
|
||||
// note: we don't reorder the address so the accumulation order is the same
|
||||
// for all ranks, ensuring bitwise identical results
|
||||
auto dp = *_dp;
|
||||
multi_gpu_barrier<ngpus, true>(sg, self_sg, rank);
|
||||
// do the actual reduction
|
||||
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += gridDim.x * blockDim.x) {
|
||||
((P*)result)[idx] = packed_reduce<P, ngpus, A>((const P**)&dp.ptrs[0], idx);
|
||||
}
|
||||
multi_gpu_barrier<ngpus, false>(sg, self_sg, rank);
|
||||
}
|
||||
|
||||
template <typename P>
|
||||
DINLINE P* get_tmp_buf(Signal* sg) {
|
||||
return (P*)(((Signal*)sg) + 1);
|
||||
}
|
||||
|
||||
template <typename T, int ngpus>
|
||||
__global__ void __launch_bounds__(512, 1) cross_device_reduce_2stage(
|
||||
RankData* _dp, RankSignals sg, Signal* self_sg, T* __restrict__ result, int rank, int size) {
|
||||
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int stride = gridDim.x * blockDim.x;
|
||||
using P = typename packed_t<T>::P;
|
||||
using A = typename packed_t<T>::A;
|
||||
int part = size / ngpus;
|
||||
int start = rank * part;
|
||||
int end = rank == ngpus - 1 ? size : start + part;
|
||||
int largest_part = part + size % ngpus;
|
||||
const P* ptrs[ngpus];
|
||||
P* tmps[ngpus];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < ngpus; i++) {
|
||||
int target = (rank + i) % ngpus;
|
||||
ptrs[i] = (const P*)_dp->ptrs[target];
|
||||
tmps[i] = get_tmp_buf<P>(sg.signals[target]);
|
||||
}
|
||||
auto tmp_out = tmps[0];
|
||||
multi_gpu_barrier<ngpus, true>(sg, self_sg, rank);
|
||||
// stage 1: reduce scatter
|
||||
for (int idx = start + tid; idx < end; idx += stride) {
|
||||
tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
|
||||
}
|
||||
multi_gpu_barrier<ngpus, false, true>(sg, self_sg, rank);
|
||||
|
||||
// stage 2: allgather. Note: it's important to match the tid between
|
||||
// the two stages, because visibility across devices is only guaranteed
|
||||
// between threads that have the same tid. If thread i computes the sum of
|
||||
// start + i in the first stage, then thread i also gathers start + i from all
|
||||
// ranks.
|
||||
for (int idx = tid; idx < largest_part; idx += stride) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < ngpus; i++) {
|
||||
int gather_from_rank = ((rank + i) % ngpus);
|
||||
if (gather_from_rank == ngpus - 1 || idx < part) {
|
||||
int dst_idx = gather_from_rank * part + idx;
|
||||
((P*)result)[dst_idx] = tmps[i][idx];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
using IPC_KEY = std::array<uint8_t, sizeof(cudaIpcMemHandle_t)>;
|
||||
static_assert(sizeof(IPC_KEY) == sizeof(cudaIpcMemHandle_t));
|
||||
static_assert(alignof(IPC_KEY) == alignof(cudaIpcMemHandle_t));
|
||||
|
||||
class CustomAllreduce {
|
||||
public:
|
||||
int rank_;
|
||||
int world_size_;
|
||||
bool full_nvlink_;
|
||||
|
||||
RankSignals sg_;
|
||||
// Stores an map from a pointer to its peer pointters from all ranks.
|
||||
std::unordered_map<void*, RankData*> buffers_;
|
||||
Signal* self_sg_;
|
||||
|
||||
// Stores rank data from all ranks. This is mainly for cuda graph purposes.
|
||||
// For cuda graph to work, all kernel arguments must be fixed during graph
|
||||
// capture time. However, the peer pointers are not known during graph capture
|
||||
// time. Therefore, during capture, we increment the rank data pointer and use
|
||||
// that as the argument to the kernel. The kernel arguments are stored in
|
||||
// graph_unreg_buffers_. The actual peer pointers will be filled in at the
|
||||
// memory pointed to by the pointers in graph_unreg_buffers_ when
|
||||
// the IPC handles are exchanged between ranks.
|
||||
//
|
||||
// The overall process looks like this:
|
||||
// 1. Graph capture.
|
||||
// 2. Each rank obtains the IPC handles for each addresses used during cuda
|
||||
// graph capture using get_graph_buffer_ipc_meta.
|
||||
// 3. (In Python) all gather the IPC handles.
|
||||
// 4. Obtain the peer pointers by opening the IPC handles, and store them in
|
||||
// the rank data array at corresponding positions.
|
||||
RankData *d_rank_data_base_, *d_rank_data_end_;
|
||||
std::vector<void*> graph_unreg_buffers_;
|
||||
// a map from IPC handles to opened IPC pointers
|
||||
std::map<IPC_KEY, char*> ipc_handles_;
|
||||
|
||||
/**
|
||||
* Signals are an array of ipc-enabled buffers from all ranks.
|
||||
* For each of the buffer, the layout is as follows:
|
||||
* | -- sizeof(Signal) -- | ------ a few MB ----- |
|
||||
* The first section is for allreduce synchronization, and the second section
|
||||
* is for storing the intermediate results required by some allreduce algos.
|
||||
*
|
||||
* Note: this class does not own any device memory. Any required buffers
|
||||
* are passed in from the constructor.
|
||||
*/
|
||||
CustomAllreduce(
|
||||
Signal** signals, void* rank_data, size_t rank_data_sz, int rank, int world_size, bool full_nvlink = true)
|
||||
: rank_(rank),
|
||||
world_size_(world_size),
|
||||
full_nvlink_(full_nvlink),
|
||||
self_sg_(signals[rank]),
|
||||
d_rank_data_base_(reinterpret_cast<RankData*>(rank_data)),
|
||||
d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
|
||||
for (int i = 0; i < world_size_; i++) {
|
||||
sg_.signals[i] = signals[i];
|
||||
}
|
||||
}
|
||||
|
||||
char* open_ipc_handle(const void* ipc_handle) {
|
||||
auto [it, new_handle] = ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr});
|
||||
if (new_handle) {
|
||||
char* ipc_ptr;
|
||||
CHECK_CUDA_SUCCESS(cudaIpcOpenMemHandle(
|
||||
(void**)&ipc_ptr, *((const cudaIpcMemHandle_t*)ipc_handle), cudaIpcMemLazyEnablePeerAccess));
|
||||
it->second = ipc_ptr;
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
std::pair<std::string, std::vector<int64_t>> get_graph_buffer_ipc_meta() {
|
||||
auto num_buffers = graph_unreg_buffers_.size();
|
||||
auto handle_sz = sizeof(cudaIpcMemHandle_t);
|
||||
std::string handles(handle_sz * num_buffers, static_cast<char>(0));
|
||||
std::vector<int64_t> offsets(num_buffers);
|
||||
for (int i = 0; i < num_buffers; i++) {
|
||||
auto ptr = graph_unreg_buffers_[i];
|
||||
void* base_ptr;
|
||||
// note: must share the base address of each allocation, or we get wrong
|
||||
// address
|
||||
if (cuPointerGetAttribute(&base_ptr, CU_POINTER_ATTRIBUTE_RANGE_START_ADDR, (CUdeviceptr)ptr) != CUDA_SUCCESS)
|
||||
throw std::runtime_error("failed to get pointer attr");
|
||||
CHECK_CUDA_SUCCESS(cudaIpcGetMemHandle((cudaIpcMemHandle_t*)&handles[i * handle_sz], base_ptr));
|
||||
offsets[i] = ((char*)ptr) - ((char*)base_ptr);
|
||||
}
|
||||
return std::make_pair(handles, offsets);
|
||||
}
|
||||
|
||||
void check_rank_data_capacity(size_t num = 1) {
|
||||
if (d_rank_data_base_ + num > d_rank_data_end_)
|
||||
throw std::runtime_error(
|
||||
"Rank data buffer is overflowed by " + std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
|
||||
}
|
||||
|
||||
/**
|
||||
* Register already-shared IPC pointers.
|
||||
*/
|
||||
void register_buffer(void** ptrs) {
|
||||
check_rank_data_capacity();
|
||||
RankData data;
|
||||
for (int i = 0; i < world_size_; i++) {
|
||||
data.ptrs[i] = ptrs[i];
|
||||
}
|
||||
auto d_data = d_rank_data_base_++;
|
||||
CHECK_CUDA_SUCCESS(cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice));
|
||||
buffers_[ptrs[rank_]] = d_data;
|
||||
}
|
||||
|
||||
// Note: when registering graph buffers, we intentionally choose to not
|
||||
// deduplicate the addresses. That means if the allocator reuses some
|
||||
// addresses, they will be registered again. This is to account for the remote
|
||||
// possibility of different allocation patterns between ranks. For example,
|
||||
// rank 1 may get the same input address for the second allreduce, but rank 2
|
||||
// got a different address. IPC handles have internal reference counting
|
||||
// mechanism so overhead should be small.
|
||||
void
|
||||
register_graph_buffers(const std::vector<std::string>& handles, const std::vector<std::vector<int64_t>>& offsets) {
|
||||
auto num_buffers = graph_unreg_buffers_.size();
|
||||
check_rank_data_capacity(num_buffers);
|
||||
std::vector<RankData> rank_data(num_buffers);
|
||||
for (int i = 0; i < num_buffers; i++) {
|
||||
auto self_ptr = graph_unreg_buffers_[i];
|
||||
auto& rd = rank_data[i];
|
||||
for (int j = 0; j < world_size_; j++) {
|
||||
if (j != rank_) {
|
||||
char* handle = open_ipc_handle(&handles[j][i * sizeof(cudaIpcMemHandle_t)]);
|
||||
handle += offsets[j][i];
|
||||
rd.ptrs[j] = handle;
|
||||
} else {
|
||||
rd.ptrs[j] = self_ptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
CHECK_CUDA_SUCCESS(
|
||||
cudaMemcpy(d_rank_data_base_, rank_data.data(), sizeof(RankData) * num_buffers, cudaMemcpyHostToDevice));
|
||||
d_rank_data_base_ += num_buffers;
|
||||
graph_unreg_buffers_.clear();
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs allreduce, assuming input has already been registered.
|
||||
*
|
||||
* Block and grid default configs are results after careful grid search. Using
|
||||
* 36 blocks give the best or close to the best runtime on the devices I
|
||||
* tried: A100, A10, A30, T4, V100. You'll notice that NCCL kernels also only
|
||||
* take a small amount of SMs. Not quite sure the underlying reason, but my
|
||||
* guess is that too many SMs will cause contention on NVLink bus.
|
||||
*/
|
||||
template <typename T>
|
||||
void allreduce(cudaStream_t stream, T* input, T* output, int size, int threads = 512, int block_limit = 36) {
|
||||
auto d = packed_t<T>::P::size;
|
||||
if (size % d != 0)
|
||||
throw std::runtime_error(
|
||||
"custom allreduce currently requires input length to be multiple "
|
||||
"of " +
|
||||
std::to_string(d));
|
||||
if (block_limit > kMaxBlocks)
|
||||
throw std::runtime_error(
|
||||
"max supported block limit is " + std::to_string(kMaxBlocks) + ". Got " + std::to_string(block_limit));
|
||||
|
||||
RankData* ptrs;
|
||||
cudaStreamCaptureStatus status;
|
||||
CHECK_CUDA_SUCCESS(cudaStreamIsCapturing(stream, &status));
|
||||
if (status == cudaStreamCaptureStatusActive) {
|
||||
ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
|
||||
graph_unreg_buffers_.push_back(input);
|
||||
} else {
|
||||
auto it = buffers_.find(input);
|
||||
if (it == buffers_.end())
|
||||
throw std::runtime_error(
|
||||
"buffer address " + std::to_string(reinterpret_cast<uint64_t>(input)) + " is not registered!");
|
||||
ptrs = it->second;
|
||||
}
|
||||
|
||||
size /= d;
|
||||
auto bytes = size * sizeof(typename packed_t<T>::P);
|
||||
int blocks = std::min(block_limit, (size + threads - 1) / threads);
|
||||
#define KL(ngpus, name) name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, rank_, size);
|
||||
// TODO(hanzhi713): Threshold is different for A100 and H100.
|
||||
// Add per device threshold.
|
||||
#define REDUCE_CASE(ngpus) \
|
||||
case ngpus: { \
|
||||
if (world_size_ == 2) { \
|
||||
KL(ngpus, cross_device_reduce_1stage); \
|
||||
} else if (full_nvlink_) { \
|
||||
if ((world_size_ <= 4 && bytes < 512 * 1024) || (world_size_ <= 8 && bytes < 256 * 1024)) { \
|
||||
KL(ngpus, cross_device_reduce_1stage); \
|
||||
} else { \
|
||||
KL(ngpus, cross_device_reduce_2stage); \
|
||||
} \
|
||||
} \
|
||||
break; \
|
||||
}
|
||||
|
||||
switch (world_size_) {
|
||||
REDUCE_CASE(2)
|
||||
REDUCE_CASE(4)
|
||||
REDUCE_CASE(6)
|
||||
REDUCE_CASE(8)
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"custom allreduce only supports num gpus in (2,4,6,8). Actual num "
|
||||
"gpus = " +
|
||||
std::to_string(world_size_));
|
||||
}
|
||||
#undef REDUCE_CASE
|
||||
#undef KL
|
||||
}
|
||||
|
||||
~CustomAllreduce() {
|
||||
for (auto [_, ptr] : ipc_handles_) {
|
||||
CHECK_CUDA_SUCCESS(cudaIpcCloseMemHandle(ptr));
|
||||
}
|
||||
}
|
||||
};
|
||||
/**
|
||||
* To inspect PTX/SASS, copy paste this header file to compiler explorer and add
|
||||
a template instantiation:
|
||||
* template void sglang::CustomAllreduce::allreduce<half>(cudaStream_t, half *,
|
||||
half *, int, int, int);
|
||||
*/
|
||||
} // namespace sglang
|
||||
180
sgl-kernel/csrc/allreduce/custom_all_reduce.hip
Normal file
180
sgl-kernel/csrc/allreduce/custom_all_reduce.hip
Normal file
@@ -0,0 +1,180 @@
|
||||
// !!! This is a file automatically generated by hipify!!!
|
||||
#include <ATen/hip/Exceptions.h>
|
||||
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
|
||||
#include <ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include "custom_all_reduce_hip.cuh"
|
||||
|
||||
// fake pointer type, must match fptr_t type in ops.h
|
||||
using fptr_t = int64_t;
|
||||
static_assert(sizeof(void*) == sizeof(fptr_t));
|
||||
|
||||
fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
|
||||
const std::vector<std::string>& handles,
|
||||
const std::vector<int64_t>& offsets, int64_t rank,
|
||||
bool full_nvlink) {
|
||||
int world_size = offsets.size();
|
||||
if (world_size > 8)
|
||||
throw std::invalid_argument("world size > 8 is not supported");
|
||||
if (world_size % 2 != 0)
|
||||
throw std::invalid_argument("Odd num gpus is not supported for now");
|
||||
if (world_size != handles.size())
|
||||
throw std::invalid_argument(
|
||||
"handles length should equal to offsets length");
|
||||
if (rank < 0 || rank >= world_size)
|
||||
throw std::invalid_argument("invalid rank passed in");
|
||||
|
||||
hipIpcMemHandle_t ipc_handles[8];
|
||||
for (int i = 0; i < world_size; i++) {
|
||||
std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(hipIpcMemHandle_t));
|
||||
}
|
||||
return (fptr_t) new sglang::CustomAllreduce(
|
||||
reinterpret_cast<sglang::Signal*>(meta.data_ptr()), rank_data.data_ptr(),
|
||||
rank_data.numel(), ipc_handles, offsets, rank, full_nvlink);
|
||||
}
|
||||
|
||||
/**
|
||||
* Make sure tensor t's data lies completely within ((char)t.data_ptr()) +
|
||||
* t.numel() * t.element_size(). This is slightly weaker than t.is_contiguous()
|
||||
* because it allows transpose of contiguous slice (i.e. slicing the first
|
||||
* dimension). Currently, we require this because stride information is not
|
||||
* passed into the kernels and we treat input tensors as flat.
|
||||
*
|
||||
* Examples
|
||||
* A = torch.zeros(3, 3, 3)
|
||||
* 1. A: OK
|
||||
* 2. A[1:]: OK
|
||||
* 3. A.permute(2, 0, 1): OK
|
||||
* 4. A[1:].permute(2, 0, 1): OK
|
||||
* 5. A[None].expand(2, -1, -1, -1): Not OK
|
||||
* 6. A[:, 1:, 1:]: Not OK
|
||||
*/
|
||||
bool _is_weak_contiguous(torch::Tensor& t) {
|
||||
return t.is_contiguous() ||
|
||||
(t.storage().nbytes() - t.storage_offset() * t.element_size() ==
|
||||
t.numel() * t.element_size());
|
||||
}
|
||||
|
||||
void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
|
||||
hipStream_t stream) {
|
||||
auto fa = reinterpret_cast<sglang::CustomAllreduce*>(_fa);
|
||||
TORCH_CHECK(_is_weak_contiguous(out));
|
||||
switch (out.scalar_type()) {
|
||||
case at::ScalarType::Float: {
|
||||
fa->allreduce<float>(stream, reinterpret_cast<float*>(inp.data_ptr()),
|
||||
reinterpret_cast<float*>(out.data_ptr()),
|
||||
out.numel());
|
||||
break;
|
||||
}
|
||||
case at::ScalarType::Half: {
|
||||
fa->allreduce<half>(stream, reinterpret_cast<half*>(inp.data_ptr()),
|
||||
reinterpret_cast<half*>(out.data_ptr()), out.numel());
|
||||
break;
|
||||
}
|
||||
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
|
||||
case at::ScalarType::BFloat16: {
|
||||
fa->allreduce<nv_bfloat16>(
|
||||
stream, reinterpret_cast<nv_bfloat16*>(inp.data_ptr()),
|
||||
reinterpret_cast<nv_bfloat16*>(out.data_ptr()), out.numel());
|
||||
break;
|
||||
}
|
||||
#endif
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"custom allreduce only supports float32, float16 and bfloat16");
|
||||
}
|
||||
}
|
||||
|
||||
void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) {
|
||||
const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(inp));
|
||||
auto stream = c10::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
|
||||
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
|
||||
TORCH_CHECK_EQ(inp.numel(), out.numel());
|
||||
_all_reduce(_fa, inp, out, stream);
|
||||
}
|
||||
|
||||
void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
|
||||
torch::Tensor& out) {
|
||||
const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(inp));
|
||||
auto stream = c10::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
|
||||
|
||||
auto input_size = inp.numel() * inp.element_size();
|
||||
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
|
||||
TORCH_CHECK_EQ(inp.numel(), out.numel());
|
||||
TORCH_CHECK(input_size <= reg_buffer.numel() * reg_buffer.element_size(),
|
||||
"registered buffer is too small to contain the input");
|
||||
AT_CUDA_CHECK(hipMemcpyAsync(reg_buffer.data_ptr(), inp.data_ptr(),
|
||||
input_size, hipMemcpyDeviceToDevice, stream));
|
||||
_all_reduce(_fa, reg_buffer, out, stream);
|
||||
}
|
||||
|
||||
void dispose(fptr_t _fa) {
|
||||
auto fa = reinterpret_cast<sglang::CustomAllreduce*>(_fa);
|
||||
delete fa;
|
||||
}
|
||||
|
||||
int64_t meta_size() { return sizeof(sglang::Signal); }
|
||||
|
||||
void register_buffer(fptr_t _fa, torch::Tensor& t,
|
||||
const std::vector<std::string>& handles,
|
||||
const std::vector<int64_t>& offsets) {
|
||||
auto fa = reinterpret_cast<sglang::CustomAllreduce*>(_fa);
|
||||
fa->register_buffer(handles, offsets, t.data_ptr());
|
||||
}
|
||||
|
||||
std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
|
||||
fptr_t _fa) {
|
||||
auto fa = reinterpret_cast<sglang::CustomAllreduce*>(_fa);
|
||||
auto [handle_bytes, offsets] = fa->get_graph_buffer_ipc_meta();
|
||||
auto options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
|
||||
auto handles =
|
||||
torch::empty({static_cast<int64_t>(handle_bytes.size())}, options);
|
||||
std::memcpy(handles.data_ptr(), handle_bytes.data(), handle_bytes.size());
|
||||
return {handles, std::move(offsets)};
|
||||
}
|
||||
|
||||
void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
|
||||
const std::vector<std::vector<int64_t>>& offsets) {
|
||||
auto fa = reinterpret_cast<sglang::CustomAllreduce*>(_fa);
|
||||
fa->register_graph_buffers(handles, offsets);
|
||||
}
|
||||
|
||||
void free_meta_buffer(void* buffer) { CUDACHECK(hipFree(buffer)); }
|
||||
|
||||
torch::Tensor get_meta_buffer_ipc_handle(torch::Tensor& inp) {
|
||||
auto options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
|
||||
auto data_handle =
|
||||
torch::empty({static_cast<int64_t>(sizeof(hipIpcMemHandle_t))}, options);
|
||||
CUDACHECK(hipIpcGetMemHandle((hipIpcMemHandle_t*)data_handle.data_ptr(),
|
||||
inp.data_ptr()));
|
||||
return data_handle;
|
||||
}
|
||||
|
||||
torch::Tensor allocate_meta_buffer(int64_t size) {
|
||||
auto device_index = c10::hip::current_device();
|
||||
at::DeviceGuard device_guard(at::Device(at::DeviceType::CUDA, device_index));
|
||||
void* buffer;
|
||||
hipStreamCaptureMode mode = hipStreamCaptureModeRelaxed;
|
||||
auto stream = c10::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
|
||||
AT_CUDA_CHECK(hipThreadExchangeStreamCaptureMode(&mode));
|
||||
AT_CUDA_CHECK(
|
||||
hipExtMallocWithFlags((void**)&buffer, size, hipDeviceMallocUncached));
|
||||
AT_CUDA_CHECK(hipMemsetAsync(buffer, 0, size, stream));
|
||||
AT_CUDA_CHECK(hipStreamSynchronize(stream));
|
||||
AT_CUDA_CHECK(hipThreadExchangeStreamCaptureMode(&mode));
|
||||
auto options = torch::TensorOptions()
|
||||
.dtype(torch::kI8)
|
||||
.device(torch::kCUDA, device_index);
|
||||
return torch::from_blob(buffer, {size}, free_meta_buffer, options);
|
||||
}
|
||||
|
||||
std::vector<uint8_t> get_device_bdf(int dev) {
|
||||
char busIdStr[] = "0000:00:00.0";
|
||||
std::vector<uint8_t> bdf(sizeof(busIdStr), 0);
|
||||
CUDACHECK(hipDeviceGetPCIBusId((char*)bdf.data(), sizeof(busIdStr), dev));
|
||||
bdf.resize(bdf.size() - 1); // remove trailing NULL
|
||||
return bdf;
|
||||
}
|
||||
582
sgl-kernel/csrc/allreduce/custom_all_reduce_hip.cuh
Normal file
582
sgl-kernel/csrc/allreduce/custom_all_reduce_hip.cuh
Normal file
@@ -0,0 +1,582 @@
|
||||
// !!! This is a file automatically generated by hipify!!!
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#ifdef USE_ROCM
|
||||
#include <hip/hip_bf16.h>
|
||||
typedef __hip_bfloat16 nv_bfloat16;
|
||||
#else
|
||||
#include <hip/hip_bf16.h>
|
||||
#endif
|
||||
#include <hip/hip_fp16.h>
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <iostream>
|
||||
#include <limits>
|
||||
#include <map>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#define CUDACHECK(cmd) \
|
||||
do { \
|
||||
hipError_t e = cmd; \
|
||||
if (e != hipSuccess) { \
|
||||
printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, hipGetErrorString(e)); \
|
||||
exit(EXIT_FAILURE); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
namespace sglang {
|
||||
|
||||
constexpr int kMaxBlocks = 64;
|
||||
// note: we don't want to use atomics for signals because peer atomics are no
|
||||
// supported on PCIe links
|
||||
struct Signal {
|
||||
alignas(128) uint32_t start[kMaxBlocks][8];
|
||||
alignas(128) uint32_t end[kMaxBlocks][8];
|
||||
alignas(128) uint32_t _flag[kMaxBlocks]; // incremental flags for each rank
|
||||
};
|
||||
|
||||
#ifdef USE_ROCM
|
||||
struct __align__(16) RankData {
|
||||
const void* ptrs[8];
|
||||
};
|
||||
#else
|
||||
struct __align__(16) RankData {
|
||||
const void* __restrict__ ptrs[8];
|
||||
};
|
||||
#endif
|
||||
|
||||
struct __align__(16) RankSignals {
|
||||
#ifndef USE_ROCM
|
||||
volatile
|
||||
#endif
|
||||
Signal* signals[8];
|
||||
};
|
||||
|
||||
// like std::array, but aligned
|
||||
template <typename T, int sz>
|
||||
struct __align__(alignof(T) * sz) array_t {
|
||||
T data[sz];
|
||||
using type = T;
|
||||
static constexpr int size = sz;
|
||||
};
|
||||
|
||||
// use packed type to maximize memory efficiency
|
||||
// goal: generate ld.128 and st.128 instructions
|
||||
template <typename T>
|
||||
struct packed_t {
|
||||
// the (P)acked type for load/store
|
||||
using P = array_t<T, 16 / sizeof(T)>;
|
||||
// the (A)ccumulator type for reduction
|
||||
using A = array_t<float, 16 / sizeof(T)>;
|
||||
};
|
||||
|
||||
#define DINLINE __device__ __forceinline__
|
||||
|
||||
// scalar cast functions
|
||||
DINLINE float upcast_s(half val) {
|
||||
return __half2float(val);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
DINLINE T downcast_s(float val);
|
||||
template <>
|
||||
DINLINE half downcast_s(float val) {
|
||||
return __float2half(val);
|
||||
}
|
||||
|
||||
// scalar add functions
|
||||
// for some reason when compiling with Pytorch, the + operator for half and
|
||||
// bfloat is disabled so we call the intrinsics directly
|
||||
DINLINE half& assign_add(half& a, half b) {
|
||||
a = __hadd(a, b);
|
||||
return a;
|
||||
}
|
||||
DINLINE float& assign_add(float& a, float b) {
|
||||
return a += b;
|
||||
}
|
||||
|
||||
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
|
||||
DINLINE float upcast_s(nv_bfloat16 val) {
|
||||
return __bfloat162float(val);
|
||||
}
|
||||
template <>
|
||||
DINLINE nv_bfloat16 downcast_s(float val) {
|
||||
return __float2bfloat16(val);
|
||||
}
|
||||
DINLINE nv_bfloat16& assign_add(nv_bfloat16& a, nv_bfloat16 b) {
|
||||
a = __hadd(a, b);
|
||||
return a;
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename T, int N>
|
||||
DINLINE array_t<T, N>& packed_assign_add(array_t<T, N>& a, array_t<T, N> b) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N; i++) {
|
||||
assign_add(a.data[i], b.data[i]);
|
||||
}
|
||||
return a;
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
DINLINE array_t<float, N> upcast(array_t<T, N> val) {
|
||||
if constexpr (std::is_same<T, float>::value) {
|
||||
return val;
|
||||
} else {
|
||||
array_t<float, N> out;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N; i++) {
|
||||
out.data[i] = upcast_s(val.data[i]);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename O>
|
||||
DINLINE O downcast(array_t<float, O::size> val) {
|
||||
if constexpr (std::is_same<typename O::type, float>::value) {
|
||||
return val;
|
||||
} else {
|
||||
O out;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < O::size; i++) {
|
||||
out.data[i] = downcast_s<typename O::type>(val.data[i]);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
}
|
||||
|
||||
// This function is meant to be used as the first synchronization in the all
|
||||
// reduce kernel. Thus, it doesn't need to make any visibility guarantees for
|
||||
// prior memory accesses. Note: volatile writes will not be reordered against
|
||||
// other volatile writes.
|
||||
template <int ngpus>
|
||||
DINLINE void start_sync(
|
||||
const RankSignals& sg,
|
||||
#ifndef USE_ROCM
|
||||
volatile
|
||||
#endif
|
||||
Signal* self_sg,
|
||||
int rank) {
|
||||
#ifdef USE_ROCM
|
||||
uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
|
||||
if (threadIdx.x < ngpus) {
|
||||
// simultaneously write to the corresponding flag of all ranks.
|
||||
// Latency = 1 p2p write
|
||||
__hip_atomic_store(
|
||||
&sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM);
|
||||
// wait until we got true from all ranks
|
||||
while (__hip_atomic_load(&self_sg->start[blockIdx.x][threadIdx.x], __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT) <
|
||||
flag)
|
||||
;
|
||||
}
|
||||
__syncthreads();
|
||||
// use one thread to update flag
|
||||
if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
|
||||
#else
|
||||
if (threadIdx.x < ngpus) {
|
||||
// reset flag for next time
|
||||
self_sg->end[blockIdx.x][threadIdx.x] = 0;
|
||||
// simultaneously write to the corresponding flag of all ranks.
|
||||
// Latency = 1 p2p write
|
||||
sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1;
|
||||
// wait until we got true from all ranks
|
||||
while (!self_sg->start[blockIdx.x][threadIdx.x])
|
||||
;
|
||||
}
|
||||
__syncthreads();
|
||||
#endif
|
||||
}
|
||||
|
||||
// This function is meant to be used as the second or the final synchronization
|
||||
// barrier in the all reduce kernel. If it's the final synchronization barrier,
|
||||
// we don't need to make any visibility guarantees for prior memory accesses.
|
||||
template <int ngpus, bool final_sync = false>
|
||||
DINLINE void end_sync(
|
||||
const RankSignals& sg,
|
||||
#ifndef USE_ROCM
|
||||
volatile
|
||||
#endif
|
||||
Signal* self_sg,
|
||||
int rank) {
|
||||
#ifdef USE_ROCM
|
||||
__syncthreads();
|
||||
// eliminate the case that prior writes are not visible after signals become
|
||||
// visible. Note that I did not managed to make this happen through a lot of
|
||||
// testing. Might be the case that hardware provides stronger guarantee than
|
||||
// the memory model.
|
||||
uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
|
||||
if (threadIdx.x < ngpus) {
|
||||
// simultaneously write to the corresponding flag of all ranks.
|
||||
// Latency = 1 p2p write
|
||||
__hip_atomic_store(
|
||||
&sg.signals[threadIdx.x]->end[blockIdx.x][rank],
|
||||
flag,
|
||||
final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE,
|
||||
__HIP_MEMORY_SCOPE_SYSTEM);
|
||||
// wait until we got true from all ranks
|
||||
while (__hip_atomic_load(
|
||||
&self_sg->end[blockIdx.x][threadIdx.x],
|
||||
final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE,
|
||||
__HIP_MEMORY_SCOPE_AGENT) < flag)
|
||||
;
|
||||
}
|
||||
__syncthreads();
|
||||
// use one thread to update flag
|
||||
if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
|
||||
#else
|
||||
__syncthreads();
|
||||
// eliminate the case that prior writes are not visible after signals become
|
||||
// visible. Note that I did not managed to make this happen through a lot of
|
||||
// testing. Might be the case that hardware provides stronger guarantee than
|
||||
// the memory model.
|
||||
if constexpr (!final_sync) __threadfence_system();
|
||||
if (threadIdx.x < ngpus) {
|
||||
// reset flag for next time
|
||||
self_sg->start[blockIdx.x][threadIdx.x] = 0;
|
||||
// simultaneously write to the corresponding flag of all ranks.
|
||||
// Latency = 1 p2p write
|
||||
sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1;
|
||||
// wait until we got true from all ranks
|
||||
while (!self_sg->end[blockIdx.x][threadIdx.x])
|
||||
;
|
||||
}
|
||||
if constexpr (!final_sync) __syncthreads();
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename P, int ngpus, typename A>
|
||||
DINLINE P packed_reduce(const P* ptrs[], int idx) {
|
||||
A tmp = upcast(ptrs[0][idx]);
|
||||
#pragma unroll
|
||||
for (int i = 1; i < ngpus; i++) {
|
||||
packed_assign_add(tmp, upcast(ptrs[i][idx]));
|
||||
}
|
||||
return downcast<P>(tmp);
|
||||
}
|
||||
|
||||
template <typename T, int ngpus>
|
||||
__global__ void __launch_bounds__(512, 1) cross_device_reduce_1stage(
|
||||
RankData* _dp,
|
||||
RankSignals sg,
|
||||
#ifndef USE_ROCM
|
||||
volatile
|
||||
#endif
|
||||
Signal* self_sg,
|
||||
T* __restrict__ result,
|
||||
int rank,
|
||||
int size) {
|
||||
using P = typename packed_t<T>::P;
|
||||
using A = typename packed_t<T>::A;
|
||||
// note: we don't reorder the address so the accumulation order is the same
|
||||
// for all ranks, ensuring bitwise identical results
|
||||
auto dp = *_dp;
|
||||
start_sync<ngpus>(sg, self_sg, rank);
|
||||
// do the actual reduction
|
||||
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += gridDim.x * blockDim.x) {
|
||||
((P*)result)[idx] = packed_reduce<P, ngpus, A>((const P**)&dp.ptrs[0], idx);
|
||||
}
|
||||
end_sync<ngpus, true>(sg, self_sg, rank);
|
||||
}
|
||||
|
||||
template <typename P>
|
||||
#ifdef USE_ROCM
|
||||
DINLINE P* get_tmp_buf(Signal* sg) {
|
||||
#else
|
||||
DINLINE P* get_tmp_buf(volatile Signal* sg) {
|
||||
#endif
|
||||
return (P*)(((Signal*)sg) + 1);
|
||||
}
|
||||
|
||||
template <typename T, int ngpus>
|
||||
__global__ void __launch_bounds__(512, 1) cross_device_reduce_2stage(
|
||||
RankData* _dp,
|
||||
RankSignals sg,
|
||||
#ifndef USE_ROCM
|
||||
volatile
|
||||
#endif
|
||||
Signal* self_sg,
|
||||
T* __restrict__ result,
|
||||
int rank,
|
||||
int size) {
|
||||
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int stride = gridDim.x * blockDim.x;
|
||||
using P = typename packed_t<T>::P;
|
||||
using A = typename packed_t<T>::A;
|
||||
int part = size / ngpus;
|
||||
int start = rank * part;
|
||||
int end = rank == ngpus - 1 ? size : start + part;
|
||||
int largest_part = part + size % ngpus;
|
||||
const P* ptrs[ngpus];
|
||||
P* tmps[ngpus];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < ngpus; i++) {
|
||||
int target = (rank + i) % ngpus;
|
||||
ptrs[i] = (const P*)_dp->ptrs[target];
|
||||
tmps[i] = get_tmp_buf<P>(sg.signals[target]);
|
||||
}
|
||||
auto tmp_out = tmps[0];
|
||||
start_sync<ngpus>(sg, self_sg, rank);
|
||||
// stage 1: reduce scatter
|
||||
for (int idx = start + tid; idx < end; idx += stride) {
|
||||
tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
|
||||
}
|
||||
end_sync<ngpus>(sg, self_sg, rank);
|
||||
|
||||
// stage 2: allgather. Note: it's important to match the tid between
|
||||
// the two stages, because visibility across devices is only guaranteed
|
||||
// between threads that have the same tid. If thread i computes the sum of
|
||||
// start + i in the first stage, then thread i also gathers start + i from all
|
||||
// ranks.
|
||||
for (int idx = tid; idx < largest_part; idx += stride) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < ngpus; i++) {
|
||||
int gather_from_rank = ((rank + i) % ngpus);
|
||||
if (gather_from_rank == ngpus - 1 || idx < part) {
|
||||
int dst_idx = gather_from_rank * part + idx;
|
||||
((P*)result)[dst_idx] = tmps[i][idx];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
using IPC_KEY = std::array<uint8_t, sizeof(hipIpcMemHandle_t)>;
|
||||
static_assert(sizeof(IPC_KEY) == sizeof(hipIpcMemHandle_t));
|
||||
static_assert(alignof(IPC_KEY) == alignof(hipIpcMemHandle_t));
|
||||
|
||||
class CustomAllreduce {
|
||||
public:
|
||||
int rank_;
|
||||
int world_size_;
|
||||
bool full_nvlink_;
|
||||
|
||||
// below are device pointers
|
||||
RankSignals sg_;
|
||||
std::unordered_map<void*, RankData*> buffers_;
|
||||
Signal* self_sg_;
|
||||
|
||||
// stores the registered device pointers from all ranks
|
||||
RankData *d_rank_data_base_, *d_rank_data_end_;
|
||||
std::vector<void*> graph_unreg_buffers_;
|
||||
// a map from IPC handles to opened IPC pointers
|
||||
std::map<IPC_KEY, char*> ipc_handles_;
|
||||
|
||||
/**
|
||||
* meta is a pointer to device metadata and temporary buffer for allreduce.
|
||||
*
|
||||
* There's a total of sizeof(Signal) of prefix before the actual data,
|
||||
* so meta + 1 points to actual temporary buffer.
|
||||
*
|
||||
* note: this class does not own any device memory. Any required buffers
|
||||
* are passed in from the constructor
|
||||
*/
|
||||
CustomAllreduce(
|
||||
Signal* meta,
|
||||
void* rank_data,
|
||||
size_t rank_data_sz,
|
||||
const hipIpcMemHandle_t* handles,
|
||||
const std::vector<int64_t>& offsets,
|
||||
int rank,
|
||||
bool full_nvlink = true)
|
||||
: rank_(rank),
|
||||
world_size_(offsets.size()),
|
||||
full_nvlink_(full_nvlink),
|
||||
self_sg_(meta),
|
||||
d_rank_data_base_(reinterpret_cast<RankData*>(rank_data)),
|
||||
d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
|
||||
for (int i = 0; i < world_size_; i++) {
|
||||
Signal* rank_sg;
|
||||
if (i != rank_) {
|
||||
char* handle = open_ipc_handle(&handles[i]);
|
||||
handle += offsets[i];
|
||||
rank_sg = (Signal*)handle;
|
||||
} else {
|
||||
rank_sg = self_sg_;
|
||||
}
|
||||
sg_.signals[i] = rank_sg;
|
||||
}
|
||||
}
|
||||
|
||||
char* open_ipc_handle(const void* ipc_handle) {
|
||||
auto [it, new_handle] = ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr});
|
||||
if (new_handle) {
|
||||
char* ipc_ptr;
|
||||
CUDACHECK(hipIpcOpenMemHandle(
|
||||
(void**)&ipc_ptr, *((const hipIpcMemHandle_t*)ipc_handle), hipIpcMemLazyEnablePeerAccess));
|
||||
it->second = ipc_ptr;
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta() {
|
||||
auto num_buffers = graph_unreg_buffers_.size();
|
||||
auto handle_sz = sizeof(hipIpcMemHandle_t);
|
||||
std::vector<uint8_t> handles(handle_sz * num_buffers, 0);
|
||||
std::vector<int64_t> offsets(num_buffers);
|
||||
for (int i = 0; i < num_buffers; i++) {
|
||||
auto ptr = graph_unreg_buffers_[i];
|
||||
void* base_ptr;
|
||||
// note: must share the base address of each allocation, or we get wrong
|
||||
// address
|
||||
if (hipPointerGetAttribute(
|
||||
&base_ptr,
|
||||
#ifdef USE_ROCM
|
||||
HIP_POINTER_ATTRIBUTE_RANGE_START_ADDR,
|
||||
#else
|
||||
CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
|
||||
#endif
|
||||
(hipDeviceptr_t)ptr) != hipSuccess)
|
||||
throw std::runtime_error("failed to get pointer attr");
|
||||
CUDACHECK(hipIpcGetMemHandle((hipIpcMemHandle_t*)&handles[i * handle_sz], base_ptr));
|
||||
offsets[i] = ((char*)ptr) - ((char*)base_ptr);
|
||||
}
|
||||
return std::make_pair(handles, offsets);
|
||||
}
|
||||
|
||||
void check_rank_data_capacity(size_t num = 1) {
|
||||
if (d_rank_data_base_ + num > d_rank_data_end_)
|
||||
throw std::runtime_error(
|
||||
"Rank data buffer is overflowed by " + std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
|
||||
}
|
||||
|
||||
void register_buffer(const std::vector<std::string>& handles, const std::vector<int64_t>& offsets, void* self) {
|
||||
check_rank_data_capacity();
|
||||
RankData data;
|
||||
for (int i = 0; i < world_size_; i++) {
|
||||
if (i != rank_) {
|
||||
char* handle = open_ipc_handle(handles[i].data());
|
||||
handle += offsets[i];
|
||||
data.ptrs[i] = handle;
|
||||
} else {
|
||||
data.ptrs[i] = self;
|
||||
}
|
||||
}
|
||||
auto d_data = d_rank_data_base_++;
|
||||
CUDACHECK(hipMemcpy(d_data, &data, sizeof(RankData), hipMemcpyHostToDevice));
|
||||
buffers_[self] = d_data;
|
||||
}
|
||||
|
||||
// note: when registering graph buffers, we intentionally choose to not
|
||||
// deduplicate the addresses. That means if the allocator reuses some
|
||||
// addresses, they will be registered again. This is to account for the remote
|
||||
// possibility of different allocation patterns between ranks. For example,
|
||||
// rank 1 may get the same input address for the second allreduce, but rank 2
|
||||
// got a different address. IPC handles have internal reference counting
|
||||
// mechanism so overhead should be small.
|
||||
void
|
||||
register_graph_buffers(const std::vector<std::string>& handles, const std::vector<std::vector<int64_t>>& offsets) {
|
||||
auto num_buffers = graph_unreg_buffers_.size();
|
||||
check_rank_data_capacity(num_buffers);
|
||||
std::vector<RankData> rank_data(num_buffers);
|
||||
for (int i = 0; i < num_buffers; i++) {
|
||||
auto self_ptr = graph_unreg_buffers_[i];
|
||||
auto& rd = rank_data[i];
|
||||
for (int j = 0; j < world_size_; j++) {
|
||||
if (j != rank_) {
|
||||
char* handle = open_ipc_handle(&handles[j][i * sizeof(hipIpcMemHandle_t)]);
|
||||
handle += offsets[j][i];
|
||||
rd.ptrs[j] = handle;
|
||||
} else {
|
||||
rd.ptrs[j] = self_ptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
CUDACHECK(hipMemcpy(d_rank_data_base_, rank_data.data(), sizeof(RankData) * num_buffers, hipMemcpyHostToDevice));
|
||||
d_rank_data_base_ += num_buffers;
|
||||
graph_unreg_buffers_.clear();
|
||||
}
|
||||
|
||||
/**
|
||||
* This is the result after careful grid search. Using 36 blocks give the best
|
||||
* or close to the best runtime on the devices I tried: A100, A10, A30, T4,
|
||||
* V100. You'll notice that NCCL kernels also only take a small amount of SMs.
|
||||
* Not quite sure the underlying reason, but my guess is that too many SMs
|
||||
* will cause contention on NVLink bus.
|
||||
*/
|
||||
template <typename T>
|
||||
void allreduce(
|
||||
hipStream_t stream,
|
||||
T* input,
|
||||
T* output,
|
||||
int size,
|
||||
#ifndef USE_ROCM
|
||||
int threads = 512,
|
||||
int block_limit = 36){
|
||||
#else
|
||||
int threads = 512,
|
||||
int block_limit = 16) {
|
||||
#endif
|
||||
auto d = packed_t<T>::P::size;
|
||||
if (size % d != 0)
|
||||
throw std::runtime_error(
|
||||
"custom allreduce currently requires input length to be multiple "
|
||||
"of " +
|
||||
std::to_string(d));
|
||||
if (block_limit > kMaxBlocks)
|
||||
throw std::runtime_error(
|
||||
"max supported block limit is " + std::to_string(kMaxBlocks) + ". Got " + std::to_string(block_limit));
|
||||
|
||||
RankData* ptrs;
|
||||
hipStreamCaptureStatus status;
|
||||
CUDACHECK(hipStreamIsCapturing(stream, &status));
|
||||
if (status == hipStreamCaptureStatusActive) {
|
||||
ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
|
||||
graph_unreg_buffers_.push_back(input);
|
||||
} else {
|
||||
auto it = buffers_.find(input);
|
||||
if (it == buffers_.end())
|
||||
throw std::runtime_error(
|
||||
"buffer address " + std::to_string(reinterpret_cast<uint64_t>(input)) + " is not registered!");
|
||||
ptrs = it->second;
|
||||
}
|
||||
|
||||
size /= d;
|
||||
auto bytes = size * sizeof(typename packed_t<T>::P);
|
||||
int blocks = ::min(block_limit, (size + threads - 1) / threads);
|
||||
#define KL(ngpus, name) \
|
||||
hipLaunchKernelGGL( \
|
||||
(name<T, ngpus>), dim3(blocks), dim3(threads), 0, stream, ptrs, sg_, self_sg_, output, rank_, size);
|
||||
#define REDUCE_CASE(ngpus) \
|
||||
case ngpus: { \
|
||||
if (world_size_ == 2) { \
|
||||
KL(ngpus, cross_device_reduce_1stage); \
|
||||
} else if (full_nvlink_) { \
|
||||
if ((world_size_ <= 4 && bytes < 512 * 1024) || (world_size_ <= 8 && bytes < 256 * 1024)) { \
|
||||
KL(ngpus, cross_device_reduce_1stage); \
|
||||
} else { \
|
||||
KL(ngpus, cross_device_reduce_2stage); \
|
||||
} \
|
||||
} \
|
||||
break; \
|
||||
}
|
||||
|
||||
switch (world_size_) {
|
||||
REDUCE_CASE(2)
|
||||
REDUCE_CASE(4)
|
||||
REDUCE_CASE(6)
|
||||
REDUCE_CASE(8)
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"custom allreduce only supports num gpus in (2,4,6,8). Actual num "
|
||||
"gpus = " +
|
||||
std::to_string(world_size_));
|
||||
}
|
||||
#undef REDUCE_CASE
|
||||
#undef KL
|
||||
}
|
||||
|
||||
~CustomAllreduce() {
|
||||
for (auto [_, ptr] : ipc_handles_) {
|
||||
CUDACHECK(hipIpcCloseMemHandle(ptr));
|
||||
}
|
||||
}
|
||||
}; // namespace sglang
|
||||
/**
|
||||
* To inspect PTX/SASS, copy paste this header file to compiler explorer and add
|
||||
a template instantiation:
|
||||
* template void sglang::CustomAllreduce::allreduce<half>(hipStream_t, half *,
|
||||
half *, int, int, int);
|
||||
*/
|
||||
} // namespace sglang
|
||||
140
sgl-kernel/csrc/allreduce/mscclpp_allreduce.cu
Normal file
140
sgl-kernel/csrc/allreduce/mscclpp_allreduce.cu
Normal file
@@ -0,0 +1,140 @@
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <torch/all.h>
|
||||
#include <torch/library.h>
|
||||
|
||||
#include "mscclpp_allreduce.cuh"
|
||||
|
||||
enum MscclContextSelection {
|
||||
MSCCL1NODELL = 1,
|
||||
MSCCL2NODELL = 2,
|
||||
};
|
||||
|
||||
class MscclContext {
|
||||
public:
|
||||
MscclContextSelection selection_;
|
||||
std::shared_ptr<sglang::Msccl1NodeLLcontext> msccl_1nodeLL_context;
|
||||
std::shared_ptr<sglang::Msccl2NodeLLcontext> msccl_2nodeLL_context;
|
||||
MscclContext(MscclContextSelection selection) : selection_(selection) {}
|
||||
template <typename T>
|
||||
void allreduce(
|
||||
cudaStream_t stream, T* input, T* output, const size_t input_numel, int threads = 512, int block_limit = 21) {
|
||||
if (selection_ == MSCCL1NODELL) {
|
||||
msccl_1nodeLL_context->allreduce<T>(stream, input, output, input_numel, threads, block_limit);
|
||||
} else if (selection_ == MSCCL2NODELL) {
|
||||
msccl_2nodeLL_context->allreduce<T>(stream, input, output, input_numel, threads, block_limit);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
using fptr_t = int64_t;
|
||||
static_assert(sizeof(void*) == sizeof(fptr_t));
|
||||
|
||||
torch::Tensor _unique_id2tensor(const mscclpp::UniqueId& unique_id) {
|
||||
auto options = torch::TensorOptions().dtype(torch::kByte).device(torch::kCPU);
|
||||
auto tensor = torch::empty({static_cast<int64_t>(unique_id.size())}, options);
|
||||
std::memcpy(tensor.data_ptr<uint8_t>(), unique_id.data(), unique_id.size());
|
||||
return tensor;
|
||||
}
|
||||
|
||||
// Function to convert vector of int32_t back to array of uint8_t
|
||||
mscclpp::UniqueId _tensor2unique_id(const torch::Tensor& tensor) {
|
||||
mscclpp::UniqueId unique_id;
|
||||
std::memcpy(unique_id.data(), tensor.data_ptr<uint8_t>(), unique_id.size());
|
||||
return unique_id;
|
||||
}
|
||||
|
||||
torch::Tensor mscclpp_generate_unique_id() {
|
||||
mscclpp::UniqueId unique_id = mscclpp::TcpBootstrap::createUniqueId();
|
||||
return _unique_id2tensor(unique_id);
|
||||
}
|
||||
|
||||
fptr_t mscclpp_init_context(
|
||||
const torch::Tensor& unique_id,
|
||||
const int64_t rank,
|
||||
const int64_t world_size,
|
||||
torch::Tensor& scratch,
|
||||
torch::Tensor& put_buffer,
|
||||
const int64_t nranks_per_node,
|
||||
const std::vector<int64_t>& rank_to_node,
|
||||
const std::vector<int64_t>& rank_to_ib,
|
||||
const int64_t context_selection) {
|
||||
MscclContext* context_ptr = new MscclContext(static_cast<MscclContextSelection>(context_selection));
|
||||
mscclpp::UniqueId uid = _tensor2unique_id(unique_id);
|
||||
if (context_selection == MSCCL1NODELL) {
|
||||
void* scratch_ptr = reinterpret_cast<void*>(scratch.data_ptr());
|
||||
const size_t scratch_bytes = scratch.numel() * scratch.element_size();
|
||||
context_ptr->msccl_1nodeLL_context = std::make_shared<sglang::Msccl1NodeLLcontext>(
|
||||
uid, rank, world_size, scratch_ptr, scratch_bytes, nranks_per_node, rank_to_node, rank_to_ib);
|
||||
} else if (context_selection == MSCCL2NODELL) {
|
||||
void* scratch_ptr = reinterpret_cast<void*>(scratch.data_ptr());
|
||||
const size_t scratch_bytes = scratch.numel() * scratch.element_size();
|
||||
void* put_buffer_ptr = reinterpret_cast<void*>(put_buffer.data_ptr());
|
||||
const size_t put_buffer_bytes = put_buffer.numel() * put_buffer.element_size();
|
||||
context_ptr->msccl_2nodeLL_context = std::make_shared<sglang::Msccl2NodeLLcontext>(
|
||||
uid,
|
||||
rank,
|
||||
world_size,
|
||||
scratch_ptr,
|
||||
scratch_bytes,
|
||||
put_buffer_ptr,
|
||||
put_buffer_bytes,
|
||||
nranks_per_node,
|
||||
rank_to_node,
|
||||
rank_to_ib);
|
||||
} else {
|
||||
throw std::runtime_error("invalid context selection");
|
||||
}
|
||||
return (fptr_t)context_ptr;
|
||||
}
|
||||
|
||||
bool _mscclpp_is_weak_contiguous(torch::Tensor& t) {
|
||||
return t.is_contiguous() ||
|
||||
(t.storage().nbytes() - t.storage_offset() * t.element_size() == t.numel() * t.element_size());
|
||||
}
|
||||
void mscclpp_allreduce(fptr_t _context, torch::Tensor& inp, torch::Tensor& out, int64_t nthreads, int64_t nblocks) {
|
||||
MscclContext* context = reinterpret_cast<MscclContext*>(_context);
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
|
||||
auto stream = c10::cuda::getCurrentCUDAStream().stream();
|
||||
|
||||
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
|
||||
TORCH_CHECK_EQ(inp.numel(), out.numel());
|
||||
TORCH_CHECK(_mscclpp_is_weak_contiguous(out));
|
||||
TORCH_CHECK(_mscclpp_is_weak_contiguous(inp));
|
||||
switch (out.scalar_type()) {
|
||||
case at::ScalarType::Float: {
|
||||
context->allreduce<float>(
|
||||
stream,
|
||||
reinterpret_cast<float*>(inp.data_ptr()),
|
||||
reinterpret_cast<float*>(out.data_ptr()),
|
||||
inp.numel(),
|
||||
nthreads,
|
||||
nblocks);
|
||||
break;
|
||||
}
|
||||
case at::ScalarType::Half: {
|
||||
context->allreduce<half>(
|
||||
stream,
|
||||
reinterpret_cast<half*>(inp.data_ptr()),
|
||||
reinterpret_cast<half*>(out.data_ptr()),
|
||||
inp.numel(),
|
||||
nthreads,
|
||||
nblocks);
|
||||
break;
|
||||
}
|
||||
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
|
||||
case at::ScalarType::BFloat16: {
|
||||
context->allreduce<__nv_bfloat16>(
|
||||
stream,
|
||||
reinterpret_cast<__nv_bfloat16*>(inp.data_ptr()),
|
||||
reinterpret_cast<__nv_bfloat16*>(out.data_ptr()),
|
||||
inp.numel(),
|
||||
nthreads,
|
||||
nblocks);
|
||||
break;
|
||||
}
|
||||
#endif
|
||||
default:
|
||||
throw std::runtime_error("custom allreduce only supports float32, float16 and bfloat16");
|
||||
}
|
||||
}
|
||||
779
sgl-kernel/csrc/allreduce/mscclpp_allreduce.cuh
Normal file
779
sgl-kernel/csrc/allreduce/mscclpp_allreduce.cuh
Normal file
@@ -0,0 +1,779 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
#pragma once
|
||||
#ifdef USE_ROCM
|
||||
#include <hip/hip_fp16.h>
|
||||
#else
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
#endif
|
||||
|
||||
#include <mscclpp/concurrency_device.hpp>
|
||||
#include <mscclpp/core.hpp>
|
||||
#include <mscclpp/memory_channel.hpp>
|
||||
#include <mscclpp/memory_channel_device.hpp>
|
||||
#include <mscclpp/nvls_device.hpp>
|
||||
#include <mscclpp/port_channel.hpp>
|
||||
#include <mscclpp/port_channel_device.hpp>
|
||||
|
||||
// comment this for test_mscclpp_allreduce.cu
|
||||
#include "utils.h"
|
||||
|
||||
namespace sglang {
|
||||
|
||||
__device__ mscclpp::DeviceSyncer deviceSyncer;
|
||||
__device__ mscclpp::DeviceSyncer allGatherDeviceSyncer;
|
||||
__device__ mscclpp::DeviceSyncer reduceScatterDeviceSyncer;
|
||||
__device__ mscclpp::DeviceSyncer ibDeviceSyncer;
|
||||
|
||||
template <typename To, typename From>
|
||||
__forceinline__ __device__ To bit_cast(const From& src) {
|
||||
static_assert(sizeof(To) == sizeof(From), "Size mismatch for bit_cast");
|
||||
|
||||
union {
|
||||
From f;
|
||||
To t;
|
||||
} u;
|
||||
u.f = src;
|
||||
return u.t;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__forceinline__ __device__ T add_elements(T a, T b) {
|
||||
return a + b;
|
||||
}
|
||||
|
||||
template <>
|
||||
__forceinline__ __device__ __half2 add_elements(__half2 a, __half2 b) {
|
||||
return __hadd2(a, b);
|
||||
}
|
||||
|
||||
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
|
||||
template <>
|
||||
__forceinline__ __device__ __nv_bfloat162 add_elements(__nv_bfloat162 a, __nv_bfloat162 b) {
|
||||
return __hadd2(a, b);
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
__forceinline__ __device__ int4 add_vectors_helper(int4 a, int4 b) {
|
||||
int4 ret;
|
||||
ret.w = bit_cast<int, T>(add_elements(bit_cast<T, int>(a.w), bit_cast<T, int>(b.w)));
|
||||
ret.x = bit_cast<int, T>(add_elements(bit_cast<T, int>(a.x), bit_cast<T, int>(b.x)));
|
||||
ret.y = bit_cast<int, T>(add_elements(bit_cast<T, int>(a.y), bit_cast<T, int>(b.y)));
|
||||
ret.z = bit_cast<int, T>(add_elements(bit_cast<T, int>(a.z), bit_cast<T, int>(b.z)));
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__forceinline__ __device__ int4 add_vectors(int4 a, int4 b) {
|
||||
return add_vectors_helper<T>(a, b);
|
||||
}
|
||||
|
||||
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
|
||||
template <>
|
||||
__forceinline__ __device__ int4 add_vectors<__nv_bfloat16>(int4 a, int4 b) {
|
||||
return add_vectors_helper<__nv_bfloat162>(a, b);
|
||||
}
|
||||
#endif
|
||||
|
||||
template <>
|
||||
__forceinline__ __device__ int4 add_vectors<__half>(int4 a, int4 b) {
|
||||
return add_vectors_helper<__half2>(a, b);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__forceinline__ __device__ uint2 add_vectors_helper(uint2 a, uint2 b) {
|
||||
uint2 ret;
|
||||
ret.x = bit_cast<int, T>(add_elements(bit_cast<T, int>(a.x), bit_cast<T, int>(b.x)));
|
||||
ret.y = bit_cast<int, T>(add_elements(bit_cast<T, int>(a.y), bit_cast<T, int>(b.y)));
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__forceinline__ __device__ uint2 add_vectors(uint2 a, uint2 b) {
|
||||
return add_vectors_helper<T>(a, b);
|
||||
}
|
||||
|
||||
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
|
||||
template <>
|
||||
__forceinline__ __device__ uint2 add_vectors<__nv_bfloat16>(uint2 a, uint2 b) {
|
||||
return add_vectors_helper<__nv_bfloat162>(a, b);
|
||||
}
|
||||
#endif
|
||||
|
||||
template <>
|
||||
__forceinline__ __device__ uint2 add_vectors<__half>(uint2 a, uint2 b) {
|
||||
return add_vectors_helper<__half2>(a, b);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__forceinline__ __device__ int add_vectors_helper(int a, int b) {
|
||||
return bit_cast<int, T>(add_elements(bit_cast<T, int>(a), bit_cast<T, int>(b)));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__forceinline__ __device__ int add_vectors(int a, int b) {
|
||||
return add_vectors_helper<T>(a, b);
|
||||
}
|
||||
|
||||
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
|
||||
template <>
|
||||
__forceinline__ __device__ int add_vectors<__nv_bfloat16>(int a, int b) {
|
||||
return add_vectors_helper<__nv_bfloat162>(a, b);
|
||||
}
|
||||
#endif
|
||||
|
||||
template <>
|
||||
__forceinline__ __device__ int add_vectors<__half>(int a, int b) {
|
||||
return add_vectors_helper<__half2>(a, b);
|
||||
}
|
||||
|
||||
// -------------------------------------------------------
|
||||
// allreduce_LL_1node using LLPacket, origin allreduce2
|
||||
// -------------------------------------------------------
|
||||
|
||||
__device__ uint64_t globalFlag = 1;
|
||||
|
||||
template <typename TYPE>
|
||||
__global__ void __launch_bounds__(1024, 1) allreduce_LL_1node(
|
||||
mscclpp::MemoryChannelDeviceHandle* memChans,
|
||||
TYPE* buff,
|
||||
TYPE* scratch,
|
||||
void* resultBuff,
|
||||
int rank,
|
||||
int worldSize,
|
||||
size_t nelems) {
|
||||
nelems = nelems / (sizeof(int) / sizeof(TYPE));
|
||||
// This version of allreduce only works for single nodes
|
||||
const int nPeers = worldSize - 1;
|
||||
const size_t nPkts = nelems / 2;
|
||||
const int nelemsPerRank = nelems / worldSize;
|
||||
const int nPktsPerRank = nelemsPerRank / 2;
|
||||
// flag for packets. Initially 1
|
||||
const uint32_t flag = (uint32_t)globalFlag;
|
||||
// thread block & channel info
|
||||
const int nBlocksPerPeer = gridDim.x / nPeers;
|
||||
const int localBlockIdx = blockIdx.x % nBlocksPerPeer;
|
||||
const int peerIdx = blockIdx.x / nBlocksPerPeer;
|
||||
const int remoteRank = peerIdx < rank ? peerIdx : peerIdx + 1;
|
||||
mscclpp::MemoryChannelDeviceHandle memChan = memChans[peerIdx];
|
||||
const int tid = threadIdx.x + localBlockIdx * blockDim.x;
|
||||
// double buffering
|
||||
size_t scratchBaseOffset = (flag & 1) ? 0 : nPkts * sizeof(mscclpp::LLPacket);
|
||||
void* scratchBuff = (void*)((char*)scratch + scratchBaseOffset);
|
||||
size_t scratchOffset = scratchBaseOffset + rank * nPktsPerRank * sizeof(mscclpp::LLPacket);
|
||||
size_t scratchResultOffset =
|
||||
(flag & 1) ? 2 * nPkts * sizeof(mscclpp::LLPacket) : 3 * nPkts * sizeof(mscclpp::LLPacket);
|
||||
size_t srcOffset = remoteRank * nelemsPerRank * sizeof(int);
|
||||
uint2* src = (uint2*)((char*)buff + rank * nelemsPerRank * sizeof(int));
|
||||
uint2* dst = (uint2*)((char*)resultBuff + rank * nelemsPerRank * sizeof(int));
|
||||
|
||||
// step 1: write to scratch buffer
|
||||
memChan.putPackets(scratchOffset, srcOffset, nelemsPerRank * sizeof(int), tid, blockDim.x * nBlocksPerPeer, flag);
|
||||
// step 2: get data from scratch buffer, reduce data and write result to remote scratch buffer
|
||||
for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nPktsPerRank; idx += blockDim.x * gridDim.x) {
|
||||
uint2 data = make_uint2(0, 0);
|
||||
for (int index = 0; index < nPeers; index++) {
|
||||
const int remoteRank = index < rank ? index : index + 1;
|
||||
mscclpp::LLPacket* dstPkt = (mscclpp::LLPacket*)scratchBuff + remoteRank * nPktsPerRank;
|
||||
uint2 val = dstPkt[idx].read(flag);
|
||||
data = add_vectors<TYPE>(val, data);
|
||||
}
|
||||
data = add_vectors<TYPE>(data, src[idx]);
|
||||
dst[idx] = data;
|
||||
|
||||
mscclpp::LLPacket packet;
|
||||
packet.data1 = data.x;
|
||||
packet.flag1 = flag;
|
||||
packet.data2 = data.y;
|
||||
packet.flag2 = flag;
|
||||
size_t offset = scratchResultOffset / sizeof(mscclpp::LLPacket) + (idx + rank * nPktsPerRank);
|
||||
for (int index = 0; index < nPeers; index++) {
|
||||
memChans[index].write(offset, packet);
|
||||
}
|
||||
}
|
||||
// step 3: get data result from scratch buffer
|
||||
mscclpp::LLPacket* dstPkt = (mscclpp::LLPacket*)((char*)scratch + scratchResultOffset);
|
||||
const int dstOffset = remoteRank * nPktsPerRank;
|
||||
uint2* result = (uint2*)((char*)resultBuff + remoteRank * nelemsPerRank * sizeof(int));
|
||||
for (int idx = threadIdx.x + localBlockIdx * blockDim.x; idx < nPktsPerRank; idx += blockDim.x * nBlocksPerPeer) {
|
||||
uint2 data = dstPkt[idx + dstOffset].read(flag);
|
||||
result[idx].x = data.x;
|
||||
result[idx].y = data.y;
|
||||
}
|
||||
if (threadIdx.x == 0 && blockIdx.x == 0) {
|
||||
globalFlag += 1;
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------------------------------------
|
||||
// allreduce_LL_2node using LLPacket, origin allreduce5
|
||||
// -------------------------------------------------------
|
||||
|
||||
template <typename TYPE>
|
||||
__global__ void __launch_bounds__(1024, 1) allreduce_LL_2node(
|
||||
mscclpp::MemoryChannelDeviceHandle* memChans,
|
||||
mscclpp::PortChannelDeviceHandle* portChans,
|
||||
TYPE* buff,
|
||||
TYPE* scratch,
|
||||
TYPE* putBuff,
|
||||
TYPE* resultBuff,
|
||||
int rank,
|
||||
int nRanksPerNode,
|
||||
int worldSize,
|
||||
size_t nelems) {
|
||||
nelems = nelems / (sizeof(int) / sizeof(TYPE));
|
||||
// This version of allreduce only works for single nodes
|
||||
const int nPeersInNode = nRanksPerNode - 1;
|
||||
const int nPkts = nelems / 2;
|
||||
const int nelemsPerLocalRank = nelems / nRanksPerNode;
|
||||
const int nPktsPerLocalRank = nelemsPerLocalRank / 2;
|
||||
const int localRankId = rank % nRanksPerNode;
|
||||
// flag for packets. Initially 1
|
||||
const uint32_t flag = (uint32_t)globalFlag;
|
||||
// thread block & channel info
|
||||
const int nBlocksPerPeer = gridDim.x / nPeersInNode;
|
||||
const int localBlockIdx = blockIdx.x % nBlocksPerPeer;
|
||||
const int peerIdx = blockIdx.x / nBlocksPerPeer;
|
||||
const int remoteRankIdx = peerIdx < localRankId ? peerIdx : peerIdx + 1;
|
||||
mscclpp::MemoryChannelDeviceHandle memChan = memChans[peerIdx];
|
||||
mscclpp::PortChannelDeviceHandle portChan = portChans[localRankId];
|
||||
const int tid = threadIdx.x + localBlockIdx * blockDim.x;
|
||||
// double buffering
|
||||
size_t scratchBaseOffset = (flag & 1) ? 0 : nPkts * sizeof(mscclpp::LLPacket);
|
||||
size_t putBaseOffset = (flag & 1) ? 0 : nPktsPerLocalRank * sizeof(mscclpp::LLPacket);
|
||||
void* scratchBuff = (void*)((char*)scratch + scratchBaseOffset);
|
||||
size_t scratchOffset = scratchBaseOffset + localRankId * nPktsPerLocalRank * sizeof(mscclpp::LLPacket);
|
||||
size_t scratchResultOffset =
|
||||
(flag & 1) ? 2 * nPkts * sizeof(mscclpp::LLPacket) : 3 * nPkts * sizeof(mscclpp::LLPacket);
|
||||
size_t srcOffset = remoteRankIdx * nelemsPerLocalRank * sizeof(int);
|
||||
uint2* src = (uint2*)((char*)buff + localRankId * nelemsPerLocalRank * sizeof(int));
|
||||
uint2* dst = (uint2*)((char*)resultBuff + localRankId * nelemsPerLocalRank * sizeof(int));
|
||||
|
||||
// step 1: write to scratch buffer
|
||||
if (nRanksPerNode > 1) {
|
||||
memChan.putPackets(
|
||||
scratchOffset, srcOffset, nelemsPerLocalRank * sizeof(int), tid, blockDim.x * nBlocksPerPeer, flag);
|
||||
}
|
||||
// step 2: get data from scratch buffer, do local reduce-scatter in each node.
|
||||
mscclpp::LLPacket* putPkt = (mscclpp::LLPacket*)((char*)putBuff + putBaseOffset);
|
||||
for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nPktsPerLocalRank; idx += blockDim.x * gridDim.x) {
|
||||
uint2 data = make_uint2(0, 0);
|
||||
for (int index = 0; index < nPeersInNode; index++) {
|
||||
const int remoteRank = index < localRankId ? index : index + 1;
|
||||
mscclpp::LLPacket* dstPkt = (mscclpp::LLPacket*)scratchBuff + remoteRank * nPktsPerLocalRank;
|
||||
uint2 val = dstPkt[idx].read(flag);
|
||||
data = add_vectors<TYPE>(val, data);
|
||||
}
|
||||
data = add_vectors<TYPE>(data, src[idx]);
|
||||
putPkt[idx].write(data.x, data.y, flag);
|
||||
dst[idx] = data;
|
||||
}
|
||||
deviceSyncer.sync(gridDim.x);
|
||||
// step 3. send local reduced data to remote node.
|
||||
if (threadIdx.x == 0 && blockIdx.x == 0) {
|
||||
portChan.put(scratchOffset, putBaseOffset, nPktsPerLocalRank * sizeof(mscclpp::LLPacket));
|
||||
if ((flag & 63) == 0) {
|
||||
portChan.flush();
|
||||
}
|
||||
}
|
||||
// step 4. try to read the data from scratch buffer and write to local peers
|
||||
mscclpp::LLPacket* dstPkt = (mscclpp::LLPacket*)scratchBuff + localRankId * nPktsPerLocalRank;
|
||||
for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nPktsPerLocalRank; idx += blockDim.x * gridDim.x) {
|
||||
uint2 res = dst[idx];
|
||||
uint2 val = dstPkt[idx].read(flag);
|
||||
res = add_vectors<TYPE>(res, val);
|
||||
|
||||
mscclpp::LLPacket packet;
|
||||
packet.data1 = res.x;
|
||||
packet.flag1 = flag;
|
||||
packet.data2 = res.y;
|
||||
packet.flag2 = flag;
|
||||
size_t offset = scratchResultOffset / sizeof(mscclpp::LLPacket) + (idx + localRankId * nPktsPerLocalRank);
|
||||
for (int index = 0; index < nPeersInNode; index++) {
|
||||
memChans[index].write(offset, packet);
|
||||
}
|
||||
dst[idx] = res;
|
||||
}
|
||||
|
||||
// step 5: get data result from scratch buffer
|
||||
dstPkt = (mscclpp::LLPacket*)((char*)scratch + scratchResultOffset);
|
||||
const int dstOffset = remoteRankIdx * nPktsPerLocalRank;
|
||||
uint2* result = (uint2*)((char*)resultBuff + remoteRankIdx * nelemsPerLocalRank * sizeof(int));
|
||||
if (nRanksPerNode > 1) {
|
||||
for (int idx = threadIdx.x + localBlockIdx * blockDim.x; idx < nPktsPerLocalRank;
|
||||
idx += blockDim.x * nBlocksPerPeer) {
|
||||
uint2 data = dstPkt[idx + dstOffset].read(flag);
|
||||
result[idx] = data;
|
||||
}
|
||||
}
|
||||
if (threadIdx.x == 0 && blockIdx.x == 0) {
|
||||
globalFlag += 1;
|
||||
}
|
||||
}
|
||||
|
||||
static const mscclpp::Transport IBs[] = {
|
||||
mscclpp::Transport::IB0,
|
||||
mscclpp::Transport::IB1,
|
||||
mscclpp::Transport::IB2,
|
||||
mscclpp::Transport::IB3,
|
||||
mscclpp::Transport::IB4,
|
||||
mscclpp::Transport::IB5,
|
||||
mscclpp::Transport::IB6,
|
||||
mscclpp::Transport::IB7};
|
||||
|
||||
class MscclCommGroup {
|
||||
public:
|
||||
std::shared_ptr<mscclpp::Communicator> comm_;
|
||||
const size_t rank_;
|
||||
const size_t world_size_;
|
||||
const std::vector<int64_t> rank_to_node_;
|
||||
const std::vector<int64_t> rank_to_ib_;
|
||||
MscclCommGroup(
|
||||
mscclpp::UniqueId unique_id,
|
||||
const size_t rank,
|
||||
const size_t world_size,
|
||||
const std::vector<int64_t>& rank_to_node,
|
||||
const std::vector<int64_t>& rank_to_ib)
|
||||
: rank_(rank), world_size_(world_size), rank_to_node_(rank_to_node), rank_to_ib_(rank_to_ib) {
|
||||
auto bootstrap = std::make_shared<mscclpp::TcpBootstrap>(rank, world_size);
|
||||
bootstrap->initialize(unique_id);
|
||||
comm_ = std::make_shared<mscclpp::Communicator>(bootstrap);
|
||||
}
|
||||
template <typename T>
|
||||
void allreduce(cudaStream_t stream, T* output, size_t input_numel, int threads = 512, int block_limit = 21) {
|
||||
throw std::runtime_error("you should not call allreduce of a base context");
|
||||
}
|
||||
bool is_same_node(int r1, int r2) {
|
||||
return rank_to_node_[r1] == rank_to_node_[r2];
|
||||
}
|
||||
|
||||
void make_connection(
|
||||
std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& same_node_connections,
|
||||
std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& cross_node_connections) {
|
||||
same_node_connections.clear();
|
||||
cross_node_connections.clear();
|
||||
std::unordered_map<int, mscclpp::NonblockingFuture<std::shared_ptr<mscclpp::Connection>>> conn_futures;
|
||||
for (int r = 0; r < world_size_; ++r) {
|
||||
if (r == rank_) continue;
|
||||
mscclpp::Transport transport = is_same_node(r, rank_) ? mscclpp::Transport::CudaIpc : IBs[rank_to_ib_[r]];
|
||||
conn_futures.emplace(r, comm_->connectOnSetup(r, 0, transport));
|
||||
}
|
||||
comm_->setup();
|
||||
for (int r = 0; r < world_size_; ++r) {
|
||||
if (r == rank_) continue;
|
||||
if (is_same_node(r, rank_)) {
|
||||
same_node_connections.emplace(r, conn_futures[r].get());
|
||||
} else {
|
||||
cross_node_connections.emplace(r, conn_futures[r].get());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void make_memory_channels_with_scratch(
|
||||
void* tensor_ptr,
|
||||
const size_t tensor_bytes,
|
||||
void* scratch_ptr,
|
||||
const size_t scratch_bytes,
|
||||
const std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& connections,
|
||||
std::unordered_map<int, std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>>& semaphores,
|
||||
std::unordered_map<int, mscclpp::RegisteredMemory>& registered_memories,
|
||||
std::unordered_map<int, mscclpp::MemoryChannel>& channels) {
|
||||
channels.clear();
|
||||
make_semaphores<mscclpp::MemoryDevice2DeviceSemaphore>(connections, semaphores);
|
||||
register_tensor_with_connections(scratch_ptr, scratch_bytes, connections, registered_memories);
|
||||
for (const auto& [peer, _] : connections) {
|
||||
channels.emplace(
|
||||
peer, mscclpp::MemoryChannel(semaphores[peer], registered_memories[peer], tensor_ptr, scratch_ptr));
|
||||
}
|
||||
}
|
||||
void make_port_channels_with_scratch(
|
||||
std::shared_ptr<mscclpp::ProxyService> proxyService,
|
||||
void* tensor_ptr,
|
||||
const size_t tensor_bytes,
|
||||
void* scratch_ptr,
|
||||
const size_t scratch_bytes,
|
||||
const std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& connections,
|
||||
std::unordered_map<int, std::shared_ptr<mscclpp::Host2DeviceSemaphore>>& semaphores,
|
||||
std::unordered_map<int, mscclpp::RegisteredMemory>& registered_memories,
|
||||
std::unordered_map<int, mscclpp::PortChannel>& channels) {
|
||||
channels.clear();
|
||||
make_semaphores<mscclpp::Host2DeviceSemaphore>(connections, semaphores);
|
||||
|
||||
mscclpp::TransportFlags flags;
|
||||
for (const auto& [_, conn] : connections) {
|
||||
flags |= conn->transport();
|
||||
}
|
||||
auto local_reg_memory = comm_->registerMemory(tensor_ptr, tensor_bytes, flags);
|
||||
|
||||
register_tensor_with_connections(scratch_ptr, scratch_bytes, connections, registered_memories);
|
||||
std::unordered_map<int, mscclpp::SemaphoreId> semaphore_ids;
|
||||
std::unordered_map<int, size_t> memory_ids;
|
||||
memory_ids[rank_] = proxyService->addMemory(local_reg_memory);
|
||||
for (const auto& [peer, memory] : registered_memories) {
|
||||
if (peer == rank_) continue;
|
||||
memory_ids[peer] = proxyService->addMemory(memory);
|
||||
}
|
||||
for (const auto& [peer, semaphore] : semaphores) {
|
||||
semaphore_ids[peer] = proxyService->addSemaphore(semaphore);
|
||||
}
|
||||
|
||||
for (const auto& [peer, _] : connections) {
|
||||
channels.emplace(peer, proxyService->portChannel(semaphore_ids[peer], memory_ids[peer], memory_ids[rank_]));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SemaphoreType>
|
||||
void make_semaphores(
|
||||
const std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& connections,
|
||||
std::unordered_map<int, std::shared_ptr<SemaphoreType>>& semaphores) {
|
||||
semaphores.clear();
|
||||
for (const auto& [peer, conn] : connections) {
|
||||
semaphores[peer] = std::make_shared<SemaphoreType>(*comm_, conn);
|
||||
}
|
||||
comm_->setup();
|
||||
}
|
||||
|
||||
void register_tensor_with_connections(
|
||||
void* tensor_ptr,
|
||||
size_t tensor_bytes,
|
||||
const std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& connections,
|
||||
std::unordered_map<int, mscclpp::RegisteredMemory>& registered_memories) {
|
||||
registered_memories.clear();
|
||||
mscclpp::TransportFlags all_transports;
|
||||
for (const auto& [_, connection] : connections) {
|
||||
all_transports |= connection->transport();
|
||||
}
|
||||
mscclpp::RegisteredMemory buf_reg_mem = comm_->registerMemory(tensor_ptr, tensor_bytes, all_transports);
|
||||
registered_memories[rank_] = buf_reg_mem;
|
||||
|
||||
std::unordered_map<int, mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>> remote_mem_futures;
|
||||
for (const auto& [r, connection] : connections) {
|
||||
comm_->sendMemoryOnSetup(buf_reg_mem, r, 0);
|
||||
auto remoteMemory = comm_->recvMemoryOnSetup(r, 0);
|
||||
remote_mem_futures.emplace(r, remoteMemory);
|
||||
}
|
||||
comm_->setup();
|
||||
for (auto& [r, mem_feature] : remote_mem_futures) {
|
||||
registered_memories.emplace(r, mem_feature.get());
|
||||
}
|
||||
}
|
||||
|
||||
void make_device_memory_handle_base_on_new_ptr(
|
||||
const std::unordered_map<int, mscclpp::MemoryChannel>& old_memory_channels,
|
||||
std::unordered_map<int, mscclpp::RegisteredMemory>& registered_sm_memories,
|
||||
std::unordered_map<int, std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>>& memory_semaphores,
|
||||
std::unordered_map<int, mscclpp::MemoryChannel>& memory_channels,
|
||||
mscclpp::GpuBuffer<mscclpp::MemoryChannelDeviceHandle>& device_memory_handle,
|
||||
void* input,
|
||||
void* scratch,
|
||||
const cudaStream_t stream) {
|
||||
memory_channels.clear();
|
||||
for (const auto& [peer, channel] : old_memory_channels) {
|
||||
memory_channels.emplace(
|
||||
peer, mscclpp::MemoryChannel(memory_semaphores[peer], registered_sm_memories[peer], input, scratch));
|
||||
}
|
||||
std::vector<mscclpp::MemoryChannel> memory_channels_list;
|
||||
for (int r = 0; r < world_size_; r++) {
|
||||
if (r == rank_) continue;
|
||||
if (is_same_node(r, rank_)) {
|
||||
memory_channels_list.push_back(memory_channels[r]);
|
||||
}
|
||||
}
|
||||
std::vector<mscclpp::MemoryChannelDeviceHandle> memory_channel_handlers(memory_channels_list.size());
|
||||
std::transform(
|
||||
memory_channels_list.begin(),
|
||||
memory_channels_list.end(),
|
||||
memory_channel_handlers.begin(),
|
||||
[](const mscclpp::MemoryChannel& channel) { return channel.deviceHandle(); });
|
||||
mscclpp::gpuMemcpyAsync<mscclpp::MemoryChannelDeviceHandle>(
|
||||
device_memory_handle.data(),
|
||||
memory_channel_handlers.data(),
|
||||
memory_channel_handlers.size(),
|
||||
stream,
|
||||
cudaMemcpyHostToDevice);
|
||||
}
|
||||
};
|
||||
|
||||
class Msccl1NodeLLcontext {
|
||||
private:
|
||||
std::shared_ptr<MscclCommGroup> comm_group_ = nullptr;
|
||||
void* scratch_;
|
||||
const size_t scratch_bytes_;
|
||||
std::unordered_map<int, std::shared_ptr<mscclpp::Connection>> same_node_connections_;
|
||||
std::unordered_map<int, std::shared_ptr<mscclpp::Connection>> cross_node_connections_;
|
||||
|
||||
std::unordered_map<int, mscclpp::RegisteredMemory> registered_sm_memories_;
|
||||
std::unordered_map<int, std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>> memory_semaphores_;
|
||||
std::unordered_map<int, mscclpp::MemoryChannel> memory_channels_;
|
||||
mscclpp::GpuBuffer<mscclpp::MemoryChannelDeviceHandle> d_memHandles_;
|
||||
std::unordered_map<void*, std::unordered_map<int, mscclpp::MemoryChannel>> input_ptr2memory_channels_;
|
||||
std::unordered_map<void*, mscclpp::GpuBuffer<mscclpp::MemoryChannelDeviceHandle>> input_ptr2d_memHandles_;
|
||||
cudaStream_t h2d_stream;
|
||||
const size_t nranks_per_node_;
|
||||
|
||||
public:
|
||||
Msccl1NodeLLcontext(
|
||||
mscclpp::UniqueId unique_id,
|
||||
const size_t rank,
|
||||
const size_t world_size,
|
||||
void* scratch,
|
||||
const size_t scratch_bytes,
|
||||
const size_t nranks_per_node,
|
||||
const std::vector<int64_t>& rank_to_node,
|
||||
const std::vector<int64_t>& rank_to_ib)
|
||||
: scratch_(scratch),
|
||||
scratch_bytes_(scratch_bytes),
|
||||
nranks_per_node_(nranks_per_node),
|
||||
d_memHandles_(nranks_per_node - 1) {
|
||||
CHECK_CUDA_SUCCESS(cudaStreamCreateWithFlags(&h2d_stream, cudaStreamNonBlocking));
|
||||
comm_group_ = std::make_shared<MscclCommGroup>(unique_id, rank, world_size, rank_to_node, rank_to_ib);
|
||||
comm_group_->make_connection(same_node_connections_, cross_node_connections_);
|
||||
comm_group_->make_memory_channels_with_scratch(
|
||||
scratch_,
|
||||
scratch_bytes_,
|
||||
scratch_,
|
||||
scratch_bytes_,
|
||||
same_node_connections_,
|
||||
memory_semaphores_,
|
||||
registered_sm_memories_,
|
||||
memory_channels_);
|
||||
std::vector<mscclpp::MemoryChannel> memory_channels_list;
|
||||
for (int r = 0; r < comm_group_->world_size_; r++) {
|
||||
if (r == comm_group_->rank_) continue;
|
||||
memory_channels_list.push_back(memory_channels_[r]);
|
||||
}
|
||||
std::vector<mscclpp::MemoryChannelDeviceHandle> memory_channel_handlers(memory_channels_list.size());
|
||||
std::transform(
|
||||
memory_channels_list.begin(),
|
||||
memory_channels_list.end(),
|
||||
memory_channel_handlers.begin(),
|
||||
[](const mscclpp::MemoryChannel& channel) { return channel.deviceHandle(); });
|
||||
mscclpp::gpuMemcpy<mscclpp::MemoryChannelDeviceHandle>(
|
||||
d_memHandles_.data(), memory_channel_handlers.data(), memory_channel_handlers.size(), cudaMemcpyHostToDevice);
|
||||
}
|
||||
|
||||
~Msccl1NodeLLcontext() {
|
||||
CHECK_CUDA_SUCCESS(cudaStreamDestroy(h2d_stream));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void allreduce(cudaStream_t stream, T* input, T* output, size_t input_numel, int nthreads = 512, int nblocks = 21) {
|
||||
dim3 nthrs(nthreads);
|
||||
dim3 nblks(nblocks);
|
||||
cudaStreamCaptureStatus capturing_status;
|
||||
CHECK_CUDA_SUCCESS(cudaStreamIsCapturing(stream, &capturing_status));
|
||||
mscclpp::MemoryChannelDeviceHandle* memChans;
|
||||
if (capturing_status != cudaStreamCaptureStatusActive) {
|
||||
std::unordered_map<int, mscclpp::MemoryChannel> memory_channels;
|
||||
comm_group_->make_device_memory_handle_base_on_new_ptr(
|
||||
memory_channels_,
|
||||
registered_sm_memories_,
|
||||
memory_semaphores_,
|
||||
memory_channels,
|
||||
d_memHandles_,
|
||||
input,
|
||||
scratch_,
|
||||
h2d_stream);
|
||||
CHECK_CUDA_SUCCESS(cudaStreamSynchronize(h2d_stream));
|
||||
memChans = d_memHandles_.data();
|
||||
} else {
|
||||
void* input_void_ptr = reinterpret_cast<void*>(input);
|
||||
if (input_ptr2d_memHandles_.find(input_void_ptr) == input_ptr2d_memHandles_.end()) {
|
||||
std::unordered_map<int, mscclpp::MemoryChannel> memory_channels;
|
||||
mscclpp::GpuBuffer<mscclpp::MemoryChannelDeviceHandle> device_memory_handle(comm_group_->world_size_ - 1);
|
||||
comm_group_->make_device_memory_handle_base_on_new_ptr(
|
||||
memory_channels_,
|
||||
registered_sm_memories_,
|
||||
memory_semaphores_,
|
||||
memory_channels,
|
||||
device_memory_handle,
|
||||
input,
|
||||
scratch_,
|
||||
h2d_stream);
|
||||
input_ptr2memory_channels_.emplace(input_void_ptr, memory_channels);
|
||||
input_ptr2d_memHandles_.emplace(input_void_ptr, device_memory_handle);
|
||||
}
|
||||
auto it = input_ptr2d_memHandles_.find(input_void_ptr);
|
||||
memChans = it->second.data();
|
||||
}
|
||||
allreduce_LL_1node<T><<<nblks, nthrs, 0, stream>>>(
|
||||
memChans, (T*)input, (T*)scratch_, output, comm_group_->rank_, comm_group_->world_size_, input_numel);
|
||||
|
||||
cudaError_t status = cudaGetLastError();
|
||||
if (status != cudaSuccess) {
|
||||
printf("rank: %lu failed to launch allreduce_LL_1node: %s\n", comm_group_->rank_, cudaGetErrorString(status));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
class Msccl2NodeLLcontext {
|
||||
private:
|
||||
std::shared_ptr<MscclCommGroup> comm_group_ = nullptr;
|
||||
void* scratch_;
|
||||
const size_t scratch_bytes_;
|
||||
void* put_buffer_;
|
||||
const size_t put_buffer_bytes_;
|
||||
std::unordered_map<int, std::shared_ptr<mscclpp::Connection>> same_node_connections_;
|
||||
std::unordered_map<int, std::shared_ptr<mscclpp::Connection>> cross_node_connections_;
|
||||
|
||||
std::unordered_map<int, mscclpp::RegisteredMemory> registered_sm_memories_;
|
||||
std::unordered_map<int, mscclpp::RegisteredMemory> registered_port_memories_;
|
||||
|
||||
std::unordered_map<int, std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>> memory_semaphores_;
|
||||
std::unordered_map<int, std::shared_ptr<mscclpp::Host2DeviceSemaphore>> port_semaphores_;
|
||||
|
||||
std::unordered_map<int, mscclpp::MemoryChannel> memory_channels_;
|
||||
std::unordered_map<int, mscclpp::PortChannel> port_channels_;
|
||||
|
||||
mscclpp::GpuBuffer<mscclpp::MemoryChannelDeviceHandle> d_memHandles_;
|
||||
mscclpp::GpuBuffer<mscclpp::PortChannelDeviceHandle> d_portHandles_;
|
||||
|
||||
std::shared_ptr<mscclpp::ProxyService> proxyService;
|
||||
cudaStream_t h2d_stream;
|
||||
const size_t nranks_per_node_;
|
||||
|
||||
std::unordered_map<void*, std::unordered_map<int, mscclpp::MemoryChannel>> input_ptr2memory_channels_;
|
||||
std::unordered_map<void*, mscclpp::GpuBuffer<mscclpp::MemoryChannelDeviceHandle>> input_ptr2d_memHandles_;
|
||||
|
||||
public:
|
||||
Msccl2NodeLLcontext(
|
||||
mscclpp::UniqueId unique_id,
|
||||
const size_t rank,
|
||||
const size_t world_size,
|
||||
void* scratch,
|
||||
const size_t scratch_bytes,
|
||||
void* put_buffer,
|
||||
const size_t put_buffer_bytes,
|
||||
const size_t nranks_per_node,
|
||||
const std::vector<int64_t>& rank_to_node,
|
||||
const std::vector<int64_t>& rank_to_ib)
|
||||
: scratch_(scratch),
|
||||
scratch_bytes_(scratch_bytes),
|
||||
put_buffer_(put_buffer),
|
||||
put_buffer_bytes_(put_buffer_bytes),
|
||||
nranks_per_node_(nranks_per_node),
|
||||
d_memHandles_(nranks_per_node - 1),
|
||||
d_portHandles_(world_size - nranks_per_node) {
|
||||
CHECK_CUDA_SUCCESS(cudaStreamCreateWithFlags(&h2d_stream, cudaStreamNonBlocking));
|
||||
comm_group_ = std::make_shared<MscclCommGroup>(unique_id, rank, world_size, rank_to_node, rank_to_ib);
|
||||
proxyService = std::make_shared<mscclpp::ProxyService>();
|
||||
proxyService->startProxy();
|
||||
comm_group_->make_connection(same_node_connections_, cross_node_connections_);
|
||||
comm_group_->make_memory_channels_with_scratch(
|
||||
scratch_,
|
||||
scratch_bytes_,
|
||||
scratch_,
|
||||
scratch_bytes_,
|
||||
same_node_connections_,
|
||||
memory_semaphores_,
|
||||
registered_sm_memories_,
|
||||
memory_channels_);
|
||||
comm_group_->make_port_channels_with_scratch(
|
||||
proxyService,
|
||||
put_buffer_,
|
||||
put_buffer_bytes_,
|
||||
scratch_,
|
||||
scratch_bytes_,
|
||||
cross_node_connections_,
|
||||
port_semaphores_,
|
||||
registered_port_memories_,
|
||||
port_channels_);
|
||||
std::vector<mscclpp::MemoryChannel> memory_channels_list;
|
||||
std::vector<mscclpp::PortChannel> port_channels_list;
|
||||
for (int r = 0; r < comm_group_->world_size_; r++) {
|
||||
if (r == comm_group_->rank_) continue;
|
||||
if (comm_group_->is_same_node(r, comm_group_->rank_)) {
|
||||
memory_channels_list.push_back(memory_channels_[r]);
|
||||
} else {
|
||||
port_channels_list.push_back(port_channels_[r]);
|
||||
}
|
||||
}
|
||||
std::vector<mscclpp::MemoryChannelDeviceHandle> memory_channel_handlers(memory_channels_list.size());
|
||||
std::transform(
|
||||
memory_channels_list.begin(),
|
||||
memory_channels_list.end(),
|
||||
memory_channel_handlers.begin(),
|
||||
[](const mscclpp::MemoryChannel& channel) { return channel.deviceHandle(); });
|
||||
mscclpp::gpuMemcpy<mscclpp::MemoryChannelDeviceHandle>(
|
||||
d_memHandles_.data(), memory_channel_handlers.data(), memory_channel_handlers.size(), cudaMemcpyHostToDevice);
|
||||
|
||||
std::vector<mscclpp::PortChannelDeviceHandle> port_channel_handlers(port_channels_list.size());
|
||||
std::transform(
|
||||
port_channels_list.begin(),
|
||||
port_channels_list.end(),
|
||||
port_channel_handlers.begin(),
|
||||
[](const mscclpp::PortChannel& channel) { return channel.deviceHandle(); });
|
||||
mscclpp::gpuMemcpy<mscclpp::PortChannelDeviceHandle>(
|
||||
d_portHandles_.data(), port_channel_handlers.data(), port_channel_handlers.size(), cudaMemcpyHostToDevice);
|
||||
}
|
||||
|
||||
~Msccl2NodeLLcontext() {
|
||||
CHECK_CUDA_SUCCESS(cudaStreamDestroy(h2d_stream));
|
||||
if (proxyService) {
|
||||
proxyService->stopProxy();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void
|
||||
allreduce(cudaStream_t stream, T* input, T* output, const size_t input_numel, int nthreads = 512, int nblocks = 21) {
|
||||
dim3 nthrs(nthreads);
|
||||
dim3 nblks(nblocks);
|
||||
cudaStreamCaptureStatus capturing_status;
|
||||
CHECK_CUDA_SUCCESS(cudaStreamIsCapturing(stream, &capturing_status));
|
||||
mscclpp::MemoryChannelDeviceHandle* memChans;
|
||||
if (capturing_status != cudaStreamCaptureStatusActive) {
|
||||
std::unordered_map<int, mscclpp::MemoryChannel> memory_channels;
|
||||
comm_group_->make_device_memory_handle_base_on_new_ptr(
|
||||
memory_channels_,
|
||||
registered_sm_memories_,
|
||||
memory_semaphores_,
|
||||
memory_channels,
|
||||
d_memHandles_,
|
||||
input,
|
||||
scratch_,
|
||||
h2d_stream);
|
||||
CHECK_CUDA_SUCCESS(cudaStreamSynchronize(h2d_stream));
|
||||
memChans = d_memHandles_.data();
|
||||
} else {
|
||||
void* input_void_ptr = reinterpret_cast<void*>(input);
|
||||
if (input_ptr2d_memHandles_.find(input_void_ptr) == input_ptr2d_memHandles_.end()) {
|
||||
std::unordered_map<int, mscclpp::MemoryChannel> memory_channels;
|
||||
mscclpp::GpuBuffer<mscclpp::MemoryChannelDeviceHandle> device_memory_handle(7);
|
||||
comm_group_->make_device_memory_handle_base_on_new_ptr(
|
||||
memory_channels_,
|
||||
registered_sm_memories_,
|
||||
memory_semaphores_,
|
||||
memory_channels,
|
||||
device_memory_handle,
|
||||
input,
|
||||
scratch_,
|
||||
h2d_stream);
|
||||
input_ptr2memory_channels_.emplace(input_void_ptr, memory_channels);
|
||||
input_ptr2d_memHandles_.emplace(input_void_ptr, device_memory_handle);
|
||||
}
|
||||
auto it = input_ptr2d_memHandles_.find(input_void_ptr);
|
||||
memChans = it->second.data();
|
||||
}
|
||||
allreduce_LL_2node<T><<<nblks, nthrs, 0, stream>>>(
|
||||
memChans,
|
||||
d_portHandles_.data(),
|
||||
(T*)input,
|
||||
(T*)scratch_,
|
||||
(T*)put_buffer_,
|
||||
output,
|
||||
comm_group_->rank_,
|
||||
nranks_per_node_,
|
||||
comm_group_->world_size_,
|
||||
input_numel);
|
||||
|
||||
cudaError_t status = cudaGetLastError();
|
||||
if (status != cudaSuccess) {
|
||||
printf("rank: %lu failed to launch allreduce_LL_2node: %s\n", comm_group_->rank_, cudaGetErrorString(status));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace sglang
|
||||
111
sgl-kernel/csrc/allreduce/quick_all_reduce.cu
Normal file
111
sgl-kernel/csrc/allreduce/quick_all_reduce.cu
Normal file
@@ -0,0 +1,111 @@
|
||||
#include <ATen/cuda/Exceptions.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#ifdef USE_ROCM
|
||||
|
||||
#include "quick_all_reduce.h"
|
||||
|
||||
quickreduce::fptr_t init_custom_qr(int64_t rank, int64_t world_size, std::optional<int64_t> qr_max_size) {
|
||||
if (world_size > 8) throw std::invalid_argument("world size > 8 is not supported");
|
||||
if (world_size == 6) throw std::invalid_argument("world size == 6 is not supported");
|
||||
if (world_size % 2 != 0) throw std::invalid_argument("Odd num gpus is not supported for now");
|
||||
if (rank < 0 || rank >= world_size) throw std::invalid_argument("invalid rank passed in");
|
||||
quickreduce::DeviceComms* fptr = new quickreduce::DeviceComms();
|
||||
fptr->init(world_size, rank, qr_max_size);
|
||||
return (quickreduce::fptr_t)fptr;
|
||||
}
|
||||
|
||||
void qr_destroy(quickreduce::fptr_t _fa) {
|
||||
if (_fa) {
|
||||
auto fa = reinterpret_cast<quickreduce::DeviceComms*>(_fa);
|
||||
fa->destroy();
|
||||
delete fa;
|
||||
}
|
||||
}
|
||||
|
||||
torch::Tensor qr_get_handle(quickreduce::fptr_t _fa) {
|
||||
auto fa = reinterpret_cast<quickreduce::DeviceComms*>(_fa);
|
||||
hipIpcMemHandle_t handle = fa->get_handle();
|
||||
auto options = torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
|
||||
auto data_handle = torch::empty({static_cast<int64_t>(sizeof(hipIpcMemHandle_t))}, options);
|
||||
std::memcpy(data_handle.data_ptr(), &handle, sizeof(hipIpcMemHandle_t));
|
||||
return data_handle;
|
||||
}
|
||||
|
||||
void qr_open_handles(quickreduce::fptr_t _fa, const std::vector<torch::Tensor>& handles) {
|
||||
auto fa = reinterpret_cast<quickreduce::DeviceComms*>(_fa);
|
||||
std::vector<hipIpcMemHandle_t> ipc_handles;
|
||||
ipc_handles.reserve(handles.size());
|
||||
for (auto& handle : handles) {
|
||||
// Ensure the tensor is on the same device as the current device.
|
||||
hipIpcMemHandle_t ipc_handle;
|
||||
std::memcpy(&ipc_handle, handle.data_ptr(), sizeof(hipIpcMemHandle_t));
|
||||
ipc_handles.push_back(ipc_handle);
|
||||
}
|
||||
fa->open_ipc_handles(ipc_handles);
|
||||
}
|
||||
|
||||
void qr_all_reduce(
|
||||
quickreduce::fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, int64_t quant_level, bool cast_bf2half) {
|
||||
auto fa = reinterpret_cast<quickreduce::DeviceComms*>(_fa);
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
|
||||
auto stream = at::cuda::getCurrentHIPStreamMasqueradingAsCUDA();
|
||||
|
||||
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
|
||||
TORCH_CHECK_EQ(inp.numel(), out.numel());
|
||||
TORCH_CHECK_LE(out.numel(), fa->kMaxProblemSize);
|
||||
if (out.scalar_type() == at::ScalarType::Half) {
|
||||
fa->allreduce<half, false>(
|
||||
reinterpret_cast<half*>(inp.data_ptr()),
|
||||
reinterpret_cast<half*>(out.data_ptr()),
|
||||
out.numel(),
|
||||
quant_level,
|
||||
stream);
|
||||
} else if (out.scalar_type() == at::ScalarType::BFloat16) {
|
||||
if (cast_bf2half) {
|
||||
fa->allreduce<half, true>(
|
||||
reinterpret_cast<half*>(inp.data_ptr()),
|
||||
reinterpret_cast<half*>(out.data_ptr()),
|
||||
out.numel(),
|
||||
quant_level,
|
||||
stream);
|
||||
} else {
|
||||
fa->allreduce<quickreduce::nv_bfloat16, false>(
|
||||
reinterpret_cast<quickreduce::nv_bfloat16*>(inp.data_ptr()),
|
||||
reinterpret_cast<quickreduce::nv_bfloat16*>(out.data_ptr()),
|
||||
out.numel(),
|
||||
quant_level,
|
||||
stream);
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error("quick allreduce only supports float16 and bfloat16");
|
||||
}
|
||||
}
|
||||
|
||||
int64_t qr_max_size() {
|
||||
// The default is 2GB (2,147,483,648 bytes)
|
||||
return static_cast<int64_t>(std::numeric_limits<int32_t>::max()) + 1;
|
||||
}
|
||||
|
||||
#define INSTANTIATE_FOR_WORLDSIZE(T, Codec, cast_bf2half) \
|
||||
template struct quickreduce::AllReduceTwoshot<T, Codec<T, 2>, cast_bf2half>; \
|
||||
template struct quickreduce::AllReduceTwoshot<T, Codec<T, 4>, cast_bf2half>; \
|
||||
template struct quickreduce::AllReduceTwoshot<T, Codec<T, 8>, cast_bf2half>;
|
||||
|
||||
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecFP, false)
|
||||
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ4, false)
|
||||
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ6, false)
|
||||
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ8, false)
|
||||
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecFP, true)
|
||||
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ4, true)
|
||||
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ6, true)
|
||||
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ8, true)
|
||||
|
||||
INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecFP, false)
|
||||
INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ4, false)
|
||||
INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ6, false)
|
||||
INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ8, false)
|
||||
|
||||
#endif // USE_ROCM
|
||||
633
sgl-kernel/csrc/allreduce/quick_all_reduce.cuh
Normal file
633
sgl-kernel/csrc/allreduce/quick_all_reduce.cuh
Normal file
@@ -0,0 +1,633 @@
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include "quick_all_reduce_base.h"
|
||||
|
||||
namespace quickreduce {
|
||||
|
||||
struct CodecBase {
|
||||
const int thread;
|
||||
const int rank;
|
||||
const int group_leader;
|
||||
__quickreduce_device_inline__ CodecBase(int thread, int rank)
|
||||
: thread(thread), rank(rank), group_leader((threadIdx.x / kThreadGroupSize) * kThreadGroupSize) {
|
||||
set_fp16_ovfl(true);
|
||||
}
|
||||
};
|
||||
|
||||
// Default full precision codec.
|
||||
template <typename T, int world_size>
|
||||
struct CodecFP : public CodecBase {
|
||||
static constexpr int kWorldSize = world_size;
|
||||
static constexpr int kRankAtoms = kAtoms / kWorldSize;
|
||||
|
||||
// Codec tile size process by this workgroup.
|
||||
// Each thread processes atoms of f16x8_t (16B).
|
||||
static constexpr int kRankTransmittedTileSize = kBlockSize * kRankAtoms * sizeof(int32x4_t);
|
||||
static_assert(kRankTransmittedTileSize % 16 == 0, "kRankTransmittedTileSize must be 16B aligned.");
|
||||
|
||||
// Total tile size for the collective communication.
|
||||
static constexpr int kTransmittedTileSize = kRankTransmittedTileSize * kWorldSize;
|
||||
|
||||
__quickreduce_device_inline__ CodecFP(int thread, int rank) : CodecBase(thread, rank) {}
|
||||
|
||||
__quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer, const int32x4_t* __restrict__ data) {
|
||||
for (int i = 0; i < kRankAtoms; i++) {
|
||||
__builtin_nontemporal_store(data[i], send_buffer + thread);
|
||||
send_buffer += kAtomStride;
|
||||
}
|
||||
}
|
||||
|
||||
__quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, int32x4_t* __restrict__ data) {
|
||||
for (int i = 0; i < kRankAtoms; i++) {
|
||||
data[i] = __builtin_nontemporal_load(*recv_buffer + thread);
|
||||
*recv_buffer += kAtomStride;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Int4 symmetric quantization codec.
|
||||
// We quantize the FP16 data to block-scaled Int4 in blocks of 4 *
|
||||
// kThreadGroupSize.
|
||||
template <typename T, int world_size>
|
||||
struct CodecQ4 : public CodecBase {
|
||||
static constexpr int kWorldSize = world_size;
|
||||
|
||||
// Codec tile size process by this workgroup.
|
||||
// Each threads processes a fragment of fp16x8_t (16B),
|
||||
// into a int4x8_t (4B) and a fp16 scale shared among 32 values.
|
||||
static constexpr int kRankAtoms = kAtoms / kWorldSize;
|
||||
static constexpr int kRankTileStride = 1152;
|
||||
static constexpr int kRankTileScaleOffset = 1024;
|
||||
static constexpr int kRankTransmittedTileSize = kRankTileStride * kRankAtoms;
|
||||
static_assert(kRankTransmittedTileSize % 16 == 0, "kRankTransmittedTileSize must be 16B aligned.");
|
||||
|
||||
static constexpr int kRankBufferTileStride = kRankTileStride / sizeof(int32x4_t);
|
||||
|
||||
// Total tile size for the collective communication.
|
||||
static constexpr int kTransmittedTileSize = kRankTransmittedTileSize * kWorldSize;
|
||||
|
||||
// Constants configuration
|
||||
|
||||
// {-1/8.0h, -1/8.0h}, f16x2_t
|
||||
static constexpr int kScaleFactor = std::is_same<T, half>::value ? 0xB000B000 : 0xBE00BE00;
|
||||
|
||||
// {1e-7, 1e-7}, f16x2_t
|
||||
static constexpr int kScaleEpsilon = std::is_same<T, half>::value ? 0x00010001 : 0x33D733D7;
|
||||
|
||||
// {-8, -8}, f16x2_t
|
||||
static constexpr int kRangeMin = std::is_same<T, half>::value ? 0xC800C800 : 0xC100C100;
|
||||
|
||||
// {+7, +7}, f16x2_t
|
||||
static constexpr int kRangeMax = std::is_same<T, half>::value ? 0x47004700 : 0x40E040E0;
|
||||
|
||||
// {+8, +8}, int16x2_t
|
||||
static constexpr int kRangeBias = 0x00080008;
|
||||
|
||||
__quickreduce_device_inline__ CodecQ4(int thread, int rank) : CodecBase(thread, rank) {}
|
||||
|
||||
__quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer, const int32x4_t* __restrict__ data) {
|
||||
for (int k = 0; k < kRankAtoms; k++) {
|
||||
int32x4_t const atom = data[k];
|
||||
|
||||
// Compute the absolute maximum of the atom in the thread group
|
||||
// In 2 blocks of values, upper/lower halves of the f16x2_t
|
||||
int wblockmax = group_abs_max<T>(atom);
|
||||
|
||||
// Derive scales
|
||||
int decoding_scale;
|
||||
int encoding_scale;
|
||||
decoding_scale = packed_mul<T>(wblockmax, kScaleFactor);
|
||||
encoding_scale = packed_add<T>(decoding_scale, kScaleEpsilon);
|
||||
encoding_scale = packed_rcp<T>(encoding_scale);
|
||||
|
||||
// Apply scales to get quantized values
|
||||
int32x4_t w;
|
||||
for (int i = 0; i < 4; i++) {
|
||||
w[i] = packed_mul<T>(atom[i], encoding_scale);
|
||||
w[i] = packed_max<T>(w[i], kRangeMin);
|
||||
w[i] = packed_min<T>(w[i], kRangeMax);
|
||||
}
|
||||
|
||||
// Convert from f16x2_t to uint16x2_t
|
||||
int32x4_t q;
|
||||
{
|
||||
int16_t* qi = reinterpret_cast<int16_t*>(&q);
|
||||
T* wh = reinterpret_cast<T*>(&w);
|
||||
for (int i = 0; i < 8; i++)
|
||||
qi[i] = (int16_t)rintf(T2float_cast(wh[i]));
|
||||
|
||||
for (int i = 0; i < 4; i++) {
|
||||
q[i] = packed_add<int16_t>(q[i], kRangeBias);
|
||||
}
|
||||
}
|
||||
|
||||
// Pack 8 x q4 into int32_t
|
||||
int qw = q[0] | (q[1] << 4) | (q[2] << 8) | (q[3] << 12);
|
||||
|
||||
// Write quantized atom to send_buffer
|
||||
// note: only the group leader stores the scale
|
||||
uint8_t* atom_ptr = reinterpret_cast<uint8_t*>(send_buffer + k * kRankBufferTileStride);
|
||||
int32_t* qw_ptr = reinterpret_cast<int32_t*>(atom_ptr) + thread;
|
||||
int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) + (thread / 8);
|
||||
|
||||
__builtin_nontemporal_store(qw, qw_ptr);
|
||||
if (threadIdx.x == group_leader) {
|
||||
__builtin_nontemporal_store(decoding_scale, qs_ptr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, int32x4_t* __restrict__ data) {
|
||||
for (int k = 0; k < kRankAtoms; k++) {
|
||||
// Directly read quantized atom from recv_buffer
|
||||
uint8_t* atom_ptr = reinterpret_cast<uint8_t*>(*recv_buffer);
|
||||
int32_t* qw_ptr = reinterpret_cast<int32_t*>(atom_ptr) + thread;
|
||||
int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) + (thread / 8);
|
||||
|
||||
int32_t qw = __builtin_nontemporal_load(qw_ptr);
|
||||
int qs = __builtin_nontemporal_load(qs_ptr);
|
||||
|
||||
*recv_buffer += kRankBufferTileStride;
|
||||
|
||||
// Unpack q4 into f16x8_t
|
||||
int32x4_t w;
|
||||
{
|
||||
static constexpr uint kMask000F = 0x000F000F;
|
||||
static constexpr uint kHalf2_1024 = 0x64006400; // {1024.0, 1024.0}, fp16x2_t
|
||||
static uint constexpr kHalf2_1032 = 0xE408E408; // {-1032.0, -1032.0}, fp16x2_t
|
||||
|
||||
for (int i = 0; i < 4; i++) {
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
int32_t q4 = ((qw >> (i * 4)) & kMask000F) | kHalf2_1024;
|
||||
w[i] = packed_add<half>(q4, kHalf2_1032);
|
||||
} else {
|
||||
int32_t int16_2 = (qw >> (i * 4)) & kMask000F;
|
||||
int16_t low = static_cast<int16_t>(int16_2 & 0xFFFF);
|
||||
int16_t high = static_cast<int16_t>((int16_2 >> 16) & 0xFFFF);
|
||||
nv_bfloat16 bf_low = __float2bfloat16(static_cast<float>(low));
|
||||
nv_bfloat16 bf_high = __float2bfloat16(static_cast<float>(high));
|
||||
nv_bfloat162 bf2 = __halves2bfloat162(bf_low, bf_high);
|
||||
int32_t packed_bf16 = *reinterpret_cast<int32_t*>(&bf2);
|
||||
w[i] = packed_add<nv_bfloat16>(packed_bf16, kRangeMin);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Apply decoding scales
|
||||
for (int i = 0; i < 4; i++) {
|
||||
w[i] = packed_mul<T>(w[i], qs);
|
||||
}
|
||||
|
||||
data[k] = w;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Int6 symmetric quantization codec.
|
||||
// We quantize the FP16 data to block-scaled Int6 in blocks of 4 *
|
||||
// kThreadGroupSize.
|
||||
template <typename T, int world_size>
|
||||
struct CodecQ6 : public CodecBase {
|
||||
static constexpr int kWorldSize = world_size;
|
||||
|
||||
// Codec tile size process by this workgroup.
|
||||
// Each threads processes a fragment of fp16x8_t (16B),
|
||||
// into a int6x8_t (4B + 2B) and a fp16 scale shared among 32 values.
|
||||
static constexpr int kRankAtoms = kAtoms / kWorldSize;
|
||||
static constexpr int kRankTileStride = 1664;
|
||||
static constexpr int kRankTileQ2Offset = 1024;
|
||||
static constexpr int kRankTileScaleOffset = 1536;
|
||||
static constexpr int kRankTransmittedTileSize = kRankTileStride * kRankAtoms;
|
||||
static_assert(kRankTransmittedTileSize % 16 == 0, "kRankTransmittedTileSize must be 16B aligned.");
|
||||
|
||||
static constexpr int kRankBufferTileStride = kRankTileStride / sizeof(int32x4_t);
|
||||
|
||||
// Total tile size for the collective communication.
|
||||
static constexpr int kTransmittedTileSize = kRankTransmittedTileSize * kWorldSize;
|
||||
|
||||
// Constants configuration
|
||||
|
||||
// {-1/32.0h, -1/32.0h}, fp16x2_t
|
||||
static constexpr int kScaleFactor = std::is_same<T, half>::value ? 0xA800A800 : 0xBD00BD00;
|
||||
|
||||
// {1e-7, 1e-7}, fp16x2_t
|
||||
static constexpr int kScaleEpsilon = std::is_same<T, half>::value ? 0x00010001 : 0x33D733D7;
|
||||
|
||||
// {-32, -32}, fp16x2_t
|
||||
static constexpr int kRangeMin = std::is_same<T, half>::value ? 0xD000D000 : 0xC200C200;
|
||||
|
||||
// {+31, +31}, fp16x2_t
|
||||
static constexpr int kRangeMax = std::is_same<T, half>::value ? 0x4FC04FC0 : 0x41F841F8;
|
||||
|
||||
// {+32, +32}, int16x2_t
|
||||
static constexpr int kRangeBias = 0x00200020;
|
||||
|
||||
__quickreduce_device_inline__ CodecQ6(int thread, int rank) : CodecBase(thread, rank) {}
|
||||
|
||||
__quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer, const int32x4_t* __restrict__ data) {
|
||||
for (int k = 0; k < kRankAtoms; k++) {
|
||||
int32x4_t const atom = data[k];
|
||||
|
||||
// Compute the absolute maximum of the atom in the thread group
|
||||
// In 2 blocks of values, upper/lower halves of the f16x2_t
|
||||
int wblockmax = group_abs_max<T>(atom);
|
||||
|
||||
// Derive scales
|
||||
int decoding_scale;
|
||||
int encoding_scale;
|
||||
decoding_scale = packed_mul<T>(wblockmax, kScaleFactor);
|
||||
encoding_scale = packed_add<T>(decoding_scale, kScaleEpsilon);
|
||||
encoding_scale = packed_rcp<T>(encoding_scale);
|
||||
|
||||
// Apply scales to get quantized values
|
||||
int32x4_t w;
|
||||
for (int i = 0; i < 4; i++) {
|
||||
w[i] = packed_mul<T>(atom[i], encoding_scale);
|
||||
w[i] = packed_max<T>(w[i], kRangeMin);
|
||||
w[i] = packed_min<T>(w[i], kRangeMax);
|
||||
}
|
||||
|
||||
// Convert from f16x2_t to uint16x2_t
|
||||
int32x4_t q;
|
||||
{
|
||||
int16_t* qi = reinterpret_cast<int16_t*>(&q);
|
||||
T* wh = reinterpret_cast<T*>(&w);
|
||||
for (int i = 0; i < 8; i++)
|
||||
qi[i] = (int16_t)rintf(T2float_cast(wh[i]));
|
||||
|
||||
for (int i = 0; i < 4; i++) {
|
||||
q[i] = packed_add<int16_t>(q[i], kRangeBias);
|
||||
}
|
||||
}
|
||||
|
||||
// Pack 8 x q6 into int32_t + int16_t
|
||||
uint32_t q4w;
|
||||
uint16_t q2w = 0;
|
||||
q4w = (q[0] & 0x000F000F) | ((q[1] & 0x000F000F) << 4) | ((q[2] & 0x000F000F) << 8) | ((q[3] & 0x000F000F) << 12);
|
||||
{
|
||||
int16_t* tw = reinterpret_cast<int16_t*>(&q);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++) {
|
||||
q2w |= (tw[i] >> 4) << (i * 2);
|
||||
}
|
||||
}
|
||||
// Write quantized atom to send_buffer
|
||||
// note: only the group leader stores the scale
|
||||
uint8_t* atom_ptr = reinterpret_cast<uint8_t*>(send_buffer + k * kRankBufferTileStride);
|
||||
uint32_t* q4w_ptr = reinterpret_cast<uint32_t*>(atom_ptr) + thread;
|
||||
uint16_t* q2w_ptr = reinterpret_cast<uint16_t*>(atom_ptr + kRankTileQ2Offset) + thread;
|
||||
int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) + (thread / 8);
|
||||
|
||||
__builtin_nontemporal_store(q4w, q4w_ptr);
|
||||
__builtin_nontemporal_store(q2w, q2w_ptr);
|
||||
if (threadIdx.x == group_leader) {
|
||||
__builtin_nontemporal_store(decoding_scale, qs_ptr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, int32x4_t* __restrict__ data) {
|
||||
for (int k = 0; k < kRankAtoms; k++) {
|
||||
// Directly read quantized atom from recv_buffer
|
||||
uint8_t* atom_ptr = reinterpret_cast<uint8_t*>(*recv_buffer);
|
||||
uint32_t* q4w_ptr = reinterpret_cast<uint32_t*>(atom_ptr) + thread;
|
||||
uint16_t* q2w_ptr = reinterpret_cast<uint16_t*>(atom_ptr + kRankTileQ2Offset) + thread;
|
||||
int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) + (thread / 8);
|
||||
|
||||
uint32_t q4w = __builtin_nontemporal_load(q4w_ptr);
|
||||
uint16_t q2w = __builtin_nontemporal_load(q2w_ptr);
|
||||
int qs = __builtin_nontemporal_load(qs_ptr);
|
||||
|
||||
*recv_buffer += kRankBufferTileStride;
|
||||
|
||||
// Unpack q6 into fp16x8_t
|
||||
int32x4_t w;
|
||||
{
|
||||
static uint constexpr kMask000F = 0x000F000F;
|
||||
static uint constexpr kHalf2_1024 = 0x64006400; // {1024.0, 1024.0}, fp16x2_t
|
||||
static uint constexpr kHalf2_1056 = 0xE420E420; // {-1056.0, -1056.0}, fp16x2_t
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
int32_t q4 = q4w & kMask000F;
|
||||
int32_t q2 = (q2w & 0x3) | ((q2w & 0xC) << 14);
|
||||
q4w >>= 4;
|
||||
q2w >>= 4;
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
int32_t q6 = q4 | (q2 << 4) | kHalf2_1024;
|
||||
asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(w[i]) : "v"(q6), "v"(kHalf2_1056));
|
||||
} else {
|
||||
int32_t int16_2 = q4 | (q2 << 4);
|
||||
int16_t low = static_cast<int16_t>(int16_2 & 0xFFFF);
|
||||
int16_t high = static_cast<int16_t>((int16_2 >> 16) & 0xFFFF);
|
||||
|
||||
nv_bfloat16 bf_low = __float2bfloat16(static_cast<float>(low));
|
||||
nv_bfloat16 bf_high = __float2bfloat16(static_cast<float>(high));
|
||||
nv_bfloat162 bf2 = __halves2bfloat162(bf_low, bf_high);
|
||||
int32_t packed_bf16 = *reinterpret_cast<int32_t*>(&bf2);
|
||||
w[i] = packed_add<nv_bfloat16>(packed_bf16, kRangeMin);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Apply decoding scales
|
||||
for (int i = 0; i < 4; i++) {
|
||||
w[i] = packed_mul<T>(w[i], qs);
|
||||
}
|
||||
|
||||
// That's pretty much it...
|
||||
data[k] = w;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Int8 symmetric quantization codec.
|
||||
// We quantize the FP16 data to block-scaled Int8 in blocks of 4 *
|
||||
// kThreadGroupSize.
|
||||
template <typename T, int world_size>
|
||||
struct CodecQ8 : public CodecBase {
|
||||
static constexpr int kWorldSize = world_size;
|
||||
|
||||
// Codec tile size process by this workgroup.
|
||||
// Each threads processes a fragment of f16x8_t (16B),
|
||||
// into a int8x8_t (8B) and a f16 scale shared among 32 values.
|
||||
static constexpr int kRankAtoms = kAtoms / kWorldSize;
|
||||
static constexpr int kRankTileStride = 2176;
|
||||
static constexpr int kRankTileScaleOffset = 2048;
|
||||
static constexpr int kRankTransmittedTileSize = kRankTileStride * kRankAtoms;
|
||||
static_assert(kRankTransmittedTileSize % 16 == 0, "kRankTileSize must be 16B aligned.");
|
||||
|
||||
static constexpr int kRankBufferTileStride = kRankTileStride / sizeof(int32x4_t);
|
||||
|
||||
// Total tile size for the collective communication.
|
||||
static constexpr int kTransmittedTileSize = kRankTransmittedTileSize * kWorldSize;
|
||||
|
||||
// Constants configuration
|
||||
|
||||
// {-1/128.0h, -1/128.0h}, f16x2_t
|
||||
static constexpr int kScaleFactor = std::is_same<T, half>::value ? 0xA000A000 : 0xBC00BC00;
|
||||
|
||||
// {1e-7, 1e-7}, f16x2_t
|
||||
static constexpr int kScaleEpsilon = std::is_same<T, half>::value ? 0x00010001 : 0x33D733D7;
|
||||
|
||||
// {-128, -128}, f16x2_t
|
||||
static constexpr int kRangeMin = std::is_same<T, half>::value ? 0xD800D800 : 0xC300C300;
|
||||
// {+127, +127}, f16x2_t
|
||||
static constexpr int kRangeMax = std::is_same<T, half>::value ? 0x57F057F0 : 0x42FE42FE;
|
||||
|
||||
// {+128, +128}, int16x2_t
|
||||
static constexpr int kRangeBias = 0x00800080;
|
||||
|
||||
__quickreduce_device_inline__ CodecQ8(int thread, int rank) : CodecBase(thread, rank) {}
|
||||
|
||||
__quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer, int32x4_t const* __restrict__ data) {
|
||||
for (int k = 0; k < kRankAtoms; k++) {
|
||||
int32x4_t const atom = data[k];
|
||||
// Compute the absolute maximum of the atom in the thread group
|
||||
// In 2 blocks of values, upper/lower halves of the f16x2_t
|
||||
int wblockmax = group_abs_max<T>(atom);
|
||||
|
||||
// Derive scales
|
||||
int decoding_scale;
|
||||
int encoding_scale;
|
||||
decoding_scale = packed_mul<T>(wblockmax, kScaleFactor);
|
||||
encoding_scale = packed_add<T>(decoding_scale, kScaleEpsilon);
|
||||
encoding_scale = packed_rcp<T>(encoding_scale);
|
||||
|
||||
// Apply scales to get quantized values
|
||||
int32x4_t w;
|
||||
for (int i = 0; i < 4; i++) {
|
||||
w[i] = packed_mul<T>(atom[i], encoding_scale);
|
||||
w[i] = packed_max<T>(w[i], kRangeMin);
|
||||
w[i] = packed_min<T>(w[i], kRangeMax);
|
||||
}
|
||||
|
||||
// Convert from f16x2_t to uint16x2_t
|
||||
int32x4_t q;
|
||||
{
|
||||
int16_t* qi = reinterpret_cast<int16_t*>(&q);
|
||||
T* wh = reinterpret_cast<T*>(&w);
|
||||
for (int i = 0; i < 8; i++)
|
||||
qi[i] = (int16_t)rintf(T2float_cast(wh[i]));
|
||||
|
||||
for (int i = 0; i < 4; i++) {
|
||||
q[i] = packed_add<int16_t>(q[i], kRangeBias);
|
||||
}
|
||||
}
|
||||
|
||||
// Pack 8 x q8 into int32x2_t
|
||||
int32x2_t qw;
|
||||
qw[0] = q[0] | (q[1] << 8);
|
||||
qw[1] = q[2] | (q[3] << 8);
|
||||
|
||||
// Write quantized atom to send_buffer
|
||||
// note: only the group leader stores the scale
|
||||
uint8_t* atom_ptr = reinterpret_cast<uint8_t*>(send_buffer + k * kRankBufferTileStride);
|
||||
int32x2_t* qw_ptr = reinterpret_cast<int32x2_t*>(atom_ptr) + thread;
|
||||
int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) + (thread / 8);
|
||||
|
||||
__builtin_nontemporal_store(qw, qw_ptr);
|
||||
if (threadIdx.x == group_leader) {
|
||||
__builtin_nontemporal_store(decoding_scale, qs_ptr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, int32x4_t* __restrict__ data) {
|
||||
for (int k = 0; k < kRankAtoms; k++) {
|
||||
// Directly read quantized atom from recv_buffer
|
||||
uint8_t* atom_ptr = reinterpret_cast<uint8_t*>(*recv_buffer);
|
||||
int32x2_t* qw_ptr = reinterpret_cast<int32x2_t*>(atom_ptr) + thread;
|
||||
int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) + (thread / 8);
|
||||
|
||||
int32x2_t qw = __builtin_nontemporal_load(qw_ptr);
|
||||
int qs = __builtin_nontemporal_load(qs_ptr);
|
||||
|
||||
*recv_buffer += kRankBufferTileStride;
|
||||
|
||||
// Unpack q8 into fp16x8_t
|
||||
int32x4_t w;
|
||||
{
|
||||
static uint constexpr kMask00FF = 0x00FF00FF;
|
||||
|
||||
// {1024.0, 1024.0}, fp16x2_t
|
||||
static uint constexpr kHalf2_1024 = 0x64006400;
|
||||
|
||||
// {-1152.0, -1152.0}, fp16x2_t
|
||||
static uint constexpr kHalf2_1152 = 0xE480E480;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
int32_t q8 = ((qw[i / 2] >> ((i % 2) * 8)) & kMask00FF) | kHalf2_1024;
|
||||
w[i] = packed_add<half>(q8, kHalf2_1152);
|
||||
} else {
|
||||
int32_t int16_2 = (qw[i / 2] >> ((i % 2) * 8)) & kMask00FF;
|
||||
int16_t low = static_cast<int16_t>(int16_2 & 0xFFFF);
|
||||
int16_t high = static_cast<int16_t>((int16_2 >> 16) & 0xFFFF);
|
||||
nv_bfloat16 bf_low = __float2bfloat16(static_cast<float>(low));
|
||||
nv_bfloat16 bf_high = __float2bfloat16(static_cast<float>(high));
|
||||
nv_bfloat162 bf2 = __halves2bfloat162(bf_low, bf_high);
|
||||
int32_t packed_bf16 = *reinterpret_cast<int32_t*>(&bf2);
|
||||
w[i] = packed_add<nv_bfloat16>(packed_bf16, kRangeMin);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Apply decoding scales
|
||||
for (int i = 0; i < 4; i++) {
|
||||
w[i] = packed_mul<T>(w[i], qs);
|
||||
}
|
||||
|
||||
data[k] = w;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Twoshot All Reduce
|
||||
template <typename T, class Codec, bool cast_bf2half>
|
||||
struct AllReduceTwoshot {
|
||||
static_assert(sizeof(T) == 2);
|
||||
|
||||
static constexpr int kWorldSize = Codec::kWorldSize;
|
||||
|
||||
__device__ static void
|
||||
run(T const* __restrict__ input,
|
||||
T* __restrict__ output,
|
||||
uint32_t const N, // number of elements
|
||||
int const block, // block index
|
||||
int const rank, // rank index
|
||||
uint8_t** __restrict__ buffer_list, // communication buffers
|
||||
uint32_t const data_offset, // offset to start of the data buffer
|
||||
uint32_t flag_color) {
|
||||
// Topology
|
||||
int thread = threadIdx.x + threadIdx.y * kWavefront;
|
||||
uint8_t* rank_buffer = buffer_list[rank];
|
||||
Codec codec(thread, rank);
|
||||
int block_id = blockIdx.x;
|
||||
int grid_size = gridDim.x;
|
||||
// --------------------------------------------------------
|
||||
// Read input into registers
|
||||
int32x4_t tA[kAtoms];
|
||||
|
||||
BufferResource src_buffer(const_cast<T*>(input), N * sizeof(T));
|
||||
uint32_t src_offset = block * kTileSize + thread * sizeof(int32x4_t);
|
||||
|
||||
for (int i = 0; i < kAtoms; i++) {
|
||||
tA[i] = buffer_load_dwordx4(src_buffer.descriptor, src_offset, 0, 0);
|
||||
src_offset += kAtomStride * sizeof(int32x4_t);
|
||||
if constexpr (cast_bf2half) {
|
||||
const nv_bfloat162* bf_buf = reinterpret_cast<const nv_bfloat162*>(&tA[i]);
|
||||
half2 half_buf[4];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
float2 f = __bfloat1622float2(bf_buf[j]);
|
||||
half_buf[j] = __float22half2_rn(f);
|
||||
}
|
||||
tA[i] = *reinterpret_cast<const int32x4_t*>(half_buf);
|
||||
}
|
||||
}
|
||||
|
||||
// --------------------------------------------------------
|
||||
// Phase-1A: Write segment data into the communication buffer of the target
|
||||
// rank responsible for this segment.
|
||||
uint32_t comm_data0_offset = data_offset + block_id * Codec::kTransmittedTileSize;
|
||||
uint32_t comm_data1_offset = grid_size * Codec::kTransmittedTileSize + comm_data0_offset;
|
||||
|
||||
uint32_t comm_flags0_offset = block_id * (kWorldSize * sizeof(uint32_t));
|
||||
uint32_t comm_flags1_offset = grid_size * (kWorldSize * sizeof(uint32_t)) + comm_flags0_offset;
|
||||
|
||||
for (int r = 0; r < kWorldSize; r++) {
|
||||
int32x4_t* send_buffer =
|
||||
reinterpret_cast<int32x4_t*>(buffer_list[r] + comm_data0_offset + rank * Codec::kRankTransmittedTileSize);
|
||||
codec.send(send_buffer, &tA[r * Codec::kRankAtoms]);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
if (thread < kWorldSize) {
|
||||
int r = thread;
|
||||
uint32_t* flag_ptr = reinterpret_cast<uint32_t*>(buffer_list[r] + comm_flags0_offset + rank * sizeof(uint32_t));
|
||||
set_sync_flag(flag_ptr, flag_color);
|
||||
}
|
||||
// --------------------------------------------------------
|
||||
// Phase-1B: Reduce the segment data from the communication buffers.
|
||||
int32x4_t tR[Codec::kRankAtoms] = {};
|
||||
{
|
||||
// Read the data from the communication buffer.
|
||||
int32x4_t* recv_buffer = reinterpret_cast<int32x4_t*>(rank_buffer + comm_data0_offset);
|
||||
uint32_t* flag_ptr = reinterpret_cast<uint32_t*>(rank_buffer + comm_flags0_offset);
|
||||
|
||||
for (int r = 0; r < kWorldSize; r++) {
|
||||
// Wait for the flags to be set.
|
||||
if (thread == 0) {
|
||||
wait_sync_flag(&flag_ptr[r], flag_color);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// note: we reuse tA as temp buffer here
|
||||
codec.recv(&recv_buffer, tA);
|
||||
|
||||
for (int i = 0; i < Codec::kRankAtoms; i++) {
|
||||
packed_assign_add<T>(&tR[i], &tA[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Phase-2: Write the reduced segment to every other rank
|
||||
for (int r = 0; r < kWorldSize; r++) {
|
||||
int32x4_t* send_buffer =
|
||||
reinterpret_cast<int32x4_t*>(buffer_list[r] + comm_data1_offset + rank * Codec::kRankTransmittedTileSize);
|
||||
codec.send(send_buffer, tR);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
if (thread < kWorldSize) {
|
||||
int r = thread;
|
||||
uint32_t* flag_ptr = reinterpret_cast<uint32_t*>(buffer_list[r] + comm_flags1_offset + rank * sizeof(uint32_t));
|
||||
set_sync_flag(flag_ptr, flag_color);
|
||||
}
|
||||
|
||||
// Phase-2: Read the gather segments from the rank's communication buffer.
|
||||
{
|
||||
// Read the data from the communication buffer.
|
||||
int32x4_t* recv_buffer = reinterpret_cast<int32x4_t*>(rank_buffer + comm_data1_offset);
|
||||
uint32_t* flag_ptr = reinterpret_cast<uint32_t*>(rank_buffer + comm_flags1_offset);
|
||||
|
||||
for (int r = 0; r < kWorldSize; r++) {
|
||||
// Wait for the flags to be set.
|
||||
if (thread == 0) {
|
||||
wait_sync_flag(&flag_ptr[r], flag_color);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Gather all reduced and final rank segments into tA.
|
||||
codec.recv(&recv_buffer, &tA[r * Codec::kRankAtoms]);
|
||||
}
|
||||
}
|
||||
|
||||
// --------------------------------------------------------
|
||||
// Write the result to output.
|
||||
BufferResource dst_buffer(output, N * sizeof(T));
|
||||
uint32_t dst_offset = block * kTileSize + thread * sizeof(int32x4_t);
|
||||
|
||||
for (int i = 0; i < kAtoms; i++) {
|
||||
if constexpr (cast_bf2half) {
|
||||
const half2* half_buf = reinterpret_cast<const half2*>(&tA[i]);
|
||||
nv_bfloat162 bf16_buf[4];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
float2 f = __half22float2(half_buf[j]);
|
||||
bf16_buf[j] = __float22bfloat162_rn(f);
|
||||
}
|
||||
buffer_store_dwordx4(*reinterpret_cast<const int32x4_t*>(bf16_buf), dst_buffer.descriptor, dst_offset, 0, 0);
|
||||
} else {
|
||||
buffer_store_dwordx4(tA[i], dst_buffer.descriptor, dst_offset, 0, 0);
|
||||
}
|
||||
dst_offset += kAtomStride * sizeof(int32x4_t);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace quickreduce
|
||||
233
sgl-kernel/csrc/allreduce/quick_all_reduce.h
Normal file
233
sgl-kernel/csrc/allreduce/quick_all_reduce.h
Normal file
@@ -0,0 +1,233 @@
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "quick_all_reduce.cuh"
|
||||
|
||||
#define HIP_CHECK(err) \
|
||||
do { \
|
||||
hipError_t err_ = (err); \
|
||||
if (err_ != hipSuccess) { \
|
||||
std::printf("HIP error %d at %s:%d. %s\n", err_, __FILE__, __LINE__, hipGetErrorString(err_)); \
|
||||
throw std::runtime_error("HIP error"); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
namespace quickreduce {
|
||||
using fptr_t = int64_t;
|
||||
static_assert(sizeof(void*) == sizeof(fptr_t));
|
||||
|
||||
template <typename AllReduceKernel, typename T>
|
||||
__global__ __quickreduce_launch_bounds_two_shot__ static void allreduce_prototype_twoshot(
|
||||
T const* A,
|
||||
T* B,
|
||||
uint32_t N,
|
||||
uint32_t num_blocks,
|
||||
int rank,
|
||||
uint8_t** dbuffer_list,
|
||||
uint32_t data_offset,
|
||||
uint32_t flag_color) {
|
||||
int block = blockIdx.x;
|
||||
int grid = gridDim.x;
|
||||
|
||||
while (block < num_blocks) {
|
||||
AllReduceKernel::run(A, B, N, block, rank, dbuffer_list, data_offset, flag_color);
|
||||
block += grid;
|
||||
flag_color++;
|
||||
}
|
||||
}
|
||||
|
||||
#define TWOSHOT_DISPATCH(__codec) \
|
||||
if (world_size == 2) { \
|
||||
using LineCodec = __codec<T, 2>; \
|
||||
using AllReduceKernel = AllReduceTwoshot<T, LineCodec, cast_bf2half>; \
|
||||
hipLaunchKernelGGL( \
|
||||
(allreduce_prototype_twoshot<AllReduceKernel, T>), \
|
||||
dim3(grid), \
|
||||
dim3(kBlockTwoShot), \
|
||||
0, \
|
||||
stream, \
|
||||
A, \
|
||||
B, \
|
||||
N, \
|
||||
num_blocks, \
|
||||
rank, \
|
||||
dbuffer_list, \
|
||||
data_offset, \
|
||||
flag_color); \
|
||||
} else if (world_size == 4) { \
|
||||
using LineCodec = __codec<T, 4>; \
|
||||
using AllReduceKernel = AllReduceTwoshot<T, LineCodec, cast_bf2half>; \
|
||||
hipLaunchKernelGGL( \
|
||||
(allreduce_prototype_twoshot<AllReduceKernel, T>), \
|
||||
dim3(grid), \
|
||||
dim3(kBlockTwoShot), \
|
||||
0, \
|
||||
stream, \
|
||||
A, \
|
||||
B, \
|
||||
N, \
|
||||
num_blocks, \
|
||||
rank, \
|
||||
dbuffer_list, \
|
||||
data_offset, \
|
||||
flag_color); \
|
||||
} else if (world_size == 8) { \
|
||||
using LineCodec = __codec<T, 8>; \
|
||||
using AllReduceKernel = AllReduceTwoshot<T, LineCodec, cast_bf2half>; \
|
||||
hipLaunchKernelGGL( \
|
||||
(allreduce_prototype_twoshot<AllReduceKernel, T>), \
|
||||
dim3(grid), \
|
||||
dim3(kBlockTwoShot), \
|
||||
0, \
|
||||
stream, \
|
||||
A, \
|
||||
B, \
|
||||
N, \
|
||||
num_blocks, \
|
||||
rank, \
|
||||
dbuffer_list, \
|
||||
data_offset, \
|
||||
flag_color); \
|
||||
}
|
||||
|
||||
enum QuickReduceQuantLevel {
|
||||
F16 = 0,
|
||||
INT8 = 1,
|
||||
INT6 = 2,
|
||||
INT4 = 3,
|
||||
};
|
||||
|
||||
struct DeviceComms {
|
||||
// Max problem size is 2GB (in bytes) or half of uint32_t max value.
|
||||
int64_t kMaxProblemSize = static_cast<int64_t>(std::numeric_limits<int32_t>::max()) + 1;
|
||||
|
||||
// Max TP-8
|
||||
static int constexpr kMaxWorldSize = 8;
|
||||
|
||||
bool initialized = false;
|
||||
uint32_t flag_color = 1;
|
||||
int world_size;
|
||||
int rank;
|
||||
|
||||
uint8_t* dbuffer;
|
||||
uint8_t** dbuffer_list;
|
||||
hipIpcMemHandle_t buffer_ipc_handle;
|
||||
std::vector<hipIpcMemHandle_t> all_buffer_ipc_handles;
|
||||
std::vector<uint8_t*> buffer_list;
|
||||
uint32_t data_offset;
|
||||
|
||||
DeviceComms() : initialized(false), world_size(1), rank(0) {}
|
||||
~DeviceComms() {
|
||||
destroy();
|
||||
}
|
||||
|
||||
void init(int world_size, int rank, std::optional<int64_t> max_problem_size = std::nullopt) {
|
||||
destroy();
|
||||
this->world_size = world_size;
|
||||
this->rank = rank;
|
||||
if (max_problem_size.has_value() && max_problem_size.value() > 0) {
|
||||
this->kMaxProblemSize = max_problem_size.value();
|
||||
}
|
||||
// Allocate buffer size for worst case: F16 2-stage buffer.
|
||||
uint32_t flags_buffer_size = 2 * world_size * kMaxNumBlocks * sizeof(uint32_t);
|
||||
static int64_t data_buffer_size = 2 * this->kMaxProblemSize;
|
||||
int64_t total_buffer_size = flags_buffer_size + data_buffer_size;
|
||||
data_offset = flags_buffer_size;
|
||||
HIP_CHECK(hipExtMallocWithFlags((void**)&dbuffer, total_buffer_size, hipDeviceMallocUncached));
|
||||
|
||||
// Clear the flags buffer.
|
||||
HIP_CHECK(hipMemset(dbuffer, 0, flags_buffer_size));
|
||||
|
||||
// Device-side list of IPC buffers.
|
||||
buffer_list.resize(world_size);
|
||||
HIP_CHECK(hipMalloc(&dbuffer_list, world_size * sizeof(uint8_t*)));
|
||||
|
||||
// Create IPC handles for rank's communication buffer.
|
||||
all_buffer_ipc_handles.resize(world_size);
|
||||
HIP_CHECK(hipIpcGetMemHandle(&buffer_ipc_handle, dbuffer));
|
||||
|
||||
initialized = true;
|
||||
}
|
||||
int get_world_size() {
|
||||
return world_size;
|
||||
}
|
||||
int get_rank() {
|
||||
return rank;
|
||||
}
|
||||
bool status() {
|
||||
return initialized;
|
||||
}
|
||||
hipIpcMemHandle_t const get_handle() {
|
||||
return buffer_ipc_handle;
|
||||
}
|
||||
|
||||
void destroy() {
|
||||
if (initialized) {
|
||||
for (int i = 0; i < world_size; i++) {
|
||||
if (i != rank) {
|
||||
HIP_CHECK(hipIpcCloseMemHandle(dbuffer_list[i]));
|
||||
}
|
||||
}
|
||||
|
||||
HIP_CHECK(hipFree(dbuffer));
|
||||
HIP_CHECK(hipFree(dbuffer_list));
|
||||
|
||||
initialized = false;
|
||||
}
|
||||
}
|
||||
|
||||
void open_ipc_handles(std::vector<hipIpcMemHandle_t> const& ipc_handles) {
|
||||
assert(ipc_handles.size() == all_buffer_ipc_handles.size());
|
||||
for (int i = 0; i < world_size; i++) {
|
||||
all_buffer_ipc_handles[i] = ipc_handles[i];
|
||||
}
|
||||
|
||||
// Open device memory access to the IPC communication buffers.
|
||||
// Note: For our own rank, we do not need to open a handle.
|
||||
for (int i = 0; i < world_size; i++) {
|
||||
if (i != rank) {
|
||||
HIP_CHECK(
|
||||
hipIpcOpenMemHandle((void**)&buffer_list[i], all_buffer_ipc_handles[i], hipIpcMemLazyEnablePeerAccess));
|
||||
} else {
|
||||
buffer_list[i] = dbuffer;
|
||||
}
|
||||
}
|
||||
|
||||
HIP_CHECK(hipMemcpy(dbuffer_list, buffer_list.data(), world_size * sizeof(uint8_t*), hipMemcpyHostToDevice));
|
||||
}
|
||||
|
||||
template <typename T, bool cast_bf2half>
|
||||
void allreduce(T const* A, T* B, uint32_t N, int quant_level, hipStream_t stream) {
|
||||
if (world_size != 2 && world_size != 4 && world_size != 8) {
|
||||
throw std::runtime_error("All Reduce not supported for world_size = " + std::to_string(world_size));
|
||||
}
|
||||
|
||||
// Configuration.
|
||||
uint32_t msg_size = N * sizeof(T);
|
||||
uint32_t num_blocks = divceil(msg_size, kTileSize);
|
||||
uint32_t grid = min(kMaxNumBlocks, num_blocks);
|
||||
auto quant_level_ = static_cast<QuickReduceQuantLevel>(quant_level);
|
||||
switch (quant_level_) {
|
||||
case QuickReduceQuantLevel::INT8:
|
||||
TWOSHOT_DISPATCH(CodecQ8)
|
||||
break;
|
||||
case QuickReduceQuantLevel::INT6:
|
||||
TWOSHOT_DISPATCH(CodecQ6)
|
||||
break;
|
||||
case QuickReduceQuantLevel::INT4:
|
||||
TWOSHOT_DISPATCH(CodecQ4)
|
||||
break;
|
||||
default:
|
||||
TWOSHOT_DISPATCH(CodecFP)
|
||||
break;
|
||||
}
|
||||
HIP_CHECK(cudaGetLastError());
|
||||
// Rotate the flag color.
|
||||
flag_color += divceil(N, grid);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace quickreduce
|
||||
318
sgl-kernel/csrc/allreduce/quick_all_reduce_base.h
Normal file
318
sgl-kernel/csrc/allreduce/quick_all_reduce_base.h
Normal file
@@ -0,0 +1,318 @@
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_bf16.h>
|
||||
#include <hip/hip_fp16.h>
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#define __quickreduce_device_inline__ __device__ __forceinline__
|
||||
#define __quickreduce_launch_bounds_two_shot__ __launch_bounds__(256, 4)
|
||||
#define __quickreduce_launch_bounds_one_shot__ __launch_bounds__(512, 4)
|
||||
|
||||
namespace quickreduce {
|
||||
|
||||
typedef __hip_bfloat16 nv_bfloat16;
|
||||
typedef __hip_bfloat162 nv_bfloat162;
|
||||
|
||||
using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
|
||||
using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
|
||||
|
||||
// Setup acquire-release semantics for vector memory reads (mubuf instruction)
|
||||
// as per architecture.
|
||||
#if defined(__gfx942__)
|
||||
// CDNA3: Scope bits sc0, sc1
|
||||
#define MUBUF_ACQUIRE 16
|
||||
#define MUBUF_RELEASE 16
|
||||
#elif (defined(__gfx908__) || defined(__gfx90a__))
|
||||
// CDNA1 and CDNA2 - glc bit
|
||||
#define MUBUF_ACQUIRE 1
|
||||
#define MUBUF_RELEASE 0
|
||||
#endif
|
||||
|
||||
static constexpr int kNegOne = 0xBC00BC00; // {-1, -1}, fp16x2_t
|
||||
|
||||
// Number of atoms (4xf16x2_t) processed by a single thread
|
||||
static constexpr int kAtoms = 8;
|
||||
|
||||
// We use a workgroup of 256 threads
|
||||
static constexpr int kBlockSize = 256;
|
||||
static constexpr int kAtomStride = kBlockSize;
|
||||
|
||||
// Size and atom stride of source/destination data that the block will
|
||||
// process.
|
||||
// Workgroup scope = Tile = (256 threads x 8 atoms x 16B)
|
||||
static constexpr int kTileSize = kBlockSize * kAtoms * sizeof(int32x4_t);
|
||||
|
||||
// Max number of blocks. 304 CUs on MI300
|
||||
static constexpr int kMaxNumBlocks = 304 * 4;
|
||||
|
||||
// Standard CDNA wavefront size.
|
||||
static constexpr int kWavefront = 64;
|
||||
|
||||
// 256 thread, 4 wavefronts.
|
||||
static dim3 constexpr kBlockTwoShot = {kWavefront, kBlockSize / kWavefront, 1};
|
||||
|
||||
// Number of threads in a group for quantization
|
||||
// It corresponds to 32 F16 elements in quantization block
|
||||
static constexpr int kThreadGroupSize = 8;
|
||||
|
||||
// Methods
|
||||
__quickreduce_device_inline__ __host__ unsigned long divceil(unsigned long x, unsigned long y) {
|
||||
return ((x + y - 1) / y);
|
||||
}
|
||||
|
||||
union BufferResource {
|
||||
__quickreduce_device_inline__ constexpr BufferResource() : config(0x00020000U) {}
|
||||
|
||||
__quickreduce_device_inline__ constexpr BufferResource(void* buffer_address, uint32_t buffer_size)
|
||||
: address(buffer_address), range(buffer_size), config(0x00020000U) {}
|
||||
|
||||
int32x4_t descriptor;
|
||||
struct {
|
||||
void* address; // 8B, out of which first 48b is address, and 16b is stride
|
||||
// (unused)
|
||||
uint32_t range; // Byte range for the buffer resource
|
||||
uint32_t config; // Constant, DFMT=32b
|
||||
};
|
||||
};
|
||||
|
||||
__quickreduce_device_inline__ static int32x4_t buffer_load_dwordx4(
|
||||
int32x4_t srsrc, int32_t voffset, int32_t soffset, int32_t aux) __asm("llvm.amdgcn.raw.buffer.load.v4i32");
|
||||
|
||||
__quickreduce_device_inline__ static void
|
||||
buffer_store_dwordx4(int32x4_t data, int32x4_t srsrc, int32_t voffset, int32_t soffset, int32_t aux) __asm(
|
||||
"llvm.amdgcn.raw.buffer.store.v4i32");
|
||||
|
||||
__quickreduce_device_inline__ static void set_fp16_ovfl(bool const value) {
|
||||
#if defined(__gfx942__)
|
||||
if (value) {
|
||||
asm volatile("s_setreg_imm32_b32 0xdc1, 1;" ::);
|
||||
} else {
|
||||
asm volatile("s_setreg_imm32_b32 0xdc1, 0;" ::);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
union bf162_int_union {
|
||||
int i;
|
||||
nv_bfloat162 bf2;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
__quickreduce_device_inline__ void packed_assign_add(int32x4_t* A, int32x4_t* B);
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ void packed_assign_add<half>(int32x4_t* A, int32x4_t* B) {
|
||||
int32x4_t& tR_fragment = A[0];
|
||||
int32x4_t& tA_fragment = B[0];
|
||||
|
||||
asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(tR_fragment[0]) : "v"(tR_fragment[0]), "v"(tA_fragment[0]));
|
||||
asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(tR_fragment[1]) : "v"(tR_fragment[1]), "v"(tA_fragment[1]));
|
||||
asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(tR_fragment[2]) : "v"(tR_fragment[2]), "v"(tA_fragment[2]));
|
||||
asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(tR_fragment[3]) : "v"(tR_fragment[3]), "v"(tA_fragment[3]));
|
||||
}
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ void packed_assign_add<nv_bfloat16>(int32x4_t* A, int32x4_t* B) {
|
||||
nv_bfloat162* tA = reinterpret_cast<nv_bfloat162*>(A);
|
||||
nv_bfloat162* tB = reinterpret_cast<nv_bfloat162*>(B);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
tA[i] = __hadd2(tA[i], tB[i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__quickreduce_device_inline__ int packed_max(int a, int b);
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_max<half>(int a, int b) {
|
||||
int result;
|
||||
asm volatile("v_pk_max_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b));
|
||||
return result;
|
||||
}
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_max<nv_bfloat16>(int a, int b) {
|
||||
bf162_int_union A, B, R;
|
||||
A.i = a;
|
||||
B.i = b;
|
||||
R.bf2 = __hmax2(A.bf2, B.bf2);
|
||||
return R.i;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__quickreduce_device_inline__ int packed_min(int a, int b);
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_min<half>(int a, int b) {
|
||||
int result;
|
||||
asm volatile("v_pk_min_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b));
|
||||
return result;
|
||||
}
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_min<nv_bfloat16>(int a, int b) {
|
||||
bf162_int_union A, B, R;
|
||||
A.i = a;
|
||||
B.i = b;
|
||||
R.bf2 = __hmin2(A.bf2, B.bf2);
|
||||
return R.i;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__quickreduce_device_inline__ int packed_abs_max(int a, int b);
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_abs_max<half>(int a, int b) {
|
||||
half2 wmaxh2 = __builtin_bit_cast(half2, a);
|
||||
half2 wminh2 = __builtin_bit_cast(half2, b);
|
||||
half2 wblockmaxh2;
|
||||
|
||||
wblockmaxh2.x = __hgt(__habs(wmaxh2.x), __habs(wminh2.x)) ? wmaxh2.x : wminh2.x;
|
||||
wblockmaxh2.y = __hgt(__habs(wmaxh2.y), __habs(wminh2.y)) ? wmaxh2.y : wminh2.y;
|
||||
return __builtin_bit_cast(int, wblockmaxh2);
|
||||
}
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_abs_max<nv_bfloat16>(int a, int b) {
|
||||
bf162_int_union A, B, R;
|
||||
A.i = a;
|
||||
B.i = b;
|
||||
R.bf2.x = __hgt(__habs(A.bf2.x), __habs(B.bf2.x)) ? A.bf2.x : B.bf2.x;
|
||||
R.bf2.y = __hgt(__habs(A.bf2.y), __habs(B.bf2.y)) ? A.bf2.y : B.bf2.y;
|
||||
return R.i;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__quickreduce_device_inline__ int packed_add(int a, int b);
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_add<half>(int a, int b) {
|
||||
int result;
|
||||
asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b));
|
||||
return result;
|
||||
}
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_add<nv_bfloat16>(int a, int b) {
|
||||
bf162_int_union A, B, R;
|
||||
A.i = a;
|
||||
B.i = b;
|
||||
R.bf2 = __hadd2(A.bf2, B.bf2);
|
||||
return R.i;
|
||||
}
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_add<int16_t>(int a, int b) {
|
||||
int result;
|
||||
asm volatile("v_pk_add_i16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b));
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__quickreduce_device_inline__ int packed_sub(int a, int b);
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_sub<half>(int a, int b) {
|
||||
int result;
|
||||
|
||||
// MI300 lacks packed fp16 sub instruction. So we do -1 * min + max
|
||||
asm volatile("v_pk_fma_f16 %0, %1, %2 %3" : "=v"(result) : "v"(kNegOne), "v"(b), "v"(a));
|
||||
return result;
|
||||
}
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_sub<nv_bfloat16>(int a, int b) {
|
||||
bf162_int_union A, B, R;
|
||||
A.i = a;
|
||||
B.i = b;
|
||||
R.bf2 = __hsub2(A.bf2, B.bf2);
|
||||
return R.i;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__quickreduce_device_inline__ int packed_mul(int a, int b);
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_mul<half>(int a, int b) {
|
||||
int result;
|
||||
asm volatile("v_pk_mul_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b));
|
||||
return result;
|
||||
}
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_mul<nv_bfloat16>(int a, int b) {
|
||||
nv_bfloat162* tA = reinterpret_cast<nv_bfloat162*>(&a);
|
||||
nv_bfloat162* tB = reinterpret_cast<nv_bfloat162*>(&b);
|
||||
nv_bfloat162 tR = __hmul2(*tA, *tB);
|
||||
return *(reinterpret_cast<int*>(&tR));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__quickreduce_device_inline__ int packed_rcp(int a);
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_rcp<half>(int a) {
|
||||
return __builtin_bit_cast(int, h2rcp(__builtin_bit_cast(half2, a)));
|
||||
}
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_rcp<nv_bfloat16>(int a) {
|
||||
bf162_int_union A, R;
|
||||
A.i = a;
|
||||
R.bf2 = h2rcp(A.bf2);
|
||||
return R.i;
|
||||
}
|
||||
|
||||
// changes dtype
|
||||
__quickreduce_device_inline__ float T2float_cast(half a) {
|
||||
return __half2float(a);
|
||||
}
|
||||
|
||||
__quickreduce_device_inline__ float T2float_cast(nv_bfloat16 a) {
|
||||
return __bfloat162float(a);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__quickreduce_device_inline__ int group_abs_max(int32x4_t atom) {
|
||||
const int group_leader = (threadIdx.x / kThreadGroupSize) * kThreadGroupSize;
|
||||
|
||||
int wmax, wmin, wblockmax;
|
||||
int a, b;
|
||||
a = packed_max<T>(atom[0], atom[1]);
|
||||
b = packed_max<T>(atom[2], atom[3]);
|
||||
|
||||
wmax = packed_max<T>(a, b);
|
||||
|
||||
a = packed_min<T>(atom[0], atom[1]);
|
||||
b = packed_min<T>(atom[2], atom[3]);
|
||||
|
||||
wmin = packed_min<T>(a, b);
|
||||
|
||||
// Reduce the max among a group of threads
|
||||
// Note: This is basically 2 blocks of values setup as the
|
||||
// upper/lower halves of the f16x2_t
|
||||
for (int i = 1; i < kThreadGroupSize; i <<= 1) {
|
||||
int x = __shfl_down(wmax, i);
|
||||
wmax = packed_max<T>(wmax, x);
|
||||
|
||||
int y = __shfl_down(wmin, i);
|
||||
wmin = packed_min<T>(wmin, y);
|
||||
}
|
||||
wblockmax = packed_abs_max<T>(wmax, wmin);
|
||||
// Share with the cohort
|
||||
wblockmax = __shfl(wblockmax, group_leader);
|
||||
return wblockmax;
|
||||
}
|
||||
|
||||
__quickreduce_device_inline__ void set_sync_flag(uint32_t* flag_ptr, uint32_t flag) {
|
||||
__atomic_store_n(flag_ptr, flag, __ATOMIC_RELEASE);
|
||||
}
|
||||
|
||||
__quickreduce_device_inline__ void wait_sync_flag(uint32_t* flag_ptr, uint32_t flag) {
|
||||
while (__atomic_load_n(flag_ptr, __ATOMIC_RELAXED) != flag) {
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace quickreduce
|
||||
153
sgl-kernel/csrc/allreduce/test_mscclpp_allreduce.cu
Normal file
153
sgl-kernel/csrc/allreduce/test_mscclpp_allreduce.cu
Normal file
@@ -0,0 +1,153 @@
|
||||
/*
|
||||
* this file is used to test mscclpp_allreduce.cu using mpirun
|
||||
* this file is adapted from https://github.com/flashinfer-ai/flashinfer/blob/v0.2.5/src/test_sum_all_reduce.cu
|
||||
usage:
|
||||
cd PATH-TO-THIS-FILE
|
||||
export MPI_HOME=/usr/local/mpi
|
||||
# export MPI_HOME=/opt/hpcx/ompi/
|
||||
export MSCCLPP_HOME=/workspace/test/mscclpp
|
||||
nvcc -O2 -arch=native -std=c++17 test_mscclpp_allreduce.cu \
|
||||
-o test_mscclpp_allreduce -D_GLIBCXX_USE_CXX11_ABI=0 \
|
||||
-I${MSCCLPP_HOME}/include -L${MSCCLPP_HOME}/build -lmscclpp \
|
||||
-lnccl -I${MPI_HOME}/include -L${MPI_HOME}/lib -lmpi
|
||||
|
||||
/opt/hpcx/ompi/bin/
|
||||
mpirun --allow-run-as-root -H 127.0.0.1:8 -np 8 \
|
||||
--map-by ppr:8:node \
|
||||
--mca btl_openib_warn_no_device_params_found 0 \
|
||||
--mca btl_tcp_if_include bond0 \
|
||||
--allow-run-as-root -np 8 \
|
||||
-x NCCL_RUNTIME_CONNECT=0 -x NCCL_IB_GID_INDEX=3 -x NCCL_DEBUG=WARN \
|
||||
-x LD_PRELOAD=${MSCCLPP_HOME}/build/libmscclpp.so ./test_mscclpp_allreduce
|
||||
*/
|
||||
#include <mpi.h>
|
||||
#include <thrust/detail/raw_pointer_cast.h>
|
||||
#include <thrust/device_vector.h>
|
||||
#include <thrust/host_vector.h>
|
||||
|
||||
#ifndef CHECK_CUDA_SUCCESS
|
||||
#define CHECK_CUDA_SUCCESS(cmd) \
|
||||
do { \
|
||||
cudaError_t e = cmd; \
|
||||
if (e != cudaSuccess) { \
|
||||
printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e)); \
|
||||
exit(EXIT_FAILURE); \
|
||||
} \
|
||||
} while (0)
|
||||
#endif
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#include "mscclpp_allreduce.cuh"
|
||||
|
||||
template <typename T>
|
||||
bool isclose(T a, T b, float rtol = 1e-5, float atol = 1e-8) {
|
||||
return fabs(a - b) <= (atol + rtol * fabs(b));
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
// init mpi
|
||||
MPI_Init(&argc, &argv);
|
||||
printf("MPI Initialized.\n");
|
||||
int nranks, rank;
|
||||
|
||||
// get work size and rank id
|
||||
MPI_Comm_size(MPI_COMM_WORLD, &nranks);
|
||||
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
|
||||
cudaSetDevice(rank);
|
||||
printf("nranks: %d, rank: %d\n", nranks, rank);
|
||||
|
||||
// init host and device buffers
|
||||
using T = float;
|
||||
using ReduceT = float;
|
||||
const size_t num_elems = 2 * 1024 * 1024;
|
||||
std::vector<T> host_buf(num_elems);
|
||||
for (uint32_t i = 0; i < num_elems; ++i) {
|
||||
host_buf[i] = T(i + rank);
|
||||
}
|
||||
thrust::device_vector<T> device_buf(host_buf);
|
||||
const size_t buf_size_in_bytes = num_elems * sizeof(T);
|
||||
std::vector<T> host_result_buf(num_elems);
|
||||
thrust::device_vector<T> device_result_buf(host_result_buf);
|
||||
|
||||
std::vector<T> host_scratch_buf(num_elems * 8);
|
||||
for (uint32_t i = 0; i < num_elems; ++i) {
|
||||
host_scratch_buf[i] = 1;
|
||||
}
|
||||
thrust::device_vector<T> device_scratch_buf(host_scratch_buf);
|
||||
std::vector<T> host_put_buf(num_elems);
|
||||
thrust::device_vector<T> device_put_buf(host_put_buf);
|
||||
|
||||
mscclpp::UniqueId unique_id;
|
||||
if (rank == 0) unique_id = mscclpp::TcpBootstrap::createUniqueId();
|
||||
MPI_Bcast(&unique_id, sizeof(unique_id), MPI_BYTE, 0, MPI_COMM_WORLD);
|
||||
|
||||
std::vector<int64_t> rank_to_node(nranks);
|
||||
std::vector<int64_t> rank_to_ib(nranks);
|
||||
for (int i = 0; i < nranks; i++) {
|
||||
rank_to_node[i] = i / 8;
|
||||
rank_to_ib[i] = i % 8;
|
||||
}
|
||||
|
||||
cudaStream_t s;
|
||||
CHECK_CUDA_SUCCESS(cudaStreamCreate(&s));
|
||||
CHECK_CUDA_SUCCESS(cudaStreamSynchronize(s));
|
||||
if (nranks == 8) {
|
||||
auto context = std::make_shared<sglang::Msccl1NodeLLcontext>(
|
||||
unique_id,
|
||||
rank,
|
||||
nranks,
|
||||
thrust::raw_pointer_cast(device_scratch_buf.data()),
|
||||
buf_size_in_bytes * 8,
|
||||
rank_to_node,
|
||||
rank_to_ib);
|
||||
printf("rank: %d, Msccl1NodeLLcontext setup.\n", rank);
|
||||
MPI_Barrier(MPI_COMM_WORLD);
|
||||
context->allreduce<T>(
|
||||
s,
|
||||
thrust::raw_pointer_cast(device_buf.data()),
|
||||
thrust::raw_pointer_cast(device_result_buf.data()),
|
||||
device_buf.size());
|
||||
} else if (nranks == 16) {
|
||||
// TODO: this branch is untested since there is something wrong with mpirun in my test machince
|
||||
auto context = std::make_shared<sglang::Msccl2NodeLLcontext>(
|
||||
unique_id,
|
||||
rank,
|
||||
nranks,
|
||||
thrust::raw_pointer_cast(device_scratch_buf.data()),
|
||||
buf_size_in_bytes * 8,
|
||||
thrust::raw_pointer_cast(device_put_buf.data()),
|
||||
buf_size_in_bytes,
|
||||
rank_to_node,
|
||||
rank_to_ib);
|
||||
printf("rank: %d, Msccl2NodeLLcontext setup.\n", rank);
|
||||
MPI_Barrier(MPI_COMM_WORLD);
|
||||
context->allreduce<T>(
|
||||
s,
|
||||
thrust::raw_pointer_cast(device_buf.data()),
|
||||
thrust::raw_pointer_cast(device_result_buf.data()),
|
||||
device_buf.size());
|
||||
}
|
||||
|
||||
// check result correctness
|
||||
thrust::host_vector<T> host_buf_result = device_result_buf;
|
||||
size_t num_results_error_atol_1e_3_rtol_1e_3 = 0;
|
||||
bool nan_detected = false;
|
||||
|
||||
for (uint32_t i = 0; i < num_elems; ++i) {
|
||||
T expected = T(i * nranks + (nranks - 1) * nranks / 2);
|
||||
if (std::isnan(float(host_buf_result[i]))) {
|
||||
nan_detected = true;
|
||||
}
|
||||
if (!isclose(float(host_buf_result[i]), float(expected), 1e-3, 1e-3)) {
|
||||
num_results_error_atol_1e_3_rtol_1e_3++;
|
||||
}
|
||||
}
|
||||
float result_accuracy = 1. - float(num_results_error_atol_1e_3_rtol_1e_3) / float(num_elems);
|
||||
|
||||
printf("rank: %d, nan_detected: %d accuracy: %f\n", rank, nan_detected, result_accuracy);
|
||||
|
||||
CHECK_CUDA_SUCCESS(cudaStreamDestroy(s));
|
||||
MPI_Finalize();
|
||||
return 0;
|
||||
}
|
||||
55
sgl-kernel/csrc/attention/cascade.cu
Normal file
55
sgl-kernel/csrc/attention/cascade.cu
Normal file
@@ -0,0 +1,55 @@
|
||||
// Adapted from
|
||||
// https://github.com/flashinfer-ai/flashinfer/blob/55576c626421b5ee7e7ebe74afd26465c8ae863f/csrc/cascade.cu
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include <flashinfer/attention/cascade.cuh>
|
||||
|
||||
#include "pytorch_extension_utils.h"
|
||||
|
||||
using namespace flashinfer;
|
||||
|
||||
void merge_state(
|
||||
at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b, at::Tensor v_merged, at::Tensor s_merged) {
|
||||
CHECK_INPUT(v_a);
|
||||
CHECK_INPUT(s_a);
|
||||
CHECK_INPUT(v_b);
|
||||
CHECK_INPUT(s_b);
|
||||
auto device = v_a.device();
|
||||
CHECK_EQ(s_a.device(), device);
|
||||
CHECK_EQ(v_b.device(), device);
|
||||
CHECK_EQ(s_b.device(), device);
|
||||
CHECK_DIM(3, v_a);
|
||||
CHECK_DIM(2, s_a);
|
||||
CHECK_DIM(3, v_b);
|
||||
CHECK_DIM(2, s_b);
|
||||
CHECK_SHAPE(v_a, v_b);
|
||||
CHECK_SHAPE(s_a, s_b);
|
||||
CHECK_EQ(v_a.size(0), s_a.size(0));
|
||||
CHECK_EQ(v_a.size(1), s_b.size(1));
|
||||
unsigned int seq_len = v_a.size(0);
|
||||
unsigned int num_heads = v_a.size(1);
|
||||
unsigned int head_dim = v_a.size(2);
|
||||
|
||||
const c10::cuda::OptionalCUDAGuard device_guard(v_a.device());
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(v_a.scalar_type(), c_type, [&] {
|
||||
cudaError_t status = MergeState(
|
||||
static_cast<c_type*>(v_a.data_ptr()),
|
||||
static_cast<float*>(s_a.data_ptr()),
|
||||
static_cast<c_type*>(v_b.data_ptr()),
|
||||
static_cast<float*>(s_b.data_ptr()),
|
||||
static_cast<c_type*>(v_merged.data_ptr()),
|
||||
static_cast<float*>(s_merged.data_ptr()),
|
||||
seq_len,
|
||||
num_heads,
|
||||
head_dim,
|
||||
stream);
|
||||
TORCH_CHECK(status == cudaSuccess, "MergeState kernel launch failed: ", cudaGetErrorString(status));
|
||||
return true;
|
||||
});
|
||||
|
||||
TORCH_CHECK(success, "MergeState kernel launch failed: unsupported data type");
|
||||
}
|
||||
269
sgl-kernel/csrc/attention/cutlass_mla_kernel.cu
Normal file
269
sgl-kernel/csrc/attention/cutlass_mla_kernel.cu
Normal file
@@ -0,0 +1,269 @@
|
||||
/*
|
||||
Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
Copyright 2025 SGLang Team. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/kernel_hardware_info.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass_sm100_mla/device/sm100_mla.hpp"
|
||||
#include "cutlass_sm100_mla/kernel/sm100_mla_tile_scheduler.hpp"
|
||||
|
||||
// clang-format off
|
||||
#if !defined(CUDA_VERSION) || CUDA_VERSION < 12040
|
||||
void cutlass_mla_decode(
|
||||
torch::Tensor const& out,
|
||||
torch::Tensor const& q_nope,
|
||||
torch::Tensor const& q_pe,
|
||||
torch::Tensor const& kv_c_and_k_pe_cache,
|
||||
torch::Tensor const& seq_lens,
|
||||
torch::Tensor const& page_table,
|
||||
torch::Tensor const& workspace,
|
||||
int64_t num_kv_splits) {
|
||||
TORCH_CHECK(false, "CUDA version must be >= 12.4 for cutlass_mla_decode");
|
||||
}
|
||||
int64_t cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count, int64_t num_kv_splits) {
|
||||
TORCH_CHECK(false, "CUDA version must be >= 12.4 for cutlass_mla_get_workspace_size");
|
||||
}
|
||||
#else
|
||||
|
||||
#define CUTLASS_CHECK(status) \
|
||||
{ \
|
||||
cutlass::Status error = status; \
|
||||
TORCH_CHECK(error == cutlass::Status::kSuccess, cutlassGetStatusString(error)); \
|
||||
}
|
||||
|
||||
using namespace cute;
|
||||
using namespace cutlass::fmha::kernel;
|
||||
|
||||
template <bool v>
|
||||
struct IsPersistent {
|
||||
static const bool value = v;
|
||||
};
|
||||
|
||||
template <typename T, bool IsPaged128, typename PersistenceOption = IsPersistent<true>>
|
||||
struct MlaSm100 {
|
||||
using Element = T;
|
||||
using ElementAcc = float;
|
||||
using ElementOut = T;
|
||||
|
||||
using TileShape = Shape<_128, _128, Shape<_512, _64>>;
|
||||
using TileShapeH = cute::tuple_element_t<0, TileShape>;
|
||||
using TileShapeD = cute::tuple_element_t<2, TileShape>;
|
||||
|
||||
// H K (D_latent D_rope) B
|
||||
using ProblemShape = cute::tuple<TileShapeH, int, TileShapeD, int>;
|
||||
|
||||
using StrideQ = cute::tuple<int64_t, _1, int64_t>; // H D B
|
||||
using StrideK = cute::tuple<int64_t, _1, int64_t>; // K D B
|
||||
using StrideO = StrideK; // H D B
|
||||
using StrideLSE = cute::tuple<_1, int>; // H B
|
||||
|
||||
using TileScheduler =
|
||||
std::conditional_t<PersistenceOption::value, Sm100MlaPersistentTileScheduler, Sm100MlaIndividualTileScheduler>;
|
||||
|
||||
using FmhaKernel = cutlass::fmha::kernel::Sm100FmhaMlaKernelTmaWarpspecialized<
|
||||
TileShape,
|
||||
Element,
|
||||
ElementAcc,
|
||||
ElementOut,
|
||||
ElementAcc,
|
||||
TileScheduler,
|
||||
/*kIsCpAsync=*/!IsPaged128>;
|
||||
using Fmha = cutlass::fmha::device::MLA<FmhaKernel>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
typename T::Fmha::Arguments args_from_options(
|
||||
at::Tensor const& out,
|
||||
at::Tensor const& q_nope,
|
||||
at::Tensor const& q_pe,
|
||||
at::Tensor const& kv_c_and_k_pe_cache,
|
||||
at::Tensor const& seq_lens,
|
||||
at::Tensor const& page_table,
|
||||
double sm_scale,
|
||||
int64_t num_kv_splits) {
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
hw_info.device_id = q_nope.device().index();
|
||||
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||
|
||||
int batches = q_nope.size(0);
|
||||
int page_count_per_seq = page_table.size(1);
|
||||
int page_count_total = kv_c_and_k_pe_cache.size(0);
|
||||
int page_size = kv_c_and_k_pe_cache.size(1);
|
||||
int max_seq_len = page_size * page_count_per_seq;
|
||||
using TileShapeH = typename T::TileShapeH;
|
||||
using TileShapeD = typename T::TileShapeD;
|
||||
auto problem_shape = cute::make_tuple(TileShapeH{}, max_seq_len, TileShapeD{}, batches);
|
||||
|
||||
auto [H, K, D, B] = problem_shape;
|
||||
auto [D_latent, D_rope] = D;
|
||||
|
||||
float scale = float(sm_scale);
|
||||
|
||||
using StrideQ = typename T::StrideQ;
|
||||
using StrideK = typename T::StrideK;
|
||||
using StrideO = typename T::StrideO;
|
||||
using StrideLSE = typename T::StrideLSE;
|
||||
|
||||
StrideQ stride_Q_nope = cute::make_tuple(
|
||||
static_cast<int64_t>(q_nope.stride(1)), _1{}, static_cast<int64_t>(q_nope.stride(0)));
|
||||
StrideQ stride_Q_pe = cute::make_tuple(
|
||||
static_cast<int64_t>(q_pe.stride(1)), _1{}, static_cast<int64_t>(q_pe.stride(0)));
|
||||
|
||||
StrideK stride_C = cute::make_tuple(
|
||||
static_cast<int64_t>(0 + D_latent + D_rope), _1{}, static_cast<int64_t>(page_size * (D_latent + D_rope)));
|
||||
StrideLSE stride_PT = cute::make_stride(_1{}, page_count_per_seq);
|
||||
StrideLSE stride_LSE = cute::make_tuple(_1{}, 0 + H);
|
||||
StrideO stride_O = cute::make_tuple(static_cast<int64_t>(0 + D_latent), _1{}, static_cast<int64_t>(0 + H * D_latent));
|
||||
|
||||
using Element = typename T::Element;
|
||||
using ElementOut = typename T::ElementOut;
|
||||
using ElementAcc = typename T::ElementAcc;
|
||||
auto Q_nope_ptr = static_cast<Element*>(q_nope.data_ptr());
|
||||
auto Q_pe_ptr = static_cast<Element*>(q_pe.data_ptr());
|
||||
auto C_ptr = static_cast<Element*>(kv_c_and_k_pe_cache.data_ptr());
|
||||
typename T::Fmha::Arguments arguments{
|
||||
problem_shape,
|
||||
{scale,
|
||||
Q_nope_ptr,
|
||||
stride_Q_nope,
|
||||
Q_pe_ptr,
|
||||
stride_Q_pe,
|
||||
C_ptr,
|
||||
stride_C,
|
||||
C_ptr + D_latent,
|
||||
stride_C,
|
||||
static_cast<int*>(seq_lens.data_ptr()),
|
||||
static_cast<int*>(page_table.data_ptr()),
|
||||
stride_PT,
|
||||
page_count_total,
|
||||
page_size},
|
||||
{static_cast<ElementOut*>(out.data_ptr()), stride_O, static_cast<ElementAcc*>(nullptr), stride_LSE},
|
||||
hw_info,
|
||||
// TODO(trevor-m): Change split_kv back to -1 when
|
||||
// https://github.com/NVIDIA/cutlass/issues/2274 is fixed. Split_kv=1 will
|
||||
// perform worse with larger context length and smaller batch sizes.
|
||||
static_cast<int>(num_kv_splits), // split_kv
|
||||
nullptr, // is_var_split_kv
|
||||
};
|
||||
// TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute
|
||||
// split_kv automatically based on batch size and sequence length to balance
|
||||
// workload across available SMs. Consider using var_split_kv for manual
|
||||
// control if needed.
|
||||
T::Fmha::set_split_kv(arguments);
|
||||
return arguments;
|
||||
}
|
||||
|
||||
template <typename Element, bool IsPaged128, typename PersistenceOption>
|
||||
void runMla(
|
||||
at::Tensor const& out,
|
||||
at::Tensor const& q_nope,
|
||||
at::Tensor const& q_pe,
|
||||
at::Tensor const& kv_c_and_k_pe_cache,
|
||||
at::Tensor const& seq_lens,
|
||||
at::Tensor const& page_table,
|
||||
at::Tensor const& workspace,
|
||||
double sm_scale,
|
||||
int64_t num_kv_splits,
|
||||
cudaStream_t stream) {
|
||||
using MlaSm100Type = MlaSm100<Element, IsPaged128, PersistenceOption>;
|
||||
typename MlaSm100Type::Fmha fmha;
|
||||
auto arguments = args_from_options<MlaSm100Type>(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, sm_scale, num_kv_splits);
|
||||
|
||||
CUTLASS_CHECK(fmha.can_implement(arguments));
|
||||
|
||||
CUTLASS_CHECK(fmha.initialize(arguments, workspace.data_ptr(), stream));
|
||||
|
||||
CUTLASS_CHECK(fmha.run(arguments, workspace.data_ptr(), stream));
|
||||
}
|
||||
|
||||
#define DISPATCH_BOOL(expr, const_expr, ...) \
|
||||
[&]() -> bool { \
|
||||
if (expr) { \
|
||||
constexpr bool const_expr = true; \
|
||||
return __VA_ARGS__(); \
|
||||
} else { \
|
||||
constexpr bool const_expr = false; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
}()
|
||||
|
||||
void cutlass_mla_decode(
|
||||
torch::Tensor const& out,
|
||||
torch::Tensor const& q_nope,
|
||||
torch::Tensor const& q_pe,
|
||||
torch::Tensor const& kv_c_and_k_pe_cache,
|
||||
torch::Tensor const& seq_lens,
|
||||
torch::Tensor const& page_table,
|
||||
torch::Tensor const& workspace,
|
||||
double sm_scale,
|
||||
int64_t num_kv_splits) {
|
||||
auto in_dtype = q_nope.dtype();
|
||||
at::cuda::CUDAGuard device_guard{(char)q_nope.get_device()};
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(q_nope.get_device());
|
||||
const int page_size = kv_c_and_k_pe_cache.size(1);
|
||||
|
||||
// NOTE(alcanderian): IsPersistent has bug with manual split_kv.
|
||||
// Kernel will hang if batch is too large with large num_kv_splits. (for example bs=8, num_kv_splits=8)
|
||||
// Maybe per batch split kv will fix this.
|
||||
DISPATCH_BOOL(page_size == 128, IsPaged128, [&] {
|
||||
DISPATCH_BOOL(num_kv_splits <= 1, NotManualSplitKV, [&] {
|
||||
if (in_dtype == at::ScalarType::Half) {
|
||||
runMla<cutlass::half_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
|
||||
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
|
||||
} else if (in_dtype == at::ScalarType::BFloat16) {
|
||||
runMla<cutlass::bfloat16_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
|
||||
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
|
||||
} else if (in_dtype == at::ScalarType::Float8_e4m3fn) {
|
||||
runMla<cutlass::float_e4m3_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
|
||||
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported input data type of MLA");
|
||||
}
|
||||
return true;
|
||||
});
|
||||
return true;
|
||||
});
|
||||
}
|
||||
|
||||
int64_t cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count, int64_t num_kv_splits) {
|
||||
// Workspace size depends on ElementAcc and ElementLSE (same as ElementAcc)
|
||||
// which are float, so Element type here doesn't matter.
|
||||
using MlaSm100Type = MlaSm100<cutlass::half_t, true>;
|
||||
|
||||
// Get split kv. Requires problem shape and sm_count only.
|
||||
typename MlaSm100Type::Fmha::Arguments arguments;
|
||||
using TileShapeH = typename MlaSm100Type::TileShapeH;
|
||||
using TileShapeD = typename MlaSm100Type::TileShapeD;
|
||||
arguments.problem_shape =
|
||||
cute::make_tuple(TileShapeH{}, static_cast<int>(max_seq_len), TileShapeD{}, static_cast<int>(num_batches));
|
||||
// Assumes device 0 when getting sm_count.
|
||||
arguments.hw_info.sm_count =
|
||||
sm_count <= 0 ? cutlass::KernelHardwareInfo::query_device_multiprocessor_count(/*device_id=*/0) : sm_count;
|
||||
arguments.split_kv = static_cast<int>(num_kv_splits);
|
||||
MlaSm100Type::Fmha::set_split_kv(arguments);
|
||||
|
||||
return MlaSm100Type::Fmha::get_workspace_size(arguments);
|
||||
}
|
||||
|
||||
#endif
|
||||
// clang-format on
|
||||
358
sgl-kernel/csrc/attention/cutlass_sm100_mla/device/sm100_mla.hpp
Normal file
358
sgl-kernel/csrc/attention/cutlass_sm100_mla/device/sm100_mla.hpp
Normal file
@@ -0,0 +1,358 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*!
|
||||
\file
|
||||
\brief An universal device layer for cutlass 3.x-style kernels.
|
||||
*/
|
||||
|
||||
// clang-format off
|
||||
#pragma once
|
||||
|
||||
// common
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/device_kernel.h"
|
||||
|
||||
#if !defined(__CUDACC_RTC__)
|
||||
#include "cutlass/cluster_launch.hpp"
|
||||
#include "cutlass/trace.h"
|
||||
#endif // !defined(__CUDACC_RTC__)
|
||||
|
||||
#include "../kernel/sm100_fmha_mla_tma_warpspecialized.hpp"
|
||||
#include "../kernel/sm100_fmha_mla_reduction.hpp"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::fmha::device {
|
||||
|
||||
using namespace cute;
|
||||
using namespace cutlass::fmha::kernel;
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
////////////////////////////// CUTLASS 3.x API /////////////////////////////////
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<
|
||||
class Kernel_
|
||||
>
|
||||
class MLA {
|
||||
public:
|
||||
|
||||
using Kernel = Kernel_;
|
||||
|
||||
using ReductionKernel = cutlass::fmha::kernel::Sm100FmhaMlaReductionKernel<
|
||||
typename Kernel::ElementOut,
|
||||
typename Kernel::ElementAcc,
|
||||
typename Kernel::ElementAcc,
|
||||
Kernel::TileShapeH::value,
|
||||
Kernel::TileShapeL::value,
|
||||
256 /*Max split*/
|
||||
>;
|
||||
|
||||
/// Argument structure: User API
|
||||
using KernelArguments = typename Kernel::Arguments;
|
||||
using ReductionArguments = typename ReductionKernel::Arguments;
|
||||
|
||||
using Arguments = KernelArguments;
|
||||
|
||||
/// Argument structure: Kernel API
|
||||
using KernelParams = typename Kernel::Params;
|
||||
using ReductionParams = typename ReductionKernel::Params;
|
||||
struct Params {
|
||||
KernelParams fmha_params;
|
||||
ReductionParams reduction_params;
|
||||
};
|
||||
|
||||
private:
|
||||
|
||||
/// Kernel API parameters object
|
||||
Params params_;
|
||||
|
||||
bool is_initialized(bool set = false) {
|
||||
static bool initialized = false;
|
||||
if (set) initialized = true;
|
||||
return initialized;
|
||||
}
|
||||
|
||||
static ReductionArguments to_reduction_args(Arguments const& args) {
|
||||
auto [H, K, D, B] = args.problem_shape;
|
||||
return ReductionArguments{
|
||||
nullptr, args.epilogue.ptr_o, nullptr, args.epilogue.ptr_lse,
|
||||
args.mainloop.softmax_scale, B, args.split_kv, K, args.mainloop.ptr_seq,
|
||||
args.ptr_split_kv, Kernel::TileShapeS::value
|
||||
};
|
||||
}
|
||||
|
||||
public:
|
||||
|
||||
/// Access the Params structure
|
||||
Params const& params() const {
|
||||
return params_;
|
||||
}
|
||||
|
||||
static void set_split_kv (KernelArguments& args) {
|
||||
if (args.split_kv >= 1) return;
|
||||
auto [H, K, D, B] = args.problem_shape;
|
||||
int sm_count = args.hw_info.sm_count;
|
||||
int max_splits = ceil_div(K, 128);
|
||||
int sms_per_batch = max(1, sm_count / B);
|
||||
int split_heur = min(max_splits, sms_per_batch);
|
||||
int waves = ceil_div(B * split_heur, sm_count);
|
||||
int k_waves = ceil_div(max_splits, split_heur);
|
||||
int split_wave_aware = ceil_div(max_splits, k_waves);
|
||||
args.split_kv = split_wave_aware;
|
||||
}
|
||||
|
||||
/// Determines whether the GEMM can execute the given problem.
|
||||
static Status
|
||||
can_implement(Arguments const& args) {
|
||||
if (! Kernel::can_implement(args)) {
|
||||
return Status::kInvalid;
|
||||
}
|
||||
if (! ReductionKernel::can_implement(to_reduction_args(args))) {
|
||||
return Status::kInvalid;
|
||||
}
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Gets the workspace size
|
||||
static size_t
|
||||
get_workspace_size(Arguments const& args) {
|
||||
size_t workspace_bytes = 0;
|
||||
workspace_bytes += Kernel::get_workspace_size(args);
|
||||
workspace_bytes += ReductionKernel::get_workspace_size(to_reduction_args(args));
|
||||
return workspace_bytes;
|
||||
}
|
||||
|
||||
/// Computes the maximum number of active blocks per multiprocessor
|
||||
static int maximum_active_blocks(int /* smem_capacity */ = -1) {
|
||||
CUTLASS_TRACE_HOST("MLA::maximum_active_blocks()");
|
||||
int max_active_blocks = -1;
|
||||
int smem_size = Kernel::SharedStorageSize;
|
||||
|
||||
// first, account for dynamic smem capacity if needed
|
||||
cudaError_t result;
|
||||
if (smem_size >= (48 << 10)) {
|
||||
CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
|
||||
result = cudaFuncSetAttribute(
|
||||
device_kernel<Kernel>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
smem_size);
|
||||
if (cudaSuccess != result) {
|
||||
result = cudaGetLastError(); // to clear the error bit
|
||||
CUTLASS_TRACE_HOST(
|
||||
" cudaFuncSetAttribute() returned error: "
|
||||
<< cudaGetErrorString(result));
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
// query occupancy after setting smem size
|
||||
result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&max_active_blocks,
|
||||
device_kernel<Kernel>,
|
||||
Kernel::MaxThreadsPerBlock,
|
||||
smem_size);
|
||||
|
||||
if (cudaSuccess != result) {
|
||||
result = cudaGetLastError(); // to clear the error bit
|
||||
CUTLASS_TRACE_HOST(
|
||||
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: "
|
||||
<< cudaGetErrorString(result));
|
||||
return -1;
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
|
||||
return max_active_blocks;
|
||||
}
|
||||
|
||||
/// Initializes GEMM state from arguments.
|
||||
Status
|
||||
initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
|
||||
CUTLASS_TRACE_HOST("MLA::initialize() - workspace "
|
||||
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
|
||||
|
||||
// Initialize the workspace
|
||||
Status status = Kernel::initialize_workspace(args, workspace, stream);
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
status = ReductionKernel::initialize_workspace(to_reduction_args(args), workspace, stream);
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
KernelParams kernel_params = Kernel::to_underlying_arguments(args, workspace);
|
||||
|
||||
ReductionArguments reduction_args = to_reduction_args(args);
|
||||
if (reduction_args.split_kv > 1) {
|
||||
reduction_args.ptr_oaccum = kernel_params.epilogue.ptr_o_acc;
|
||||
reduction_args.ptr_lseaccum = kernel_params.epilogue.ptr_lse_acc;
|
||||
}
|
||||
ReductionParams reduction_params = ReductionKernel::to_underlying_arguments(reduction_args, workspace);
|
||||
// Initialize the Params structure
|
||||
params_ = Params {kernel_params, reduction_params};
|
||||
|
||||
if (is_initialized()) return Status::kSuccess;
|
||||
|
||||
// account for dynamic smem capacity if needed
|
||||
// no dynamic smem is needed for reduction kernel
|
||||
int smem_size = Kernel::SharedStorageSize;
|
||||
if (smem_size >= (48 << 10)) {
|
||||
CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
|
||||
cudaError_t result = cudaFuncSetAttribute(
|
||||
device_kernel<Kernel>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
smem_size);
|
||||
if (cudaSuccess != result) {
|
||||
result = cudaGetLastError(); // to clear the error bit
|
||||
CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result));
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
is_initialized(true);
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Update API is preserved in 3.0, but does not guarantee a lightweight update of params.
|
||||
Status
|
||||
update(Arguments const& args, void* workspace = nullptr) {
|
||||
CUTLASS_TRACE_HOST("MLA()::update() - workspace: " << workspace);
|
||||
|
||||
size_t workspace_bytes = get_workspace_size(args);
|
||||
if (workspace_bytes > 0 && nullptr == workspace) {
|
||||
return Status::kErrorWorkspaceNull;
|
||||
}
|
||||
|
||||
auto fmha_params = Kernel::to_underlying_arguments(args, workspace);
|
||||
|
||||
ReductionArguments reduction_args = to_reduction_args(args);
|
||||
if (reduction_args.split_kv > 1) {
|
||||
reduction_args.ptr_oaccum = fmha_params.epilogue.ptr_o_acc;
|
||||
reduction_args.ptr_lseaccum = fmha_params.epilogue.ptr_lse_acc;
|
||||
}
|
||||
ReductionParams reduction_params = ReductionKernel::to_underlying_arguments(reduction_args, workspace);
|
||||
// Initialize the Params structure
|
||||
params_ = Params {fmha_params, reduction_params};
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Primary run() entry point API that is static allowing users to create and manage their own params.
|
||||
/// Supplied params struct must be construct by calling Kernel::to_underling_arguments()
|
||||
static Status
|
||||
run(Params& params, cudaStream_t stream = nullptr) {
|
||||
CUTLASS_TRACE_HOST("MLA::run()");
|
||||
dim3 const block = Kernel::get_block_shape();
|
||||
dim3 const grid = Kernel::get_grid_shape(params.fmha_params);
|
||||
|
||||
// configure smem size and carveout
|
||||
int smem_size = Kernel::SharedStorageSize;
|
||||
|
||||
Status launch_result;
|
||||
// Use extended launch API only for mainloops that use it
|
||||
if constexpr(Kernel::ArchTag::kMinComputeCapability >= 90) {
|
||||
dim3 cluster(cute::size<0>(typename Kernel::ClusterShape{}),
|
||||
cute::size<1>(typename Kernel::ClusterShape{}),
|
||||
cute::size<2>(typename Kernel::ClusterShape{}));
|
||||
void const* kernel = (void const*) device_kernel<Kernel>;
|
||||
void* kernel_params[] = {¶ms.fmha_params};
|
||||
launch_result = ClusterLauncher::launch(grid, cluster, block, smem_size, stream, kernel, kernel_params);
|
||||
}
|
||||
else {
|
||||
launch_result = Status::kSuccess;
|
||||
device_kernel<Kernel><<<grid, block, smem_size, stream>>>(params.fmha_params);
|
||||
}
|
||||
|
||||
cudaError_t result = cudaGetLastError();
|
||||
if (cudaSuccess != result or Status::kSuccess != launch_result) {
|
||||
//return Status::kSuccess;
|
||||
CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result);
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
if (params.reduction_params.split_kv > 1) {
|
||||
// launch reduction kernel
|
||||
dim3 const block = ReductionKernel::get_block_shape();
|
||||
dim3 const grid = ReductionKernel::get_grid_shape(params.reduction_params);
|
||||
device_kernel<ReductionKernel><<<grid, block, 0, stream>>>(params.reduction_params);
|
||||
cudaError_t result = cudaGetLastError();
|
||||
if (cudaSuccess == result) {
|
||||
return Status::kSuccess;
|
||||
}
|
||||
else {
|
||||
CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result);
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
else {
|
||||
return Status::kSuccess;
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Non-static launch overloads that first create and set the internal params struct of this kernel handle.
|
||||
//
|
||||
|
||||
/// Launches the kernel after first constructing Params internal state from supplied arguments.
|
||||
Status
|
||||
run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
|
||||
Status status = initialize(args, workspace, stream);
|
||||
if (Status::kSuccess == status) {
|
||||
status = run(params_, stream);
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
/// Launches the kernel after first constructing Params internal state from supplied arguments.
|
||||
Status
|
||||
operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
|
||||
return run(args, workspace, stream);
|
||||
}
|
||||
|
||||
/// Overload that allows a user to re-launch the same kernel without updating internal params struct.
|
||||
Status
|
||||
run(cudaStream_t stream = nullptr) {
|
||||
return run(params_, stream);
|
||||
}
|
||||
|
||||
/// Overload that allows a user to re-launch the same kernel without updating internal params struct.
|
||||
Status
|
||||
operator()(cudaStream_t stream = nullptr) {
|
||||
return run(params_, stream);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::fmha::device
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -0,0 +1,198 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
// clang-format off
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/arch/arch.h"
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
namespace cutlass::fmha::kernel {
|
||||
|
||||
using namespace cute;
|
||||
template<
|
||||
class ElementOut,
|
||||
class ElementAcc,
|
||||
class ElementScale,
|
||||
size_t kNumHeads,
|
||||
size_t kHeadDimLatent,
|
||||
int kMaxSplits
|
||||
>
|
||||
struct Sm100FmhaMlaReductionKernel {
|
||||
|
||||
static const int SharedStorageSize = 0;
|
||||
static const int MaxThreadsPerBlock = 128;
|
||||
static const int MinBlocksPerMultiprocessor = 1;
|
||||
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
|
||||
static_assert(kHeadDimLatent % MaxThreadsPerBlock == 0);
|
||||
struct Arguments {
|
||||
ElementAcc* ptr_oaccum = nullptr;
|
||||
ElementOut* ptr_o = nullptr;
|
||||
ElementAcc* ptr_lseaccum = nullptr;
|
||||
ElementAcc* ptr_lse = nullptr;
|
||||
ElementScale scale = 1.f;
|
||||
int num_batches = 0;
|
||||
int split_kv = -1;
|
||||
int dim_k = -1;
|
||||
int* ptr_seq = nullptr;
|
||||
int* ptr_split_kv = nullptr;
|
||||
int tile_shape_s = 128;
|
||||
};
|
||||
using Params = Arguments;
|
||||
|
||||
static Params to_underlying_arguments(Arguments const& args, void* workspace) {
|
||||
return {args.ptr_oaccum, args.ptr_o, args.ptr_lseaccum, args.ptr_lse,
|
||||
args.scale, args.num_batches, args.split_kv, args.dim_k, args.ptr_seq,
|
||||
args.ptr_split_kv, args.tile_shape_s};
|
||||
}
|
||||
|
||||
static size_t get_workspace_size(Arguments const& /*args*/) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
static Status initialize_workspace(
|
||||
Arguments const& /*args*/, void* /*ws*/, cudaStream_t /*stream*/) {
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
static dim3 get_grid_shape(Params const& params) {
|
||||
return dim3(kNumHeads, 1, params.num_batches);
|
||||
}
|
||||
|
||||
static dim3 get_block_shape() {
|
||||
return dim3(MaxThreadsPerBlock, 1, 1);
|
||||
}
|
||||
|
||||
static bool can_implement(Arguments const& args) {
|
||||
if (args.num_batches <= 0) return false;
|
||||
if (args.split_kv <= 0) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE void operator() (Params const& params, char* smem_raw) {
|
||||
if (params.split_kv <= 1) return;
|
||||
auto blk_coord = make_coord(blockIdx.x, _0{}, blockIdx.z);
|
||||
|
||||
__shared__ ElementAcc sLseScale[kMaxSplits];
|
||||
const size_t offset_lseaccum = get<0>(blk_coord) + kNumHeads * params.split_kv * get<2>(blk_coord);
|
||||
const size_t offset_lse = get<0>(blk_coord) + kNumHeads * get<2>(blk_coord);
|
||||
|
||||
Tensor gLSEaccum = make_tensor(make_gmem_ptr(params.ptr_lseaccum + offset_lseaccum),
|
||||
make_shape(params.split_kv), Stride<Int<kNumHeads>>{});
|
||||
|
||||
Tensor gLSE = make_tensor(make_gmem_ptr(params.ptr_lse + offset_lse),
|
||||
Shape<_1>{}, Stride<_1>{});
|
||||
|
||||
auto dim_k = params.ptr_seq == nullptr ? params.dim_k : params.ptr_seq[get<2>(blk_coord)];
|
||||
auto local_split_kv = params.ptr_split_kv == nullptr ? params.split_kv : params.ptr_split_kv[get<2>(blk_coord)];
|
||||
auto k_tile_total = ceil_div(dim_k, params.tile_shape_s);
|
||||
auto k_tile_per_cta = ceil_div(k_tile_total, local_split_kv);
|
||||
local_split_kv = ceil_div(k_tile_total, k_tile_per_cta);
|
||||
|
||||
int warp_idx = cutlass::canonical_warp_idx_sync();
|
||||
if (warp_idx == 0) {
|
||||
constexpr int kNLsePerThread = cute::ceil_div(kMaxSplits, 32);
|
||||
|
||||
ElementAcc local_lse[kNLsePerThread];
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kNLsePerThread; ++i) {
|
||||
const int split = i * 32 + threadIdx.x;
|
||||
local_lse[i] = split < local_split_kv ? gLSEaccum(split) : -std::numeric_limits<ElementAcc>::infinity();
|
||||
}
|
||||
|
||||
ElementAcc lse_max = -std::numeric_limits<ElementAcc>::infinity();
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kNLsePerThread; ++i) {
|
||||
lse_max = max(lse_max, local_lse[i]);
|
||||
}
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int offset = 16; offset >= 1; offset /= 2) {
|
||||
lse_max = max(lse_max, __shfl_xor_sync(0xffffffff, lse_max, offset));
|
||||
}
|
||||
lse_max = lse_max == -std::numeric_limits<ElementAcc>::infinity() ? 0.0f : lse_max; // In case all local LSEs are -inf
|
||||
lse_max = __shfl_sync(0xffffffff, lse_max, 0);
|
||||
|
||||
ElementAcc sum_lse = 0;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kNLsePerThread; ++i) {
|
||||
sum_lse = sum_lse + expf(local_lse[i] - lse_max);
|
||||
}
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int offset = 16; offset >= 1; offset /= 2) {
|
||||
sum_lse = sum_lse + __shfl_xor_sync(0xffffffff, sum_lse, offset);
|
||||
}
|
||||
|
||||
sum_lse = __shfl_sync(0xffffffff, sum_lse, 0);
|
||||
|
||||
ElementAcc global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? std::numeric_limits<ElementAcc>::infinity() : logf(sum_lse) + lse_max;
|
||||
if (threadIdx.x == 0 and params.ptr_lse != nullptr) {
|
||||
gLSE(0) = global_lse;
|
||||
}
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kNLsePerThread; ++i) {
|
||||
const int split = i * 32 + threadIdx.x;
|
||||
if (split < local_split_kv) {
|
||||
sLseScale[split] = expf(local_lse[i] - global_lse);
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
constexpr int Elements = kHeadDimLatent / MaxThreadsPerBlock;
|
||||
const size_t offset_oaccum = kHeadDimLatent * params.split_kv * (get<0>(blk_coord) + kNumHeads * get<2>(blk_coord));
|
||||
Tensor gOaccum = make_tensor(make_gmem_ptr(params.ptr_oaccum + offset_oaccum),
|
||||
Shape<Int<kHeadDimLatent>>{}, Stride<_1>{});
|
||||
ElementAcc local_val[Elements] = {0};
|
||||
for (int split = 0; split < local_split_kv; ++split) {
|
||||
ElementAcc lse_scale = sLseScale[split];
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for(int i = 0; i < Elements; ++i) {
|
||||
local_val[i] += lse_scale * gOaccum(threadIdx.x + MaxThreadsPerBlock * i);
|
||||
}
|
||||
gOaccum.data() = gOaccum.data() + kHeadDimLatent;
|
||||
}
|
||||
auto ptr_o_local = params.ptr_o + (get<0>(blk_coord) + get<2>(blk_coord) * kNumHeads) * kHeadDimLatent;
|
||||
Tensor gO = make_tensor(make_gmem_ptr(ptr_o_local), Shape<Int<kHeadDimLatent>>{}, Stride<_1>{});
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for(int i = 0; i < Elements; ++i) {
|
||||
gO(threadIdx.x + MaxThreadsPerBlock * i) = static_cast<ElementOut>(local_val[i]);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cutlass::fmha::kernel
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,160 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
// clang-format off
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/fast_math.h"
|
||||
#include "cutlass/kernel_hardware_info.h"
|
||||
|
||||
namespace cutlass::fmha::kernel {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct Sm100MlaIndividualTileScheduler {
|
||||
|
||||
struct Params {
|
||||
dim3 grid;
|
||||
};
|
||||
|
||||
bool valid_ = true;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
Sm100MlaIndividualTileScheduler(Params const&) {}
|
||||
|
||||
template<class ProblemShape, class ClusterShape>
|
||||
static Params to_underlying_arguments(
|
||||
ProblemShape const& problem_shape, KernelHardwareInfo hw_info,
|
||||
ClusterShape const& cluster_shape, int const& split_kv) {
|
||||
using namespace cute;
|
||||
dim3 grid(get<0>(cluster_shape), get<3>(problem_shape) /* Batch */, split_kv /*Maximum Split KV*/);
|
||||
return Params{ grid };
|
||||
}
|
||||
|
||||
static dim3 get_grid_shape(Params const& params) {
|
||||
return params.grid;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
bool is_valid() {
|
||||
return valid_;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
auto get_block_coord() {
|
||||
using namespace cute;
|
||||
return make_coord(blockIdx.x, _0{}, blockIdx.y, blockIdx.z);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
Sm100MlaIndividualTileScheduler& operator++() {
|
||||
valid_ = false;
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct Sm100MlaPersistentTileScheduler {
|
||||
|
||||
struct Params {
|
||||
int num_blocks;
|
||||
FastDivmod divmod_m_block;
|
||||
FastDivmod divmod_b;
|
||||
FastDivmod divmod_split_kv;
|
||||
KernelHardwareInfo hw_info;
|
||||
};
|
||||
|
||||
int block_idx = 0;
|
||||
Params params;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
Sm100MlaPersistentTileScheduler(Params const& params) : block_idx(blockIdx.x), params(params) {}
|
||||
|
||||
template<class ProblemShape, class ClusterShape>
|
||||
static Params to_underlying_arguments(
|
||||
ProblemShape const& problem_shape, KernelHardwareInfo hw_info,
|
||||
ClusterShape const& cluster_shape, int const& split_kv) {
|
||||
using namespace cute;
|
||||
// Get SM count if needed, otherwise use user supplied SM count
|
||||
int sm_count = hw_info.sm_count;
|
||||
if (sm_count <= 1 || sm_count % size<0>(cluster_shape) != 0) {
|
||||
CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
|
||||
" For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
|
||||
sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
|
||||
hw_info.sm_count = sm_count;
|
||||
|
||||
int num_m_blocks = size<0>(cluster_shape);
|
||||
int num_blocks = num_m_blocks * get<3>(problem_shape) /* Batch */;
|
||||
num_blocks *= split_kv; /* Maximum Split KV*/
|
||||
|
||||
return Params {
|
||||
num_blocks,
|
||||
{ num_m_blocks}, { get<3>(problem_shape) }, {split_kv},
|
||||
hw_info
|
||||
};
|
||||
}
|
||||
|
||||
static dim3 get_grid_shape(Params const& params) {
|
||||
dim3 grid(std::min(params.num_blocks, params.hw_info.sm_count), 1, 1);
|
||||
return grid;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
bool is_valid() {
|
||||
return block_idx < params.num_blocks;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
auto get_block_coord() {
|
||||
using namespace cute;
|
||||
int block_decode = block_idx;
|
||||
int m_block, bidb, n_split_kv;
|
||||
params.divmod_m_block(block_decode, m_block, block_decode);
|
||||
params.divmod_b(block_decode, bidb, block_decode);
|
||||
params.divmod_split_kv(block_decode, n_split_kv, block_decode);
|
||||
return make_coord(m_block, _0{}, bidb, n_split_kv);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
Sm100MlaPersistentTileScheduler& operator++() {
|
||||
block_idx += gridDim.x;
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::fmha::kernel
|
||||
154
sgl-kernel/csrc/attention/lightning_attention_decode_kernel.cu
Normal file
154
sgl-kernel/csrc/attention/lightning_attention_decode_kernel.cu
Normal file
@@ -0,0 +1,154 @@
|
||||
/* Copyright 2025 SGLang Team. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#define THREADS_PER_BLOCK 128
|
||||
|
||||
template <typename T>
|
||||
__global__ void lightning_attention_decode_kernel(
|
||||
const T* __restrict__ q, // [b, h, 1, d]
|
||||
const T* __restrict__ k, // [b, h, 1, d]
|
||||
const T* __restrict__ v, // [b, h, 1, e]
|
||||
const float* __restrict__ past_kv, // [b, h, d, e]
|
||||
const float* __restrict__ slope, // [h, 1, 1]
|
||||
T* __restrict__ output, // [b, h, 1, e]
|
||||
float* __restrict__ new_kv, // [b, h, d, e]
|
||||
const int batch_size,
|
||||
const int num_heads,
|
||||
const int qk_dim,
|
||||
const int v_dim) {
|
||||
extern __shared__ char smem[];
|
||||
T* __restrict__ q_shared = reinterpret_cast<T*>(smem);
|
||||
T* __restrict__ k_shared = reinterpret_cast<T*>(smem + qk_dim * sizeof(T));
|
||||
T* __restrict__ v_shared = reinterpret_cast<T*>(smem + 2 * qk_dim * sizeof(T));
|
||||
float* __restrict__ new_kv_shared = reinterpret_cast<float*>(smem + (2 * qk_dim + v_dim) * sizeof(T));
|
||||
T* __restrict__ output_shared =
|
||||
reinterpret_cast<T*>(smem + (2 * qk_dim + v_dim) * sizeof(T) + qk_dim * (v_dim + 1) * sizeof(float));
|
||||
|
||||
const int32_t tid = threadIdx.x;
|
||||
const int32_t current_head = blockIdx.x;
|
||||
const int32_t b = current_head / num_heads;
|
||||
const int32_t h = current_head % num_heads;
|
||||
|
||||
if (b >= batch_size) return;
|
||||
|
||||
const int32_t qk_offset = b * num_heads * qk_dim + h * qk_dim;
|
||||
const int32_t v_offset = b * num_heads * v_dim + h * v_dim;
|
||||
const int32_t kv_offset = b * num_heads * qk_dim * v_dim + h * qk_dim * v_dim;
|
||||
|
||||
// Load q, k, v into shared memory
|
||||
for (int d = tid; d < qk_dim; d += blockDim.x) {
|
||||
q_shared[d] = q[qk_offset + d];
|
||||
k_shared[d] = k[qk_offset + d];
|
||||
}
|
||||
for (int e = tid; e < v_dim; e += blockDim.x) {
|
||||
v_shared[e] = v[v_offset + e];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
const float ratio = expf(-1.0f * slope[h]);
|
||||
|
||||
// Compute new_kv
|
||||
for (int d = tid; d < qk_dim; d += blockDim.x) {
|
||||
const T k_val = k_shared[d];
|
||||
for (int e = 0; e < v_dim; ++e) {
|
||||
const int past_kv_idx = kv_offset + d * v_dim + e;
|
||||
const T v_val = v_shared[e];
|
||||
const float new_val = ratio * past_kv[past_kv_idx] + k_val * v_val;
|
||||
const int shared_idx = d * (v_dim + 1) + e;
|
||||
new_kv_shared[shared_idx] = new_val;
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Store new_kv to global memory
|
||||
for (int idx = tid; idx < qk_dim * v_dim; idx += blockDim.x) {
|
||||
const int d = idx / v_dim;
|
||||
const int e = idx % v_dim;
|
||||
const int shared_idx = d * (v_dim + 1) + e;
|
||||
const int global_idx = kv_offset + idx;
|
||||
new_kv[global_idx] = new_kv_shared[shared_idx];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Compute output
|
||||
for (int e = tid; e < v_dim; e += blockDim.x) {
|
||||
float sum = 0.0f;
|
||||
for (int d = 0; d < qk_dim; ++d) {
|
||||
const int shared_idx = d * (v_dim + 1) + e;
|
||||
sum += q_shared[d] * new_kv_shared[shared_idx];
|
||||
}
|
||||
output_shared[e] = static_cast<T>(sum);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Store output to global memory
|
||||
if (tid == 0) {
|
||||
for (int e = 0; e < v_dim; ++e) {
|
||||
output[v_offset + e] = output_shared[e];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void lightning_attention_decode(
|
||||
const torch::Tensor& q,
|
||||
const torch::Tensor& k,
|
||||
const torch::Tensor& v,
|
||||
const torch::Tensor& past_kv,
|
||||
const torch::Tensor& slope,
|
||||
torch::Tensor output,
|
||||
torch::Tensor new_kv) {
|
||||
TORCH_CHECK(q.is_contiguous(), "q must be contiguous");
|
||||
TORCH_CHECK(k.is_contiguous(), "k must be contiguous");
|
||||
TORCH_CHECK(v.is_contiguous(), "v must be contiguous");
|
||||
TORCH_CHECK(past_kv.is_contiguous(), "past_kv must be contiguous");
|
||||
|
||||
auto batch_size = q.size(0);
|
||||
auto num_heads = q.size(1);
|
||||
auto qk_dim = q.size(3);
|
||||
auto v_dim = v.size(3);
|
||||
|
||||
dim3 block(THREADS_PER_BLOCK);
|
||||
dim3 grid(batch_size * num_heads);
|
||||
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
||||
at::ScalarType::Half, at::ScalarType::BFloat16, q.scalar_type(), "lightning_attention_decode_kernel", ([&] {
|
||||
size_t smem_size = (2 * qk_dim + 2 * v_dim) * sizeof(scalar_t) + qk_dim * (v_dim + 1) * sizeof(float);
|
||||
lightning_attention_decode_kernel<scalar_t><<<grid, block, smem_size, stream>>>(
|
||||
q.data_ptr<scalar_t>(),
|
||||
k.data_ptr<scalar_t>(),
|
||||
v.data_ptr<scalar_t>(),
|
||||
past_kv.data_ptr<float>(),
|
||||
slope.data_ptr<float>(),
|
||||
output.data_ptr<scalar_t>(),
|
||||
new_kv.data_ptr<float>(),
|
||||
batch_size,
|
||||
num_heads,
|
||||
qk_dim,
|
||||
v_dim);
|
||||
}));
|
||||
}
|
||||
204
sgl-kernel/csrc/attention/merge_attn_states.cu
Normal file
204
sgl-kernel/csrc/attention/merge_attn_states.cu
Normal file
@@ -0,0 +1,204 @@
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <optional>
|
||||
|
||||
#include "pytorch_extension_utils.h"
|
||||
|
||||
// Helper functions to convert between different data types
|
||||
// (float, half, bfloat16) for the merge attention states kernel.
|
||||
inline __device__ float to_float(float u) {
|
||||
return u;
|
||||
}
|
||||
inline __device__ float to_float(half u) {
|
||||
return __half2float(u);
|
||||
}
|
||||
inline __device__ float to_float(__nv_bfloat16 u) {
|
||||
return __bfloat162float(u);
|
||||
}
|
||||
inline __device__ void from_float(float& d, float s) {
|
||||
d = s;
|
||||
}
|
||||
inline __device__ void from_float(half& d, float s) {
|
||||
d = __float2half(s);
|
||||
}
|
||||
inline __device__ void from_float(__nv_bfloat16& d, float s) {
|
||||
d = __float2bfloat16(s);
|
||||
}
|
||||
|
||||
// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
|
||||
template <typename scalar_t, const uint NUM_THREADS>
|
||||
__global__ void merge_attn_states_kernel(
|
||||
scalar_t* output,
|
||||
float* output_lse,
|
||||
const scalar_t* prefix_output,
|
||||
const float* prefix_lse,
|
||||
const scalar_t* suffix_output,
|
||||
const float* suffix_lse,
|
||||
const uint num_tokens,
|
||||
const uint num_heads,
|
||||
const uint head_size) {
|
||||
using pack_128b_t = uint4;
|
||||
const uint pack_size = 16 / sizeof(scalar_t);
|
||||
const uint threads_per_head = head_size / pack_size;
|
||||
|
||||
const uint global_idx = blockIdx.x * NUM_THREADS + threadIdx.x;
|
||||
const uint token_head_threads = num_tokens * num_heads * threads_per_head;
|
||||
|
||||
if (global_idx >= token_head_threads) return;
|
||||
|
||||
// global_idx -> token_idx + head_idx + pack_idx
|
||||
const uint token_head_idx = global_idx / threads_per_head;
|
||||
const uint pack_idx = global_idx % threads_per_head;
|
||||
|
||||
const uint token_idx = token_head_idx / num_heads;
|
||||
const uint head_idx = token_head_idx % num_heads;
|
||||
|
||||
const uint pack_offset = pack_idx * pack_size; // (0~15)*8, etc.
|
||||
const uint head_offset = token_idx * num_heads * head_size + head_idx * head_size;
|
||||
const scalar_t* prefix_head_ptr = prefix_output + head_offset;
|
||||
const scalar_t* suffix_head_ptr = suffix_output + head_offset;
|
||||
scalar_t* output_head_ptr = output + head_offset;
|
||||
|
||||
// float p_lse = prefix_lse[head_idx * num_tokens + token_idx];
|
||||
// float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
|
||||
float p_lse = prefix_lse[token_idx * num_heads + head_idx];
|
||||
float s_lse = suffix_lse[token_idx * num_heads + head_idx];
|
||||
p_lse = std::isinf(p_lse) ? -std::numeric_limits<float>::infinity() : p_lse;
|
||||
s_lse = std::isinf(s_lse) ? -std::numeric_limits<float>::infinity() : s_lse;
|
||||
|
||||
const float max_lse = fmaxf(p_lse, s_lse);
|
||||
p_lse = p_lse - max_lse;
|
||||
s_lse = s_lse - max_lse;
|
||||
const float p_se = expf(p_lse);
|
||||
const float s_se = expf(s_lse);
|
||||
const float out_se = p_se + s_se;
|
||||
const float p_scale = p_se / out_se;
|
||||
const float s_scale = s_se / out_se;
|
||||
|
||||
if (pack_offset < head_size) {
|
||||
// Pack 128b load
|
||||
pack_128b_t p_out_pack = reinterpret_cast<const pack_128b_t*>(prefix_head_ptr)[pack_offset / pack_size];
|
||||
pack_128b_t s_out_pack = reinterpret_cast<const pack_128b_t*>(suffix_head_ptr)[pack_offset / pack_size];
|
||||
pack_128b_t o_out_pack;
|
||||
|
||||
#pragma unroll
|
||||
for (uint i = 0; i < pack_size; ++i) {
|
||||
// Always use float for FMA to keep high precision.
|
||||
// half(uint16_t), bfloat16, float -> float.
|
||||
const float p_out_f = to_float(reinterpret_cast<const scalar_t*>(&p_out_pack)[i]);
|
||||
const float s_out_f = to_float(reinterpret_cast<const scalar_t*>(&s_out_pack)[i]);
|
||||
// fma: a * b + c = p_out_f * p_scale + (s_out_f * s_scale)
|
||||
const float o_out_f = p_out_f * p_scale + (s_out_f * s_scale);
|
||||
// float -> half(uint16_t), bfloat16, float.
|
||||
from_float(reinterpret_cast<scalar_t*>(&o_out_pack)[i], o_out_f);
|
||||
}
|
||||
|
||||
// Pack 128b storage
|
||||
reinterpret_cast<pack_128b_t*>(output_head_ptr)[pack_offset / pack_size] = o_out_pack;
|
||||
}
|
||||
// We only need to write to output_lse once per head.
|
||||
if (output_lse != nullptr && pack_idx == 0) {
|
||||
float out_lse = logf(out_se) + max_lse;
|
||||
output_lse[token_idx * num_heads + head_idx] = out_lse;
|
||||
}
|
||||
}
|
||||
|
||||
// The following macro is used to dispatch the conversion function based on
|
||||
// the output data type. The FN is a macro that calls a function with
|
||||
// template<typename scalar_t>.
|
||||
#define DISPATCH_BY_SCALAR_DTYPE(scalar_dtype, fn) \
|
||||
{ \
|
||||
if (scalar_dtype == at::ScalarType::Float) { \
|
||||
fn(float); \
|
||||
} else if (scalar_dtype == at::ScalarType::Half) { \
|
||||
fn(half); \
|
||||
} else if (scalar_dtype == at::ScalarType::BFloat16) { \
|
||||
fn(__nv_bfloat16); \
|
||||
} else { \
|
||||
TORCH_CHECK(false, "Unsupported data type of O: ", scalar_dtype); \
|
||||
} \
|
||||
}
|
||||
|
||||
#define LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS) \
|
||||
{ \
|
||||
merge_attn_states_kernel<scalar_t, NUM_THREADS><<<grid, block, 0, stream>>>( \
|
||||
reinterpret_cast<scalar_t*>(output.data_ptr()), \
|
||||
reinterpret_cast<float*>(output_lse.data_ptr()), \
|
||||
reinterpret_cast<scalar_t*>(prefix_output.data_ptr()), \
|
||||
reinterpret_cast<float*>(prefix_lse.data_ptr()), \
|
||||
reinterpret_cast<scalar_t*>(suffix_output.data_ptr()), \
|
||||
reinterpret_cast<float*>(suffix_lse.data_ptr()), \
|
||||
num_tokens, \
|
||||
num_heads, \
|
||||
head_size); \
|
||||
}
|
||||
|
||||
/*@brief Merges the attention states from prefix and suffix
|
||||
* into the output tensor. NUM_TOKENS: n, NUM_HEADS: h, HEAD_SIZE: d
|
||||
*
|
||||
* @param output [n,h,d] The output tensor to store the merged attention states.
|
||||
* @param output_lse [h,d] Optional tensor to store the log-sum-exp values.
|
||||
* @param prefix_output [n,h,d] The prefix attention states.
|
||||
* @param prefix_lse [n,h] The log-sum-exp values for the prefix attention
|
||||
* states.
|
||||
* @param suffix_output [n,h,d] The suffix attention states.
|
||||
* @param suffix_lse [n,h] The log-sum-exp values for the suffix attention
|
||||
* states.
|
||||
*/
|
||||
template <typename scalar_t>
|
||||
void merge_attn_states_launcher(
|
||||
const at::Tensor& prefix_output, // [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
|
||||
const at::Tensor& prefix_lse, // [NUM_TOKENS, NUM_HEADS]
|
||||
const at::Tensor& suffix_output, // [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
|
||||
const at::Tensor& suffix_lse, // [NUM_TOKENS, NUM_HEADS]
|
||||
at::Tensor& output, // [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
|
||||
at::Tensor& output_lse // [NUM_TOKENS, NUM_HEADS]
|
||||
) {
|
||||
constexpr uint NUM_THREADS = 128;
|
||||
const uint num_tokens = output.size(0);
|
||||
const uint num_heads = output.size(1);
|
||||
const uint head_size = output.size(2);
|
||||
const uint pack_size = 16 / sizeof(scalar_t);
|
||||
TORCH_CHECK(head_size % pack_size == 0, "headsize must be multiple of pack_size:", pack_size);
|
||||
// Process one pack elements per thread. for float, the
|
||||
// pack_size is 4 for half/bf16, the pack_size is 8.
|
||||
const uint threads_per_head = head_size / pack_size;
|
||||
const uint total_threads = num_tokens * num_heads * threads_per_head;
|
||||
|
||||
dim3 block(NUM_THREADS);
|
||||
dim3 grid((total_threads + NUM_THREADS - 1) / NUM_THREADS);
|
||||
|
||||
const c10::cuda::OptionalCUDAGuard device_guard(prefix_output.device());
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS);
|
||||
}
|
||||
|
||||
#define CALL_MERGE_ATTN_STATES_LAUNCHER(scalar_t) \
|
||||
{ merge_attn_states_launcher<scalar_t>(v_a, s_a, v_b, s_b, v_merged, s_merged); }
|
||||
|
||||
void merge_state_v2(
|
||||
at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b, at::Tensor v_merged, at::Tensor s_merged) {
|
||||
// Input tensors must be contiguous
|
||||
CHECK_INPUT(v_a); // v_a prefix_output (seq_len, num_heads, head_dim)
|
||||
CHECK_INPUT(s_a); // s_a prefix_lse (seq_len, num_heads)
|
||||
CHECK_INPUT(v_b); // v_b suffix_output (seq_len, num_heads, head_dim)
|
||||
CHECK_INPUT(s_b); // s_b suffix_lse (seq_len, num_heads)
|
||||
// v_merged output (seq_len, num_heads, head_dim)
|
||||
// s_merged output_lse (seq_len, num_heads)
|
||||
auto device = v_a.device();
|
||||
CHECK_EQ(s_a.device(), device);
|
||||
CHECK_EQ(v_b.device(), device);
|
||||
CHECK_EQ(s_b.device(), device);
|
||||
CHECK_DIM(3, v_a);
|
||||
CHECK_DIM(2, s_a);
|
||||
CHECK_DIM(3, v_b);
|
||||
CHECK_DIM(2, s_b);
|
||||
CHECK_SHAPE(v_a, v_b);
|
||||
CHECK_SHAPE(s_a, s_b);
|
||||
CHECK_EQ(v_a.size(0), s_a.size(0));
|
||||
CHECK_EQ(v_a.size(1), s_b.size(1));
|
||||
DISPATCH_BY_SCALAR_DTYPE(v_merged.dtype(), CALL_MERGE_ATTN_STATES_LAUNCHER);
|
||||
}
|
||||
462
sgl-kernel/csrc/attention/vertical_slash_index.cu
Normal file
462
sgl-kernel/csrc/attention/vertical_slash_index.cu
Normal file
@@ -0,0 +1,462 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT license.
|
||||
// This file is for blocksparse attention utils cuda kernel.
|
||||
|
||||
#include <assert.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <cuda.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
// Save the start index of each block in the given range into block_offset.
|
||||
// Returns the updated block count.
|
||||
__device__ int64_t save_blocks(
|
||||
int* block_offset,
|
||||
int64_t range_start,
|
||||
int64_t range_end,
|
||||
int64_t block_size,
|
||||
int64_t input_block_count,
|
||||
int64_t kv_seqlen) {
|
||||
if (range_start >= kv_seqlen) {
|
||||
return input_block_count;
|
||||
}
|
||||
if (range_end > kv_seqlen) {
|
||||
range_end = kv_seqlen;
|
||||
}
|
||||
int64_t current_block_count = input_block_count;
|
||||
for (int idx = range_start; idx < range_end; idx += block_size) {
|
||||
block_offset[current_block_count++] = idx;
|
||||
}
|
||||
return current_block_count;
|
||||
}
|
||||
|
||||
// CUDA kernel: convert sparse vertical/slash indices to block/column offsets.
|
||||
__global__ void convert_vertical_slash_indexes_kernel(
|
||||
const int* q_seqlens, // [BATCH, ]
|
||||
const int* kv_seqlens, // [BATCH, ]
|
||||
const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
|
||||
const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S]
|
||||
int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
|
||||
int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
|
||||
int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
|
||||
int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
|
||||
int64_t N_HEADS,
|
||||
int64_t N_ROWS,
|
||||
int64_t BLOCK_SIZE_M,
|
||||
int64_t BLOCK_SIZE_N,
|
||||
int64_t NNZ_V,
|
||||
int64_t NNZ_S,
|
||||
bool causal // True for intra, False for succ
|
||||
) {
|
||||
const int batch_idx = blockIdx.y;
|
||||
const int head_idx = blockIdx.x;
|
||||
const int group_idx = blockIdx.z;
|
||||
|
||||
int64_t q_seqlen = q_seqlens[batch_idx];
|
||||
int64_t kv_seqlen = kv_seqlens[batch_idx];
|
||||
int64_t block_idx_m = group_idx * blockDim.x + threadIdx.x;
|
||||
int64_t start_m = block_idx_m * BLOCK_SIZE_M;
|
||||
if (start_m >= q_seqlen) {
|
||||
return;
|
||||
}
|
||||
int64_t end_m = start_m + BLOCK_SIZE_M;
|
||||
vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V;
|
||||
slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S;
|
||||
int64_t row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m;
|
||||
block_count += row_offset;
|
||||
block_offset += row_offset * NNZ_S;
|
||||
column_count += row_offset;
|
||||
column_index += row_offset * NNZ_V;
|
||||
|
||||
bool has_slash = true;
|
||||
int64_t tmp_col_cnt = 0, tmp_blk_cnt = 0;
|
||||
int64_t s = 0, v = 0;
|
||||
int64_t v_idx = vertical_indexes[v++];
|
||||
int64_t s_idx = slash_indexes[s++];
|
||||
if (causal) {
|
||||
while (s_idx >= end_m + (kv_seqlen - q_seqlen) && s < NNZ_S) {
|
||||
s_idx = slash_indexes[s++];
|
||||
}
|
||||
if (s_idx > end_m + (kv_seqlen - q_seqlen)) has_slash = false;
|
||||
s_idx = max((kv_seqlen - q_seqlen) + end_m - s_idx, BLOCK_SIZE_M);
|
||||
} else {
|
||||
while (s_idx >= end_m + kv_seqlen && s < NNZ_S) {
|
||||
s_idx = slash_indexes[s++];
|
||||
}
|
||||
if (s_idx > end_m + kv_seqlen) has_slash = false;
|
||||
s_idx = max(kv_seqlen + end_m - s_idx, BLOCK_SIZE_M);
|
||||
}
|
||||
|
||||
int64_t range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx;
|
||||
if (!has_slash) {
|
||||
if (causal) {
|
||||
range_start = (kv_seqlen - q_seqlen) + end_m;
|
||||
range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N;
|
||||
} else {
|
||||
range_start = kv_seqlen;
|
||||
range_end = kv_seqlen + BLOCK_SIZE_N;
|
||||
}
|
||||
}
|
||||
|
||||
bool slash_finished = false;
|
||||
while (1) {
|
||||
if (v_idx < range_end) {
|
||||
if (v_idx < range_start) {
|
||||
column_index[tmp_col_cnt++] = v_idx;
|
||||
}
|
||||
if (v < NNZ_V) {
|
||||
v_idx = vertical_indexes[v++];
|
||||
} else {
|
||||
if (causal)
|
||||
v_idx = end_m + BLOCK_SIZE_N + (kv_seqlen - q_seqlen);
|
||||
else
|
||||
v_idx = end_m + BLOCK_SIZE_N + kv_seqlen;
|
||||
}
|
||||
} else {
|
||||
if ((s < NNZ_S && causal) || (s < NNZ_S && !causal && slash_indexes[s] >= start_m)) {
|
||||
if (causal)
|
||||
s_idx = max((kv_seqlen - q_seqlen) + end_m - slash_indexes[s++], BLOCK_SIZE_M);
|
||||
else
|
||||
s_idx = max(kv_seqlen + end_m - slash_indexes[s++], BLOCK_SIZE_M);
|
||||
} else {
|
||||
if (v == NNZ_V || (v_idx > range_start && causal)) {
|
||||
// add the last vertical if no more slash
|
||||
if (v == NNZ_V && !causal && v_idx < kv_seqlen) {
|
||||
column_index[tmp_col_cnt++] = v_idx;
|
||||
}
|
||||
tmp_blk_cnt = save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
|
||||
break;
|
||||
} else {
|
||||
if (causal) {
|
||||
range_start = (kv_seqlen - q_seqlen) + end_m;
|
||||
range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N;
|
||||
} else {
|
||||
// if slash_finished but there are vertical left, save current
|
||||
// blocks
|
||||
tmp_blk_cnt = save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
|
||||
range_start = kv_seqlen;
|
||||
range_end = kv_seqlen + BLOCK_SIZE_N;
|
||||
}
|
||||
slash_finished = true;
|
||||
}
|
||||
}
|
||||
if (!slash_finished) {
|
||||
if (s_idx > range_end + BLOCK_SIZE_M) {
|
||||
tmp_blk_cnt = save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
|
||||
range_start = s_idx - BLOCK_SIZE_M;
|
||||
range_end = s_idx;
|
||||
} else if (s_idx > range_end) {
|
||||
range_end += BLOCK_SIZE_M;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
block_count[0] = tmp_blk_cnt;
|
||||
column_count[0] = tmp_col_cnt;
|
||||
}
|
||||
|
||||
// Host function: launches the kernel with 64 threads per block.
|
||||
void convert_vertical_slash_indexes_64x64(
|
||||
const int* q_seqlens, // [BATCH, ]
|
||||
const int* kv_seqlens, // [BATCH, ]
|
||||
const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
|
||||
const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S]
|
||||
int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
|
||||
int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
|
||||
int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
|
||||
int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
|
||||
int64_t BATCH_SIZE,
|
||||
int64_t N_HEADS,
|
||||
int64_t N_ROWS,
|
||||
int64_t BLOCK_SIZE_M,
|
||||
int64_t BLOCK_SIZE_N,
|
||||
int64_t NNZ_V,
|
||||
int64_t NNZ_S,
|
||||
bool causal) {
|
||||
const int N_THREADS = 64;
|
||||
const dim3 dimBlock((int32_t)N_THREADS);
|
||||
const dim3 dimGrid(
|
||||
(int32_t)N_HEADS, (int32_t)BATCH_SIZE, ((int32_t)N_ROWS + (int32_t)N_THREADS - 1) / (int32_t)N_THREADS);
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
convert_vertical_slash_indexes_kernel<<<dimGrid, dimBlock, 0, stream>>>(
|
||||
q_seqlens,
|
||||
kv_seqlens,
|
||||
vertical_indexes,
|
||||
slash_indexes,
|
||||
block_count,
|
||||
block_offset,
|
||||
column_count,
|
||||
column_index,
|
||||
N_HEADS,
|
||||
N_ROWS,
|
||||
BLOCK_SIZE_M,
|
||||
BLOCK_SIZE_N,
|
||||
NNZ_V,
|
||||
NNZ_S,
|
||||
causal);
|
||||
}
|
||||
|
||||
// Host function: prepares tensor pointers and launches the CUDA kernel.
|
||||
void convert_vertical_slash_indexes(
|
||||
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
|
||||
torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
|
||||
torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS]
|
||||
torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
|
||||
torch::Tensor q_seqlens, // [BATCH, ]
|
||||
torch::Tensor kv_seqlens, // [BATCH, ]
|
||||
torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
|
||||
torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S]
|
||||
int64_t context_size,
|
||||
int64_t block_size_M,
|
||||
int64_t block_size_N,
|
||||
bool causal) {
|
||||
cudaSetDevice(q_seqlens.get_device());
|
||||
|
||||
int64_t batch_size = slash_indexes.size(0);
|
||||
int64_t num_heads = slash_indexes.size(1);
|
||||
int64_t nnz_slash = slash_indexes.size(2);
|
||||
int64_t nnz_vertical = vertical_indexes.size(2);
|
||||
int64_t num_rows = (context_size + block_size_M - 1) / block_size_M;
|
||||
|
||||
convert_vertical_slash_indexes_64x64(
|
||||
q_seqlens.data_ptr<int>(),
|
||||
kv_seqlens.data_ptr<int>(),
|
||||
vertical_indexes.data_ptr<int>(),
|
||||
slash_indexes.data_ptr<int>(),
|
||||
block_count.data_ptr<int>(),
|
||||
block_offset.data_ptr<int>(),
|
||||
column_count.data_ptr<int>(),
|
||||
column_index.data_ptr<int>(),
|
||||
batch_size,
|
||||
num_heads,
|
||||
num_rows,
|
||||
block_size_M,
|
||||
block_size_N,
|
||||
nnz_vertical,
|
||||
nnz_slash,
|
||||
causal);
|
||||
}
|
||||
|
||||
// --- mergehead kernels --- //
|
||||
|
||||
// Kernel: like above, but supports per-head variable NNZ_V/NNZ_S.
|
||||
__global__ void convert_vertical_slash_indexes_kernel_mergehead(
|
||||
const int* q_seqlens, // [BATCH, ]
|
||||
const int* kv_seqlens, // [BATCH, ]
|
||||
const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
|
||||
const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S]
|
||||
const int* per_head_vertical_topkv,
|
||||
const int* per_head_slash_topkv,
|
||||
int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
|
||||
int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
|
||||
int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
|
||||
int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
|
||||
int64_t N_HEADS,
|
||||
int64_t N_ROWS,
|
||||
int64_t BLOCK_SIZE_M,
|
||||
int64_t BLOCK_SIZE_N,
|
||||
int64_t NNZ_V,
|
||||
int64_t NNZ_S,
|
||||
bool causal // True for intra, False for succ
|
||||
) {
|
||||
const int batch_idx = blockIdx.y;
|
||||
const int head_idx = blockIdx.x;
|
||||
const int group_idx = blockIdx.z;
|
||||
|
||||
int64_t q_seqlen = q_seqlens[batch_idx];
|
||||
int64_t kv_seqlen = kv_seqlens[batch_idx];
|
||||
int64_t block_idx_m = group_idx * blockDim.x + threadIdx.x;
|
||||
int64_t start_m = block_idx_m * BLOCK_SIZE_M;
|
||||
if (start_m >= q_seqlen) {
|
||||
return;
|
||||
}
|
||||
int64_t end_m = start_m + BLOCK_SIZE_M;
|
||||
vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V;
|
||||
slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S;
|
||||
int64_t row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m;
|
||||
block_count += row_offset;
|
||||
block_offset += row_offset * NNZ_S;
|
||||
column_count += row_offset;
|
||||
column_index += row_offset * NNZ_V;
|
||||
|
||||
// MergeHead: each head has it's unique max topk NNZ_V,NNZ_S. (NNZ_V,NNZ_S
|
||||
// above is buffer size, use to compute offset)
|
||||
NNZ_S = per_head_slash_topkv[head_idx];
|
||||
NNZ_V = per_head_vertical_topkv[head_idx];
|
||||
|
||||
bool has_slash = true;
|
||||
int64_t tmp_col_cnt = 0, tmp_blk_cnt = 0;
|
||||
int64_t s = 0, v = 0;
|
||||
int64_t v_idx = vertical_indexes[v++];
|
||||
int64_t s_idx = slash_indexes[s++];
|
||||
if (causal) {
|
||||
while (s_idx >= end_m + (kv_seqlen - q_seqlen) && s < NNZ_S) {
|
||||
s_idx = slash_indexes[s++];
|
||||
}
|
||||
if (s_idx > end_m + (kv_seqlen - q_seqlen)) has_slash = false;
|
||||
s_idx = max((kv_seqlen - q_seqlen) + end_m - s_idx, BLOCK_SIZE_M);
|
||||
} else {
|
||||
while (s_idx >= end_m + kv_seqlen && s < NNZ_S) {
|
||||
s_idx = slash_indexes[s++];
|
||||
}
|
||||
if (s_idx > end_m + kv_seqlen) has_slash = false;
|
||||
s_idx = max(kv_seqlen + end_m - s_idx, BLOCK_SIZE_M);
|
||||
}
|
||||
|
||||
int64_t range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx;
|
||||
if (!has_slash) {
|
||||
if (causal) {
|
||||
range_start = (kv_seqlen - q_seqlen) + end_m;
|
||||
range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N;
|
||||
} else {
|
||||
range_start = kv_seqlen;
|
||||
range_end = kv_seqlen + BLOCK_SIZE_N;
|
||||
}
|
||||
}
|
||||
|
||||
bool slash_finished = false;
|
||||
while (1) {
|
||||
if (v_idx < range_end) {
|
||||
if (v_idx < range_start) {
|
||||
column_index[tmp_col_cnt++] = v_idx;
|
||||
}
|
||||
if (v < NNZ_V) {
|
||||
v_idx = vertical_indexes[v++];
|
||||
} else {
|
||||
if (causal)
|
||||
v_idx = end_m + BLOCK_SIZE_N + (kv_seqlen - q_seqlen);
|
||||
else
|
||||
v_idx = end_m + BLOCK_SIZE_N + kv_seqlen;
|
||||
}
|
||||
} else {
|
||||
if ((s < NNZ_S && causal) || (s < NNZ_S && !causal && slash_indexes[s] >= start_m)) {
|
||||
if (causal)
|
||||
s_idx = max((kv_seqlen - q_seqlen) + end_m - slash_indexes[s++], BLOCK_SIZE_M);
|
||||
else
|
||||
s_idx = max(kv_seqlen + end_m - slash_indexes[s++], BLOCK_SIZE_M);
|
||||
} else {
|
||||
if (v == NNZ_V || (v_idx > range_start && causal)) {
|
||||
// add the last vertical if no more slash
|
||||
if (v == NNZ_V && !causal && v_idx < kv_seqlen) {
|
||||
column_index[tmp_col_cnt++] = v_idx;
|
||||
}
|
||||
tmp_blk_cnt = save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
|
||||
break;
|
||||
} else {
|
||||
if (causal) {
|
||||
range_start = (kv_seqlen - q_seqlen) + end_m;
|
||||
range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N;
|
||||
} else {
|
||||
// if slash_finished but there are vertical left, save current
|
||||
// blocks
|
||||
tmp_blk_cnt = save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
|
||||
range_start = kv_seqlen;
|
||||
range_end = kv_seqlen + BLOCK_SIZE_N;
|
||||
}
|
||||
slash_finished = true;
|
||||
}
|
||||
}
|
||||
if (!slash_finished) {
|
||||
if (s_idx > range_end + BLOCK_SIZE_M) {
|
||||
tmp_blk_cnt = save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
|
||||
range_start = s_idx - BLOCK_SIZE_M;
|
||||
range_end = s_idx;
|
||||
} else if (s_idx > range_end) {
|
||||
range_end += BLOCK_SIZE_M;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
block_count[0] = tmp_blk_cnt;
|
||||
column_count[0] = tmp_col_cnt;
|
||||
}
|
||||
|
||||
// Launch the mergehead kernel with 64 threads per block.
|
||||
void convert_vertical_slash_indexes_64x64_mergehead(
|
||||
const int* q_seqlens, // [BATCH, ]
|
||||
const int* kv_seqlens, // [BATCH, ]
|
||||
const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
|
||||
const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S]
|
||||
int* per_head_vertical_topkv,
|
||||
int* per_head_slash_topkv,
|
||||
int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
|
||||
int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
|
||||
int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
|
||||
int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
|
||||
int64_t BATCH_SIZE,
|
||||
int64_t N_HEADS,
|
||||
int64_t N_ROWS,
|
||||
int64_t BLOCK_SIZE_M,
|
||||
int64_t BLOCK_SIZE_N,
|
||||
int64_t NNZ_V,
|
||||
int64_t NNZ_S,
|
||||
bool causal) {
|
||||
const int N_THREADS = 64;
|
||||
const dim3 dimBlock(N_THREADS);
|
||||
const dim3 dimGrid(N_HEADS, BATCH_SIZE, (N_ROWS + N_THREADS - 1) / N_THREADS);
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
convert_vertical_slash_indexes_kernel_mergehead<<<dimGrid, dimBlock, 0, stream>>>(
|
||||
q_seqlens,
|
||||
kv_seqlens,
|
||||
vertical_indexes,
|
||||
slash_indexes,
|
||||
per_head_vertical_topkv,
|
||||
per_head_slash_topkv,
|
||||
block_count,
|
||||
block_offset,
|
||||
column_count,
|
||||
column_index,
|
||||
N_HEADS,
|
||||
N_ROWS,
|
||||
BLOCK_SIZE_M,
|
||||
BLOCK_SIZE_N,
|
||||
NNZ_V,
|
||||
NNZ_S,
|
||||
causal);
|
||||
}
|
||||
|
||||
// Host wrapper for mergehead kernel.
|
||||
void convert_vertical_slash_indexes_mergehead(
|
||||
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
|
||||
torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
|
||||
torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS]
|
||||
torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
|
||||
torch::Tensor q_seqlens, // [BATCH, ]
|
||||
torch::Tensor kv_seqlens, // [BATCH, ]
|
||||
torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
|
||||
torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S]
|
||||
torch::Tensor vertical_indices_count, // [N_HEADS, ]
|
||||
torch::Tensor slash_indices_count,
|
||||
int64_t context_size,
|
||||
int64_t block_size_M,
|
||||
int64_t block_size_N,
|
||||
bool causal) {
|
||||
cudaSetDevice(q_seqlens.get_device());
|
||||
|
||||
int batch_size = slash_indexes.size(0);
|
||||
int num_heads = slash_indexes.size(1);
|
||||
int nnz_slash = slash_indexes.size(2);
|
||||
int nnz_vertical = vertical_indexes.size(2);
|
||||
int num_rows = (context_size + block_size_M - 1) / block_size_M;
|
||||
|
||||
convert_vertical_slash_indexes_64x64_mergehead(
|
||||
q_seqlens.data_ptr<int>(),
|
||||
kv_seqlens.data_ptr<int>(),
|
||||
vertical_indexes.data_ptr<int>(),
|
||||
slash_indexes.data_ptr<int>(),
|
||||
vertical_indices_count.data_ptr<int>(),
|
||||
slash_indices_count.data_ptr<int>(),
|
||||
block_count.data_ptr<int>(),
|
||||
block_offset.data_ptr<int>(),
|
||||
column_count.data_ptr<int>(),
|
||||
column_index.data_ptr<int>(),
|
||||
batch_size,
|
||||
num_heads,
|
||||
num_rows,
|
||||
block_size_M,
|
||||
block_size_N,
|
||||
nnz_vertical,
|
||||
nnz_slash,
|
||||
causal);
|
||||
}
|
||||
455
sgl-kernel/csrc/common_extension.cc
Normal file
455
sgl-kernel/csrc/common_extension.cc
Normal file
@@ -0,0 +1,455 @@
|
||||
/* Copyright 2025 SGLang Team. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include <ATen/core/dispatch/Dispatcher.h>
|
||||
#include <torch/all.h>
|
||||
#include <torch/library.h>
|
||||
|
||||
#include "sgl_kernel_ops.h"
|
||||
|
||||
TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
/*
|
||||
* From csrc/allreduce
|
||||
*/
|
||||
m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
|
||||
m.def("register_graph_buffers", ®ister_graph_buffers);
|
||||
m.def("dispose", &dispose);
|
||||
m.def("meta_size", &meta_size);
|
||||
m.def("register_buffer", ®ister_buffer);
|
||||
|
||||
m.def(
|
||||
"init_custom_ar(int[] ipc_tensors, Tensor rank_data, "
|
||||
"int rank, bool full_nvlink) -> int");
|
||||
m.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);
|
||||
|
||||
m.def(
|
||||
"all_reduce(int fa, Tensor inp, Tensor! out, int reg_buffer, "
|
||||
"int reg_buffer_sz_bytes) -> ()");
|
||||
m.impl("all_reduce", torch::kCUDA, &all_reduce);
|
||||
|
||||
m.def("mscclpp_generate_unique_id", &mscclpp_generate_unique_id);
|
||||
m.def(
|
||||
"mscclpp_init_context(Tensor unique_id, int rank, int world_size, Tensor scratch, Tensor put_buffer, "
|
||||
"int nranks_per_node, int[] rank_to_node, int[] rank_to_ib, int context_selection) -> int");
|
||||
m.impl("mscclpp_init_context", torch::kCUDA, &mscclpp_init_context);
|
||||
|
||||
m.def("mscclpp_allreduce(int context, Tensor inp, Tensor! out, int nthreads, int nblocks) -> ()");
|
||||
m.impl("mscclpp_allreduce", torch::kCUDA, &mscclpp_allreduce);
|
||||
|
||||
/*
|
||||
* From csrc/attention
|
||||
*/
|
||||
m.def(
|
||||
"lightning_attention_decode(Tensor q, Tensor k, Tensor v, Tensor past_kv, Tensor slope, Tensor! output, Tensor! "
|
||||
"new_kv) -> ()");
|
||||
m.impl("lightning_attention_decode", torch::kCUDA, &lightning_attention_decode);
|
||||
m.def("merge_state(Tensor v_a, Tensor s_a, Tensor v_b, Tensor s_b, Tensor! v_merged, Tensor! s_merged) -> ()");
|
||||
m.impl("merge_state", torch::kCUDA, &merge_state);
|
||||
m.def("merge_state_v2(Tensor v_a, Tensor s_a, Tensor v_b, Tensor s_b, Tensor! v_merged, Tensor! s_merged) -> ()");
|
||||
m.impl("merge_state_v2", torch::kCUDA, &merge_state_v2);
|
||||
m.def(
|
||||
"cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe, Tensor kv_c_and_k_pe_cache, Tensor seq_lens, Tensor "
|
||||
"page_table, Tensor! workspace, float sm_scale, int num_kv_splits) -> ()");
|
||||
m.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);
|
||||
m.def("cutlass_mla_get_workspace_size", &cutlass_mla_get_workspace_size);
|
||||
|
||||
/*
|
||||
* From csrc/elementwise
|
||||
*/
|
||||
m.def("rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, bool enable_pdl) -> ()");
|
||||
m.impl("rmsnorm", torch::kCUDA, &rmsnorm);
|
||||
|
||||
m.def("fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, bool enable_pdl) -> ()");
|
||||
m.impl("fused_add_rmsnorm", torch::kCUDA, &sgl_fused_add_rmsnorm);
|
||||
|
||||
m.def("gemma_rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, bool enable_pdl) -> ()");
|
||||
m.impl("gemma_rmsnorm", torch::kCUDA, &gemma_rmsnorm);
|
||||
|
||||
m.def("gemma_fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, bool enable_pdl) -> ()");
|
||||
m.impl("gemma_fused_add_rmsnorm", torch::kCUDA, &gemma_fused_add_rmsnorm);
|
||||
|
||||
m.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
|
||||
m.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
|
||||
|
||||
m.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()");
|
||||
m.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul);
|
||||
|
||||
m.def("gelu_and_mul(Tensor! out, Tensor input) -> ()");
|
||||
m.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);
|
||||
|
||||
m.def(
|
||||
"apply_rope_pos_ids_cos_sin_cache(Tensor q, Tensor k, Tensor! q_rope, Tensor! k_rope, Tensor cos_sin_cache, "
|
||||
"Tensor pos_ids, bool interleave, bool enable_pdl, int cuda_stream, "
|
||||
"Tensor? v, Tensor!? k_buffer, Tensor!? v_buffer, Tensor? kv_cache_loc) -> ()");
|
||||
m.impl("apply_rope_pos_ids_cos_sin_cache", torch::kCUDA, &apply_rope_pos_ids_cos_sin_cache);
|
||||
|
||||
m.def(
|
||||
"downcast_fp8(Tensor k, Tensor v, Tensor k_out, Tensor v_out, Tensor k_scale, Tensor v_scale, Tensor loc, int "
|
||||
"mult, int offset, int cuda_stream) -> ()");
|
||||
m.impl("downcast_fp8", torch::kCUDA, &downcast_fp8);
|
||||
|
||||
/*
|
||||
* From csrc/gemm
|
||||
*/
|
||||
m.def("awq_dequantize(Tensor qweight, Tensor scales, Tensor qzeros) -> Tensor");
|
||||
m.impl("awq_dequantize", torch::kCUDA, &awq_dequantize);
|
||||
|
||||
m.def(
|
||||
"int8_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype, Tensor? "
|
||||
"bias) -> Tensor");
|
||||
m.impl("int8_scaled_mm", torch::kCUDA, &int8_scaled_mm);
|
||||
|
||||
m.def(
|
||||
"fp8_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype, Tensor? "
|
||||
"bias) -> Tensor");
|
||||
m.impl("fp8_scaled_mm", torch::kCUDA, &fp8_scaled_mm);
|
||||
|
||||
m.def(
|
||||
"fp8_blockwise_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype) -> "
|
||||
"Tensor");
|
||||
m.impl("fp8_blockwise_scaled_mm", torch::kCUDA, &fp8_blockwise_scaled_mm);
|
||||
|
||||
m.def(
|
||||
"sgl_per_token_group_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, int group_size,"
|
||||
" float eps, float fp8_min, float fp8_max, bool scale_ue8m0) -> ()");
|
||||
m.impl("sgl_per_token_group_quant_fp8", torch::kCUDA, &sgl_per_token_group_quant_fp8);
|
||||
|
||||
m.def(
|
||||
"sgl_per_token_group_quant_int8(Tensor input, Tensor output_q, Tensor output_s, int group_size,"
|
||||
" float eps, float int8_min, float int8_max) -> ()");
|
||||
m.impl("sgl_per_token_group_quant_int8", torch::kCUDA, &sgl_per_token_group_quant_int8);
|
||||
|
||||
m.def("sgl_per_tensor_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, bool is_static) -> ()");
|
||||
m.impl("sgl_per_tensor_quant_fp8", torch::kCUDA, &sgl_per_tensor_quant_fp8);
|
||||
|
||||
m.def("sgl_per_token_quant_fp8(Tensor input, Tensor output_q, Tensor output_s) -> ()");
|
||||
m.impl("sgl_per_token_quant_fp8", torch::kCUDA, &sgl_per_token_quant_fp8);
|
||||
|
||||
m.def(
|
||||
"cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b,"
|
||||
" Tensor block_scale_a, Tensor block_scale_b,"
|
||||
" Tensor alpha) -> ()");
|
||||
m.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm);
|
||||
|
||||
m.def(
|
||||
"scaled_fp4_quant(Tensor! output, Tensor! input,"
|
||||
" Tensor! output_scale, Tensor! input_scale) -> ()");
|
||||
m.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant);
|
||||
|
||||
m.def("dsv3_fused_a_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()");
|
||||
m.impl("dsv3_fused_a_gemm", torch::kCUDA, &dsv3_fused_a_gemm);
|
||||
|
||||
// Compute NVFP4 experts quantization.
|
||||
m.def(
|
||||
"scaled_fp4_experts_quant(Tensor! output, Tensor! output_scale,"
|
||||
"Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts,"
|
||||
"Tensor output_scale_offset_by_experts) -> ()");
|
||||
m.impl("scaled_fp4_experts_quant", torch::kCUDA, &scaled_fp4_experts_quant);
|
||||
|
||||
m.def(
|
||||
"silu_and_mul_scaled_fp4_experts_quant(Tensor! output, Tensor! output_scale,"
|
||||
"Tensor input, Tensor input_global_scale, Tensor mask, bool use_silu_and_mul) -> ()");
|
||||
m.impl("silu_and_mul_scaled_fp4_experts_quant", torch::kCUDA, &silu_and_mul_scaled_fp4_experts_quant);
|
||||
|
||||
m.def(
|
||||
"cutlass_fp4_group_mm(Tensor! output, Tensor a, Tensor b,"
|
||||
"Tensor a_blockscale, Tensor b_blockscale, Tensor alphas,"
|
||||
"Tensor ab_strides, Tensor c_strides, Tensor problem_sizes,"
|
||||
" Tensor expert_offsets, Tensor sf_offsets) -> ()");
|
||||
m.impl("cutlass_fp4_group_mm", torch::kCUDA, &cutlass_fp4_group_mm);
|
||||
|
||||
m.def("dsv3_router_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()");
|
||||
m.impl("dsv3_router_gemm", torch::kCUDA, &dsv3_router_gemm);
|
||||
|
||||
/*
|
||||
* From csrc/gemm/gptq
|
||||
*/
|
||||
m.def(
|
||||
"gptq_marlin_gemm(Tensor! a, Tensor? c_or_none,"
|
||||
"Tensor! b_q_weight, Tensor! b_scales, Tensor? global_scale_or_none,"
|
||||
"Tensor? b_zeros_or_none, Tensor? g_idx_or_none, Tensor? perm_or_none,"
|
||||
"Tensor! workspace, int b_q_type_id, int size_m, int size_n, int size_k,"
|
||||
"bool is_k_full, bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) -> Tensor");
|
||||
m.impl("gptq_marlin_gemm", torch::kCUDA, &gptq_marlin_gemm);
|
||||
|
||||
m.def(
|
||||
"gptq_gemm(Tensor a, Tensor b_q_weight, Tensor b_gptq_qzeros, Tensor b_gptq_scales, Tensor b_g_idx, bool "
|
||||
"use_shuffle, int bit) -> Tensor");
|
||||
m.impl("gptq_gemm", torch::kCUDA, &gptq_gemm);
|
||||
|
||||
m.def("gptq_shuffle(Tensor! q_weight, Tensor q_perm, int bit) -> ()");
|
||||
m.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle);
|
||||
|
||||
m.def("gptq_marlin_repack(Tensor! b_q_weight, Tensor! perm, int size_k, int size_n, int num_bits) -> Tensor");
|
||||
m.impl("gptq_marlin_repack", torch::kCUDA, &gptq_marlin_repack);
|
||||
|
||||
m.def("awq_marlin_repack(Tensor! b_q_weight, int size_k, int size_n, int num_bits) -> Tensor");
|
||||
m.impl("awq_marlin_repack", torch::kCUDA, &awq_marlin_repack);
|
||||
|
||||
/*
|
||||
* From csrc/moe
|
||||
*/
|
||||
m.def(
|
||||
"moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! "
|
||||
"experts_ids, Tensor! num_tokens_post_pad, Tensor! cumsum_buffer, bool "
|
||||
"pad_sorted_token_ids) -> ()");
|
||||
m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
|
||||
|
||||
m.def("topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor gating_output, bool renormalize) -> ()");
|
||||
m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
|
||||
|
||||
m.def(
|
||||
"moe_fused_gate(Tensor input, Tensor bias, int num_expert_group, int topk_group, int topk, int "
|
||||
"num_fused_shared_experts, float routed_scaling_factor, bool apply_routed_scaling_factor_on_output) -> "
|
||||
"(Tensor[])");
|
||||
m.impl("moe_fused_gate", torch::kCUDA, &moe_fused_gate);
|
||||
m.def(
|
||||
"ep_moe_pre_reorder(Tensor input, Tensor gateup_input, Tensor src2dst, Tensor topk_ids, Tensor "
|
||||
"a1_scales, int start_expert_id, int end_expert_id, int topk, bool use_per_token_if_dynamic) -> ()");
|
||||
m.impl("ep_moe_pre_reorder", torch::kCUDA, &ep_moe_pre_reorder);
|
||||
m.def(
|
||||
"ep_moe_silu_and_mul(Tensor gateup_output, Tensor down_input, Tensor reorder_topk_ids, Tensor scales, int "
|
||||
"start_expert_id, int end_expert_id) -> ()");
|
||||
m.impl("ep_moe_silu_and_mul", torch::kCUDA, &ep_moe_silu_and_mul);
|
||||
m.def(
|
||||
"ep_moe_post_reorder(Tensor down_output, Tensor output, Tensor src2dst, Tensor topk_ids, Tensor "
|
||||
"topk_weights, int start_expert_id, int end_expert_id, int topk) -> ()");
|
||||
m.impl("ep_moe_post_reorder", torch::kCUDA, &ep_moe_post_reorder);
|
||||
m.def(
|
||||
"fp8_blockwise_scaled_grouped_mm(Tensor output, Tensor a_ptrs, Tensor b_ptrs, Tensor out_ptrs, Tensor "
|
||||
"a_scales_ptrs, Tensor b_scales_ptrs, Tensor a, Tensor b, Tensor scales_a, Tensor scales_b, Tensor "
|
||||
"stride_a, Tensor stride_b, Tensor stride_c, Tensor layout_sfa, Tensor layout_sfb, Tensor problem_sizes, Tensor "
|
||||
"expert_offsets, Tensor workspace) -> ()");
|
||||
m.impl("fp8_blockwise_scaled_grouped_mm", torch::kCUDA, &fp8_blockwise_scaled_grouped_mm);
|
||||
m.def(
|
||||
"prepare_moe_input(Tensor topk_ids, Tensor expert_offsets, Tensor? blockscale_offsets, Tensor problem_sizes1,"
|
||||
" Tensor problem_sizes2, Tensor input_permutation, Tensor output_permutation, int num_experts, int n, int k) -> "
|
||||
"()");
|
||||
m.impl("prepare_moe_input", torch::kCUDA, &prepare_moe_input);
|
||||
|
||||
m.def("shuffle_rows(Tensor input, Tensor dst2src_map, Tensor output) -> ()");
|
||||
m.impl("shuffle_rows", torch::kCUDA, &shuffle_rows);
|
||||
m.def("apply_shuffle_mul_sum(Tensor input, Tensor output, Tensor permutation, Tensor? factors) -> ()");
|
||||
m.impl("apply_shuffle_mul_sum", torch::kCUDA, &apply_shuffle_mul_sum);
|
||||
|
||||
/*
|
||||
* From csrc/moe/marlin_moe_wna16
|
||||
*/
|
||||
m.def(
|
||||
"moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none,"
|
||||
"Tensor! b_q_weight, Tensor! b_scales, Tensor? b_zeros_or_none,"
|
||||
"Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace,"
|
||||
"Tensor sorted_token_ids,"
|
||||
"Tensor! expert_ids, Tensor! num_tokens_past_padded,"
|
||||
"Tensor! topk_weights, int moe_block_size, int top_k, "
|
||||
"bool mul_topk_weights, bool is_ep, int b_q_type_id,"
|
||||
"int size_m, int size_n, int size_k,"
|
||||
"bool is_k_full, bool use_atomic_add,"
|
||||
"bool use_fp32_reduce, bool is_zp_float) -> Tensor");
|
||||
m.impl("moe_wna16_marlin_gemm", torch::kCUDA, &moe_wna16_marlin_gemm);
|
||||
|
||||
/*
|
||||
* From csrc/moe/cutlass_moe/w4a8
|
||||
*/
|
||||
m.def(
|
||||
"get_cutlass_w4a8_moe_mm_data(Tensor topk_ids, Tensor! expert_offsets, "
|
||||
" Tensor! problem_sizes1, Tensor! problem_sizes2, "
|
||||
" Tensor! input_permutation, "
|
||||
" Tensor! output_permutation, int num_experts, "
|
||||
" int n, int k) -> ()");
|
||||
m.impl("get_cutlass_w4a8_moe_mm_data", torch::kCUDA, &get_cutlass_w4a8_moe_mm_data);
|
||||
|
||||
m.def(
|
||||
"cutlass_w4a8_moe_mm(Tensor! d, Tensor a, Tensor b, "
|
||||
" Tensor a_scales, Tensor b_scales, Tensor expert_offsets, "
|
||||
" Tensor problem_sizes, Tensor a_strides, "
|
||||
" Tensor b_strides, Tensor d_strides, Tensor s_strides,"
|
||||
" int chunk_size, int topk) -> ()");
|
||||
m.impl("cutlass_w4a8_moe_mm", torch::kCUDA, &cutlass_w4a8_moe_mm);
|
||||
|
||||
/*
|
||||
* From csrc/speculative
|
||||
*/
|
||||
m.def(
|
||||
"tree_speculative_sampling_target_only(Tensor! predicts, Tensor! accept_index, Tensor! accept_token_num, "
|
||||
"Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, "
|
||||
"Tensor uniform_samples, Tensor uniform_samples_for_final_sampling, Tensor target_probs, Tensor draft_probs, "
|
||||
"float threshold_single, float threshold_acc, "
|
||||
"bool deterministic, int cuda_stream) -> ()");
|
||||
m.impl("tree_speculative_sampling_target_only", torch::kCUDA, &tree_speculative_sampling_target_only);
|
||||
|
||||
m.def(
|
||||
"verify_tree_greedy(Tensor! predicts, Tensor! accept_index, Tensor! accept_token_num, "
|
||||
"Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, "
|
||||
"Tensor target_predict, int cuda_stream) -> ()");
|
||||
m.impl("verify_tree_greedy", torch::kCUDA, &verify_tree_greedy);
|
||||
|
||||
m.def(
|
||||
"build_tree_kernel_efficient(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, "
|
||||
"Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, Tensor! retrive_next_token, "
|
||||
"Tensor! retrive_next_sibling, int topk, int depth, int draft_token_num, int tree_mask_mode) -> "
|
||||
"()");
|
||||
m.impl("build_tree_kernel_efficient", torch::kCUDA, &build_tree_kernel_efficient);
|
||||
|
||||
m.def(
|
||||
"segment_packbits(Tensor x, Tensor input_indptr, Tensor output_indptr, Tensor! y, int batch_size, "
|
||||
"int cuda_stream) -> ()");
|
||||
m.impl("segment_packbits", torch::kCUDA, &segment_packbits);
|
||||
|
||||
/*
|
||||
* From csrc/kvcacheio
|
||||
*/
|
||||
m.def(
|
||||
"transfer_kv_per_layer(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
|
||||
"dst_indices, int item_size, int block_quota, int num_warps_per_block) -> ()");
|
||||
m.impl("transfer_kv_per_layer", torch::kCUDA, &transfer_kv_per_layer);
|
||||
m.def(
|
||||
"transfer_kv_per_layer_pf_lf(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
|
||||
"dst_indices, int layer_id, int item_size, int src_layout_dim, int block_quota, int num_warps_per_block) -> ()");
|
||||
m.impl("transfer_kv_per_layer_pf_lf", torch::kCUDA, &transfer_kv_per_layer_pf_lf);
|
||||
m.def(
|
||||
"transfer_kv_all_layer(Tensor src_k_layers, Tensor dst_k_layers, Tensor src_v_layers, Tensor dst_v_layers, "
|
||||
"Tensor src_indices, Tensor dst_indices, int item_size, int num_layers, int block_quota, int "
|
||||
"num_warps_per_block) -> ()");
|
||||
m.impl("transfer_kv_all_layer", torch::kCUDA, &transfer_kv_all_layer);
|
||||
m.def(
|
||||
"transfer_kv_all_layer_lf_pf(Tensor src_k_layers, Tensor dst_k, Tensor src_v_layers, Tensor dst_v, "
|
||||
"Tensor src_indices, Tensor dst_indices, int item_size, int dst_layout_dim, int num_layers, int block_quota, int "
|
||||
"num_warps_per_block) -> ()");
|
||||
m.impl("transfer_kv_all_layer_lf_pf", torch::kCUDA, &transfer_kv_all_layer_lf_pf);
|
||||
m.def(
|
||||
"transfer_kv_per_layer_mla(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int item_size, int "
|
||||
"block_quota, int num_warps_per_block) -> ()");
|
||||
m.impl("transfer_kv_per_layer_mla", torch::kCUDA, &transfer_kv_per_layer_mla);
|
||||
m.def(
|
||||
"transfer_kv_per_layer_mla_pf_lf(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int layer_id, "
|
||||
"int item_size, int src_layout_dim, int block_quota, int num_warps_per_block) -> ()");
|
||||
m.impl("transfer_kv_per_layer_mla_pf_lf", torch::kCUDA, &transfer_kv_per_layer_mla_pf_lf);
|
||||
m.def(
|
||||
"transfer_kv_all_layer_mla(Tensor src_layers, Tensor dst_layers, Tensor src_indices, Tensor dst_indices, int "
|
||||
"item_size, int num_layers, int block_quota, int num_warps_per_block) -> ()");
|
||||
m.impl("transfer_kv_all_layer_mla", torch::kCUDA, &transfer_kv_all_layer_mla);
|
||||
m.def(
|
||||
"transfer_kv_all_layer_mla_lf_pf(Tensor src_layers, Tensor dst, Tensor src_indices, Tensor dst_indices, "
|
||||
"int item_size, int dst_layout_dim, int num_layers, int block_quota, int num_warps_per_block) -> ()");
|
||||
m.impl("transfer_kv_all_layer_mla_lf_pf", torch::kCUDA, &transfer_kv_all_layer_mla_lf_pf);
|
||||
m.def(
|
||||
"transfer_kv_direct(Tensor[] src_layers, Tensor[] dst_layers, Tensor src_indices, Tensor dst_indices, int "
|
||||
"page_size) -> ()");
|
||||
m.impl("transfer_kv_direct", torch::kCUDA, &transfer_kv_direct);
|
||||
|
||||
/*
|
||||
* From csrc/memory
|
||||
*/
|
||||
m.def("store_kv_cache(Tensor k_cache, Tensor v_cache, Tensor out_loc, Tensor k, Tensor v) -> ()");
|
||||
m.impl("store_kv_cache", &store_kv_cache);
|
||||
|
||||
/*
|
||||
* From FlashInfer
|
||||
*/
|
||||
m.def(
|
||||
"bmm_fp8(Tensor A, Tensor B, Tensor! D, Tensor A_scale, Tensor B_scale, Tensor workspace_buffer, int "
|
||||
"cublas_handle, int cuda_stream) -> ()",
|
||||
{at::Tag::needs_fixed_stride_order});
|
||||
m.impl("bmm_fp8", torch::kCUDA, &bmm_fp8);
|
||||
|
||||
m.def(
|
||||
"min_p_sampling_from_probs(Tensor probs, Tensor output, Tensor? maybe_indices, Tensor? maybe_min_p_arr, float "
|
||||
"min_p_val, bool deterministic, Generator? gen) -> ()");
|
||||
m.impl("min_p_sampling_from_probs", torch::kCUDA, &min_p_sampling_from_probs);
|
||||
|
||||
m.def("top_k_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_k_arr, int top_k_val) -> ()");
|
||||
m.impl("top_k_renorm_probs", torch::kCUDA, &top_k_renorm_probs);
|
||||
|
||||
m.def("top_p_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_p_arr, float top_p_val) -> ()");
|
||||
m.impl("top_p_renorm_probs", torch::kCUDA, &top_p_renorm_probs);
|
||||
|
||||
m.def(
|
||||
"top_p_sampling_from_probs(Tensor probs, Tensor output, Tensor? maybe_indices, Tensor? "
|
||||
"maybe_top_p_arr, float top_p_val, bool deterministic, Generator? gen) -> ()");
|
||||
m.impl("top_p_sampling_from_probs", torch::kCUDA, &top_p_sampling_from_probs);
|
||||
|
||||
m.def(
|
||||
"top_k_top_p_sampling_from_probs(Tensor probs, Tensor output, Tensor? maybe_indices, Tensor? maybe_top_k_arr, "
|
||||
"float top_k_val, Tensor? maybe_top_p_arr, float top_p_val, bool deterministic, Generator? gen) -> ()");
|
||||
m.impl("top_k_top_p_sampling_from_probs", torch::kCUDA, &top_k_top_p_sampling_from_probs);
|
||||
|
||||
m.def("top_k_mask_logits(Tensor logits, Tensor mask_logits, Tensor? maybe_top_k_arr, int top_k_val) -> ()");
|
||||
m.impl("top_k_mask_logits", torch::kCUDA, &top_k_mask_logits);
|
||||
|
||||
/*
|
||||
* From Sparse Flash Attention
|
||||
*/
|
||||
m.def(
|
||||
"fwd_sparse(Tensor! q, Tensor k, Tensor v, "
|
||||
"Tensor block_count, Tensor block_offset, Tensor column_count, Tensor column_index, "
|
||||
"Tensor!? out, Tensor? alibi_slopes, "
|
||||
"float p_dropout, float softmax_scale, bool is_causal, "
|
||||
"float softcap, bool return_softmax, Generator? gen)"
|
||||
"-> Tensor[]");
|
||||
m.impl("fwd_sparse", torch::kCUDA, &flash::mha_fwd_sparse);
|
||||
|
||||
m.def(
|
||||
"varlen_fwd_sparse(Tensor! q, Tensor k, Tensor v, "
|
||||
"Tensor block_count, Tensor block_offset, Tensor column_count, Tensor column_index, "
|
||||
"Tensor!? out, Tensor cu_seqlens_q, "
|
||||
"Tensor cu_seqlens_k, Tensor? seqused_k, Tensor? alibi_slopes, "
|
||||
"int max_seqlen_q, int max_seqlen_k, float p_dropout, float softmax_scale, bool zero_tensors, "
|
||||
"bool is_causal, float softcap, bool return_softmax, "
|
||||
"Generator? gen) -> Tensor[]");
|
||||
m.impl("varlen_fwd_sparse", torch::kCUDA, &flash::mha_varlen_fwd_sparse);
|
||||
|
||||
// Sparse Attention utils
|
||||
m.def(
|
||||
"convert_vertical_slash_indexes("
|
||||
" Tensor! block_count, Tensor! block_offset, "
|
||||
" Tensor! column_count, Tensor! column_index, "
|
||||
" Tensor q_seqlens, Tensor q_seqlens, "
|
||||
" Tensor vertical_indexes, Tensor slash_indexes, "
|
||||
" int context_size, int block_size_M, int block_size_N, "
|
||||
" bool causal) -> ()");
|
||||
m.impl("convert_vertical_slash_indexes", torch::kCUDA, &convert_vertical_slash_indexes);
|
||||
|
||||
m.def(
|
||||
"convert_vertical_slash_indexes_mergehead("
|
||||
" Tensor! block_count, Tensor! block_offset, "
|
||||
" Tensor! column_count, Tensor! column_index, "
|
||||
" Tensor q_seqlens, Tensor q_seqlens, "
|
||||
" Tensor vertical_indexes, Tensor slash_indexes, "
|
||||
" Tensor vertical_indices_count, Tensor slash_indices_count, "
|
||||
" int context_size, int block_size_M, int block_size_N, "
|
||||
" bool causal) -> ()");
|
||||
m.impl("convert_vertical_slash_indexes_mergehead", torch::kCUDA, &convert_vertical_slash_indexes_mergehead);
|
||||
|
||||
/*
|
||||
* From csrc/grammar
|
||||
*/
|
||||
m.def("apply_token_bitmask_inplace_cuda(Tensor logits, Tensor bitmask, Tensor? indices=None) -> ()");
|
||||
m.impl("apply_token_bitmask_inplace_cuda", &ApplyTokenBitmaskInplace);
|
||||
|
||||
/*
|
||||
* From csrc/gemm (QServe)
|
||||
*/
|
||||
m.def(
|
||||
"qserve_w4a8_per_chn_gemm(Tensor _in_feats, Tensor _kernel, Tensor _wscales, Tensor _ascales, Tensor _w_szs, "
|
||||
"Tensor _a_ssums, Tensor! _out_feats) -> ()");
|
||||
m.impl("qserve_w4a8_per_chn_gemm", torch::kCUDA, &qserve_w4a8_per_chn_gemm);
|
||||
|
||||
m.def(
|
||||
"qserve_w4a8_per_group_gemm(Tensor _in_feats, Tensor _kernel, Tensor _zeros, Tensor _scales_i8, Tensor _wscales, "
|
||||
"Tensor _ascales, Tensor! _out_feats) -> ()");
|
||||
m.impl("qserve_w4a8_per_group_gemm", torch::kCUDA, &qserve_w4a8_per_group_gemm);
|
||||
}
|
||||
|
||||
REGISTER_EXTENSION(common_ops)
|
||||
168
sgl-kernel/csrc/common_extension_rocm.cc
Normal file
168
sgl-kernel/csrc/common_extension_rocm.cc
Normal file
@@ -0,0 +1,168 @@
|
||||
/* Copyright 2025 SGLang Team. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <ATen/core/dispatch/Dispatcher.h>
|
||||
#include <torch/library.h>
|
||||
|
||||
#include "sgl_kernel_ops.h"
|
||||
|
||||
TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
|
||||
/*
|
||||
* From csrc/activation
|
||||
*/
|
||||
m.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
|
||||
m.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
|
||||
|
||||
m.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()");
|
||||
m.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul);
|
||||
|
||||
m.def("gelu_and_mul(Tensor! out, Tensor input) -> ()");
|
||||
m.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);
|
||||
|
||||
m.def("gelu_quick(Tensor! out, Tensor input) -> ()");
|
||||
m.impl("gelu_quick", torch::kCUDA, &gelu_quick);
|
||||
|
||||
/*
|
||||
* From csrc/allreduce
|
||||
*/
|
||||
m.def(
|
||||
"init_custom_ar(Tensor meta, Tensor rank_data, "
|
||||
"str[] handles, int[] offsets, int rank, "
|
||||
"bool full_nvlink) -> int");
|
||||
m.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);
|
||||
|
||||
m.def("all_reduce_reg(int fa, Tensor inp, Tensor! out) -> ()");
|
||||
m.impl("all_reduce_reg", torch::kCUDA, &all_reduce_reg);
|
||||
|
||||
m.def(
|
||||
"all_reduce_unreg(int fa, Tensor inp, Tensor reg_buffer, Tensor! out) -> "
|
||||
"()");
|
||||
m.impl("all_reduce_unreg", torch::kCUDA, &all_reduce_unreg);
|
||||
|
||||
m.def("dispose", &dispose);
|
||||
|
||||
m.def("meta_size", &meta_size);
|
||||
|
||||
m.def(
|
||||
"register_buffer(int fa, Tensor t, str[] handles, "
|
||||
"int[] offsets) -> ()");
|
||||
m.impl("register_buffer", torch::kCUDA, ®ister_buffer);
|
||||
|
||||
m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
|
||||
m.def("register_graph_buffers", ®ister_graph_buffers);
|
||||
|
||||
m.def("allocate_meta_buffer", &allocate_meta_buffer);
|
||||
m.impl("allocate_meta_buffer", torch::kCUDA, &allocate_meta_buffer);
|
||||
|
||||
m.def("get_meta_buffer_ipc_handle", &get_meta_buffer_ipc_handle);
|
||||
m.impl("get_meta_buffer_ipc_handle", torch::kCPU, &get_meta_buffer_ipc_handle);
|
||||
|
||||
// quick allreduce
|
||||
#ifdef USE_ROCM
|
||||
m.def(
|
||||
"qr_all_reduce(int fa, Tensor inp, Tensor out, int quant_level, bool "
|
||||
"cast_bf2half) -> ()");
|
||||
m.impl("qr_all_reduce", torch::kCUDA, &qr_all_reduce);
|
||||
|
||||
m.def("init_custom_qr", &init_custom_qr);
|
||||
m.def("qr_destroy", &qr_destroy);
|
||||
|
||||
m.def("qr_get_handle", &qr_get_handle);
|
||||
|
||||
m.def("qr_open_handles(int _fa, Tensor[](b!) handles) -> ()");
|
||||
m.impl("qr_open_handles", torch::kCPU, &qr_open_handles);
|
||||
|
||||
// Max input size in bytes
|
||||
m.def("qr_max_size", &qr_max_size);
|
||||
#endif
|
||||
|
||||
/*
|
||||
* From csrc/moe
|
||||
*/
|
||||
m.def(
|
||||
"moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! "
|
||||
"experts_ids, Tensor! num_tokens_post_pad, Tensor! cumsum_buffer, bool "
|
||||
"pad_sorted_token_ids) -> ()");
|
||||
m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
|
||||
|
||||
m.def("topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor gating_output, bool renormalize) -> ()");
|
||||
m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
|
||||
|
||||
/*
|
||||
* From csrc/speculative
|
||||
*/
|
||||
m.def(
|
||||
"verify_tree_greedy(Tensor! predicts, Tensor! accept_index, Tensor! accept_token_num, "
|
||||
"Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, "
|
||||
"Tensor target_predict, int cuda_stream) -> ()");
|
||||
m.impl("verify_tree_greedy", torch::kCUDA, &verify_tree_greedy);
|
||||
|
||||
m.def(
|
||||
"build_tree_kernel_efficient(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, "
|
||||
"Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, Tensor! retrive_next_token, "
|
||||
"Tensor! retrive_next_sibling, int topk, int depth, int draft_token_num, int tree_mask_mode) -> "
|
||||
"()");
|
||||
m.impl("build_tree_kernel_efficient", torch::kCUDA, &build_tree_kernel_efficient);
|
||||
|
||||
/*
|
||||
* From XGrammar
|
||||
*/
|
||||
m.def("apply_token_bitmask_inplace_cuda(Tensor logits, Tensor bitmask, Tensor? indices=None) -> ()");
|
||||
m.impl("apply_token_bitmask_inplace_cuda", &ApplyTokenBitmaskInplace);
|
||||
|
||||
/*
|
||||
* From csrc/kvcacheio
|
||||
*/
|
||||
m.def(
|
||||
"transfer_kv_per_layer(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
|
||||
"dst_indices, int item_size, int block_quota, int num_warps_per_block) -> ()");
|
||||
m.impl("transfer_kv_per_layer", torch::kCUDA, &transfer_kv_per_layer);
|
||||
m.def(
|
||||
"transfer_kv_per_layer_pf_lf(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
|
||||
"dst_indices, int layer_id, int item_size, int src_layout_dim, int block_quota, int num_warps_per_block) -> ()");
|
||||
m.impl("transfer_kv_per_layer_pf_lf", torch::kCUDA, &transfer_kv_per_layer_pf_lf);
|
||||
m.def(
|
||||
"transfer_kv_all_layer(Tensor src_k_layers, Tensor dst_k_layers, Tensor src_v_layers, Tensor dst_v_layers, "
|
||||
"Tensor src_indices, Tensor dst_indices, int item_size, int num_layers, int block_quota, int "
|
||||
"num_warps_per_block) -> ()");
|
||||
m.impl("transfer_kv_all_layer", torch::kCUDA, &transfer_kv_all_layer);
|
||||
m.def(
|
||||
"transfer_kv_all_layer_lf_pf(Tensor src_k_layers, Tensor dst_k, Tensor src_v_layers, Tensor dst_v, "
|
||||
"Tensor src_indices, Tensor dst_indices, int item_size, int dst_layout_dim, int num_layers, int block_quota, int "
|
||||
"num_warps_per_block) -> ()");
|
||||
m.impl("transfer_kv_all_layer_lf_pf", torch::kCUDA, &transfer_kv_all_layer_lf_pf);
|
||||
m.def(
|
||||
"transfer_kv_per_layer_mla(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int item_size, int "
|
||||
"block_quota, int num_warps_per_block) -> ()");
|
||||
m.impl("transfer_kv_per_layer_mla", torch::kCUDA, &transfer_kv_per_layer_mla);
|
||||
m.def(
|
||||
"transfer_kv_per_layer_mla_pf_lf(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int layer_id, "
|
||||
"int item_size, int src_layout_dim, int block_quota, int num_warps_per_block) -> ()");
|
||||
m.impl("transfer_kv_per_layer_mla_pf_lf", torch::kCUDA, &transfer_kv_per_layer_mla_pf_lf);
|
||||
m.def(
|
||||
"transfer_kv_all_layer_mla(Tensor src_layers, Tensor dst_layers, Tensor src_indices, Tensor dst_indices, int "
|
||||
"item_size, int num_layers, int block_quota, int num_warps_per_block) -> ()");
|
||||
m.impl("transfer_kv_all_layer_mla", torch::kCUDA, &transfer_kv_all_layer_mla);
|
||||
m.def(
|
||||
"transfer_kv_all_layer_mla_lf_pf(Tensor src_layers, Tensor dst, Tensor src_indices, Tensor dst_indices, "
|
||||
"int item_size, int dst_layout_dim, int num_layers, int block_quota, int num_warps_per_block) -> ()");
|
||||
m.impl("transfer_kv_all_layer_mla_lf_pf", torch::kCUDA, &transfer_kv_all_layer_mla_lf_pf);
|
||||
m.def(
|
||||
"transfer_kv_direct(Tensor[] src_layers, Tensor[] dst_layers, Tensor src_indices, Tensor dst_indices, int "
|
||||
"page_size) -> ()");
|
||||
m.impl("transfer_kv_direct", torch::kCUDA, &transfer_kv_direct);
|
||||
}
|
||||
|
||||
REGISTER_EXTENSION(common_ops)
|
||||
83
sgl-kernel/csrc/cpu/CMakeLists.txt
Executable file
83
sgl-kernel/csrc/cpu/CMakeLists.txt
Executable file
@@ -0,0 +1,83 @@
|
||||
cmake_minimum_required(VERSION 3.18 FATAL_ERROR)
|
||||
project(sgl_kernel)
|
||||
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
set(CMAKE_CXX_EXTENSIONS OFF)
|
||||
|
||||
find_package(Python COMPONENTS Interpreter Development.Module ${SKBUILD_SABI_COMPONENT} REQUIRED)
|
||||
|
||||
execute_process(
|
||||
COMMAND ${Python_EXECUTABLE}
|
||||
-c "import torch; print(torch.utils.cmake_prefix_path)"
|
||||
OUTPUT_VARIABLE TORCH_PY_PREFIX
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
)
|
||||
|
||||
message(STATUS ${TORCH_PY_PREFIX})
|
||||
list(APPEND CMAKE_PREFIX_PATH ${TORCH_PY_PREFIX}/Torch)
|
||||
find_package(Torch REQUIRED)
|
||||
|
||||
include_directories(
|
||||
${TORCH_INCLUDE_DIRS}
|
||||
${TORCH_INSTALL_PREFIX}/include
|
||||
${Python_INCLUDE_DIRS}
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../csrc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../include
|
||||
)
|
||||
|
||||
# Platform-specific library directory
|
||||
if(CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64|AMD64")
|
||||
set(PLAT_LIB_DIR "/usr/lib/x86_64-linux-gnu")
|
||||
elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64|arm64")
|
||||
set(PLAT_LIB_DIR "/usr/lib/aarch64-linux-gnu")
|
||||
elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "ppc64le|ppc64")
|
||||
set(PLAT_LIB_DIR "/usr/lib/powerpc64le-linux-gnu")
|
||||
else()
|
||||
set(PLAT_LIB_DIR "/usr/lib/${CMAKE_SYSTEM_PROCESSOR}-linux-gnu")
|
||||
endif()
|
||||
link_directories(${PLAT_LIB_DIR})
|
||||
|
||||
# Conda library path support
|
||||
if(DEFINED ENV{CONDA_PREFIX})
|
||||
set(CONDA_LIB_DIR "$ENV{CONDA_PREFIX}/lib")
|
||||
message(STATUS "Using Conda lib dir: ${CONDA_LIB_DIR}")
|
||||
link_directories(${CONDA_LIB_DIR})
|
||||
set(CONDA_INCLUDE_DIR "$ENV{CONDA_PREFIX}/include")
|
||||
include_directories(${CONDA_INCLUDE_DIR})
|
||||
|
||||
# Look for libnuma in Conda's lib directory
|
||||
find_library(NUMA_LIB numa HINTS "${CONDA_LIB_DIR}")
|
||||
if(NUMA_LIB)
|
||||
message(STATUS "Found libnuma: ${NUMA_LIB}")
|
||||
else()
|
||||
message(FATAL_ERROR "libnuma not found in Conda environment at ${CONDA_LIB_DIR}\n"
|
||||
"Please install it using: conda install libnuma numactl\n")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
file(GLOB SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/*.cpp")
|
||||
|
||||
if(NOT DEFINED ENV{SGLANG_CPU_FP8_CVT_FTZ})
|
||||
set(ENV{SGLANG_CPU_FP8_CVT_FTZ} "1")
|
||||
endif()
|
||||
|
||||
if("$ENV{SGLANG_CPU_FP8_CVT_FTZ}" STREQUAL "1")
|
||||
message(STATUS "Enabling macro: SGLANG_CPU_FP8_CVT_FTZ")
|
||||
add_compile_definitions(SGLANG_CPU_FP8_CVT_FTZ)
|
||||
endif()
|
||||
|
||||
add_compile_options(
|
||||
-O3
|
||||
-Wno-unknown-pragmas
|
||||
-march=native
|
||||
-fopenmp
|
||||
)
|
||||
|
||||
Python_add_library(common_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES})
|
||||
target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES} ${NUMA_LIB})
|
||||
target_include_directories(common_ops PRIVATE ${TORCH_INCLUDE_DIRS})
|
||||
|
||||
install(TARGETS common_ops
|
||||
LIBRARY DESTINATION sgl_kernel
|
||||
)
|
||||
79
sgl-kernel/csrc/cpu/activation.cpp
Normal file
79
sgl-kernel/csrc/cpu/activation.cpp
Normal file
@@ -0,0 +1,79 @@
|
||||
#include "common.h"
|
||||
#include "vec.h"
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename scalar_t, typename func_t, typename vec_func_t>
|
||||
void act_and_mul_kernel_impl(
|
||||
scalar_t* __restrict__ output,
|
||||
const scalar_t* __restrict__ input,
|
||||
int64_t num_tokens,
|
||||
int64_t dim,
|
||||
const func_t& f,
|
||||
const vec_func_t& vf) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
|
||||
constexpr int64_t kVecSize = bVec::size();
|
||||
at::parallel_for(0, num_tokens, 0, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
// local ptrs
|
||||
const scalar_t* __restrict__ input_ptr = input + i * 2 * dim;
|
||||
const scalar_t* __restrict__ input_other_ptr = input_ptr + dim;
|
||||
scalar_t* __restrict__ output_ptr = output + i * dim;
|
||||
|
||||
int64_t d;
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= dim - kVecSize; d += kVecSize) {
|
||||
bVec x_bvec = bVec::loadu(input_ptr + d);
|
||||
fVec x_fvec0, x_fvec1;
|
||||
std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
|
||||
|
||||
bVec y_bvec = bVec::loadu(input_other_ptr + d);
|
||||
fVec y_fvec0, y_fvec1;
|
||||
std::tie(y_fvec0, y_fvec1) = at::vec::convert_to_float(y_bvec);
|
||||
|
||||
x_fvec0 = vf(x_fvec0);
|
||||
x_fvec1 = vf(x_fvec1);
|
||||
|
||||
x_fvec0 = x_fvec0 * y_fvec0;
|
||||
x_fvec1 = x_fvec1 * y_fvec1;
|
||||
|
||||
x_bvec = convert_from_float_ext<scalar_t>(x_fvec0, x_fvec1);
|
||||
x_bvec.store(output_ptr + d);
|
||||
}
|
||||
#pragma GCC unroll 4
|
||||
for (; d < dim; ++d) {
|
||||
float x_val = static_cast<float>(input_ptr[d]);
|
||||
float y_val = static_cast<float>(input_other_ptr[d]);
|
||||
output_ptr[d] = f(x_val) * y_val;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
// input : {num_tokens, 2 * d}
|
||||
// output : {num_tokens, d}
|
||||
at::Tensor silu_and_mul_cpu(at::Tensor& input) {
|
||||
RECORD_FUNCTION("sgl-kernel::silu_and_mul_cpu", std::vector<c10::IValue>({input}));
|
||||
auto sizes = input.sizes().vec();
|
||||
int64_t last_dim = input.ndimension() - 1;
|
||||
int64_t d = sizes[last_dim] / 2;
|
||||
sizes[last_dim] = d;
|
||||
int64_t num_tokens = input.numel() / input.size(-1);
|
||||
at::Tensor out = at::empty(sizes, input.options());
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "silu_and_mul", [&] {
|
||||
using Vec = at::vec::Vectorized<float>;
|
||||
act_and_mul_kernel_impl(
|
||||
out.data_ptr<scalar_t>(),
|
||||
input.data_ptr<scalar_t>(),
|
||||
num_tokens,
|
||||
d,
|
||||
[](float x) { return x / (1.f + std::exp(-x)); },
|
||||
[](Vec x) { return x / (Vec(1.f) + x.neg().exp()); });
|
||||
});
|
||||
return out;
|
||||
}
|
||||
123
sgl-kernel/csrc/cpu/bmm.cpp
Normal file
123
sgl-kernel/csrc/cpu/bmm.cpp
Normal file
@@ -0,0 +1,123 @@
|
||||
#include "common.h"
|
||||
#include "gemm.h"
|
||||
#include "vec.h"
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename scalar_t>
|
||||
void bmm_kernel_impl(
|
||||
scalar_t* __restrict__ out,
|
||||
const scalar_t* __restrict__ mat1,
|
||||
const scalar_t* __restrict__ mat2,
|
||||
int64_t B,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t mat1_strideB,
|
||||
int64_t mat1_strideM,
|
||||
int64_t out_strideB,
|
||||
int64_t out_strideM,
|
||||
float scale = 0.f) {
|
||||
constexpr int64_t BLOCK_M = block_size_m();
|
||||
constexpr int64_t BLOCK_N = block_size_n();
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
|
||||
// mat2 contiguous in [B, N, K]
|
||||
int64_t mat2_strideB = N * K;
|
||||
int64_t mat2_strideN = K;
|
||||
|
||||
const bool use_brgemm = can_use_brgemm<scalar_t>(M);
|
||||
|
||||
// parallel on [B, MB, NB]
|
||||
at::parallel_for(0, B * MB * NB, 0, [&](int64_t begin, int64_t end) {
|
||||
int64_t bs{0}, mb{0}, nb{0};
|
||||
data_index_init(begin, bs, B, mb, MB, nb, NB);
|
||||
|
||||
// for brgemm, use float32 for accumulate
|
||||
alignas(64) float Ctmp[BLOCK_M * BLOCK_N];
|
||||
|
||||
for (int i = begin; i < end; ++i) {
|
||||
UNUSED(i);
|
||||
int mb_start = mb * BLOCK_M;
|
||||
int mb_size = std::min(M - mb_start, BLOCK_M);
|
||||
int nb_start = nb * BLOCK_N;
|
||||
int nb_size = std::min(N - nb_start, BLOCK_N);
|
||||
|
||||
tinygemm_kernel<scalar_t>(
|
||||
/* A */ mat1 + bs * mat1_strideB + mb_start * mat1_strideM,
|
||||
/* B */ mat2 + bs * mat2_strideB + nb_start * mat2_strideN /* nb * BLOCK_N * K */,
|
||||
/* C */ out + bs * out_strideB + mb_start * out_strideM + nb_start,
|
||||
/* Ctmp*/ Ctmp,
|
||||
/* M */ mb_size,
|
||||
/* N */ nb_size,
|
||||
/* K */ K,
|
||||
/* lda */ mat1_strideM,
|
||||
/* ldb */ nb_size,
|
||||
/* ldc */ out_strideM,
|
||||
/* brg */ use_brgemm);
|
||||
|
||||
// move to the next index
|
||||
data_index_step(bs, B, mb, MB, nb, NB);
|
||||
}
|
||||
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
// mat1 : [B, M, K]
|
||||
// mat2 : [B, N, K] or [B, OC, IC]
|
||||
// out : [B, M, N]
|
||||
// scale: [] 0-dim tensor for per tensor quant
|
||||
//
|
||||
void bmm_cpu(
|
||||
at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, const std::optional<at::Tensor>& scale) {
|
||||
RECORD_FUNCTION("sgl-kernel::bmm_cpu", std::vector<c10::IValue>({out, mat1, mat2}));
|
||||
|
||||
auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
|
||||
|
||||
// input and out could be non-contiguous
|
||||
// weight needs to be contiguous in [OC, IC] order
|
||||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1);
|
||||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(out);
|
||||
CHECK_INPUT(mat2);
|
||||
CHECK_DIM(3, out);
|
||||
CHECK_DIM(3, mat1);
|
||||
CHECK_DIM(3, mat2);
|
||||
|
||||
int64_t B = mat1.size(0);
|
||||
int64_t M = mat1.size(1);
|
||||
int64_t N = mat2.size(1);
|
||||
int64_t K = mat1.size(2);
|
||||
|
||||
TORCH_CHECK(!scale.has_value(), "bmm: do not support fp8 weight for now.")
|
||||
TORCH_CHECK(N % 32 == 0, "tinygemm requires N to be 32x.");
|
||||
|
||||
int64_t mat1_strideB = mat1.stride(0);
|
||||
int64_t mat1_strideM = mat1.stride(1);
|
||||
int64_t out_strideB = out.stride(0);
|
||||
int64_t out_strideM = out.stride(1);
|
||||
|
||||
// check shapes
|
||||
TORCH_CHECK(mat2.size(0) == B && mat2.size(2) == K, "bmm: mat2 shape mismatch!");
|
||||
TORCH_CHECK(out.size(0) == B && out.size(1) == M, "bmm: out shape mismatch!");
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(mat1.scalar_type(), "bmm_kernel_impl", [&] {
|
||||
bmm_kernel_impl<scalar_t>(
|
||||
out.data_ptr<scalar_t>(),
|
||||
mat1.data_ptr<scalar_t>(),
|
||||
packed_w.data_ptr<scalar_t>(),
|
||||
B,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
mat1_strideB,
|
||||
mat1_strideM,
|
||||
out_strideB,
|
||||
out_strideM);
|
||||
});
|
||||
}
|
||||
324
sgl-kernel/csrc/cpu/common.h
Normal file
324
sgl-kernel/csrc/cpu/common.h
Normal file
@@ -0,0 +1,324 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/Parallel.h>
|
||||
#include <ATen/record_function.h>
|
||||
|
||||
#if defined(_OPENMP)
|
||||
#include <omp.h>
|
||||
#endif
|
||||
|
||||
namespace {
|
||||
|
||||
// dispatch bool
|
||||
#define AT_DISPATCH_BOOL(BOOL_V, BOOL_NAME, ...) \
|
||||
[&] { \
|
||||
if (BOOL_V) { \
|
||||
constexpr bool BOOL_NAME = true; \
|
||||
return __VA_ARGS__(); \
|
||||
} else { \
|
||||
constexpr bool BOOL_NAME = false; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
}()
|
||||
|
||||
// dispatch: bfloat16, float16, int8_t, fp8_e4m3
|
||||
#define CPU_DISPATCH_PACKED_TYPES(TYPE, ...) \
|
||||
[&] { \
|
||||
switch (TYPE) { \
|
||||
case at::ScalarType::BFloat16: { \
|
||||
using packed_t = at::BFloat16; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
case at::ScalarType::Half: { \
|
||||
using packed_t = at::Half; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
case at::ScalarType::Char: { \
|
||||
using packed_t = int8_t; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
case at::ScalarType::Float8_e4m3fn: { \
|
||||
using packed_t = at::Float8_e4m3fn; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
default: \
|
||||
TORCH_CHECK(false, "Unsupported floating data type.\n"); \
|
||||
} \
|
||||
}()
|
||||
|
||||
// dispatch with mixed dtypes (TYPE1, TYPE2):
|
||||
// TYPE1: the primary dtype (input, output, weight);
|
||||
// TYPE2: the secondary dtype (bias, etc.).
|
||||
#define CPU_DISPATCH_REDUCED_FLOATING_TYPES_EXT(TYPE1, TYPE2, ...) \
|
||||
[&] { \
|
||||
if (TYPE2 == at::kFloat) { \
|
||||
switch (TYPE1) { \
|
||||
case at::ScalarType::BFloat16: { \
|
||||
using scalar_t = at::BFloat16; \
|
||||
using param_t = float; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
case at::ScalarType::Half: { \
|
||||
using scalar_t = at::Half; \
|
||||
using param_t = float; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
default: \
|
||||
TORCH_CHECK(false, "Unsupported floating data type.\n"); \
|
||||
} \
|
||||
} else { \
|
||||
TORCH_CHECK(TYPE1 == TYPE2); \
|
||||
switch (TYPE1) { \
|
||||
case at::ScalarType::BFloat16: { \
|
||||
using scalar_t = at::BFloat16; \
|
||||
using param_t = at::BFloat16; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
case at::ScalarType::Half: { \
|
||||
using scalar_t = at::Half; \
|
||||
using param_t = at::Half; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
default: \
|
||||
TORCH_CHECK(false, "Unsupported floating data type.\n"); \
|
||||
} \
|
||||
} \
|
||||
}()
|
||||
|
||||
#define UNUSED(x) (void)(x)
|
||||
|
||||
#define CHECK_CPU(x) TORCH_CHECK(x.device().type() == at::kCPU, #x " must be a CPU tensor")
|
||||
|
||||
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
||||
#define CHECK_LAST_DIM_CONTIGUOUS(x) \
|
||||
TORCH_CHECK(x.strides()[x.strides().size() - 1] == 1, #x "must be contiguous at last dimension")
|
||||
|
||||
#define CHECK_INPUT(x) \
|
||||
CHECK_CPU(x); \
|
||||
CHECK_CONTIGUOUS(x)
|
||||
#define CHECK_LAST_DIM_CONTIGUOUS_INPUT(x) \
|
||||
CHECK_CPU(x); \
|
||||
CHECK_LAST_DIM_CONTIGUOUS(x)
|
||||
|
||||
#define CHECK_DIM(d, x) TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor")
|
||||
|
||||
#define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)
|
||||
|
||||
// [NB] Parallel Routines
|
||||
//
|
||||
// * at::parallel_for - applies for most of generic use cases, this will be compiled
|
||||
// against openmp in default torch release.
|
||||
//
|
||||
// * parallel_for - same function as above, can choose payload partition scheme in
|
||||
// balance211.
|
||||
//
|
||||
// * parallel_2d - parallel for 2 dimensions, used in GEMM, etc.
|
||||
// this one will do payload balance across 2 dimensions.
|
||||
//
|
||||
|
||||
// grain size for each thread
|
||||
constexpr int GRAIN_SIZE = 1024;
|
||||
|
||||
template <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
|
||||
inline T div_up(T x, T y) {
|
||||
return (x + y - 1) / y;
|
||||
}
|
||||
|
||||
// you can only use at::get_thread_num() with at::parallel_for()
|
||||
// as it is lazy initialized, otherwise it will always return 0.
|
||||
inline int get_thread_num() {
|
||||
#if defined(_OPENMP)
|
||||
return omp_get_thread_num();
|
||||
#else
|
||||
return 0;
|
||||
#endif
|
||||
}
|
||||
|
||||
// balance payload across each thread
|
||||
template <typename T>
|
||||
inline void balance211(T n, T nth, T ith, T& n_start, T& n_end) {
|
||||
#if 0
|
||||
// onednn partition pattern
|
||||
T& n_my = n_end;
|
||||
if (nth <= 1 || n == 0) {
|
||||
n_start = 0;
|
||||
n_my = n;
|
||||
} else {
|
||||
T n1 = div_up(n, nth);
|
||||
T n2 = n1 - 1;
|
||||
T T1 = n - n2 * nth;
|
||||
n_my = ith < T1 ? n1 : n2;
|
||||
n_start = ith <= T1 ? ith*n1 : T1 * n1 + (ith - T1) * n2;
|
||||
}
|
||||
n_end += n_start;
|
||||
#else
|
||||
// pytorch aten partition pattern
|
||||
T n_my = div_up(n, nth);
|
||||
n_start = ith * n_my;
|
||||
n_end = std::min(n_start + n_my, n);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename func_t>
|
||||
inline void parallel_for(int n, const func_t& f) {
|
||||
#if defined(_OPENMP)
|
||||
#pragma omp parallel
|
||||
{
|
||||
int nth = omp_get_num_threads();
|
||||
int ith = omp_get_thread_num();
|
||||
int tbegin, tend;
|
||||
balance211(n, nth, ith, tbegin, tend);
|
||||
f(tbegin, tend);
|
||||
}
|
||||
#else
|
||||
f(0, n);
|
||||
#endif
|
||||
}
|
||||
|
||||
// for 1d parallel, use `actual_nth`
|
||||
// for 2d parallel, use even nths, e.g. 43->42
|
||||
int inline adjust_num_threads(int m) {
|
||||
int actual_nth = at::get_num_threads();
|
||||
if (m == 1) {
|
||||
return actual_nth;
|
||||
}
|
||||
return std::max(1, (actual_nth >> 1) * 2);
|
||||
}
|
||||
|
||||
template <typename func_t>
|
||||
inline void parallel_2d(int m, int n, const func_t& f) {
|
||||
// make sure we have even num_threads
|
||||
int nth = adjust_num_threads(m);
|
||||
|
||||
// [NOTE] thread blocking:
|
||||
//
|
||||
// 1) prefer square block per thread
|
||||
// 2) use even number of CPU cores
|
||||
// 3) use all `num_threads` cores
|
||||
//
|
||||
// we have:
|
||||
// TM * TN = T
|
||||
// BM / TM = BN / TN
|
||||
// then:
|
||||
// TM = ((BM / BN) * T) ^ 0.5
|
||||
//
|
||||
float r = float(m) / n;
|
||||
int nth_m = std::ceil(std::sqrt(r * nth));
|
||||
int nth_n = 1;
|
||||
for (; nth_m > 0; --nth_m) {
|
||||
nth_n = nth / nth_m;
|
||||
if (nth_m * nth_n == nth) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(_OPENMP)
|
||||
#pragma omp parallel num_threads(nth)
|
||||
{
|
||||
int ith = omp_get_thread_num();
|
||||
int ith_m = ith / nth_n;
|
||||
int ith_n = ith % nth_n;
|
||||
|
||||
int thread_block_m = div_up(m, nth_m);
|
||||
int thread_block_n = div_up(n, nth_n);
|
||||
|
||||
int begin_m = ith_m * thread_block_m;
|
||||
int end_m = std::min(m, begin_m + thread_block_m);
|
||||
int begin_n = ith_n * thread_block_n;
|
||||
int end_n = std::min(n, begin_n + thread_block_n);
|
||||
|
||||
f(begin_m, end_m, begin_n, end_n);
|
||||
}
|
||||
#else
|
||||
f(0, m, 0, n);
|
||||
#endif
|
||||
}
|
||||
|
||||
// limit max cache blocks
|
||||
// when we need to do pre-unpack for weights, e.g. fp8
|
||||
#define MAX_CACHE_BLOCK_SIZE 4
|
||||
|
||||
template <typename T>
|
||||
inline int get_cache_blocks(int chunk_size) {
|
||||
// L2 2MB and ratio of 50%
|
||||
const int L2_size = 2048 * 1024 >> 1;
|
||||
return std::max(1, int(L2_size / (chunk_size * sizeof(T))));
|
||||
}
|
||||
|
||||
template <>
|
||||
inline int get_cache_blocks<at::Float8_e4m3fn>(int chunk_size) {
|
||||
// fp8 uses bf16 as accumulate type
|
||||
int cache_block_size = get_cache_blocks<at::BFloat16>(chunk_size);
|
||||
return std::min(MAX_CACHE_BLOCK_SIZE, cache_block_size);
|
||||
}
|
||||
|
||||
// 2d sequential loop in range : [mb0, mb1), [nb0, nb1)
|
||||
template <typename T, typename func_t>
|
||||
inline void loop_2d(int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1, int64_t chunk_size, const func_t& f) {
|
||||
// get number of blocks for L2 in most inner loop
|
||||
int64_t cache_blocks_nb = get_cache_blocks<T>(chunk_size);
|
||||
|
||||
// loop order: [NB / cache_blocks_nb, MB, cache_blocks_nb]
|
||||
// TODO: implement reverse order of [MB / cache_blocks_mb, NB, cache_blocks_mb]
|
||||
for (int64_t nbb = nb0; nbb < nb1; nbb += cache_blocks_nb) {
|
||||
for (int64_t mb = mb0; mb < mb1; ++mb) {
|
||||
for (int64_t nb = nbb; nb < std::min(nbb + cache_blocks_nb, nb1); ++nb) {
|
||||
f(mb, nb, nb - nbb);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// data indexing for dimension collapse
|
||||
template <typename T>
|
||||
inline T data_index_init(T offset) {
|
||||
return offset;
|
||||
}
|
||||
|
||||
template <typename T, typename... Args>
|
||||
inline T data_index_init(T offset, T& x, const T& X, Args&&... args) {
|
||||
offset = data_index_init(offset, std::forward<Args>(args)...);
|
||||
x = offset % X;
|
||||
return offset / X;
|
||||
}
|
||||
|
||||
inline bool data_index_step() {
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T, typename... Args>
|
||||
inline bool data_index_step(T& x, const T& X, Args&&... args) {
|
||||
if (data_index_step(std::forward<Args>(args)...)) {
|
||||
x = ((x + 1) == X) ? 0 : (x + 1);
|
||||
return x == 0;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// forced unroll for perf critical path
|
||||
|
||||
#if __has_attribute(always_inline)
|
||||
#define ALWAYS_INLINE __attribute__((__always_inline__)) inline
|
||||
#else
|
||||
#define ALWAYS_INLINE inline
|
||||
#endif
|
||||
|
||||
template <int n>
|
||||
struct Unroll {
|
||||
template <typename Func, typename... Args>
|
||||
ALWAYS_INLINE void operator()(const Func& f, Args... args) const {
|
||||
Unroll<n - 1>{}(f, args...);
|
||||
f(std::integral_constant<int, n - 1>{}, args...);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Unroll<1> {
|
||||
template <typename Func, typename... Args>
|
||||
ALWAYS_INLINE void operator()(const Func& f, Args... args) const {
|
||||
f(std::integral_constant<int, 0>{}, args...);
|
||||
}
|
||||
};
|
||||
|
||||
} // anonymous namespace
|
||||
1575
sgl-kernel/csrc/cpu/decode.cpp
Normal file
1575
sgl-kernel/csrc/cpu/decode.cpp
Normal file
File diff suppressed because it is too large
Load Diff
723
sgl-kernel/csrc/cpu/extend.cpp
Normal file
723
sgl-kernel/csrc/cpu/extend.cpp
Normal file
@@ -0,0 +1,723 @@
|
||||
#include "common.h"
|
||||
#include "gemm.h"
|
||||
#include "vec.h"
|
||||
|
||||
namespace {
|
||||
|
||||
// [NOTE]: extend attention for CPU
|
||||
// 1. tune BLOCK_M and BLOCK_N
|
||||
// 2. can handle non-contiguous k_exttend and v_extend
|
||||
// 3. computes attention for prefix and extend separately
|
||||
// 4. TODO: vectorize `pack_vnni` and `pack_vnni2`
|
||||
//
|
||||
|
||||
template <typename index_t>
|
||||
inline index_t get_index(index_t* ind, int i) {
|
||||
return (ind == nullptr) ? (index_t)i : ind[i];
|
||||
}
|
||||
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
// key: from [N, 32] to [32/2, N, 2]
|
||||
template <typename scalar_t, typename index_t>
|
||||
inline void pack_vnni_Nx32(
|
||||
scalar_t* __restrict__ dst,
|
||||
const scalar_t* __restrict__ src,
|
||||
const index_t* __restrict__ ind,
|
||||
int N,
|
||||
int ld_src,
|
||||
int ld_dst) {
|
||||
__m512i vinputs[16];
|
||||
|
||||
int n = 0;
|
||||
for (; n < N; ++n) {
|
||||
index_t index = get_index(ind, n);
|
||||
vinputs[n] = _mm512_loadu_si512(src + index * ld_src);
|
||||
}
|
||||
// padding with zero to avoid uninitialized vectors
|
||||
for (; n < 16; ++n) {
|
||||
vinputs[n] = _mm512_set1_epi32(0);
|
||||
}
|
||||
|
||||
// pack key
|
||||
transpose_16x16_32bit(vinputs);
|
||||
|
||||
const __mmask16 vmask = (1 << N) - 1;
|
||||
for (int k = 0; k < 16; ++k) {
|
||||
_mm512_mask_storeu_epi32(dst + k * ld_dst * 2, vmask, vinputs[k]);
|
||||
}
|
||||
}
|
||||
|
||||
// value: from [K, 32] to [K/2, 32, 2]
|
||||
template <typename scalar_t, typename index_t>
|
||||
inline void pack_vnni_Kx32(
|
||||
scalar_t* __restrict__ dst,
|
||||
const scalar_t* __restrict__ src,
|
||||
const index_t* __restrict__ ind,
|
||||
int K,
|
||||
int ld_src,
|
||||
int ld_dst) {
|
||||
__m512i vinputs[2];
|
||||
|
||||
int k = 0;
|
||||
for (; k < K; ++k) {
|
||||
index_t index = get_index(ind, k);
|
||||
vinputs[k] = _mm512_loadu_si512(src + index * ld_src);
|
||||
}
|
||||
// padding with zero to avoid uninitialized vectors
|
||||
for (; k < 2; ++k) {
|
||||
vinputs[k] = _mm512_set1_epi32(0);
|
||||
}
|
||||
|
||||
// pack value
|
||||
__m512i d0, d1;
|
||||
std::tie(d0, d1) = transpose_2x32_16bit(vinputs[0], vinputs[1]);
|
||||
_mm512_storeu_si512(dst + 0 * ld_dst * 2, d0);
|
||||
_mm512_storeu_si512(dst + 0 * ld_dst * 2 + 32, d1);
|
||||
}
|
||||
#endif
|
||||
|
||||
// convert to vnni format
|
||||
// from [N, K/2, 2] to [K/2, N, 2] for bfloat16 and float16
|
||||
template <typename scalar_t, typename index_t>
|
||||
void pack_vnni(
|
||||
scalar_t* __restrict__ dst,
|
||||
const scalar_t* __restrict__ src,
|
||||
const index_t* __restrict__ ind,
|
||||
int N,
|
||||
int K,
|
||||
int ld_src,
|
||||
int ld_dst) {
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
const int NB = div_up(N, 16);
|
||||
const int KB = K / 32; // no remainder
|
||||
const bool is_indexed = ind != nullptr;
|
||||
|
||||
for (int nb = 0; nb < NB; ++nb) {
|
||||
for (int kb = 0; kb < KB; ++kb) {
|
||||
// handle 16x512bits each block
|
||||
int nb_size = std::min(N - nb * 16, 16);
|
||||
pack_vnni_Nx32<scalar_t, index_t>(
|
||||
/* dst */ dst + ((kb * 32) >> 1) * ld_dst * 2 + nb * 16 * 2,
|
||||
/* src */ src + kb * 32 + (is_indexed ? 0 : nb * 16 * ld_src),
|
||||
/* ind */ is_indexed ? ind + nb * 16 : nullptr,
|
||||
/* N */ nb_size,
|
||||
/* ld_src */ ld_src,
|
||||
/* ld_dst */ ld_dst);
|
||||
}
|
||||
}
|
||||
#else
|
||||
for (int n = 0; n < N; ++n) {
|
||||
index_t index = get_index(ind, n);
|
||||
for (int k = 0; k < K / 2; ++k) {
|
||||
for (int d = 0; d < 2; ++d) {
|
||||
dst[k * ld_dst * 2 + n * 2 + d] = src[index * ld_src + k * 2 + d];
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
// convert to vnni format
|
||||
// from [K/2, 2, N] to [K/2, N, 2] for bfloat16 and float16
|
||||
template <typename scalar_t, typename index_t>
|
||||
void pack_vnni2(
|
||||
scalar_t* __restrict__ dst,
|
||||
const scalar_t* __restrict__ src,
|
||||
const index_t* __restrict__ ind,
|
||||
int K,
|
||||
int N,
|
||||
int ld_src,
|
||||
int ld_dst) {
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
const int KB = div_up(K, 2);
|
||||
const int NB = N / 32; // no remainder
|
||||
const bool is_indexed = ind != nullptr;
|
||||
|
||||
for (int kb = 0; kb < KB; ++kb) {
|
||||
for (int nb = 0; nb < NB; ++nb) {
|
||||
// handle 2x512bits each block
|
||||
int kb_size = std::min(K - kb * 2, 2);
|
||||
pack_vnni_Kx32<scalar_t, index_t>(
|
||||
/* dst */ dst + ((kb * 2) >> 1) * ld_dst * 2 + nb * 32 * 2,
|
||||
/* src */ src + (is_indexed ? 0 : kb * 2 * ld_src) + nb * 32,
|
||||
/* ind */ is_indexed ? ind + kb * 2 : nullptr,
|
||||
/* K */ kb_size,
|
||||
/* ld_src */ ld_src,
|
||||
/* ld_dst */ ld_dst);
|
||||
}
|
||||
}
|
||||
#else
|
||||
int k = 0;
|
||||
for (; k < (K >> 1) * 2; k += 2) {
|
||||
index_t index0 = get_index(ind, k + 0);
|
||||
index_t index1 = get_index(ind, k + 1);
|
||||
for (int n = 0; n < N; ++n) {
|
||||
dst[(k >> 1) * ld_dst * 2 + n * 2 + 0] = src[index0 * ld_src + n];
|
||||
dst[(k >> 1) * ld_dst * 2 + n * 2 + 1] = src[index1 * ld_src + n];
|
||||
}
|
||||
}
|
||||
if (K % 2 != 0) {
|
||||
index_t index = get_index(ind, K - 1);
|
||||
for (int n = 0; n < N; ++n) {
|
||||
dst[(K >> 1) * ld_dst * 2 + n * 2 + 0] = src[index * ld_src + n];
|
||||
dst[(K >> 1) * ld_dst * 2 + n * 2 + 1] = 0;
|
||||
}
|
||||
k += 2;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void fill_stub(scalar_t* __restrict__ out, float val, int size) {
|
||||
using Vec = at::vec::Vectorized<scalar_t>;
|
||||
constexpr int kVecSize = Vec::size();
|
||||
const Vec data_vec = Vec(static_cast<scalar_t>(val));
|
||||
int d = 0;
|
||||
#pragma GCC unroll 4
|
||||
for (; d <= size - kVecSize; d += kVecSize) {
|
||||
data_vec.store(out + d);
|
||||
}
|
||||
if (size - d > 0) {
|
||||
data_vec.store(out + d, size - d);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, int BLOCK_N>
|
||||
inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ input) {
|
||||
static_assert(BLOCK_N % 32 == 0);
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
|
||||
constexpr int COLS = BLOCK_N / 16;
|
||||
auto store = [&](auto i) {
|
||||
constexpr int col = i % COLS;
|
||||
// for COLS = 2, 4 use 512bit store
|
||||
if constexpr (col % 2 == 0) {
|
||||
fVec a_fvec0 = fVec::loadu(input + col * 16);
|
||||
fVec a_fvec1 = fVec::loadu(input + col * 16 + 16);
|
||||
bVec out_bvec = convert_from_float_ext<scalar_t>(a_fvec0, a_fvec1);
|
||||
out_bvec.store(out + col * 16);
|
||||
}
|
||||
};
|
||||
Unroll<COLS>{}(store);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ acc, float s, int size) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
constexpr int kVecSize = bVec::size();
|
||||
const fVec s_fvec = fVec(s);
|
||||
int d = 0;
|
||||
#pragma GCC unroll 4
|
||||
for (; d <= size - kVecSize; d += kVecSize) {
|
||||
fVec a_fvec0 = fVec::loadu(acc + d) * s_fvec;
|
||||
fVec a_fvec1 = fVec::loadu(acc + d + fVec::size()) * s_fvec;
|
||||
bVec out_bvec = convert_from_float_ext<scalar_t>(a_fvec0, a_fvec1);
|
||||
out_bvec.store(out + d);
|
||||
}
|
||||
for (; d < size; ++d) {
|
||||
out[d] = static_cast<scalar_t>(acc[d] * s);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, typename index_t, int BLOCK_M, int BLOCK_N>
|
||||
void extend_attention_kernel_impl(
|
||||
scalar_t* __restrict__ o_extend,
|
||||
const scalar_t* __restrict__ q_extend,
|
||||
const scalar_t* __restrict__ k_extend,
|
||||
const scalar_t* __restrict__ v_extend,
|
||||
const scalar_t* __restrict__ k_buffer,
|
||||
const scalar_t* __restrict__ v_buffer,
|
||||
const index_t* __restrict__ req_to_token,
|
||||
const int64_t* __restrict__ req_pool_indices,
|
||||
const int64_t* __restrict__ seq_lens,
|
||||
const index_t* __restrict__ extend_seq_lens,
|
||||
const index_t* __restrict__ extend_start_loc,
|
||||
const void* __restrict__ buffer,
|
||||
int batches,
|
||||
int num_heads,
|
||||
int num_heads_kv,
|
||||
int head_size,
|
||||
int head_size_v,
|
||||
int q_strideM,
|
||||
int q_strideH,
|
||||
int ke_strideN,
|
||||
int ke_strideH,
|
||||
int ve_strideN,
|
||||
int ve_strideH,
|
||||
int k_strideN,
|
||||
int k_strideH,
|
||||
int v_strideN,
|
||||
int v_strideH,
|
||||
float scaling,
|
||||
float logit_cap,
|
||||
int max_num_reqs,
|
||||
int max_context_len,
|
||||
int max_total_num_tokens,
|
||||
int max_len_extend,
|
||||
int buffer_size_per_thread,
|
||||
bool is_prefix_skipped) {
|
||||
using Vec = at::vec::Vectorized<float>;
|
||||
|
||||
// strides
|
||||
const int o_strideM = num_heads * head_size_v;
|
||||
const int o_strideH = head_size_v;
|
||||
|
||||
// we use same buffer for packed key and value
|
||||
const int ldb_tmp = std::max(head_size, head_size_v);
|
||||
|
||||
const bool has_logit_cap = logit_cap > 0;
|
||||
float rlogit_cap = has_logit_cap ? 1 / logit_cap : 0.f;
|
||||
|
||||
const int num_groups = num_heads / num_heads_kv;
|
||||
TORCH_CHECK(num_groups * num_heads_kv == num_heads);
|
||||
|
||||
// number of blocks along M
|
||||
int MB = div_up(max_len_extend, BLOCK_M);
|
||||
|
||||
// parallel on [batches, num_heads, BM]
|
||||
at::parallel_for(0, batches * num_heads * MB, 0, [&](int begin, int end) {
|
||||
int bs{0}, head_id{0}, mb{0};
|
||||
data_index_init(begin, bs, batches, head_id, num_heads, mb, MB);
|
||||
|
||||
int tid = at::get_thread_num();
|
||||
// s_i and s_delta: [BLOCK_M, BLOCK_N]
|
||||
float* __restrict__ s_i = reinterpret_cast<float*>((char*)(buffer) + tid * buffer_size_per_thread);
|
||||
float* __restrict__ s_delta = s_i;
|
||||
|
||||
// v_prime: [BLOCK_M, head_size_v]
|
||||
float* __restrict__ v_prime = s_i + BLOCK_M * BLOCK_N;
|
||||
|
||||
// s_delta2: [BLOCK_M, BLOCK_N]; copy of s_delta in scalar_t
|
||||
scalar_t* __restrict__ s_delta2 = reinterpret_cast<scalar_t*>(v_prime + BLOCK_N * head_size_v);
|
||||
|
||||
// Btmp: [BLOCK_N, max(head_size, head_size_v)]
|
||||
scalar_t* __restrict__ Btmp = s_delta2 + BLOCK_M * BLOCK_N;
|
||||
|
||||
// init Btmp just once for each thread to prevent NaN
|
||||
fill_stub(Btmp, 0.f, BLOCK_N * ldb_tmp);
|
||||
|
||||
alignas(64) float s_prime[BLOCK_M];
|
||||
alignas(64) float m_prime[BLOCK_M];
|
||||
|
||||
for (int i = begin; i < end; ++i) {
|
||||
// seq_len = prefix + extend
|
||||
int head_kv_id = head_id / num_groups;
|
||||
int seq_len = seq_lens[bs];
|
||||
int seq_len_extend = extend_seq_lens[bs];
|
||||
int seq_len_prefix = seq_len - seq_len_extend;
|
||||
int seq_extend_start_loc = extend_start_loc[bs];
|
||||
|
||||
int req_pool_id = req_pool_indices[bs];
|
||||
TORCH_CHECK(seq_len_prefix >= 0, "prefix len < 0!");
|
||||
TORCH_CHECK(seq_len <= max_context_len, "seq_len out of scope!");
|
||||
TORCH_CHECK(req_pool_id < max_num_reqs, "req_pool_id out of scope!");
|
||||
|
||||
if (is_prefix_skipped) {
|
||||
TORCH_CHECK(seq_len_prefix == 0, "extend attention: expect seq_len_prefix to be 0, got ", seq_len_prefix);
|
||||
}
|
||||
|
||||
// offset and size in MB
|
||||
int m = mb * BLOCK_N;
|
||||
int m_size = std::min(BLOCK_M, seq_len_extend - m);
|
||||
|
||||
if (m_size <= 0) {
|
||||
data_index_step(bs, batches, head_id, num_heads, mb, MB);
|
||||
continue;
|
||||
}
|
||||
|
||||
// get query
|
||||
const scalar_t* __restrict__ q_ptr = q_extend + (seq_extend_start_loc + m) * q_strideM + head_id * q_strideH;
|
||||
|
||||
// init v', s' and m'
|
||||
fill_stub(v_prime, 0.f, m_size * head_size_v);
|
||||
fill_stub(s_prime, 0.f, m_size);
|
||||
fill_stub(m_prime, -std::numeric_limits<scalar_t>::infinity(), m_size);
|
||||
|
||||
// stage 1: compute scores with prefix
|
||||
for (int n = 0; n < seq_len_prefix; n += BLOCK_N) {
|
||||
int n_size = std::min(BLOCK_N, seq_len_prefix - n);
|
||||
|
||||
// `n_size` is K in 2nd gemm, pad to TILE_K;
|
||||
const int padded_n_size = div_up(n_size, TILE_K) * TILE_K;
|
||||
|
||||
// get key and pack
|
||||
pack_vnni<scalar_t, index_t>(
|
||||
/* dst */ Btmp,
|
||||
/* src */ k_buffer + head_kv_id * k_strideH,
|
||||
/* ind */ req_to_token + req_pool_id * max_context_len + n,
|
||||
/* N */ n_size,
|
||||
/* K */ head_size,
|
||||
/* ld_src */ k_strideN,
|
||||
/* ld_dst */ BLOCK_N);
|
||||
|
||||
// calculate s_i <- Q @ K
|
||||
at::native::cpublas::brgemm(
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ head_size,
|
||||
/* lda */ q_strideM,
|
||||
/* ldb */ BLOCK_N,
|
||||
/* ldc */ BLOCK_N,
|
||||
/* add_C */ false,
|
||||
/* A */ q_ptr,
|
||||
/* B */ Btmp,
|
||||
/* C */ s_i);
|
||||
|
||||
const Vec scale_vec = Vec(scaling);
|
||||
for (int row = 0; row < m_size; ++row) {
|
||||
// s_i <- s_i * scale
|
||||
at::vec::map<float>(
|
||||
[scale_vec](Vec x) { return x * scale_vec; }, s_i + row * BLOCK_N, s_i + row * BLOCK_N, n_size);
|
||||
|
||||
// TODO: `tanh` from torch uses sleef u10, going to be slow
|
||||
if (has_logit_cap) {
|
||||
at::vec::map<float>(
|
||||
[logit_cap, rlogit_cap](Vec x) { return Vec(logit_cap) * (x * Vec(rlogit_cap)).tanh(); },
|
||||
s_i + row * BLOCK_N,
|
||||
s_i + row * BLOCK_N,
|
||||
n_size);
|
||||
}
|
||||
|
||||
// m_i: max value per row
|
||||
float m_i = at::vec::reduce_all<float>(
|
||||
[](Vec& x, Vec& y) { return at::vec::maximum(x, y); }, s_i + row * BLOCK_N, n_size);
|
||||
m_i = std::max(m_i, m_prime[row]);
|
||||
|
||||
// m_delta <- exp(m' - m_i)
|
||||
float m_delta = std::exp(m_prime[row] - m_i);
|
||||
|
||||
// s_delta <- exp(s_i - m_i)
|
||||
at::vec::map<float>(
|
||||
[m_i](Vec x) { return (x - Vec(m_i)).exp_u20(); }, s_delta + row * BLOCK_N, s_i + row * BLOCK_N, n_size);
|
||||
|
||||
// s' <- s' * m_delta + sum(s_delta)
|
||||
s_prime[row] *= m_delta;
|
||||
s_prime[row] +=
|
||||
at::vec::reduce_all<float>([](Vec& x, Vec& y) { return x + y; }, s_delta + row * BLOCK_N, n_size);
|
||||
|
||||
m_prime[row] = m_i;
|
||||
|
||||
// v' <- v' * m_delta
|
||||
at::vec::map<float>(
|
||||
[m_delta](Vec x) { return x * Vec(m_delta); },
|
||||
v_prime + row * head_size_v,
|
||||
v_prime + row * head_size_v,
|
||||
head_size_v);
|
||||
|
||||
// pad s_delta with 0 first and then convert to scalar_t
|
||||
fill_stub(s_delta + row * BLOCK_N + n_size, 0.f, padded_n_size - n_size);
|
||||
copy_stub<scalar_t, BLOCK_N>(s_delta2 + row * BLOCK_N, s_delta + row * BLOCK_N);
|
||||
}
|
||||
|
||||
// get value and pack
|
||||
pack_vnni2<scalar_t, index_t>(
|
||||
/* dst */ Btmp,
|
||||
/* src */ v_buffer + head_kv_id * v_strideH,
|
||||
/* ind */ req_to_token + req_pool_id * max_context_len + n,
|
||||
/* K */ n_size,
|
||||
/* N */ head_size_v,
|
||||
/* ld_src */ v_strideN,
|
||||
/* ld_dst */ head_size_v);
|
||||
|
||||
// calculate V' <- s_delta @ V + V'
|
||||
at::native::cpublas::brgemm(
|
||||
/* M */ m_size,
|
||||
/* N */ head_size_v,
|
||||
/* K */ padded_n_size, // n_size
|
||||
/* lda */ BLOCK_N,
|
||||
/* ldb */ head_size_v,
|
||||
/* ldc */ head_size_v,
|
||||
/* add_C */ true,
|
||||
/* A */ s_delta2,
|
||||
/* B */ Btmp,
|
||||
/* C */ v_prime);
|
||||
} // loop with seq_len_prefix
|
||||
|
||||
// stage 2: compute the triangle part
|
||||
int num_keys = std::min(seq_len_extend, m + BLOCK_M);
|
||||
for (int n = 0; n < num_keys; n += BLOCK_N) {
|
||||
int n_size = std::min(BLOCK_N, num_keys - n);
|
||||
|
||||
// `n_size` is K in 2nd gemm, pad to TILE_K;
|
||||
const int padded_n_size = div_up(n_size, TILE_K) * TILE_K;
|
||||
|
||||
// get key and pack
|
||||
pack_vnni<scalar_t, index_t>(
|
||||
/* dst */ Btmp,
|
||||
/* src */ k_extend + (seq_extend_start_loc + n) * ke_strideN + head_kv_id * ke_strideH,
|
||||
/* ind */ nullptr,
|
||||
/* N */ n_size,
|
||||
/* K */ head_size,
|
||||
/* ld_src */ ke_strideN,
|
||||
/* ld_dst */ BLOCK_N);
|
||||
|
||||
// calculate s_i <- Q @ K
|
||||
at::native::cpublas::brgemm(
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ head_size,
|
||||
/* lda */ q_strideM,
|
||||
/* ldb */ BLOCK_N,
|
||||
/* ldc */ BLOCK_N,
|
||||
/* add_C */ false,
|
||||
/* A */ q_ptr,
|
||||
/* B */ Btmp,
|
||||
/* C */ s_i);
|
||||
|
||||
// apply causal mask
|
||||
if (num_keys - n <= BLOCK_N) {
|
||||
for (int row = 0; row < m_size; ++row) {
|
||||
int last_col = m + row - n;
|
||||
// fill [last_col + 1, n_size) to -inf
|
||||
float* row_ptr = s_i + row * BLOCK_N;
|
||||
fill_stub(row_ptr + last_col + 1, -std::numeric_limits<float>::infinity(), n_size - last_col - 1);
|
||||
}
|
||||
}
|
||||
|
||||
const Vec scale_vec = Vec(scaling);
|
||||
for (int row = 0; row < m_size; ++row) {
|
||||
// s_i <- s_i * scale
|
||||
at::vec::map<float>(
|
||||
[scale_vec](Vec x) { return x * scale_vec; }, s_i + row * BLOCK_N, s_i + row * BLOCK_N, n_size);
|
||||
|
||||
// TODO: `tanh` from torch uses sleef u10, going to be slow
|
||||
if (has_logit_cap) {
|
||||
at::vec::map<float>(
|
||||
[logit_cap, rlogit_cap](Vec x) { return Vec(logit_cap) * (x * Vec(rlogit_cap)).tanh(); },
|
||||
s_i + row * BLOCK_N,
|
||||
s_i + row * BLOCK_N,
|
||||
n_size);
|
||||
}
|
||||
|
||||
// m_i: max value per row
|
||||
float m_i = at::vec::reduce_all<float>(
|
||||
[](Vec& x, Vec& y) { return at::vec::maximum(x, y); }, s_i + row * BLOCK_N, n_size);
|
||||
m_i = std::max(m_i, m_prime[row]);
|
||||
|
||||
// m_delta <- exp(m' - m_i)
|
||||
float m_delta = std::exp(m_prime[row] - m_i);
|
||||
|
||||
// s_delta <- exp(s_i - m_i)
|
||||
at::vec::map<float>(
|
||||
[m_i](Vec x) { return (x - Vec(m_i)).exp_u20(); }, s_delta + row * BLOCK_N, s_i + row * BLOCK_N, n_size);
|
||||
|
||||
// s' <- s' * m_delta + sum(s_delta)
|
||||
s_prime[row] *= m_delta;
|
||||
s_prime[row] +=
|
||||
at::vec::reduce_all<float>([](Vec& x, Vec& y) { return x + y; }, s_delta + row * BLOCK_N, n_size);
|
||||
|
||||
m_prime[row] = m_i;
|
||||
|
||||
// v' <- v' * m_delta
|
||||
at::vec::map<float>(
|
||||
[m_delta](Vec x) { return x * Vec(m_delta); },
|
||||
v_prime + row * head_size_v,
|
||||
v_prime + row * head_size_v,
|
||||
head_size_v);
|
||||
|
||||
// pad s_delta with 0 first and then convert to scalar_t
|
||||
fill_stub(s_delta + row * BLOCK_N + n_size, 0.f, padded_n_size - n_size);
|
||||
copy_stub<scalar_t, BLOCK_N>(s_delta2 + row * BLOCK_N, s_delta + row * BLOCK_N);
|
||||
}
|
||||
|
||||
// get value and pack
|
||||
pack_vnni2<scalar_t, index_t>(
|
||||
/* dst */ Btmp,
|
||||
/* src */ v_extend + (seq_extend_start_loc + n) * ve_strideN + head_kv_id * ve_strideH,
|
||||
/* ind */ nullptr,
|
||||
/* K */ n_size,
|
||||
/* N */ head_size_v,
|
||||
/* ld_src */ ve_strideN,
|
||||
/* ld_dst */ head_size_v);
|
||||
|
||||
// calculate V' <- s_delta @ V + V'
|
||||
at::native::cpublas::brgemm(
|
||||
/* M */ m_size,
|
||||
/* N */ head_size_v,
|
||||
/* K */ padded_n_size, // n_size
|
||||
/* lda */ BLOCK_N,
|
||||
/* ldb */ head_size_v,
|
||||
/* ldc */ head_size_v,
|
||||
/* add_C */ true,
|
||||
/* A */ s_delta2,
|
||||
/* B */ Btmp,
|
||||
/* C */ v_prime);
|
||||
} // loop with seq_len_extend
|
||||
|
||||
scalar_t* __restrict__ out_ptr = o_extend + (seq_extend_start_loc + m) * o_strideM + head_id * o_strideH;
|
||||
for (int row = 0; row < m_size; ++row) {
|
||||
float s = 1 / s_prime[row];
|
||||
copy_stub<scalar_t>(out_ptr + row * o_strideM, v_prime + row * head_size_v, s, head_size_v);
|
||||
}
|
||||
|
||||
// move to the next index
|
||||
data_index_step(bs, batches, head_id, num_heads, mb, MB);
|
||||
}
|
||||
at::native::cpublas::brgemm_release();
|
||||
});
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
// q_extend, k_extend, v_extend, o_extend: contiguous tensors
|
||||
// k_buffer, v_buffer: (prefix + extend) tensors in mem_manager
|
||||
//
|
||||
// q_extend: [num_tokens, num_heads, head_size]
|
||||
// k_extend: [num_extend_tokens, num_heads, head_size]
|
||||
// v_extend: [num_extend_tokens, num_heads, head_size]
|
||||
// o_extend: [num_tokens, num_heads, head_size]
|
||||
// k_buffer: [max_total_num_tokens, num_heads, head_size]
|
||||
// v_buffer: [max_total_num_tokens, num_heads, head_size]
|
||||
// req_to_token: [max_num_reqs, max_context_len] int32 or int64
|
||||
// req_pool_indices: [num_seqs] int64
|
||||
// seq_lens: [num_seqs] int64
|
||||
// extend_seq_lens: [num_seqs]
|
||||
// extend_start_loc: [num_seqs]
|
||||
//
|
||||
void extend_attention_cpu(
|
||||
at::Tensor& q_extend,
|
||||
at::Tensor& k_extend,
|
||||
at::Tensor& v_extend,
|
||||
at::Tensor& o_extend,
|
||||
at::Tensor& k_buffer,
|
||||
at::Tensor& v_buffer,
|
||||
at::Tensor& req_to_token,
|
||||
at::Tensor& req_pool_indices,
|
||||
at::Tensor& seq_lens,
|
||||
at::Tensor& extend_seq_lens,
|
||||
at::Tensor& extend_start_loc,
|
||||
int64_t max_len_extend,
|
||||
double sm_scale,
|
||||
double logit_cap) {
|
||||
RECORD_FUNCTION(
|
||||
"sgl-kernel::extend_attention_cpu",
|
||||
std::vector<c10::IValue>(
|
||||
{q_extend,
|
||||
k_extend,
|
||||
v_extend,
|
||||
o_extend,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
req_to_token,
|
||||
req_pool_indices,
|
||||
seq_lens,
|
||||
extend_seq_lens,
|
||||
extend_start_loc}));
|
||||
|
||||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(q_extend);
|
||||
CHECK_INPUT(o_extend);
|
||||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_extend);
|
||||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(v_extend);
|
||||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_buffer);
|
||||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(v_buffer);
|
||||
|
||||
int num_seqs = seq_lens.size(0);
|
||||
int max_num_reqs = req_to_token.size(0);
|
||||
int max_context_len = req_to_token.size(1);
|
||||
int max_total_num_tokens = k_buffer.size(0);
|
||||
|
||||
int num_heads = q_extend.size(1);
|
||||
int num_heads_kv = k_extend.size(1);
|
||||
int head_size = q_extend.size(2);
|
||||
int head_size_v = v_extend.size(2);
|
||||
|
||||
// strides for q_extend, k_extend and v_extend
|
||||
int q_strideM = q_extend.stride(0);
|
||||
int q_strideH = q_extend.stride(1);
|
||||
int ke_strideN = k_extend.stride(0);
|
||||
int ke_strideH = k_extend.stride(1);
|
||||
int ve_strideN = v_extend.stride(0);
|
||||
int ve_strideH = v_extend.stride(1);
|
||||
|
||||
// strides for k_buffer and v_buffer
|
||||
int k_strideN = k_buffer.stride(0);
|
||||
int k_strideH = k_buffer.stride(1);
|
||||
int v_strideN = v_buffer.stride(0);
|
||||
int v_strideH = v_buffer.stride(1);
|
||||
|
||||
// check sizes
|
||||
CHECK_EQ(req_pool_indices.size(0), num_seqs);
|
||||
CHECK_EQ(extend_seq_lens.size(0), num_seqs);
|
||||
CHECK_EQ(extend_start_loc.size(0), num_seqs);
|
||||
CHECK_EQ(v_extend.size(1), num_heads_kv);
|
||||
CHECK_EQ(k_buffer.size(1), v_buffer.size(1));
|
||||
|
||||
// MLA will skip prefix part
|
||||
const bool is_prefix_skipped = k_buffer.size(1) != num_heads_kv;
|
||||
|
||||
// check index data types
|
||||
const auto index_dtype = req_to_token.scalar_type();
|
||||
TORCH_CHECK(
|
||||
index_dtype == at::kInt || index_dtype == at::kLong,
|
||||
"extend: expect req_to_token to be int32 or int64, got ",
|
||||
index_dtype);
|
||||
TORCH_CHECK(seq_lens.scalar_type() == at::kLong, "extend: expect req_lens to be int64, got ", seq_lens.scalar_type());
|
||||
TORCH_CHECK(
|
||||
req_pool_indices.scalar_type() == at::kLong,
|
||||
"extend: expect req_pool_indices to be int64, got ",
|
||||
req_pool_indices.scalar_type());
|
||||
TORCH_CHECK(
|
||||
extend_seq_lens.scalar_type() == index_dtype && extend_start_loc.scalar_type() == index_dtype,
|
||||
"extend: expect extend_seq_lens and extend_start_loc to have same dtype as req_to_token.");
|
||||
|
||||
// D and DV need to be 32x as we transpose by 512-bit
|
||||
TORCH_CHECK(head_size % 32 == 0, "invalid head_size ", head_size);
|
||||
TORCH_CHECK(head_size_v % 32 == 0, "invalid head_size_v ", head_size_v);
|
||||
|
||||
// block size for query seq length
|
||||
constexpr int BLOCK_M = 32;
|
||||
// block size for key/value seq length
|
||||
constexpr int BLOCK_N = 32;
|
||||
|
||||
const int size_per_thread =
|
||||
/* s_i */ BLOCK_M * BLOCK_N * sizeof(float) +
|
||||
/* v_prime */ BLOCK_M * head_size_v * sizeof(float) +
|
||||
/* s_delta */ BLOCK_M * BLOCK_N * sizeof(uint16_t) +
|
||||
/* Btmp */ BLOCK_N * std::max(head_size, head_size_v) * sizeof(uint16_t);
|
||||
|
||||
int num_threads = at::get_num_threads();
|
||||
auto buffer = at::empty({num_threads, size_per_thread}, q_extend.options().dtype(at::kChar));
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(q_extend.scalar_type(), "extend_attention_kernel", [&] {
|
||||
AT_DISPATCH_INDEX_TYPES(index_dtype, "extend_attention_indices", [&] {
|
||||
extend_attention_kernel_impl<scalar_t, index_t, BLOCK_M, BLOCK_N>(
|
||||
o_extend.data_ptr<scalar_t>(),
|
||||
q_extend.data_ptr<scalar_t>(),
|
||||
k_extend.data_ptr<scalar_t>(),
|
||||
v_extend.data_ptr<scalar_t>(),
|
||||
k_buffer.data_ptr<scalar_t>(),
|
||||
v_buffer.data_ptr<scalar_t>(),
|
||||
req_to_token.data_ptr<index_t>(),
|
||||
req_pool_indices.data_ptr<int64_t>(),
|
||||
seq_lens.data_ptr<int64_t>(),
|
||||
extend_seq_lens.data_ptr<index_t>(),
|
||||
extend_start_loc.data_ptr<index_t>(),
|
||||
buffer.data_ptr(),
|
||||
num_seqs,
|
||||
num_heads,
|
||||
num_heads_kv,
|
||||
head_size,
|
||||
head_size_v,
|
||||
q_strideM,
|
||||
q_strideH,
|
||||
ke_strideN,
|
||||
ke_strideH,
|
||||
ve_strideN,
|
||||
ve_strideH,
|
||||
k_strideN,
|
||||
k_strideH,
|
||||
v_strideN,
|
||||
v_strideH,
|
||||
sm_scale,
|
||||
logit_cap,
|
||||
max_num_reqs,
|
||||
max_context_len,
|
||||
max_total_num_tokens,
|
||||
max_len_extend,
|
||||
size_per_thread,
|
||||
is_prefix_skipped);
|
||||
});
|
||||
});
|
||||
}
|
||||
525
sgl-kernel/csrc/cpu/gemm.cpp
Normal file
525
sgl-kernel/csrc/cpu/gemm.cpp
Normal file
@@ -0,0 +1,525 @@
|
||||
#include "gemm.h"
|
||||
|
||||
#include "common.h"
|
||||
#include "vec.h"
|
||||
|
||||
namespace {
|
||||
|
||||
// packed layout:
|
||||
// quants {N, K} int8_t
|
||||
// comp {N} int32_t
|
||||
template <int BLOCK_N>
|
||||
inline void s8s8_compensation(int8_t* __restrict__ packed, int K) {
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
constexpr int COLS = BLOCK_N / 16;
|
||||
__m512i vcomp[COLS];
|
||||
|
||||
for (int col = 0; col < COLS; ++col) {
|
||||
vcomp[col] = _mm512_setzero_si512();
|
||||
}
|
||||
|
||||
const int64_t offset = BLOCK_N * K;
|
||||
const __m512i off = _mm512_set1_epi8(static_cast<char>(0x80));
|
||||
for (int k = 0; k < K / 4; ++k) {
|
||||
for (int col = 0; col < COLS; ++col) {
|
||||
__m512i vb = _mm512_loadu_si512((const __m512i*)(packed + k * BLOCK_N * 4 + col * 64));
|
||||
vcomp[col] = _mm512_dpbusd_epi32(vcomp[col], off, vb);
|
||||
}
|
||||
}
|
||||
|
||||
for (int col = 0; col < COLS; ++col) {
|
||||
_mm512_storeu_si512((__m512i*)(packed + offset + col * 64), vcomp[col]);
|
||||
}
|
||||
#else
|
||||
TORCH_CHECK(false, "s8s8_compensation not implemented!");
|
||||
#endif
|
||||
}
|
||||
|
||||
// convert to vnni format
|
||||
// from [N, K] to [K/2, N, 2] for bfloat16 and float16
|
||||
template <typename packed_t>
|
||||
inline void pack_vnni(packed_t* __restrict__ packed, const packed_t* __restrict__ weight, int N, int K) {
|
||||
const int VNNI_BLK = 2;
|
||||
for (int n = 0; n < N; ++n) {
|
||||
for (int k = 0; k < K / VNNI_BLK; ++k) {
|
||||
for (int d = 0; d < VNNI_BLK; ++d) {
|
||||
packed[k * N * VNNI_BLK + n * VNNI_BLK + d] = weight[n * K + k * VNNI_BLK + d];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void pack_vnni<int8_t>(int8_t* __restrict__ packed, const int8_t* __restrict__ weight, int N, int K) {
|
||||
constexpr int BLOCK_N = block_size_n();
|
||||
TORCH_CHECK(N == BLOCK_N);
|
||||
|
||||
const int VNNI_BLK = 4;
|
||||
for (int n = 0; n < N; ++n) {
|
||||
for (int k = 0; k < K / VNNI_BLK; ++k) {
|
||||
for (int d = 0; d < VNNI_BLK; ++d) {
|
||||
packed[k * N * VNNI_BLK + n * VNNI_BLK + d] = weight[n * K + k * VNNI_BLK + d];
|
||||
}
|
||||
}
|
||||
}
|
||||
s8s8_compensation<BLOCK_N>(packed, K);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ input, int64_t size) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
constexpr int kVecSize = bVec::size();
|
||||
|
||||
int64_t d;
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= size - kVecSize; d += kVecSize) {
|
||||
fVec data0 = fVec::loadu(input + d);
|
||||
fVec data1 = fVec::loadu(input + d + fVec::size());
|
||||
bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1);
|
||||
out_vec.store(out + d);
|
||||
}
|
||||
for (; d < size; ++d) {
|
||||
out[d] = static_cast<scalar_t>(input[d]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void copy_add_stub(
|
||||
scalar_t* __restrict__ out, const float* __restrict__ input, const float* __restrict__ bias, int64_t size) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
constexpr int kVecSize = bVec::size();
|
||||
|
||||
int64_t d;
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= size - kVecSize; d += kVecSize) {
|
||||
fVec data0 = fVec::loadu(input + d) + fVec::loadu(bias + d);
|
||||
fVec data1 = fVec::loadu(input + d + fVec::size()) + fVec::loadu(bias + d + fVec::size());
|
||||
bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1);
|
||||
out_vec.store(out + d);
|
||||
}
|
||||
for (; d < size; ++d) {
|
||||
out[d] = static_cast<scalar_t>(input[d] + bias[d]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, bool has_bias, int BLOCK_M, int BLOCK_N>
|
||||
struct tinygemm_kernel_nn {
|
||||
static inline void apply(
|
||||
const scalar_t* __restrict__ A,
|
||||
const scalar_t* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
const float* __restrict__ bias,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc) {
|
||||
TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!");
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
template <bool has_bias, int BLOCK_M, int BLOCK_N>
|
||||
struct tinygemm_kernel_nn<at::BFloat16, has_bias, BLOCK_M, BLOCK_N> {
|
||||
static inline void apply(
|
||||
const at::BFloat16* __restrict__ A,
|
||||
const at::BFloat16* __restrict__ B,
|
||||
at::BFloat16* __restrict__ C,
|
||||
const float* __restrict__ bias,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc) {
|
||||
constexpr int ROWS = BLOCK_M;
|
||||
constexpr int COLS = BLOCK_N / 16;
|
||||
|
||||
// prefetch distance
|
||||
constexpr int PREFETCH_SIZE_K = 0;
|
||||
|
||||
__m512bh va;
|
||||
__m512bh vb[COLS];
|
||||
__m512 vc[ROWS * COLS];
|
||||
|
||||
auto loadc = [&](auto i) {
|
||||
constexpr int col = i % COLS;
|
||||
if constexpr (has_bias) {
|
||||
vc[i] = _mm512_loadu_ps(bias + col * 16);
|
||||
} else {
|
||||
vc[i] = _mm512_set1_ps(0.f);
|
||||
}
|
||||
};
|
||||
Unroll<ROWS * COLS>{}(loadc);
|
||||
|
||||
const int64_t K2 = K >> 1;
|
||||
const int64_t lda2 = lda >> 1;
|
||||
const int64_t ldb2 = ldb; // ldb * 2 >> 1;
|
||||
const float* a_ptr = reinterpret_cast<const float*>(A);
|
||||
const float* b_ptr = reinterpret_cast<const float*>(B);
|
||||
|
||||
auto compute = [&](auto i, int64_t k) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
|
||||
if constexpr (col == 0) {
|
||||
va = (__m512bh)(_mm512_set1_ps(a_ptr[row * lda2 + k]));
|
||||
}
|
||||
if constexpr (row == 0) {
|
||||
vb[col] = (__m512bh)(_mm512_loadu_si512(b_ptr + k * ldb2 + col * 16));
|
||||
if constexpr (PREFETCH_SIZE_K > 0) {
|
||||
_mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0);
|
||||
}
|
||||
}
|
||||
vc[i] = _mm512_dpbf16_ps(vc[i], va, vb[col]);
|
||||
};
|
||||
for (int64_t k = 0; k < K2; ++k) {
|
||||
Unroll<ROWS * COLS>{}(compute, k);
|
||||
}
|
||||
|
||||
auto storec = [&](auto i) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
// for COLS = 2, 4 use 512bit store
|
||||
// for COLS = 1, 3 use 256bit store
|
||||
if constexpr (COLS % 2 == 0) {
|
||||
if constexpr (col % 2 == 0) {
|
||||
_mm512_storeu_si512(
|
||||
reinterpret_cast<__m512i*>((C + row * ldc + col * 16)),
|
||||
(__m512i)(_mm512_cvtne2ps_pbh(vc[row * COLS + col + 1], vc[row * COLS + col])));
|
||||
}
|
||||
} else {
|
||||
_mm256_storeu_si256(reinterpret_cast<__m256i*>(C + row * ldc + col * 16), (__m256i)(_mm512_cvtneps_pbh(vc[i])));
|
||||
}
|
||||
};
|
||||
Unroll<ROWS * COLS>{}(storec);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \
|
||||
tinygemm_kernel_nn<scalar_t, has_bias, MB_SIZE, NB_SIZE>::apply( \
|
||||
A + mb_start * lda, \
|
||||
B + nb_start * 2, \
|
||||
C + mb_start * ldc + nb_start, \
|
||||
has_bias ? bias + nb_start : nullptr, \
|
||||
K, \
|
||||
lda, \
|
||||
ldb, \
|
||||
ldc);
|
||||
|
||||
template <typename scalar_t, bool has_bias>
|
||||
struct brgemm {
|
||||
static inline void apply(
|
||||
const scalar_t* __restrict__ A,
|
||||
const scalar_t* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
float* __restrict__ Ctmp,
|
||||
const float* __restrict__ bias,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc) {
|
||||
constexpr int BLOCK_N = block_size_n();
|
||||
at::native::cpublas::brgemm(M, N, K, lda, ldb, BLOCK_N, /* add_C */ false, A, B, Ctmp);
|
||||
|
||||
// copy from Ctmp to C
|
||||
for (int64_t m = 0; m < M; ++m) {
|
||||
if constexpr (has_bias) {
|
||||
copy_add_stub(C + m * ldc, Ctmp + m * BLOCK_N, bias, N);
|
||||
} else {
|
||||
copy_stub(C + m * ldc, Ctmp + m * BLOCK_N, N);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename scalar_t, bool has_bias>
|
||||
void tinygemm_kernel(
|
||||
const scalar_t* __restrict__ A,
|
||||
const scalar_t* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
float* __restrict__ Ctmp,
|
||||
const float* __restrict__ bias,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
bool brg) {
|
||||
if (brg) {
|
||||
brgemm<scalar_t, has_bias>::apply(A, B, C, Ctmp, bias, M, N, K, lda, ldb, ldc);
|
||||
return;
|
||||
}
|
||||
|
||||
// pattern: 1-4-16, N = 16, 32, 48, 64
|
||||
constexpr int64_t BLOCK_M = 4;
|
||||
constexpr int64_t BLOCK_N = 64;
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
for (int mb = 0; mb < MB; ++mb) {
|
||||
int64_t mb_start = mb * BLOCK_M;
|
||||
int64_t mb_size = std::min(BLOCK_M, M - mb_start);
|
||||
for (int64_t nb = 0; nb < NB; ++nb) {
|
||||
int64_t nb_start = nb * BLOCK_N;
|
||||
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
|
||||
|
||||
switch (mb_size << 4 | nb_size >> 4) {
|
||||
// mb_size = 1
|
||||
case 0x11:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(1, 16);
|
||||
break;
|
||||
case 0x12:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(1, 32);
|
||||
break;
|
||||
case 0x13:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(1, 48);
|
||||
break;
|
||||
case 0x14:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(1, 64);
|
||||
break;
|
||||
// mb_size = 2
|
||||
case 0x21:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(2, 16);
|
||||
break;
|
||||
case 0x22:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(2, 32);
|
||||
break;
|
||||
case 0x23:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(2, 48);
|
||||
break;
|
||||
case 0x24:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(2, 64);
|
||||
break;
|
||||
// mb_size = 3
|
||||
case 0x31:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(3, 16);
|
||||
break;
|
||||
case 0x32:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(3, 32);
|
||||
break;
|
||||
case 0x33:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(3, 48);
|
||||
break;
|
||||
case 0x34:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(3, 64);
|
||||
break;
|
||||
// mb_size = 4
|
||||
case 0x41:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(4, 16);
|
||||
break;
|
||||
case 0x42:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(4, 32);
|
||||
break;
|
||||
case 0x43:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(4, 48);
|
||||
break;
|
||||
case 0x44:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(4, 64);
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(false, "Unexpected block size, ", mb_size, " x ", nb_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void weight_packed_linear_kernel_impl(
|
||||
scalar_t* __restrict__ out,
|
||||
const scalar_t* __restrict__ mat1,
|
||||
const scalar_t* __restrict__ mat2,
|
||||
const float* __restrict__ bias,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t mat1_strideM,
|
||||
int64_t out_strideM) {
|
||||
constexpr int64_t BLOCK_M = block_size_m();
|
||||
constexpr int64_t BLOCK_N = block_size_n();
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
|
||||
const bool use_brgemm = can_use_brgemm<scalar_t>(M);
|
||||
|
||||
// parallel on [MB, NB]
|
||||
AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] {
|
||||
parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
|
||||
// for brgemm, use float32 for accumulate
|
||||
alignas(64) float Ctmp[BLOCK_M * BLOCK_N];
|
||||
|
||||
loop_2d<scalar_t>(mb0, mb1, nb0, nb1, BLOCK_N * K, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
|
||||
int64_t mb_start = mb * BLOCK_M;
|
||||
int64_t mb_size = std::min(M - mb_start, BLOCK_M);
|
||||
int64_t nb_start = nb * BLOCK_N;
|
||||
int64_t nb_size = std::min(N - nb_start, BLOCK_N);
|
||||
|
||||
tinygemm_kernel<scalar_t, has_bias>(
|
||||
/* A */ mat1 + mb_start * mat1_strideM,
|
||||
/* B */ mat2 + nb_start * K /* nb * BLOCK_N * K */,
|
||||
/* C */ out + mb_start * out_strideM + nb_start,
|
||||
/* Ctmp*/ Ctmp,
|
||||
/* bias*/ bias + nb_start,
|
||||
/* M */ mb_size,
|
||||
/* N */ nb_size,
|
||||
/* K */ K,
|
||||
/* lda */ mat1_strideM,
|
||||
/* ldb */ nb_size,
|
||||
/* ldc */ out_strideM,
|
||||
/* brg */ use_brgemm);
|
||||
});
|
||||
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
// tinygemm interface
|
||||
template <typename scalar_t>
|
||||
void tinygemm_kernel(
|
||||
const scalar_t* __restrict__ A,
|
||||
const scalar_t* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
float* __restrict__ Ctmp,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
bool brg) {
|
||||
tinygemm_kernel<scalar_t, false>(A, B, C, Ctmp, nullptr, M, N, K, lda, ldb, ldc, brg);
|
||||
}
|
||||
|
||||
#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \
|
||||
template void tinygemm_kernel<TYPE>( \
|
||||
const TYPE* __restrict__ A, \
|
||||
const TYPE* __restrict__ B, \
|
||||
TYPE* __restrict__ C, \
|
||||
float* __restrict__ Ctmp, \
|
||||
int64_t M, \
|
||||
int64_t N, \
|
||||
int64_t K, \
|
||||
int64_t lda, \
|
||||
int64_t ldb, \
|
||||
int64_t ldc, \
|
||||
bool brg)
|
||||
|
||||
INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16);
|
||||
INSTANTIATE_TINYGEMM_TEMPLATE(at::Half);
|
||||
|
||||
at::Tensor convert_weight_packed(at::Tensor& weight) {
|
||||
// for 3d moe weights
|
||||
// weight : [E, OC, IC]
|
||||
// w1 : [E, 2N, K]
|
||||
// w2 : [E, K, N]
|
||||
CHECK_INPUT(weight);
|
||||
|
||||
const int64_t ndim = weight.ndimension();
|
||||
TORCH_CHECK(ndim == 2 || ndim == 3, "expect weight to be 2d or 3d, got ", ndim, "d tensor.");
|
||||
const auto st = weight.scalar_type();
|
||||
const int64_t E = ndim == 3 ? weight.size(0) : 1;
|
||||
const int64_t OC = ndim == 3 ? weight.size(1) : weight.size(0);
|
||||
const int64_t IC = ndim == 3 ? weight.size(2) : weight.size(1);
|
||||
|
||||
// we handle 2 TILE_N at a time.
|
||||
TORCH_CHECK(OC % TILE_N == 0, "invalid weight out features ", OC);
|
||||
TORCH_CHECK(IC % TILE_K == 0, "invalid weight input features ", IC);
|
||||
|
||||
constexpr int64_t BLOCK_N = block_size_n();
|
||||
const int64_t NB = div_up(OC, BLOCK_N);
|
||||
|
||||
// use phony sizes here [E, OC, IC], for each [E], [OC, IC] -> [IC / 2, OC, 2]
|
||||
auto packed_weight = at::empty({}, weight.options());
|
||||
const int64_t stride = OC * IC;
|
||||
|
||||
TORCH_CHECK(
|
||||
st == at::kBFloat16 || st == at::kHalf || st == at::kChar || st == at::kFloat8_e4m3fn,
|
||||
"expect weight to be bfloat16, float16, int8 or fp8_e4m3.");
|
||||
|
||||
CPU_DISPATCH_PACKED_TYPES(st, [&] {
|
||||
// adjust most inner dimension size
|
||||
const int packed_row_size = get_row_size<packed_t>(IC);
|
||||
auto sizes = weight.sizes().vec();
|
||||
sizes[ndim - 1] = packed_row_size;
|
||||
packed_weight.resize_(sizes);
|
||||
|
||||
const packed_t* w_data = weight.data_ptr<packed_t>();
|
||||
packed_t* packed_data = packed_weight.data_ptr<packed_t>();
|
||||
|
||||
// parallel on {E, NB}
|
||||
at::parallel_for(0, E * NB, 0, [&](int64_t begin, int64_t end) {
|
||||
int64_t e{0}, nb{0};
|
||||
data_index_init(begin, e, E, nb, NB);
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
UNUSED(i);
|
||||
|
||||
int64_t n = nb * BLOCK_N;
|
||||
int64_t n_size = std::min(BLOCK_N, OC - n);
|
||||
pack_vnni<packed_t>(
|
||||
packed_data + e * OC * packed_row_size + n * packed_row_size, w_data + e * stride + n * IC, n_size, IC);
|
||||
|
||||
// move to the next index
|
||||
data_index_step(e, E, nb, NB);
|
||||
}
|
||||
});
|
||||
});
|
||||
return packed_weight;
|
||||
}
|
||||
|
||||
// mat1 : [M, K]
|
||||
// mat2 : [N, K]
|
||||
// bias : [N]
|
||||
// out : [M, N]
|
||||
//
|
||||
at::Tensor
|
||||
weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, const std::optional<at::Tensor>& bias, bool is_vnni) {
|
||||
RECORD_FUNCTION("sgl-kernel::weight_packed_linear", std::vector<c10::IValue>({mat1, mat2, bias}));
|
||||
|
||||
auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
|
||||
|
||||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1);
|
||||
CHECK_INPUT(mat2);
|
||||
|
||||
int64_t M = mat1.size(0);
|
||||
int64_t N = mat2.size(0);
|
||||
int64_t K = mat2.size(1);
|
||||
CHECK_EQ(mat1.size(1), K);
|
||||
CHECK_DIM(2, mat1);
|
||||
CHECK_DIM(2, mat2);
|
||||
|
||||
auto out = at::empty({M, N}, mat1.options());
|
||||
|
||||
// strides
|
||||
int64_t mat1_strideM = mat1.stride(0);
|
||||
int64_t out_strideM = out.stride(0);
|
||||
|
||||
const bool has_bias = bias.has_value();
|
||||
const float* bias_data = nullptr;
|
||||
if (has_bias) {
|
||||
CHECK_EQ(bias.value().size(0), N);
|
||||
bias_data = bias.value().data_ptr<float>();
|
||||
}
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(mat1.scalar_type(), "weight_packed_linear_kernel_impl", [&] {
|
||||
weight_packed_linear_kernel_impl<scalar_t>(
|
||||
out.data_ptr<scalar_t>(),
|
||||
mat1.data_ptr<scalar_t>(),
|
||||
packed_w.data_ptr<scalar_t>(),
|
||||
bias_data,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
mat1_strideM,
|
||||
out_strideM);
|
||||
});
|
||||
|
||||
return out;
|
||||
}
|
||||
202
sgl-kernel/csrc/cpu/gemm.h
Normal file
202
sgl-kernel/csrc/cpu/gemm.h
Normal file
@@ -0,0 +1,202 @@
|
||||
#pragma once
|
||||
#include <ATen/native/CPUBlas.h>
|
||||
|
||||
#include "common.h"
|
||||
|
||||
// amx-bf16
|
||||
#define TILE_M 16
|
||||
#define TILE_N 16
|
||||
#define TILE_K 32
|
||||
|
||||
// block size for AMX gemm
|
||||
constexpr int block_size_m() {
|
||||
return 2 * TILE_M;
|
||||
}
|
||||
constexpr int block_size_n() {
|
||||
return 2 * TILE_N;
|
||||
}
|
||||
|
||||
// define threshold using brgemm (intel AMX)
|
||||
template <typename T>
|
||||
inline bool can_use_brgemm(int M);
|
||||
template <>
|
||||
inline bool can_use_brgemm<at::BFloat16>(int M) {
|
||||
return M > 4;
|
||||
}
|
||||
template <>
|
||||
inline bool can_use_brgemm<at::Half>(int M) {
|
||||
return true;
|
||||
}
|
||||
// this requires PyTorch 2.7 or above
|
||||
template <>
|
||||
inline bool can_use_brgemm<int8_t>(int M) {
|
||||
return M > 4;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline bool can_use_brgemm<at::Float8_e4m3fn>(int M) {
|
||||
return M > 4;
|
||||
}
|
||||
|
||||
// work around compiler internal error
|
||||
#define BLOCK_K 128 // 4 * TILE_K
|
||||
|
||||
// adjust leading dimension size for K
|
||||
template <typename T>
|
||||
inline int64_t get_row_size(int64_t K) {
|
||||
return K;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline int64_t get_row_size<int8_t>(int64_t K) {
|
||||
return K + sizeof(int32_t);
|
||||
}
|
||||
|
||||
inline int64_t get_row_size(int64_t K, bool use_int8_w8a8) {
|
||||
return use_int8_w8a8 ? K + sizeof(int32_t) : K;
|
||||
}
|
||||
|
||||
// pack weight to vnni format
|
||||
at::Tensor convert_weight_packed(at::Tensor& weight);
|
||||
|
||||
// moe implementations for int8 w8a8
|
||||
template <typename scalar_t>
|
||||
void fused_experts_int8_kernel_impl(
|
||||
scalar_t* __restrict__ output,
|
||||
scalar_t* __restrict__ ic1,
|
||||
scalar_t* __restrict__ ic2,
|
||||
uint8_t* __restrict__ A_tmp,
|
||||
float* __restrict__ C_tmp,
|
||||
uint8_t* __restrict__ Aq_tmp,
|
||||
float* __restrict__ As_tmp,
|
||||
const scalar_t* __restrict__ input,
|
||||
const int8_t* __restrict__ packed_w1,
|
||||
const int8_t* __restrict__ packed_w2,
|
||||
const float* __restrict__ w1s,
|
||||
const float* __restrict__ w2s,
|
||||
const float* __restrict__ topk_weights,
|
||||
const int32_t* __restrict__ sorted_ids,
|
||||
const int32_t* __restrict__ expert_ids,
|
||||
const int32_t* __restrict__ offsets,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t E,
|
||||
int64_t topk,
|
||||
int64_t num_tokens_post_pad);
|
||||
|
||||
// moe implementations for fp8 w8a16
|
||||
template <typename scalar_t>
|
||||
void fused_experts_fp8_kernel_impl(
|
||||
scalar_t* __restrict__ output,
|
||||
scalar_t* __restrict__ ic0,
|
||||
scalar_t* __restrict__ ic1,
|
||||
scalar_t* __restrict__ ic2,
|
||||
scalar_t* __restrict__ A_tmp,
|
||||
scalar_t* __restrict__ B_tmp,
|
||||
float* __restrict__ C_tmp,
|
||||
const scalar_t* __restrict__ input,
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w1,
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w2,
|
||||
const float* __restrict__ w1s,
|
||||
const float* __restrict__ w2s,
|
||||
int64_t block_size_N,
|
||||
int64_t block_size_K,
|
||||
const float* __restrict__ topk_weights,
|
||||
const int32_t* __restrict__ sorted_ids,
|
||||
const int32_t* __restrict__ expert_ids,
|
||||
const int32_t* __restrict__ offsets,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t E,
|
||||
int64_t topk,
|
||||
int64_t num_tokens_post_pad);
|
||||
|
||||
// shared expert implementation for int8 w8a8
|
||||
template <typename scalar_t>
|
||||
void shared_expert_int8_kernel_impl(
|
||||
scalar_t* __restrict__ output,
|
||||
scalar_t* __restrict__ ic1,
|
||||
float* __restrict__ C_tmp,
|
||||
uint8_t* __restrict__ Aq_tmp,
|
||||
float* __restrict__ As_tmp,
|
||||
const scalar_t* __restrict__ input,
|
||||
const int8_t* __restrict__ packed_w1,
|
||||
const int8_t* __restrict__ packed_w2,
|
||||
const float* __restrict__ w1s,
|
||||
const float* __restrict__ w2s,
|
||||
const scalar_t* __restrict__ fused_experts_out,
|
||||
float routed_scaling_factor,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K);
|
||||
|
||||
template <typename scalar_t>
|
||||
void shared_expert_fp8_kernel_impl(
|
||||
scalar_t* __restrict__ output,
|
||||
scalar_t* __restrict__ ic0,
|
||||
scalar_t* __restrict__ ic1,
|
||||
scalar_t* __restrict__ B_tmp,
|
||||
float* __restrict__ C_tmp,
|
||||
const scalar_t* __restrict__ input,
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w1,
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w2,
|
||||
const float* __restrict__ w1s,
|
||||
const float* __restrict__ w2s,
|
||||
int64_t block_size_N,
|
||||
int64_t block_size_K,
|
||||
const scalar_t* __restrict__ fused_experts_out,
|
||||
float routed_scaling_factor,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K);
|
||||
|
||||
// tinygemm interface
|
||||
template <typename scalar_t>
|
||||
void tinygemm_kernel(
|
||||
const scalar_t* __restrict__ A,
|
||||
const scalar_t* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
float* __restrict__ Ctmp,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
bool brg);
|
||||
|
||||
template <typename scalar_t>
|
||||
void tinygemm_kernel(
|
||||
const uint8_t* __restrict__ A,
|
||||
const int8_t* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
int32_t* __restrict__ Ctmp,
|
||||
const float* __restrict__ As,
|
||||
const float* __restrict__ Bs,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
bool brg);
|
||||
|
||||
template <typename scalar_t>
|
||||
void tinygemm_kernel(
|
||||
const scalar_t* __restrict__ A,
|
||||
const at::Float8_e4m3fn* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
scalar_t* __restrict__ Btmp,
|
||||
float* __restrict__ Ctmp,
|
||||
const float* __restrict__ scale,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
bool brg,
|
||||
int64_t block_size_K,
|
||||
bool do_unpack = true);
|
||||
551
sgl-kernel/csrc/cpu/gemm_fp8.cpp
Normal file
551
sgl-kernel/csrc/cpu/gemm_fp8.cpp
Normal file
@@ -0,0 +1,551 @@
|
||||
#include "common.h"
|
||||
#include "gemm.h"
|
||||
#include "vec.h"
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ input, int64_t size) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
constexpr int kVecSize = bVec::size();
|
||||
|
||||
int64_t d;
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= size - kVecSize; d += kVecSize) {
|
||||
fVec data0 = fVec::loadu(input + d);
|
||||
fVec data1 = fVec::loadu(input + d + fVec::size());
|
||||
bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1);
|
||||
out_vec.store(out + d);
|
||||
}
|
||||
for (; d < size; ++d) {
|
||||
out[d] = static_cast<scalar_t>(input[d]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void copy_add_stub(
|
||||
scalar_t* __restrict__ out, const float* __restrict__ input, const float* __restrict__ bias, int64_t size) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
constexpr int kVecSize = bVec::size();
|
||||
|
||||
int64_t d;
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= size - kVecSize; d += kVecSize) {
|
||||
fVec data0 = fVec::loadu(input + d) + fVec::loadu(bias + d);
|
||||
fVec data1 = fVec::loadu(input + d + fVec::size()) + fVec::loadu(bias + d + fVec::size());
|
||||
bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1);
|
||||
out_vec.store(out + d);
|
||||
}
|
||||
for (; d < size; ++d) {
|
||||
out[d] = static_cast<scalar_t>(input[d] + bias[d]);
|
||||
}
|
||||
}
|
||||
|
||||
inline void unpack_B(
|
||||
at::BFloat16* __restrict__ Btmp,
|
||||
const at::Float8_e4m3fn* __restrict__ packed_B,
|
||||
int N,
|
||||
int K,
|
||||
int ldb,
|
||||
int ldb_tmp,
|
||||
float scale) {
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
// [K/2, N, 2]
|
||||
const int K2 = K >> 1;
|
||||
const int ldb2 = ldb; // ldb * 2 >> 1;
|
||||
const uint16_t* b_ptr = reinterpret_cast<const uint16_t*>(packed_B);
|
||||
const __m512 vd = _mm512_set1_ps(scale);
|
||||
|
||||
constexpr int BLOCK_N = block_size_n();
|
||||
static_assert(BLOCK_N == 32);
|
||||
|
||||
// prefetch distance
|
||||
constexpr int PREFETCH_SIZE_K = 64;
|
||||
|
||||
#pragma GCC unroll 4
|
||||
for (int k = 0; k < K2; ++k) {
|
||||
__m512i b8 = _mm512_loadu_si512(b_ptr + k * ldb2);
|
||||
if constexpr (PREFETCH_SIZE_K > 0) {
|
||||
_mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2, _MM_HINT_T0);
|
||||
}
|
||||
|
||||
__m256i b8_0 = _mm512_extracti32x8_epi32(b8, 0);
|
||||
__m256i b8_1 = _mm512_extracti32x8_epi32(b8, 1);
|
||||
|
||||
__m512bh bf16_0 = CVT_FP8_TO_BF16(b8_0);
|
||||
__m512bh bf16_1 = CVT_FP8_TO_BF16(b8_1);
|
||||
|
||||
// Apply scale
|
||||
__m512 f0_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 0));
|
||||
__m512 f0_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 1));
|
||||
__m512 f1_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 0));
|
||||
__m512 f1_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 1));
|
||||
|
||||
f0_lo = _mm512_mul_ps(f0_lo, vd);
|
||||
f0_hi = _mm512_mul_ps(f0_hi, vd);
|
||||
f1_lo = _mm512_mul_ps(f1_lo, vd);
|
||||
f1_hi = _mm512_mul_ps(f1_hi, vd);
|
||||
|
||||
bf16_0 = _mm512_cvtne2ps_pbh(f0_hi, f0_lo);
|
||||
bf16_1 = _mm512_cvtne2ps_pbh(f1_hi, f1_lo);
|
||||
|
||||
_mm512_storeu_si512(Btmp + k * ldb_tmp * 2 + 0, (__m512i)bf16_0);
|
||||
_mm512_storeu_si512(Btmp + k * ldb_tmp * 2 + 32, (__m512i)bf16_1);
|
||||
}
|
||||
#else
|
||||
TORCH_CHECK(false, "unpack_B: scalar path not implemented!");
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename scalar_t, typename packed_t, bool has_bias, int BLOCK_M, int BLOCK_N>
|
||||
struct tinygemm_kernel_nn {
|
||||
static inline void apply(
|
||||
const scalar_t* __restrict__ A,
|
||||
const packed_t* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
const float* __restrict__ bias,
|
||||
const float* __restrict__ scale,
|
||||
int K,
|
||||
int lda,
|
||||
int ldb,
|
||||
int ldc,
|
||||
int64_t block_size_K) {
|
||||
TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!");
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
template <bool has_bias, int BLOCK_M, int BLOCK_N>
|
||||
struct tinygemm_kernel_nn<at::BFloat16, at::Float8_e4m3fn, has_bias, BLOCK_M, BLOCK_N> {
|
||||
static inline void apply(
|
||||
const at::BFloat16* __restrict__ A,
|
||||
const at::Float8_e4m3fn* __restrict__ B,
|
||||
at::BFloat16* __restrict__ C,
|
||||
const float* __restrict__ bias,
|
||||
const float* __restrict__ scale,
|
||||
int K,
|
||||
int lda,
|
||||
int ldb,
|
||||
int ldc,
|
||||
int64_t block_size_K) {
|
||||
constexpr int ROWS = BLOCK_M;
|
||||
constexpr int COLS = BLOCK_N / 16;
|
||||
|
||||
const int KB = div_up(K, BLOCK_K);
|
||||
|
||||
// prefetch distance
|
||||
constexpr int PREFETCH_SIZE_K = 64;
|
||||
constexpr int PREFETCH_SIZE_KB = 1;
|
||||
|
||||
__m512bh va;
|
||||
__m512bh vb[COLS];
|
||||
__m512 vc[ROWS * COLS];
|
||||
__m512 vsum[ROWS * COLS];
|
||||
|
||||
// block quant scale
|
||||
__m512 vscale;
|
||||
|
||||
auto loadc = [&](auto i) {
|
||||
constexpr int col = i % COLS;
|
||||
if constexpr (has_bias) {
|
||||
vc[i] = _mm512_loadu_ps(bias + col * 16);
|
||||
} else {
|
||||
vc[i] = _mm512_setzero_ps();
|
||||
}
|
||||
};
|
||||
Unroll<ROWS * COLS>{}(loadc);
|
||||
|
||||
const int lda2 = lda >> 1;
|
||||
const int ldb2 = ldb; // ldb * 2 >> 1;
|
||||
const float* a_ptr = reinterpret_cast<const float*>(A);
|
||||
const uint16_t* b_ptr = reinterpret_cast<const uint16_t*>(B);
|
||||
|
||||
auto compute = [&](auto i, int k) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
|
||||
if constexpr (col == 0) {
|
||||
va = (__m512bh)(_mm512_set1_ps(a_ptr[row * lda2 + k]));
|
||||
if constexpr (PREFETCH_SIZE_K > 0) {
|
||||
_mm_prefetch(a_ptr + row * lda2 + k + PREFETCH_SIZE_K, _MM_HINT_T0);
|
||||
}
|
||||
}
|
||||
if constexpr (row == 0) {
|
||||
if constexpr (col % 2 == 0) {
|
||||
__m512i b8 = _mm512_loadu_si512(b_ptr + k * ldb2 + col * 16);
|
||||
if constexpr (PREFETCH_SIZE_K > 0) {
|
||||
_mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0);
|
||||
}
|
||||
vb[col + 0] = CVT_FP8_TO_BF16(_mm512_extracti32x8_epi32(b8, 0));
|
||||
vb[col + 1] = CVT_FP8_TO_BF16(_mm512_extracti32x8_epi32(b8, 1));
|
||||
}
|
||||
}
|
||||
vsum[i] = _mm512_dpbf16_ps(vsum[i], va, vb[col]);
|
||||
};
|
||||
|
||||
constexpr int BLOCK_K2 = BLOCK_K >> 1;
|
||||
for (int kb = 0; kb < KB; ++kb) {
|
||||
int kb_start = kb * BLOCK_K2;
|
||||
int kb_end = std::min(K >> 1, kb_start + BLOCK_K2);
|
||||
// 1. load scale vector
|
||||
vscale = _mm512_set1_ps(scale[kb]);
|
||||
if constexpr (PREFETCH_SIZE_KB > 0) {
|
||||
_mm_prefetch(scale + kb + PREFETCH_SIZE_KB, _MM_HINT_T0);
|
||||
}
|
||||
// 2. zero vsum for each block
|
||||
Unroll<ROWS * COLS>{}([&](auto i) { vsum[i] = _mm512_setzero_ps(); });
|
||||
// 3. accumulate across each block
|
||||
for (int k = kb_start; k < kb_end; ++k) {
|
||||
Unroll<ROWS * COLS>{}(compute, k);
|
||||
}
|
||||
// 4. apply scale
|
||||
Unroll<ROWS * COLS>{}([&](auto i) { vc[i] = _mm512_fmadd_ps(vsum[i], vscale, vc[i]); });
|
||||
}
|
||||
|
||||
auto storec = [&](auto i) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
// for COLS = 2,4 use 512bit store
|
||||
if constexpr (col % 2 == 0) {
|
||||
_mm512_storeu_si512(
|
||||
reinterpret_cast<__m512i*>((C + row * ldc + col * 16)),
|
||||
(__m512i)(_mm512_cvtne2ps_pbh(vc[row * COLS + col + 1], vc[row * COLS + col])));
|
||||
}
|
||||
};
|
||||
Unroll<ROWS * COLS>{}(storec);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \
|
||||
tinygemm_kernel_nn<scalar_t, at::Float8_e4m3fn, has_bias, MB_SIZE, NB_SIZE>::apply( \
|
||||
A + mb_start * lda, \
|
||||
B + nb_start * 2, \
|
||||
C + mb_start * ldc + nb_start, \
|
||||
has_bias ? bias + nb_start : nullptr, \
|
||||
scale, \
|
||||
K, \
|
||||
lda, \
|
||||
ldb, \
|
||||
ldc, \
|
||||
block_size_K);
|
||||
|
||||
template <typename scalar_t, typename packed_t, bool has_bias>
|
||||
struct brgemm {
|
||||
static inline void apply(
|
||||
const scalar_t* __restrict__ A,
|
||||
const packed_t* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
scalar_t* __restrict__ Btmp,
|
||||
float* __restrict__ Ctmp,
|
||||
const float* __restrict__ bias,
|
||||
const float* __restrict__ scale,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int lda,
|
||||
int ldb,
|
||||
int ldc,
|
||||
bool do_unpack = true) {
|
||||
TORCH_CHECK(false, "struct brgemm: primary template not implemented!");
|
||||
}
|
||||
};
|
||||
|
||||
template <bool has_bias>
|
||||
struct brgemm<at::BFloat16, at::Float8_e4m3fn, has_bias> {
|
||||
static inline void apply(
|
||||
const at::BFloat16* __restrict__ A,
|
||||
const at::Float8_e4m3fn* __restrict__ B,
|
||||
at::BFloat16* __restrict__ C,
|
||||
at::BFloat16* __restrict__ Btmp,
|
||||
float* __restrict__ Ctmp,
|
||||
const float* __restrict__ bias,
|
||||
const float* __restrict__ scale,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int lda,
|
||||
int ldb,
|
||||
int ldc,
|
||||
bool do_unpack = true) {
|
||||
constexpr int BLOCK_N = block_size_n();
|
||||
|
||||
// [K, BLOCK_N] -> [K / 2, BLOCK_N * 2]
|
||||
const int ldb_tmp = BLOCK_N;
|
||||
|
||||
if (do_unpack) {
|
||||
for (int k = 0; k < K; k += BLOCK_K) {
|
||||
int kb_size = std::min(BLOCK_K, K - k);
|
||||
|
||||
int idx = k >> 7; // k / BLOCK_K where BLOCK_K = 128
|
||||
unpack_B(Btmp + k * ldb_tmp, B + k * ldb, N, kb_size, ldb, ldb_tmp, scale[idx]);
|
||||
}
|
||||
}
|
||||
|
||||
at::native::cpublas::brgemm(M, N, K, lda, ldb_tmp, BLOCK_N, /* add_C */ false, A, Btmp, Ctmp);
|
||||
|
||||
// copy from Ctmp to C
|
||||
for (int m = 0; m < M; ++m) {
|
||||
if constexpr (has_bias) {
|
||||
copy_add_stub(C + m * ldc, Ctmp + m * BLOCK_N, bias, N);
|
||||
} else {
|
||||
copy_stub(C + m * ldc, Ctmp + m * BLOCK_N, N);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename scalar_t, bool has_bias>
|
||||
void tinygemm_kernel(
|
||||
const scalar_t* __restrict__ A,
|
||||
const at::Float8_e4m3fn* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
scalar_t* __restrict__ Btmp,
|
||||
float* __restrict__ Ctmp,
|
||||
const float* __restrict__ scale,
|
||||
const float* __restrict__ bias,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
bool brg,
|
||||
int64_t block_size_K,
|
||||
bool do_unpack = true) {
|
||||
if (brg) {
|
||||
brgemm<scalar_t, at::Float8_e4m3fn, has_bias>::apply(
|
||||
A, B, C, Btmp, Ctmp, bias, scale, M, N, K, lda, ldb, ldc, do_unpack);
|
||||
return;
|
||||
}
|
||||
|
||||
// pattern: 1-4-16
|
||||
constexpr int64_t BLOCK_M = 4;
|
||||
constexpr int64_t BLOCK_N = 64;
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
for (int mb = 0; mb < MB; ++mb) {
|
||||
int64_t mb_start = mb * BLOCK_M;
|
||||
int64_t mb_size = std::min(BLOCK_M, M - mb_start);
|
||||
for (int64_t nb = 0; nb < NB; ++nb) {
|
||||
int64_t nb_start = nb * BLOCK_N;
|
||||
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
|
||||
|
||||
switch (mb_size << 4 | nb_size >> 4) {
|
||||
case 0x12:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(1, 32);
|
||||
break;
|
||||
case 0x22:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(2, 32);
|
||||
break;
|
||||
case 0x32:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(3, 32);
|
||||
break;
|
||||
case 0x42:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(4, 32);
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void fp8_scaled_mm_kernel_impl(
|
||||
scalar_t* __restrict__ out,
|
||||
const scalar_t* __restrict__ mat1,
|
||||
const at::Float8_e4m3fn* __restrict__ mat2,
|
||||
const float* __restrict__ scales2,
|
||||
const float* __restrict__ bias,
|
||||
scalar_t* __restrict__ buffer,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t mat1_strideM,
|
||||
int64_t out_strideM,
|
||||
int64_t block_size_N,
|
||||
int64_t block_size_K,
|
||||
int64_t buffer_size_per_thread) {
|
||||
constexpr int64_t BLOCK_M = block_size_m();
|
||||
constexpr int64_t BLOCK_N = block_size_n();
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
|
||||
const int64_t scale_size_K = div_up(K, block_size_K);
|
||||
const int64_t blocks_n_per_group = block_size_N / BLOCK_N;
|
||||
|
||||
const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(M);
|
||||
|
||||
// parallel on [MB, NB]
|
||||
AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] {
|
||||
parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
|
||||
int tid = get_thread_num();
|
||||
scalar_t* __restrict__ Btmp = buffer + tid * buffer_size_per_thread;
|
||||
float* __restrict__ Ctmp = (float*)((void*)(Btmp + MAX_CACHE_BLOCK_SIZE * BLOCK_N * K));
|
||||
|
||||
loop_2d<at::Float8_e4m3fn>(mb0, mb1, nb0, nb1, BLOCK_N * K, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
|
||||
const float* scale_ptr = scales2 + (nb / blocks_n_per_group) * scale_size_K;
|
||||
|
||||
int64_t mb_start = mb * BLOCK_M;
|
||||
int64_t mb_size = std::min(M - mb_start, BLOCK_M);
|
||||
int64_t nb_start = nb * BLOCK_N;
|
||||
int64_t nb_size = std::min(N - nb_start, BLOCK_N);
|
||||
|
||||
// only do unpacking for the first row
|
||||
bool do_unpack = (mb == mb0);
|
||||
|
||||
tinygemm_kernel<scalar_t, has_bias>(
|
||||
/* A */ mat1 + mb_start * mat1_strideM,
|
||||
/* B */ mat2 + nb_start * K, // nb * BLOCK_N * K
|
||||
/* C */ out + mb_start * out_strideM + nb_start,
|
||||
/* Btmp */ Btmp + nb_offset * BLOCK_N * K,
|
||||
/* Ctmp */ Ctmp,
|
||||
/* scale */ scale_ptr,
|
||||
/* bias */ bias + nb_start,
|
||||
/* M */ mb_size,
|
||||
/* N */ nb_size,
|
||||
/* K */ K,
|
||||
/* lda */ mat1_strideM,
|
||||
/* ldb */ nb_size,
|
||||
/* ldc */ out_strideM,
|
||||
/* brg */ use_brgemm,
|
||||
/* block_size_K */ block_size_K,
|
||||
/* do_unpack */ do_unpack);
|
||||
});
|
||||
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
// tinygemm interface
|
||||
template <typename scalar_t>
|
||||
void tinygemm_kernel(
|
||||
const scalar_t* __restrict__ A,
|
||||
const at::Float8_e4m3fn* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
scalar_t* __restrict__ Btmp,
|
||||
float* __restrict__ Ctmp,
|
||||
const float* __restrict__ scale,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
bool brg,
|
||||
int64_t block_size_K,
|
||||
bool do_unpack) {
|
||||
tinygemm_kernel<scalar_t, false>(
|
||||
A, B, C, Btmp, Ctmp, scale, nullptr, M, N, K, lda, ldb, ldc, brg, block_size_K, do_unpack);
|
||||
}
|
||||
|
||||
#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \
|
||||
template void tinygemm_kernel<TYPE>( \
|
||||
const TYPE* __restrict__ A, \
|
||||
const at::Float8_e4m3fn* __restrict__ B, \
|
||||
TYPE* __restrict__ C, \
|
||||
TYPE* __restrict__ Btmp, \
|
||||
float* __restrict__ Ctmp, \
|
||||
const float* __restrict__ scale, \
|
||||
int64_t M, \
|
||||
int64_t N, \
|
||||
int64_t K, \
|
||||
int64_t lda, \
|
||||
int64_t ldb, \
|
||||
int64_t ldc, \
|
||||
bool brg, \
|
||||
int64_t block_size_K, \
|
||||
bool do_unpack)
|
||||
|
||||
INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16);
|
||||
INSTANTIATE_TINYGEMM_TEMPLATE(at::Half);
|
||||
|
||||
at::Tensor fp8_scaled_mm_cpu(
|
||||
at::Tensor& mat1,
|
||||
at::Tensor& mat2,
|
||||
at::Tensor& scales2,
|
||||
std::vector<int64_t> block_size,
|
||||
const std::optional<at::Tensor>& bias,
|
||||
at::ScalarType out_dtype,
|
||||
bool is_vnni) {
|
||||
RECORD_FUNCTION("sgl-kernel::fp8_scaled_mm_cpu", std::vector<c10::IValue>({mat1, mat2, scales2, block_size, bias}));
|
||||
|
||||
auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
|
||||
|
||||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1);
|
||||
CHECK_INPUT(mat2);
|
||||
CHECK_INPUT(scales2);
|
||||
TORCH_CHECK(scales2.scalar_type() == at::kFloat, "fp8_scaled_mm_cpu: expect scales2 to be float32.");
|
||||
|
||||
int64_t M = mat1.size(0);
|
||||
int64_t N = mat2.size(0);
|
||||
int64_t K = mat2.size(1);
|
||||
|
||||
CHECK_EQ(mat1.size(1), K);
|
||||
CHECK_DIM(2, mat1);
|
||||
CHECK_DIM(2, mat2);
|
||||
|
||||
TORCH_CHECK(block_size.size() == 2, "fp8_scaled_mm_cpu: expect block_size.size() to be 2.");
|
||||
|
||||
int64_t block_size_N = block_size[0];
|
||||
int64_t block_size_K = block_size[1];
|
||||
|
||||
constexpr int64_t BLOCK_M = block_size_m();
|
||||
constexpr int64_t BLOCK_N = block_size_n();
|
||||
TORCH_CHECK(block_size_N % BLOCK_N == 0, "fp8_scaled_mm_cpu: expect block_size_N to be multiples of BLOCK_N");
|
||||
TORCH_CHECK(block_size_K == BLOCK_K, "fp8_scaled_mm_cpu: expect block_size_K equals to BLOCK_K");
|
||||
CHECK_EQ(scales2.size(0), div_up(N, block_size_N));
|
||||
CHECK_EQ(scales2.size(1), div_up(K, block_size_K));
|
||||
|
||||
const auto st = mat1.scalar_type();
|
||||
TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf, "fp8_scaled_mm_cpu: expect A to be bfloat16 or half.");
|
||||
TORCH_CHECK(st == out_dtype, "fp8_scaled_mm_cpu: expect A has same dtype with out_dtype.");
|
||||
TORCH_CHECK(mat2.scalar_type() == at::kFloat8_e4m3fn, "fp8_scaled_mm_cpu: expect mat2 to be fp8_e4m3.");
|
||||
TORCH_CHECK(scales2.scalar_type() == at::kFloat, "fp8_scaled_mm_cpu: expect scales to be float32.");
|
||||
auto out = at::empty({M, N}, mat1.options().dtype(out_dtype));
|
||||
|
||||
// strides
|
||||
int64_t mat1_strideM = mat1.stride(0);
|
||||
int64_t out_strideM = out.stride(0);
|
||||
|
||||
const bool has_bias = bias.has_value();
|
||||
const float* bias_data = nullptr;
|
||||
if (has_bias) {
|
||||
CHECK_EQ(bias.value().size(0), N);
|
||||
bias_data = bias.value().data_ptr<float>();
|
||||
}
|
||||
|
||||
// Btmp : [T, BLOCK_N * K]
|
||||
// Ctmp : [T, BLOCK_M * BLOCK_N]
|
||||
int num_threads = at::get_num_threads();
|
||||
int64_t size_per_thread = MAX_CACHE_BLOCK_SIZE * BLOCK_N * K + BLOCK_M * BLOCK_N * 2;
|
||||
auto buffer = at::empty({num_threads, size_per_thread}, mat1.options());
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "fp8_scaled_mm_kernel_impl", [&] {
|
||||
fp8_scaled_mm_kernel_impl<scalar_t>(
|
||||
out.data_ptr<scalar_t>(),
|
||||
mat1.data_ptr<scalar_t>(),
|
||||
packed_w.data_ptr<at::Float8_e4m3fn>(),
|
||||
scales2.data_ptr<float>(),
|
||||
bias_data,
|
||||
buffer.data_ptr<scalar_t>(),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
mat1_strideM,
|
||||
out_strideM,
|
||||
block_size_N,
|
||||
block_size_K,
|
||||
size_per_thread);
|
||||
});
|
||||
|
||||
return out;
|
||||
}
|
||||
547
sgl-kernel/csrc/cpu/gemm_int8.cpp
Normal file
547
sgl-kernel/csrc/cpu/gemm_int8.cpp
Normal file
@@ -0,0 +1,547 @@
|
||||
#include "common.h"
|
||||
#include "gemm.h"
|
||||
#include "vec.h"
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename scalar_t, bool has_bias, int BLOCK_N>
|
||||
struct scale_C {
|
||||
static inline void apply(
|
||||
scalar_t* __restrict__ C,
|
||||
const int32_t* __restrict__ Ctmp,
|
||||
const int32_t* __restrict__ Bcomp,
|
||||
const float* __restrict__ bias,
|
||||
float As,
|
||||
const float* __restrict__ Bs) {
|
||||
TORCH_CHECK(false, "scale_C: scalar path not implemented!");
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
template <bool has_bias, int BLOCK_N>
|
||||
struct scale_C<at::BFloat16, has_bias, BLOCK_N> {
|
||||
static inline void apply(
|
||||
at::BFloat16* __restrict__ C,
|
||||
const int32_t* __restrict__ Ctmp,
|
||||
const int32_t* __restrict__ Bcomp,
|
||||
const float* __restrict__ bias,
|
||||
float As,
|
||||
const float* __restrict__ Bs) {
|
||||
constexpr int COLS = BLOCK_N / 16;
|
||||
static_assert(COLS % 2 == 0);
|
||||
|
||||
__m512 vc[COLS];
|
||||
__m512 vd0 = _mm512_set1_ps(As);
|
||||
|
||||
auto compute = [&](auto col) {
|
||||
__m512 vd1 = _mm512_loadu_ps(Bs + col * 16);
|
||||
__m512i vcomp = _mm512_loadu_si512(Bcomp + col * 16);
|
||||
__m512i vc32 = _mm512_loadu_si512(Ctmp + col * 16);
|
||||
vc[col] = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc32, vcomp));
|
||||
if constexpr (has_bias) {
|
||||
__m512 vbias = _mm512_loadu_ps(bias + col * 16);
|
||||
vc[col] = _mm512_fmadd_ps(_mm512_mul_ps(vc[col], vd0), vd1, vbias);
|
||||
} else {
|
||||
vc[col] = _mm512_mul_ps(_mm512_mul_ps(vc[col], vd0), vd1);
|
||||
}
|
||||
};
|
||||
Unroll<COLS>{}(compute);
|
||||
|
||||
auto storec = [&](auto col) {
|
||||
// for COLS = 2, 4 use 512bit store
|
||||
if constexpr (col % 2 == 0) {
|
||||
_mm512_storeu_si512(
|
||||
reinterpret_cast<__m512i*>((C + col * 16)), (__m512i)(_mm512_cvtne2ps_pbh(vc[col + 1], vc[col + 0])));
|
||||
}
|
||||
};
|
||||
Unroll<COLS>{}(storec);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
template <typename scalar_t, bool has_bias, int BLOCK_M, int BLOCK_N>
|
||||
struct tinygemm_kernel_nn {
|
||||
static inline void apply(
|
||||
const uint8_t* __restrict__ A,
|
||||
const int8_t* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
const float* __restrict__ As,
|
||||
const float* __restrict__ Bs,
|
||||
const int32_t* __restrict__ Bcomp,
|
||||
const float* __restrict__ bias,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc) {
|
||||
TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!");
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
template <bool has_bias, int BLOCK_M, int BLOCK_N>
|
||||
struct tinygemm_kernel_nn<at::BFloat16, has_bias, BLOCK_M, BLOCK_N> {
|
||||
static inline void apply(
|
||||
const uint8_t* __restrict__ A,
|
||||
const int8_t* __restrict__ B,
|
||||
at::BFloat16* __restrict__ C,
|
||||
const float* __restrict__ As,
|
||||
const float* __restrict__ Bs,
|
||||
const int32_t* __restrict__ Bcomp,
|
||||
const float* __restrict__ bias,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc) {
|
||||
constexpr int ROWS = BLOCK_M;
|
||||
constexpr int COLS = BLOCK_N / 16;
|
||||
static_assert(COLS % 2 == 0);
|
||||
|
||||
// prefetch distance
|
||||
constexpr int PREFETCH_SIZE_K = 0;
|
||||
|
||||
__m512i va;
|
||||
__m512i vb[COLS];
|
||||
__m512i vc[ROWS * COLS];
|
||||
__m512i vcomp[COLS];
|
||||
__m512 vd0;
|
||||
__m512 vd1[COLS];
|
||||
|
||||
// oops! 4x4 spills but we use 4x2
|
||||
__m512 vbias[COLS];
|
||||
|
||||
// [NOTE]: s8s8 igemm compensation in avx512-vnni
|
||||
//
|
||||
// avx512-vnni has no s8s8, so we need to change s8s8 to u8s8 with compensate:
|
||||
//
|
||||
// a * b = (a + 128) * b - 128 * b
|
||||
// s s u s u s
|
||||
//
|
||||
// 1) 128 * b is pre-computed when packing B to vnni formats
|
||||
// 2) a + 128 is fused when dynamically quantize A
|
||||
//
|
||||
auto loadc = [&](auto i) { vc[i] = _mm512_set1_epi32(0); };
|
||||
Unroll<ROWS * COLS>{}(loadc);
|
||||
|
||||
const int64_t K4 = K >> 2;
|
||||
const int64_t lda4 = lda >> 2;
|
||||
const int64_t ldb4 = ldb; // ldb * 4 >> 2;
|
||||
const int32_t* a_ptr = reinterpret_cast<const int32_t*>(A);
|
||||
const int32_t* b_ptr = reinterpret_cast<const int32_t*>(B);
|
||||
|
||||
auto compute = [&](auto i, int64_t k) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
|
||||
if constexpr (col == 0) {
|
||||
va = _mm512_set1_epi32(a_ptr[row * lda4 + k]);
|
||||
}
|
||||
if constexpr (row == 0) {
|
||||
vb[col] = _mm512_loadu_si512(b_ptr + k * ldb4 + col * 16);
|
||||
if constexpr (PREFETCH_SIZE_K > 0) {
|
||||
_mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb4 + col * 16, _MM_HINT_T0);
|
||||
}
|
||||
}
|
||||
vc[i] = _mm512_dpbusd_epi32(vc[i], va, vb[col]);
|
||||
};
|
||||
for (int64_t k = 0; k < K4; ++k) {
|
||||
Unroll<ROWS * COLS>{}(compute, k);
|
||||
}
|
||||
|
||||
auto storec = [&](auto i) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
|
||||
// load a scale
|
||||
if constexpr (col == 0) {
|
||||
vd0 = _mm512_set1_ps(As[row]);
|
||||
}
|
||||
// load b scale and vcomp per 2 vectors
|
||||
// also load bias if any
|
||||
if constexpr (row == 0) {
|
||||
if constexpr (col % 2 == 0) {
|
||||
vd1[col + 0] = _mm512_loadu_ps(Bs + col * 16);
|
||||
vd1[col + 1] = _mm512_loadu_ps(Bs + col * 16 + 16);
|
||||
vcomp[col + 0] = _mm512_loadu_si512(Bcomp + col * 16);
|
||||
vcomp[col + 1] = _mm512_loadu_si512(Bcomp + col * 16 + 16);
|
||||
if constexpr (has_bias) {
|
||||
vbias[col + 0] = _mm512_loadu_ps(bias + col * 16);
|
||||
vbias[col + 1] = _mm512_loadu_ps(bias + col * 16 + 16);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// for COLS = 2, 4 use 512bit store
|
||||
if constexpr (col % 2 == 0) {
|
||||
__m512 vc0 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc[row * COLS + col + 0], vcomp[col + 0]));
|
||||
__m512 vc1 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc[row * COLS + col + 1], vcomp[col + 1]));
|
||||
if constexpr (has_bias) {
|
||||
vc0 = _mm512_fmadd_ps(_mm512_mul_ps(vc0, vd0), vd1[col + 0], vbias[col + 0]);
|
||||
vc1 = _mm512_fmadd_ps(_mm512_mul_ps(vc1, vd0), vd1[col + 1], vbias[col + 1]);
|
||||
} else {
|
||||
vc0 = _mm512_mul_ps(_mm512_mul_ps(vc0, vd0), vd1[col + 0]);
|
||||
vc1 = _mm512_mul_ps(_mm512_mul_ps(vc1, vd0), vd1[col + 1]);
|
||||
}
|
||||
|
||||
_mm512_storeu_si512(
|
||||
reinterpret_cast<__m512i*>((C + row * ldc + col * 16)), (__m512i)(_mm512_cvtne2ps_pbh(vc1, vc0)));
|
||||
}
|
||||
};
|
||||
Unroll<ROWS * COLS>{}(storec);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \
|
||||
tinygemm_kernel_nn<scalar_t, has_bias, MB_SIZE, NB_SIZE>::apply( \
|
||||
A + mb_start * lda, \
|
||||
B + nb_start * 4, \
|
||||
C + mb_start * ldc + nb_start, \
|
||||
As + mb_start, \
|
||||
Bs + nb_start, \
|
||||
Bcomp + nb_start, \
|
||||
has_bias ? bias + nb_start : nullptr, \
|
||||
K, \
|
||||
lda, \
|
||||
ldb, \
|
||||
ldc);
|
||||
|
||||
template <typename scalar_t, bool has_bias>
|
||||
void tinygemm_kernel(
|
||||
const uint8_t* __restrict__ A,
|
||||
const int8_t* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
int32_t* __restrict__ Ctmp,
|
||||
const float* __restrict__ As,
|
||||
const float* __restrict__ Bs,
|
||||
const float* __restrict__ bias,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
bool brg) {
|
||||
// B compensation
|
||||
const int32_t* Bcomp = reinterpret_cast<const int32_t*>(B + block_size_n() * K);
|
||||
|
||||
if (brg) {
|
||||
constexpr int BLOCK_N = block_size_n();
|
||||
at::native::cpublas::brgemm(M, N, K, lda, ldb, BLOCK_N, /* add_C */ false, A, B, Ctmp);
|
||||
|
||||
// apply compensation and scale
|
||||
for (int64_t m = 0; m < M; ++m) {
|
||||
scale_C<scalar_t, has_bias, BLOCK_N>::apply(C + m * ldc, Ctmp + m * BLOCK_N, Bcomp, bias, As[m], Bs);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// pattern: 1-4-16
|
||||
constexpr int64_t BLOCK_M = 4;
|
||||
constexpr int64_t BLOCK_N = 64;
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
for (int64_t mb = 0; mb < MB; ++mb) {
|
||||
int64_t mb_start = mb * BLOCK_M;
|
||||
int64_t mb_size = std::min(BLOCK_M, M - mb_start);
|
||||
for (int64_t nb = 0; nb < NB; ++nb) {
|
||||
int64_t nb_start = nb * BLOCK_N;
|
||||
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
|
||||
|
||||
switch (mb_size << 4 | nb_size >> 4) {
|
||||
// mb_size = 1
|
||||
case 0x12:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(1, 32);
|
||||
break;
|
||||
case 0x14:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(1, 64);
|
||||
break;
|
||||
// mb_size = 2
|
||||
case 0x22:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(2, 32);
|
||||
break;
|
||||
case 0x24:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(2, 64);
|
||||
break;
|
||||
// mb_size = 3
|
||||
case 0x32:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(3, 32);
|
||||
break;
|
||||
case 0x34:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(3, 64);
|
||||
break;
|
||||
// mb_size = 4
|
||||
case 0x42:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(4, 32);
|
||||
break;
|
||||
case 0x44:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(4, 64);
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void int8_scaled_mm_kernel_impl(
|
||||
scalar_t* __restrict__ out,
|
||||
const uint8_t* __restrict__ mat1,
|
||||
const int8_t* __restrict__ mat2,
|
||||
const float* __restrict__ scales1,
|
||||
const float* __restrict__ scales2,
|
||||
const float* __restrict__ bias,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K) {
|
||||
constexpr int64_t BLOCK_M = block_size_m();
|
||||
constexpr int64_t BLOCK_N = block_size_n();
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
|
||||
const bool use_brgemm = can_use_brgemm<int8_t>(M);
|
||||
|
||||
// K + 4 after compensation
|
||||
const int64_t packed_row_size = get_row_size<int8_t>(K);
|
||||
|
||||
AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] {
|
||||
parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
|
||||
// for brgemm, use int32_t for accumulate
|
||||
alignas(64) int32_t Ctmp[BLOCK_M * BLOCK_N];
|
||||
|
||||
loop_2d<int8_t>(mb0, mb1, nb0, nb1, BLOCK_N * K, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
|
||||
int mb_start = mb * BLOCK_M;
|
||||
int mb_size = std::min(M - mb_start, BLOCK_M);
|
||||
int nb_start = nb * BLOCK_N;
|
||||
int nb_size = std::min(N - nb_start, BLOCK_N);
|
||||
|
||||
tinygemm_kernel<scalar_t, has_bias>(
|
||||
/* A */ mat1 + mb_start * K,
|
||||
/* B */ mat2 + nb_start * packed_row_size /* nb * BLOCK_N * (K + 4) */,
|
||||
/* C */ out + mb_start * N + nb_start,
|
||||
/* Ctmp*/ Ctmp,
|
||||
/* As */ scales1 + mb_start,
|
||||
/* Bs */ scales2 + nb_start,
|
||||
/* bias*/ bias + nb_start,
|
||||
/* M */ mb_size,
|
||||
/* N */ nb_size,
|
||||
/* K */ K,
|
||||
/* lda */ K,
|
||||
/* ldb */ nb_size,
|
||||
/* ldc */ N,
|
||||
/* brg */ use_brgemm);
|
||||
});
|
||||
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
// tinygemm interface
|
||||
template <typename scalar_t>
|
||||
void tinygemm_kernel(
|
||||
const uint8_t* __restrict__ A,
|
||||
const int8_t* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
int32_t* __restrict__ Ctmp,
|
||||
const float* __restrict__ As,
|
||||
const float* __restrict__ Bs,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
bool brg) {
|
||||
tinygemm_kernel<scalar_t, false>(A, B, C, Ctmp, As, Bs, nullptr, M, N, K, lda, ldb, ldc, brg);
|
||||
}
|
||||
|
||||
#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \
|
||||
template void tinygemm_kernel<TYPE>( \
|
||||
const uint8_t* __restrict__ A, \
|
||||
const int8_t* __restrict__ B, \
|
||||
TYPE* __restrict__ C, \
|
||||
int32_t* __restrict__ Ctmp, \
|
||||
const float* __restrict__ As, \
|
||||
const float* __restrict__ Bs, \
|
||||
int64_t M, \
|
||||
int64_t N, \
|
||||
int64_t K, \
|
||||
int64_t lda, \
|
||||
int64_t ldb, \
|
||||
int64_t ldc, \
|
||||
bool brg)
|
||||
|
||||
INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16);
|
||||
INSTANTIATE_TINYGEMM_TEMPLATE(at::Half);
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor> per_token_quant_int8_cpu(at::Tensor& A) {
|
||||
RECORD_FUNCTION("sgl-kernel::per_token_quant_int8_cpu", std::vector<c10::IValue>({A}));
|
||||
|
||||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(A);
|
||||
CHECK_DIM(2, A);
|
||||
|
||||
int64_t M = A.size(0);
|
||||
int64_t K = A.size(1);
|
||||
int64_t lda = A.stride(0);
|
||||
|
||||
const auto st = A.scalar_type();
|
||||
TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf, "per_token_quant_int8: expect A to be bfloat16 or half.");
|
||||
|
||||
auto Aq = at::empty({M, K}, A.options().dtype(at::kByte));
|
||||
auto As = at::empty({M}, A.options().dtype(at::kFloat));
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "per_token_quant_int8", [&] {
|
||||
uint8_t* __restrict__ Aq_data = Aq.data_ptr<uint8_t>();
|
||||
float* __restrict__ As_data = As.data_ptr<float>();
|
||||
const scalar_t* __restrict__ A_data = A.data_ptr<scalar_t>();
|
||||
|
||||
at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t m = begin; m < end; ++m) {
|
||||
quantize_row_int8<scalar_t>(Aq_data + m * K, As_data[m], A_data + m * lda, K);
|
||||
}
|
||||
});
|
||||
});
|
||||
return std::make_tuple(Aq, As);
|
||||
}
|
||||
|
||||
// weight : static, per-channel, symmetric
|
||||
// activation : dynamic, per-token, symmetric
|
||||
//
|
||||
// mat1 : [M, K]
|
||||
// mat2 : [N, K]
|
||||
// scales1 : [M]
|
||||
// scales2 : [N]
|
||||
// bias : [N]
|
||||
// out : [M, N]
|
||||
//
|
||||
at::Tensor int8_scaled_mm_cpu(
|
||||
at::Tensor& mat1,
|
||||
at::Tensor& mat2,
|
||||
at::Tensor& scales1,
|
||||
at::Tensor& scales2,
|
||||
const std::optional<at::Tensor>& bias,
|
||||
at::ScalarType out_dtype,
|
||||
bool is_vnni) {
|
||||
RECORD_FUNCTION("sgl-kernel::int8_scaled_mm_cpu", std::vector<c10::IValue>({mat1, mat2, scales1, scales2, bias}));
|
||||
|
||||
auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
|
||||
|
||||
CHECK_INPUT(mat1);
|
||||
CHECK_INPUT(mat2);
|
||||
CHECK_INPUT(scales1);
|
||||
CHECK_INPUT(scales2);
|
||||
CHECK_DIM(2, mat1);
|
||||
CHECK_DIM(2, mat2);
|
||||
|
||||
int64_t M = mat1.size(0);
|
||||
int64_t N = mat2.size(0);
|
||||
int64_t K = mat1.size(1);
|
||||
|
||||
// see [NOTE]: s8s8 igemm compensation in avx512-vnni
|
||||
CHECK_EQ(mat2.size(1), (int64_t)(is_vnni ? K + sizeof(int32_t) : K));
|
||||
CHECK_EQ(scales1.numel(), M);
|
||||
CHECK_EQ(scales2.numel(), N);
|
||||
|
||||
TORCH_CHECK(mat1.scalar_type() == at::kByte, "int8_scaled_mm: expect mat1 to be uint8.");
|
||||
TORCH_CHECK(mat2.scalar_type() == at::kChar, "int8_scaled_mm: expect mat2 to be int8.");
|
||||
TORCH_CHECK(
|
||||
scales1.scalar_type() == at::kFloat && scales2.scalar_type() == at::kFloat,
|
||||
"int8_scaled_mm: expect scales to be float32.");
|
||||
|
||||
auto out = at::empty({M, N}, mat1.options().dtype(out_dtype));
|
||||
|
||||
const bool has_bias = bias.has_value();
|
||||
const float* bias_data = nullptr;
|
||||
if (has_bias) {
|
||||
CHECK_EQ(bias.value().size(0), N);
|
||||
bias_data = bias.value().data_ptr<float>();
|
||||
}
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "int8_scaled_mm_kernel_impl", [&] {
|
||||
int8_scaled_mm_kernel_impl<scalar_t>(
|
||||
out.data_ptr<scalar_t>(),
|
||||
mat1.data_ptr<uint8_t>(),
|
||||
packed_w.data_ptr<int8_t>(),
|
||||
scales1.data_ptr<float>(),
|
||||
scales2.data_ptr<float>(),
|
||||
bias_data,
|
||||
M,
|
||||
N,
|
||||
K);
|
||||
});
|
||||
return out;
|
||||
}
|
||||
|
||||
// fused `per_token_quant_int8_cpu` and `int8_scaled_mm_cpu`
|
||||
at::Tensor int8_scaled_mm_with_quant(
|
||||
at::Tensor& mat1,
|
||||
at::Tensor& mat2,
|
||||
at::Tensor& scales2,
|
||||
const std::optional<at::Tensor>& bias,
|
||||
at::ScalarType out_dtype,
|
||||
bool is_vnni) {
|
||||
RECORD_FUNCTION("sgl-kernel::int8_scaled_mm_cpu", std::vector<c10::IValue>({mat1, mat2, scales2, bias}));
|
||||
|
||||
auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
|
||||
|
||||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1);
|
||||
CHECK_INPUT(mat2);
|
||||
CHECK_INPUT(scales2);
|
||||
CHECK_DIM(2, mat1);
|
||||
CHECK_DIM(2, mat2);
|
||||
|
||||
int64_t M = mat1.size(0);
|
||||
int64_t N = mat2.size(0);
|
||||
int64_t K = mat1.size(1);
|
||||
int64_t lda = mat1.stride(0);
|
||||
|
||||
// see [NOTE]: s8s8 igemm compensation in avx512-vnni
|
||||
CHECK_EQ(mat2.size(1), (int64_t)(is_vnni ? K + sizeof(int32_t) : K));
|
||||
CHECK_EQ(scales2.numel(), N);
|
||||
|
||||
const auto st = mat1.scalar_type();
|
||||
TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf, "int8_scaled_mm_with_quant: expect A to be bfloat16 or half.");
|
||||
TORCH_CHECK(st == out_dtype, "int8_scaled_mm_with_quant: expect A has same dtype with out_dtype.");
|
||||
TORCH_CHECK(mat2.scalar_type() == at::kChar, "int8_scaled_mm_with_quant: expect mat2 to be int8.");
|
||||
TORCH_CHECK(scales2.scalar_type() == at::kFloat, "int8_scaled_mm_with_quant: expect scales to be float32.");
|
||||
|
||||
const int64_t buffer_size = M * K + M * sizeof(float);
|
||||
auto buffer = at::empty({buffer_size}, mat1.options().dtype(at::kByte));
|
||||
auto out = at::empty({M, N}, mat1.options().dtype(out_dtype));
|
||||
|
||||
const bool has_bias = bias.has_value();
|
||||
const float* bias_data = nullptr;
|
||||
if (has_bias) {
|
||||
CHECK_EQ(bias.value().size(0), N);
|
||||
bias_data = bias.value().data_ptr<float>();
|
||||
}
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "int8_scaled_mm_with_quant_kernel_impl", [&] {
|
||||
uint8_t* __restrict__ Aq_data = buffer.data_ptr<uint8_t>();
|
||||
float* __restrict__ As_data = (float*)((void*)(Aq_data + M * K));
|
||||
const scalar_t* __restrict__ A_data = mat1.data_ptr<scalar_t>();
|
||||
|
||||
at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t m = begin; m < end; ++m) {
|
||||
quantize_row_int8<scalar_t>(Aq_data + m * K, As_data[m], A_data + m * lda, K);
|
||||
}
|
||||
});
|
||||
|
||||
int8_scaled_mm_kernel_impl<scalar_t>(
|
||||
out.data_ptr<scalar_t>(),
|
||||
Aq_data,
|
||||
packed_w.data_ptr<int8_t>(),
|
||||
As_data,
|
||||
scales2.data_ptr<float>(),
|
||||
bias_data,
|
||||
M,
|
||||
N,
|
||||
K);
|
||||
});
|
||||
return out;
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user