adapt to sglang v0.5.2rc1 on dcu

This commit is contained in:
maxiao
2025-09-04 15:56:33 +08:00
commit 909abb58f5
2320 changed files with 489411 additions and 0 deletions

15
sgl-kernel/.clang-format Normal file
View File

@@ -0,0 +1,15 @@
BasedOnStyle: Google
IndentWidth: 2
ColumnLimit: 120
AllowShortFunctionsOnASingleLine: Empty
DerivePointerAlignment: false
PointerAlignment: Left
NamespaceIndentation: None
SortIncludes: true
AllowShortLoopsOnASingleLine: false
BinPackParameters: false # Prevents packing parameters in declarations
BinPackArguments: false # Prevents packing arguments in function calls
AlignAfterOpenBracket: AlwaysBreak # Forces a break after the opening parenthesis
AlignOperands: Align # Aligns arguments vertically
PenaltyBreakBeforeFirstCallParameter: 1 # Encourages breaking before the first argument
PenaltyReturnTypeOnItsOwnLine: 100 # Keeps return type with function name

503
sgl-kernel/CMakeLists.txt Normal file
View File

@@ -0,0 +1,503 @@
cmake_minimum_required(VERSION 3.26 FATAL_ERROR)
project(sgl-kernel LANGUAGES CXX CUDA)
# CMake
cmake_policy(SET CMP0169 OLD)
cmake_policy(SET CMP0177 NEW)
include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
set(CMAKE_COLOR_DIAGNOSTICS ON)
set(CMAKE_VERBOSE_MAKEFILE ON CACHE BOOL "ON")
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_SHARED_LIBRARY_PREFIX "")
# Python
find_package(Python COMPONENTS Interpreter Development.Module ${SKBUILD_SABI_COMPONENT} REQUIRED)
# CXX
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3")
# CUDA
enable_language(CUDA)
find_package(CUDAToolkit REQUIRED)
set_property(GLOBAL PROPERTY CUDA_SEPARABLE_COMPILATION ON)
message(STATUS "Detected CUDA_VERSION=${CUDA_VERSION}")
if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "13.0")
message("CUDA_VERSION ${CUDA_VERSION} >= 13.0")
elseif ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8")
message("CUDA_VERSION ${CUDA_VERSION} >= 12.8")
elseif ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.4")
message("CUDA_VERSION ${CUDA_VERSION} >= 12.4")
elseif ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.1")
message("CUDA_VERSION ${CUDA_VERSION} >= 12.1")
elseif ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "11.8")
message("CUDA_VERSION ${CUDA_VERSION} >= 11.8")
endif()
# Torch
find_package(Torch REQUIRED)
# clean Torch Flag
clear_cuda_arches(CMAKE_FLAG)
include(FetchContent)
# cutlass
FetchContent_Declare(
repo-cutlass
GIT_REPOSITORY https://github.com/NVIDIA/cutlass
GIT_TAG a49a78ffefc86a87160dfe0ccc3a3a2d1622c918
GIT_SHALLOW OFF
)
FetchContent_Populate(repo-cutlass)
# DeepGEMM
FetchContent_Declare(
repo-deepgemm
GIT_REPOSITORY https://github.com/sgl-project/DeepGEMM
GIT_TAG sgl
GIT_SHALLOW OFF
)
FetchContent_Populate(repo-deepgemm)
FetchContent_Declare(
repo-fmt
GIT_REPOSITORY https://github.com/fmtlib/fmt
GIT_TAG 553ec11ec06fbe0beebfbb45f9dc3c9eabd83d28
GIT_SHALLOW OFF
)
FetchContent_Populate(repo-fmt)
# Triton
FetchContent_Declare(
repo-triton
GIT_REPOSITORY "https://github.com/triton-lang/triton"
GIT_TAG 8f9f695ea8fde23a0c7c88e4ab256634ca27789f
GIT_SHALLOW OFF
)
FetchContent_Populate(repo-triton)
# flashinfer
FetchContent_Declare(
repo-flashinfer
GIT_REPOSITORY https://github.com/flashinfer-ai/flashinfer.git
GIT_TAG 018b551825c8e5579206e6eb9d3229fa679202b3
GIT_SHALLOW OFF
)
FetchContent_Populate(repo-flashinfer)
# flash-attention
FetchContent_Declare(
repo-flash-attention
GIT_REPOSITORY https://github.com/sgl-project/sgl-attn
GIT_TAG sgl-kernel
GIT_SHALLOW OFF
)
FetchContent_Populate(repo-flash-attention)
# mscclpp
FetchContent_Declare(
repo-mscclpp
GIT_REPOSITORY https://github.com/microsoft/mscclpp.git
GIT_TAG 51eca89d20f0cfb3764ccd764338d7b22cd486a6
GIT_SHALLOW OFF
)
FetchContent_Populate(repo-mscclpp)
# ccache option
option(ENABLE_CCACHE "Whether to use ccache" ON)
find_program(CCACHE_FOUND ccache)
if(CCACHE_FOUND AND ENABLE_CCACHE AND DEFINED ENV{CCACHE_DIR})
message(STATUS "Building with CCACHE enabled")
set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE "ccache")
set_property(GLOBAL PROPERTY RULE_LAUNCH_LINK "ccache")
endif()
# Enable gencode below SM90
option(ENABLE_BELOW_SM90 "Enable below SM90" ON)
if (CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64")
set(ENABLE_BELOW_SM90 OFF)
message(STATUS "For aarch64, disable gencode below SM90 by default")
endif()
include_directories(
${PROJECT_SOURCE_DIR}/include
${PROJECT_SOURCE_DIR}/csrc
${repo-cutlass_SOURCE_DIR}/include
${repo-cutlass_SOURCE_DIR}/tools/util/include
${repo-flashinfer_SOURCE_DIR}/include
${repo-flashinfer_SOURCE_DIR}/csrc
${repo-mscclpp_SOURCE_DIR}/include
)
set(SGL_KERNEL_CUDA_FLAGS
"-DNDEBUG"
"-DOPERATOR_NAMESPACE=sgl-kernel"
"-O3"
"-Xcompiler"
"-fPIC"
"-gencode=arch=compute_90,code=sm_90"
"-std=c++17"
"-DFLASHINFER_ENABLE_F16"
"-DCUTE_USE_PACKED_TUPLE=1"
"-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1"
"-DCUTLASS_VERSIONS_GENERATED"
"-DCUTLASS_TEST_LEVEL=0"
"-DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1"
"-DCUTLASS_DEBUG_TRACE_LEVEL=0"
"--expt-relaxed-constexpr"
"--expt-extended-lambda"
"--threads=32"
# Supress warnings
"-Xcompiler=-Wno-clang-format-violations"
"-Xcompiler=-Wno-conversion"
"-Xcompiler=-Wno-deprecated-declarations"
"-Xcompiler=-Wno-terminate"
"-Xcompiler=-Wfatal-errors"
"-Xcompiler=-ftemplate-backtrace-limit=1"
"-Xcudafe=--diag_suppress=177" # variable was declared but never referenced
# uncomment to debug
# "--ptxas-options=-v"
# "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage"
)
option(SGL_KERNEL_ENABLE_BF16 "Enable BF16" ON)
option(SGL_KERNEL_ENABLE_FP8 "Enable FP8" ON)
option(SGL_KERNEL_ENABLE_FP4 "Enable FP4" OFF)
option(SGL_KERNEL_ENABLE_FA3 "Enable FA3" OFF)
option(SGL_KERNEL_ENABLE_SM90A "Enable SM90A" OFF)
option(SGL_KERNEL_ENABLE_SM100A "Enable SM100A" OFF)
if (SGL_KERNEL_ENABLE_BF16)
list(APPEND SGL_KERNEL_CUDA_FLAGS
"-DFLASHINFER_ENABLE_BF16"
)
endif()
if (SGL_KERNEL_ENABLE_FP8)
list(APPEND SGL_KERNEL_CUDA_FLAGS
"-DFLASHINFER_ENABLE_FP8"
"-DFLASHINFER_ENABLE_FP8_E4M3"
"-DFLASHINFER_ENABLE_FP8_E5M2"
)
endif()
if (ENABLE_BELOW_SM90)
list(APPEND SGL_KERNEL_CUDA_FLAGS
"-gencode=arch=compute_80,code=sm_80"
"-gencode=arch=compute_89,code=sm_89"
)
endif()
if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8" OR SGL_KERNEL_ENABLE_SM100A)
list(APPEND SGL_KERNEL_CUDA_FLAGS
"-gencode=arch=compute_100,code=sm_100"
"-gencode=arch=compute_100a,code=sm_100a"
"-gencode=arch=compute_120,code=sm_120"
"-gencode=arch=compute_120a,code=sm_120a"
)
# refer sm_121, sm_110 and sm_101 description https://github.com/pytorch/pytorch/pull/156176
if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "13.0")
list(APPEND SGL_KERNEL_CUDA_FLAGS
"-gencode=arch=compute_103,code=sm_103"
"-gencode=arch=compute_103a,code=sm_103a"
"-gencode=arch=compute_110,code=sm_110"
"-gencode=arch=compute_110a,code=sm_110a"
"-gencode=arch=compute_121,code=sm_121"
"-gencode=arch=compute_121a,code=sm_121a"
"--compress-mode=size"
)
else()
list(APPEND SGL_KERNEL_CUDA_FLAGS
"-gencode=arch=compute_101,code=sm_101"
"-gencode=arch=compute_101a,code=sm_101a"
)
endif()
else()
list(APPEND SGL_KERNEL_CUDA_FLAGS
"-use_fast_math"
)
endif()
if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.4" OR SGL_KERNEL_ENABLE_SM90A)
set(SGL_KERNEL_ENABLE_FA3 ON)
list(APPEND SGL_KERNEL_CUDA_FLAGS
"-gencode=arch=compute_90a,code=sm_90a"
)
endif()
if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8" OR SGL_KERNEL_ENABLE_FP4)
list(APPEND SGL_KERNEL_CUDA_FLAGS
"-DENABLE_NVFP4=1"
)
endif()
set(SOURCES
"csrc/allreduce/custom_all_reduce.cu"
"csrc/allreduce/mscclpp_allreduce.cu"
"csrc/attention/cascade.cu"
"csrc/attention/cutlass_mla_kernel.cu"
"csrc/attention/lightning_attention_decode_kernel.cu"
"csrc/attention/merge_attn_states.cu"
"csrc/attention/vertical_slash_index.cu"
"csrc/elementwise/activation.cu"
"csrc/elementwise/cast.cu"
"csrc/elementwise/fused_add_rms_norm_kernel.cu"
"csrc/elementwise/rope.cu"
"csrc/common_extension.cc"
"csrc/gemm/awq_kernel.cu"
"csrc/gemm/bmm_fp8.cu"
"csrc/gemm/dsv3_fused_a_gemm.cu"
"csrc/gemm/dsv3_router_gemm_bf16_out.cu"
"csrc/gemm/dsv3_router_gemm_entry.cu"
"csrc/gemm/dsv3_router_gemm_float_out.cu"
"csrc/gemm/fp8_blockwise_gemm_kernel.cu"
"csrc/gemm/fp8_gemm_kernel.cu"
"csrc/gemm/int8_gemm_kernel.cu"
"csrc/gemm/nvfp4_expert_quant.cu"
"csrc/gemm/nvfp4_quant_entry.cu"
"csrc/gemm/nvfp4_quant_kernels.cu"
"csrc/gemm/nvfp4_scaled_mm_entry.cu"
"csrc/gemm/nvfp4_scaled_mm_kernels.cu"
"csrc/gemm/per_tensor_quant_fp8.cu"
"csrc/gemm/per_token_group_quant_8bit.cu"
"csrc/gemm/per_token_quant_fp8.cu"
"csrc/gemm/qserve_w4a8_per_chn_gemm.cu"
"csrc/gemm/qserve_w4a8_per_group_gemm.cu"
"csrc/gemm/marlin/gptq_marlin.cu"
"csrc/gemm/marlin/gptq_marlin_repack.cu"
"csrc/gemm/marlin/awq_marlin_repack.cu"
"csrc/gemm/gptq/gptq_kernel.cu"
"csrc/grammar/apply_token_bitmask_inplace_cuda.cu"
"csrc/moe/cutlass_moe/w4a8/scaled_mm_entry.cu"
"csrc/moe/cutlass_moe/w4a8/w4a8_moe_data.cu"
"csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cu"
"csrc/moe/marlin_moe_wna16/ops.cu"
"csrc/moe/moe_align_kernel.cu"
"csrc/moe/moe_fused_gate.cu"
"csrc/moe/moe_topk_softmax_kernels.cu"
"csrc/moe/nvfp4_blockwise_moe.cu"
"csrc/moe/fp8_blockwise_moe_kernel.cu"
"csrc/moe/prepare_moe_input.cu"
"csrc/moe/ep_moe_reorder_kernel.cu"
"csrc/moe/ep_moe_silu_and_mul_kernel.cu"
"csrc/memory/store.cu"
"csrc/kvcacheio/transfer.cu"
"csrc/speculative/eagle_utils.cu"
"csrc/speculative/packbit.cu"
"csrc/speculative/speculative_sampling.cu"
"${repo-flashinfer_SOURCE_DIR}/csrc/norm.cu"
"${repo-flashinfer_SOURCE_DIR}/csrc/renorm.cu"
"${repo-flashinfer_SOURCE_DIR}/csrc/sampling.cu"
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_causal_sm80.cu"
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_sm80.cu"
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_fp16_causal_sm80.cu"
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_fp16_sm80.cu"
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/flash_sparse_api.cpp"
)
Python_add_library(common_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES})
target_compile_options(common_ops PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_KERNEL_CUDA_FLAGS}>)
target_include_directories(common_ops PRIVATE
${repo-cutlass_SOURCE_DIR}/examples/77_blackwell_fmha
${repo-cutlass_SOURCE_DIR}/examples/common
${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src
)
find_package(Python3 COMPONENTS Interpreter REQUIRED)
execute_process(
COMMAND ${Python3_EXECUTABLE} -c "import torch; print(int(torch._C._GLIBCXX_USE_CXX11_ABI))"
OUTPUT_VARIABLE TORCH_CXX11_ABI
OUTPUT_STRIP_TRAILING_WHITESPACE
)
if(TORCH_CXX11_ABI STREQUAL "0")
message(STATUS "Using old C++ ABI (-D_GLIBCXX_USE_CXX11_ABI=0)")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=0")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=0")
else()
message(STATUS "Using new C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI=1)")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=1")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=1")
endif()
# mscclpp
set(MSCCLPP_USE_CUDA ON)
set(MSCCLPP_BYPASS_GPU_CHECK ON)
set(MSCCLPP_BUILD_TESTS OFF)
add_subdirectory(
${repo-mscclpp_SOURCE_DIR}
${CMAKE_CURRENT_BINARY_DIR}/mscclpp-build
)
target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cublasLt mscclpp_static)
# flash attention
target_compile_definitions(common_ops PRIVATE
FLASHATTENTION_DISABLE_BACKWARD
FLASHATTENTION_DISABLE_DROPOUT
FLASHATTENTION_DISABLE_UNEVEN_K
)
install(TARGETS common_ops LIBRARY DESTINATION sgl_kernel)
# ============================ Optional Install ============================= #
# set flash-attention sources file
# Now FA3 support sm80/sm86/sm90
if (SGL_KERNEL_ENABLE_FA3)
set(SGL_FLASH_KERNEL_CUDA_FLAGS
"-DNDEBUG"
"-DOPERATOR_NAMESPACE=sgl-kernel"
"-O3"
"-Xcompiler"
"-fPIC"
"-gencode=arch=compute_90a,code=sm_90a"
"-std=c++17"
"-DCUTE_USE_PACKED_TUPLE=1"
"-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1"
"-DCUTLASS_VERSIONS_GENERATED"
"-DCUTLASS_TEST_LEVEL=0"
"-DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1"
"-DCUTLASS_DEBUG_TRACE_LEVEL=0"
"--expt-relaxed-constexpr"
"--expt-extended-lambda"
"--use_fast_math"
"-Xcompiler=-Wconversion"
"-Xcompiler=-fno-strict-aliasing"
)
if (ENABLE_BELOW_SM90)
list(APPEND SGL_FLASH_KERNEL_CUDA_FLAGS
"-gencode=arch=compute_80,code=sm_80"
"-gencode=arch=compute_86,code=sm_86"
)
# SM8X Logic
file(GLOB FA3_SM8X_GEN_SRCS
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdim*_sm80.cu")
endif()
file(GLOB FA3_BF16_GEN_SRCS
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_bf16*_sm90.cu")
file(GLOB FA3_BF16_GEN_SRCS_
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_bf16*_sm90.cu")
list(APPEND FA3_BF16_GEN_SRCS ${FA3_BF16_GEN_SRCS_})
# FP16 source files
file(GLOB FA3_FP16_GEN_SRCS
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_fp16*_sm90.cu")
file(GLOB FA3_FP16_GEN_SRCS_
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_fp16*_sm90.cu")
list(APPEND FA3_FP16_GEN_SRCS ${FA3_FP16_GEN_SRCS_})
# FP8 source files
file(GLOB FA3_FP8_GEN_SRCS
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_e4m3*_sm90.cu")
file(GLOB FA3_FP8_GEN_SRCS_
"${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_e4m3*_sm90.cu")
list(APPEND FA3_FP8_GEN_SRCS ${FA3_FP8_GEN_SRCS_})
set(FA3_GEN_SRCS ${FA3_BF16_GEN_SRCS} ${FA3_FP16_GEN_SRCS} ${FA3_FP8_GEN_SRCS} ${FA3_SM8X_GEN_SRCS})
set(FLASH_SOURCES
"csrc/flash_extension.cc"
"${repo-flash-attention_SOURCE_DIR}/hopper/flash_prepare_scheduler.cu"
"${repo-flash-attention_SOURCE_DIR}/hopper/flash_api.cpp"
"${repo-flash-attention_SOURCE_DIR}/hopper/flash_fwd_combine.cu"
"${FA3_GEN_SRCS}"
)
Python_add_library(flash_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${FLASH_SOURCES})
target_compile_options(flash_ops PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_FLASH_KERNEL_CUDA_FLAGS}>)
target_include_directories(flash_ops PRIVATE
${repo-flash-attention_SOURCE_DIR}/hopper
)
target_link_libraries(flash_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda)
install(TARGETS flash_ops LIBRARY DESTINATION "sgl_kernel")
set(FLASH_OPS_COMPILE_DEFS
FLASHATTENTION_DISABLE_BACKWARD
FLASHATTENTION_DISABLE_DROPOUT
FLASHATTENTION_DISABLE_UNEVEN_K
FLASHATTENTION_VARLEN_ONLY
)
if(NOT ENABLE_BELOW_SM90)
list(APPEND FLASH_OPS_COMPILE_DEFS FLASHATTENTION_DISABLE_SM8x)
endif()
target_compile_definitions(flash_ops PRIVATE ${FLASH_OPS_COMPILE_DEFS})
endif()
# Build spatial_ops as a separate, optional extension for green contexts
set(SPATIAL_SOURCES
"csrc/spatial/greenctx_stream.cu"
"csrc/spatial_extension.cc"
)
Python_add_library(spatial_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SPATIAL_SOURCES})
target_compile_options(spatial_ops PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_KERNEL_CUDA_FLAGS}>)
target_link_libraries(spatial_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda)
install(TARGETS spatial_ops LIBRARY DESTINATION sgl_kernel)
# ============================ DeepGEMM (JIT) ============================= #
# Create a separate library for DeepGEMM's Python API.
# This keeps its compilation isolated from the main common_ops.
set(DEEPGEMM_SOURCES
"${repo-deepgemm_SOURCE_DIR}/csrc/python_api.cpp"
)
Python_add_library(deep_gemm_cpp MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${DEEPGEMM_SOURCES})
# Link against necessary libraries, including nvrtc for JIT compilation.
target_link_libraries(deep_gemm_cpp PRIVATE ${TORCH_LIBRARIES} c10 cuda nvrtc mscclpp_static)
# Add include directories needed by DeepGEMM.
target_include_directories(deep_gemm_cpp PRIVATE
${repo-deepgemm_SOURCE_DIR}/deep_gemm/include
${repo-cutlass_SOURCE_DIR}/include
${repo-fmt_SOURCE_DIR}/include
)
# Apply the same compile options as common_ops.
target_compile_options(deep_gemm_cpp PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_KERNEL_CUDA_FLAGS}>)
# Create an empty __init__.py to make `deepgemm` a Python package.
file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/deepgemm_pkg_init.py "")
install(
FILES ${CMAKE_CURRENT_BINARY_DIR}/deepgemm_pkg_init.py
DESTINATION deep_gemm
RENAME __init__.py
)
# Install the compiled DeepGEMM API library.
install(TARGETS deep_gemm_cpp LIBRARY DESTINATION deep_gemm)
# Install the source files required by DeepGEMM for runtime JIT compilation.
install(
DIRECTORY ${repo-deepgemm_SOURCE_DIR}/deep_gemm/
DESTINATION deep_gemm
)
install(DIRECTORY "${repo-cutlass_SOURCE_DIR}/include/cute/"
DESTINATION "deep_gemm/include/cute")
install(DIRECTORY "${repo-cutlass_SOURCE_DIR}/include/cutlass/"
DESTINATION "deep_gemm/include/cutlass")
# triton_kernels
install(DIRECTORY "${repo-triton_SOURCE_DIR}/python/triton_kernels/triton_kernels/"
DESTINATION "triton_kernels"
PATTERN ".git*" EXCLUDE
PATTERN "__pycache__" EXCLUDE)

201
sgl-kernel/LICENSE Normal file
View File

@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

71
sgl-kernel/Makefile Normal file
View File

@@ -0,0 +1,71 @@
.PHONY: help check-deps install-deps tree ln submodule install build clean rebuild test format update
# Show help for each target
help: ## Show this help message
@echo "Available targets:"
@grep -E '^[a-zA-Z0-9_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-20s\033[0m %s\n", $$1, $$2}'
check-deps: ## Check and install required Python formatting dependencies
@command -v isort >/dev/null 2>&1 || (echo "Installing isort..." && pip install isort)
@command -v black >/dev/null 2>&1 || (echo "Installing black..." && pip install black)
install-deps: ## Install Python formatting tools (isort and black)
pip install scikit-build-core isort black
tree: ## Show project directory structure
@tree --prune -I "__pycache__|*.egg-info|*.so|build|3rdparty|dist"
submodule: ## Initialize and update git submodules
@git submodule update --init --recursive
ln: submodule ## Create compilation database
@rm -rf build && mkdir build && cd build && cmake .. -DCMAKE_EXPORT_COMPILE_COMMANDS=YES -DCMAKE_POLICY_VERSION_MINIMUM=3.5
install: submodule ## Install package in development mode
@pip install -e . --no-build-isolation
build: install-deps submodule ## Build and install wheel package
@rm -rf dist/* || true && CMAKE_POLICY_VERSION_MINIMUM=3.5 MAX_JOBS=$(nproc) CMAKE_BUILD_PARALLEL_LEVEL=$(nproc) uv build --wheel -Cbuild-dir=build . --verbose --color=always --no-build-isolation && pip3 install dist/*whl --force-reinstall --no-deps
clean: ## Remove build artifacts
@rm -rf build dist *.egg-info
rebuild: clean submodule build ## Clean and rebuild the project
@echo "Succeed to rebuild"
test: ## Run all tests
@find tests -name "test_*.py" | xargs -n 1 python3
format: check-deps ## Format all source files
@echo "Formatting source files..."
@find csrc tests -name '*.cc' -o -name '*.cu' -o -name '*.cuh' -o -name '*.h' -o -name '*.hpp' | xargs clang-format -i
@find python tests -name '*.py' | xargs isort
@find python tests -name '*.py' | xargs black
@pre-commit run --all-files
FILES_TO_UPDATE = python/sgl_kernel/version.py \
pyproject.toml \
pyproject_rocm.toml \
pyproject_cpu.toml \
../docker/Dockerfile \
../.github/workflows/pr-test-pd-router.yml
update: ## Update version numbers across project files. Usage: make update <new_version>
@if [ -z "$(filter-out $@,$(MAKECMDGOALS))" ]; then \
echo "Version required. Usage: make update <new_version>"; \
exit 1; \
fi
@OLD_VERSION=$$(grep "version" python/sgl_kernel/version.py | cut -d '"' -f2); \
NEW_VERSION=$(filter-out $@,$(MAKECMDGOALS)); \
echo "Updating version from $$OLD_VERSION to $$NEW_VERSION"; \
for file in $(FILES_TO_UPDATE); do \
if [ "$(shell uname)" = "Darwin" ]; then \
sed -i '' -e "s/$$OLD_VERSION/$$NEW_VERSION/g" $$file; \
else \
sed -i -e "s/$$OLD_VERSION/$$NEW_VERSION/g" $$file; \
fi \
done; \
echo "Version update complete"
%:
@:

261
sgl-kernel/README.md Normal file
View File

@@ -0,0 +1,261 @@
# SGL Kernel
[Kernel Library](https://github.com/sgl-project/sglang/tree/main/sgl-kernel) for SGLang
[![PyPI](https://img.shields.io/pypi/v/sgl-kernel)](https://pypi.org/project/sgl-kernel)
## Installation
For CUDA 12.1 and above:
```bash
pip3 install sgl-kernel
```
For CUDA 11.8:
```bash
pip3 install sgl-kernel -i https://docs.sglang.ai/whl/cu118
```
## Build from source
Development build:
```bash
make build
```
Note:
The `sgl-kernel` is rapidly evolving. If you experience a compilation failure, try using `make rebuild`.
### Build with [ccache](https://github.com/ccache/ccache)
```bash
# or `yum install -y ccache`.
apt-get install -y ccache
# Building with ccache is enabled when ccache is installed and CCACHE_DIR is set.
export CCACHE_DIR=/path/to/your/ccache/dir
export CCACHE_BACKEND=""
export CCACHE_KEEP_LOCAL_STORAGE="TRUE"
unset CCACHE_READONLY
python -m uv build --wheel -Cbuild-dir=build --color=always .
```
### Configuring CMake Build Options
Cmake options can be configuring by adding `-Ccmake.define.<option>=<value>` to the `uv build` flags.
For example, to enable building FP4 kernels, use:
```bash
python -m uv build --wheel -Cbuild-dir=build -Ccmake.define.SGL_KERNEL_ENABLE_FP4=1 --color=always .
```
See CMakeLists.txt for more options.
### Parallel Build
We highly recommend you build sgl-kernel with Ninja. Ninja can automatically build sgl-kernel in parallel.
And if you build the sgl-kernel with cmake, you need to add `CMAKE_BUILD_PARALLEL_LEVEL` for parallel build like:
```bash
CMAKE_BUILD_PARALLEL_LEVEL=$(nproc) python -m uv build --wheel -Cbuild-dir=build --color=always .
```
### ⚠️ Compilation Issue with `sgl-kernel` and CUDA 12.6
When compiling `sgl-kernel` with FlashAttention on a Hopper GPU using CUDA 12.6, you may encounter a segmentation fault:
```bash
kernel/build/_deps/repo-flash-attention-src/hopper/instantiations/flash_fwd_hdimall_bf16_paged_softcap_sm90.cu -o CMakeFiles/flash_ops.dir/_deps/repo-flash-attention-src/hopper/instantiations/flash_fwd_hdimall_bf16_paged_softcap_sm90.cu.o
Segmentation fault (core dumped)
```
⚠️ **Note**: To ensure that FlashAttention compiles correctly on Hopper GPU Architecture(sm90), it is strongly [recommended](https://github.com/Dao-AILab/flash-attention/issues/1453) to use:
- nvcc version: 12.6
- ptxas version: 12.8
**1. Check Current Versions**
Before proceeding, verify your current CUDA tool versions:
```bash
nvcc --version
ptxas --version
```
**2. Update ptxas to 12.8 (if needed)**
1. Save the following script to a file (e.g., `update_ptxas.sh`).
```bash
#!/usr/bin/env bash
# Source: https://github.com/Dao-AILab/flash-attention/blob/7ff1b621112ba8b538e2fc6a316f2a6b6f22e518/hopper/setup.py#L404
set -ex
if [ -z "$1" ]; then
echo "Usage: $0 <CUDA_VERSION>"
exit 1
fi
CUDA_VERSION=$1
if awk "BEGIN {exit !("$CUDA_VERSION" >= 12.6 && "$CUDA_VERSION" < 12.8)}"; then
NVCC_ARCHIVE_VERSION="12.8.93"
NVCC_ARCHIVE_NAME="cuda_nvcc-linux-x86_64-${NVCC_ARCHIVE_VERSION}-archive"
NVCC_ARCHIVE_TAR="${NVCC_ARCHIVE_NAME}.tar.xz"
NVCC_ARCHIVE_URL="https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/linux-x86_64/${NVCC_ARCHIVE_TAR}"
wget "$NVCC_ARCHIVE_URL"
tar -xf "$NVCC_ARCHIVE_TAR"
mkdir -p /usr/local/cuda/bin
cp "${NVCC_ARCHIVE_NAME}/bin/ptxas" /usr/local/cuda/bin/
# Clean up temporary files
rm -f "${NVCC_ARCHIVE_TAR}"
rm -rf "${NVCC_ARCHIVE_NAME}"
fi
```
2. Run the script with your CUDA version as the argument, using `sudo`:
```bash
sudo bash update_ptxas.sh 12.6
# Check the version
ptxas --version
```
# Developer Guide
## Development Environment Setup
Use Docker to set up the development environment. See [Docker setup guide](https://github.com/sgl-project/sglang/blob/main/docs/developer_guide/development_guide_using_docker.md#setup-docker-container).
Create and enter development container:
```bash
docker run -itd --shm-size 32g --gpus all -v $HOME/.cache:/root/.cache --ipc=host --name sglang_zhyncs lmsysorg/sglang:dev /bin/zsh
docker exec -it sglang_zhyncs /bin/zsh
```
## Project Structure
### Dependencies
Third-party libraries:
- [CUTLASS](https://github.com/NVIDIA/cutlass)
- [FlashInfer](https://github.com/flashinfer-ai/flashinfer)
- [DeepGEMM](https://github.com/deepseek-ai/DeepGEMM)
- [FlashAttention](https://github.com/Dao-AILab/flash-attention)
### FlashAttention FYI
FA3 can fail without a enough shared memory for a some shapes, such as higher hidden_dim or some special cases. Right now, fa3 is supported for sm80/sm87 and sm86/sm89.
The main different Between sm80/sm87 and sm86/sm89 is the shared memory size. you can follow the link below for more information https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x.
And for sgl-kernel right now, we can build fa3 on sm80/sm86/sm89/sm90a. That means if you use **A100(tested)**/A*0/**L20(tested)**/L40/L40s/**3090(tested)** you can use fa3.
### Kernel Development
Steps to add a new kernel:
1. Implement the kernel in [csrc](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc)
2. Expose the interface in [include/sgl_kernel_ops.h](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/include/sgl_kernel_ops.h)
3. Create torch extension in [csrc/common_extension.cc](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/csrc/common_extension.cc)
4. Update [CMakeLists.txt](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/CMakeLists.txt) to include new CUDA source
5. Expose Python interface in [python](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/python/sgl_kernel)
### Development Tips
1. When implementing kernels in [csrc](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc), only define pure CUDA files and C++ interfaces. If you need to use `Torch::tensor`, use `<torch/all.h>` instead of `<torch/extension.h>`. Using `<torch/extension.h>` will cause compilation errors when using SABI.
2. When creating torch extensions, add the function definition with `m.def`, and device binding with `m.impl`:
- Using torch.compile need `m.def` with schema, it helps auto capture the custom kernel. Reference: [How to add FakeTensor](https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit?tab=t.0#heading=h.ptttacy8y1u9)
- How to write schema: [Schema reference](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#func)
```cpp
// We need def with schema here for torch.compile
m.def(
"bmm_fp8(Tensor A, Tensor B, Tensor! D, Tensor A_scale, Tensor B_scale, Tensor workspace_buffer, int "
"cublas_handle, int cuda_stream) -> ()");
m.impl("bmm_fp8", torch::kCUDA, &bmm_fp8);
```
3. When exposing Python interfaces, avoid using kwargs in C++ interface kernels.
**Avoid this:**
```cpp
torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache.default(
q=query.view(query.shape[0], -1, head_size),
k=key.view(key.shape[0], -1, head_size),
q_rope=query.view(query.shape[0], -1, head_size),
k_rope=key.view(key.shape[0], -1, head_size),
cos_sin_cache=cos_sin_cache,
pos_ids=positions.long(),
interleave=(not is_neox),
cuda_stream=get_cuda_stream(),
)
```
**Use this instead:**
```cpp
torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache.default(
query.view(query.shape[0], -1, head_size),
key.view(key.shape[0], -1, head_size),
query.view(query.shape[0], -1, head_size),
key.view(key.shape[0], -1, head_size),
cos_sin_cache,
positions.long(),
(not is_neox),
get_cuda_stream(),
)
```
### Integrating Third-Party Libraries with Data Type Conversion
When integrating new third-party libraries like flash-attention, you may encounter data type compatibility issues between the C++ interface and PyTorch bindings. For example, the third-party code might use `float` or `int` types, while PyTorch requires `double` and `int64_t`.
> The reason we need `double` and `int64_t` in torch binding is that TORCH_LIBRARY handles the `Python-to-C++` conversion process. Python's `float` data type actually corresponds to `double` in C++, while Python's `int` corresponds to `int64_t` in C++.
To address this issue, we provide the `make_pytorch_shim` function in [sgl_kernel_torch_shim](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/include/sgl_kernel_torch_shim.h) that handles data type conversions automatically.
When you need to support new data type conversions, you can easily add conversion functions like this:
```cpp
// Map `int` -> `int64_t`
template <>
struct pytorch_library_compatible_type<int> {
using type = int64_t;
static int convert_from_type(int64_t arg) {
TORCH_CHECK(arg <= std::numeric_limits<int>::max(), "int64_t value is too large to be converted to int");
TORCH_CHECK(arg >= std::numeric_limits<int>::min(), "int64_t value is too small to be converted to int");
return arg;
}
};
```
To use this with your library functions, simply wrap them with make_pytorch_shim:
```cpp
/*
* From flash-attention
*/
m.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd));
```
### Testing & Benchmarking
1. Add pytest tests in [tests/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/tests), if you need to skip some test, please use `@pytest.mark.skipif`
```python
@pytest.mark.skipif(
skip_condition, reason="Nvfp4 Requires compute capability of 10 or above."
)
```
2. Add benchmarks using [triton benchmark](https://triton-lang.org/main/python-api/generated/triton.testing.Benchmark.html) in [benchmark/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/benchmark)
3. Run test suite
### FAQ
- When encountering this error while compiling using ccache: `ImportError: /usr/local/lib/python3.10/dist-packages/sgl_kernel/common_ops.abi3.so: undefined symbol: _ZN3c108ListType3getERKNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEENS_4Type24SingletonOrSharedTypePtrIS9_EE`, please modify the last command as follows to resolve it: `python3 -m uv build --wheel -Cbuild-dir=build . --color=always --no-build-isolation` .
### Release new version
Update version in [pyproject.toml](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/pyproject.toml) and [version.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/python/sgl_kernel/version.py)

View File

@@ -0,0 +1,488 @@
Notice for flashinfer-ai/flashinfer
-------------------------------
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-------------------------------------------------------------------------------------------------
Some of the code in this project are adapted from other open-source projects with different
licenses. This product also bundles some third-party components under other open source licenses.
This section summarizes those components and their licenses.
See licenses/ for text of these licenses.
BSD 3-Clause License
--------------------
include/flashinfer/attention/hopper/epilogue.cuh
include/flashinfer/attention/hopper/mainloop.cuh
include/flashinfer/attention/hopper/kernel_traits.cuh
include/flashinfer/attention/hopper/named_barrier.cuh
include/flashinfer/attention/hopper/tile_scheduler.cuh
include/flashinfer/attention/hopper/utils.cuh
BSD 3-Clause "New" License
--------------------------
3rdparty/cutlass
include/flashinfer/attention/hopper/block_sparse_gather.cuh
Notice for NVIDIA/TensorRT-LLM
-------------------------------
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Notice for deepseek-ai/DeepGEMM
-------------------------------
MIT License
Copyright (c) 2025 DeepSeek
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
Notice for Dao-AILab/flash-attention
-------------------------------
BSD 3-Clause License
Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file.
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@@ -0,0 +1,153 @@
# Benchmarks SGLang kernels versus vLLM across
# (kernel, dtype, batch_size, seq_len, dim) and prints speed-up.
import argparse
import itertools
import re
from typing import List, Tuple
import sgl_kernel
import torch
import torch.nn.functional as F
import triton
import triton.testing
from sgl_kernel import gelu_quick # activation-only kernel
from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
from vllm import _custom_ops as vllm_ops
if not hasattr(vllm_ops, "silu_and_mul"):
vllm_ops = torch.ops._C
def str2int_list(arg: str) -> List[int]:
if arg in ("", None):
return []
if re.fullmatch(r"\d+(,\d+)*", arg.strip()) is None:
raise argparse.ArgumentTypeError(f"Bad int list: {arg}")
return [int(x) for x in arg.split(",")]
def calculate_diff(
kernel: str, dtype: torch.dtype, batch_size: int, seq_len: int, dim: int
) -> bool:
"""Compare vLLM with SGLang for one shape."""
device = torch.device("cuda")
# activation-only quick GELU
if kernel == "gelu_quick":
x = torch.randn(batch_size, seq_len, dim, dtype=dtype, device=device)
ref_out = torch.zeros_like(x)
getattr(vllm_ops, kernel)(ref_out, x)
test_out = getattr(sgl_kernel, kernel)(x)
# fused activation x mul kernels
else:
x = torch.randn(batch_size, seq_len, 2 * dim, dtype=dtype, device=device)
ref_out = torch.zeros(batch_size, seq_len, dim, dtype=dtype, device=device)
getattr(vllm_ops, kernel)(ref_out, x)
test_out = getattr(sgl_kernel, kernel)(x)
ok = torch.allclose(ref_out, test_out, rtol=1e-3, atol=1e-5)
tag = "✅ match" if ok else "❌ mismatch"
print(
f"[{kernel:14s} | {str(dtype):9s} | B={batch_size:3d} | "
f"L={seq_len:3d} | D={dim:5d}] {tag}"
)
return ok
kernels = ["silu_and_mul", "gelu_and_mul", "gelu_tanh_and_mul", "gelu_quick"]
dtypes = [torch.float16, torch.bfloat16]
def make_configs(bsizes: List[int], slens: List[int], dims_: List[int]) -> List[Tuple]:
return list(itertools.product(kernels, dtypes, bsizes, slens, dims_))
default_batch_sizes = [2**i for i in range(0, 5, 2)] # 1,4,16
default_seq_lens = [2**i for i in range(0, 8, 2)] # 1,4,16,64
default_dims = [2**i for i in range(7, 15)] # 128...16384
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["kernel", "dtype", "batch_size", "seq_len", "dim"],
x_vals=[],
line_arg="provider",
line_vals=["vllm", "sglang", "speedup"],
line_names=["vLLM", "SGL Kernel", "Speed-up (x)"],
styles=[("blue", "-"), ("green", "-"), ("red", "--")],
ylabel="µs (median) or × (speed-up)",
plot_name="activation-performance",
args={},
)
)
def benchmark(kernel, dtype, batch_size, seq_len, dim, provider):
device = torch.device("cuda")
in_mult = 1 if kernel == "gelu_quick" else 2
x = torch.randn(batch_size, seq_len, in_mult * dim, dtype=dtype, device=device)
y0 = torch.zeros(batch_size, seq_len, dim, dtype=dtype, device=device)
vllm_kernel = getattr(vllm_ops, kernel)
sglang_kernel = getattr(sgl_kernel, kernel)
def baseline():
tmp = y0.clone()
vllm_kernel(tmp, x)
return tmp
def sglang():
return sglang_kernel(x)
# one-time correctness check
if provider == "vllm" and not calculate_diff(
kernel, dtype, batch_size, seq_len, dim
):
raise ValueError("Mismatch abort benchmark")
# timing helper
def timed(fn):
for _ in range(5):
fn()
torch.cuda.synchronize()
ms, qmin, qmax = triton.testing.do_bench(fn, quantiles=[0.5, 0.2, 0.8])
return 1000 * ms, 1000 * qmax, 1000 * qmin
if provider == "vllm":
return timed(baseline)
if provider == "sglang":
return timed(sglang)
# provider == "speedup"
t_ref, _, _ = timed(baseline)
t_sgl, _, _ = timed(sglang)
spd = t_ref / t_sgl
return (spd, spd, spd)
if __name__ == "__main__":
p = argparse.ArgumentParser("Activation kernel benchmark")
p.add_argument("--batch_sizes", type=str2int_list, default=default_batch_sizes)
p.add_argument("--seq_lens", type=str2int_list, default=default_seq_lens)
p.add_argument("--dims", type=str2int_list, default=default_dims)
p.add_argument("--verify_only", action="store_true")
args = p.parse_args()
# coerce lists
if isinstance(args.batch_sizes, str):
args.batch_sizes = str2int_list(args.batch_sizes)
if isinstance(args.seq_lens, str):
args.seq_lens = str2int_list(args.seq_lens)
if isinstance(args.dims, str):
args.dims = str2int_list(args.dims)
# patch perf_report grid
benchmark_grid = make_configs(args.batch_sizes, args.seq_lens, args.dims)
if hasattr(benchmark, "benchmarks"):
benchmark.benchmarks.x_vals = benchmark_grid
else:
benchmark.benchmark.x_vals = benchmark_grid
if args.verify_only:
ok = calculate_diff("gelu_quick", torch.float16, 1, 1, args.dims[0])
print("✅ sanity pass" if ok else "❌ mismatch")
else:
benchmark.run(print_data=True)

View File

@@ -0,0 +1,118 @@
import itertools
from typing import List, Tuple
import torch
import triton
import triton.testing
from sgl_kernel import awq_dequantize
from vllm import _custom_ops as ops
def vllm_awq_dequantize(
qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
return ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
def sglang_awq_dequantize(
qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
return awq_dequantize(qweight, scales, qzeros)
def calculate_diff(qweight_row: int, qweight_col: int):
"""Calculate difference between VLLM and SGLang implementations."""
device = torch.device("cuda")
qweight = torch.randint(
0,
torch.iinfo(torch.int32).max,
(qweight_row, qweight_col),
dtype=torch.int32,
device=device,
)
group_size = qweight_row
scales_row = qweight_row // group_size
scales_col = qweight_col * 8
scales = torch.rand(scales_row, scales_col, dtype=torch.float16, device=device)
qzeros = torch.randint(
0,
torch.iinfo(torch.int32).max,
(scales_row, qweight_col),
dtype=torch.int32,
device=device,
)
vllm_out = vllm_awq_dequantize(qweight, scales, qzeros)
sglang_out = sglang_awq_dequantize(qweight, scales, qzeros)
output_diff = torch.abs(vllm_out.float() - sglang_out.float()).mean().item()
if torch.allclose(
vllm_out.to(torch.float32), sglang_out.to(torch.float32), rtol=1e-3, atol=1e-5
):
print("✅ All implementations match")
else:
print("❌ Implementations differ")
qweight_row_range = [3584, 18944, 128, 256, 512, 1024]
qweight_cols_range = [448, 576, 4736, 16, 32, 64, 128]
configs = list(itertools.product(qweight_row_range, qweight_cols_range))
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["qweight_row", "qweight_col"],
x_vals=configs,
line_arg="provider",
line_vals=["vllm", "sglang"],
line_names=["VLLM", "SGL Kernel"],
styles=[("blue", "-"), ("green", "-")],
ylabel="us",
plot_name="awq-dequantize-performance",
args={},
)
)
def benchmark(qweight_row, qweight_col, provider):
dtype = torch.float16
device = torch.device("cuda")
qweight = torch.randint(
0,
torch.iinfo(torch.int32).max,
(qweight_row, qweight_col),
dtype=torch.int32,
device=device,
)
group_size = qweight_row
scales_row = qweight_row // group_size
scales_col = qweight_col * 8
scales = torch.rand(scales_row, scales_col, dtype=torch.float16, device=device)
qzeros = torch.randint(
0,
torch.iinfo(torch.int32).max,
(scales_row, qweight_col),
dtype=torch.int32,
device=device,
)
quantiles = [0.5, 0.2, 0.8]
if provider == "vllm":
fn = lambda: vllm_awq_dequantize(
qweight.clone(), scales.clone(), qzeros.clone()
)
elif provider == "sglang":
fn = lambda: sglang_awq_dequantize(
qweight.clone(), scales.clone(), qzeros.clone()
)
ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
if __name__ == "__main__":
calculate_diff(qweight_row=3584, qweight_col=448)
benchmark.run(print_data=True)

View File

@@ -0,0 +1,145 @@
import argparse
import copy
import itertools
import torch
import triton
from sgl_kernel import cutlass_mla_decode, cutlass_mla_get_workspace_size
bs_range = [1, 8, 32, 64, 128, 256]
qlen_range = [1, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
configs = list(itertools.product(bs_range, qlen_range))
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size", "seq_len"],
x_vals=configs,
x_log=False,
line_arg="provider",
line_vals=[
"128 heads",
"64 heads",
"32 heads",
"16 heads",
],
line_names=[
"128 heads",
"64 heads",
"32 heads",
"16 heads",
],
styles=[("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")],
ylabel="GB/s",
plot_name="cutlass mla",
args={},
)
)
def benchmark(batch_size, seq_len, provider, block_size, num_kv_splits):
d = 576
dn = 64
dv = 512
h_q_map = {
"128": 128,
"64": 64,
"32": 32,
"16": 16,
}
parsed_h_q = next(
(value for key, value in h_q_map.items() if key in provider), None
)
if parsed_h_q is None:
raise ValueError(f"Unknown head configuration in provider: {provider}")
h_q = parsed_h_q
seq_lens = torch.full((batch_size,), seq_len, dtype=torch.int32, device="cuda")
max_seq_len = seq_lens.max().item()
block_num = (max_seq_len + block_size - 1) // block_size
# Pad block_num so that small blocks can be packed into full 128-sized CUTLASS tiles.
# One 128-wide tile can hold (128 // block_size) small blocks.
pack_factor = 128 // block_size
block_num = ((block_num + pack_factor - 1) // pack_factor) * pack_factor
qn = (
torch.randn(h_q, batch_size, d - dn, dtype=torch.bfloat16, device="cuda")
* 100.0
)
qr = torch.randn(batch_size, h_q, dn, dtype=torch.bfloat16, device="cuda") * 100.0
block_table = torch.randint(
0,
batch_size * block_num,
(batch_size, block_num),
dtype=torch.int32,
device="cuda",
)
kv_cache = torch.randn(
block_table.numel(), block_size, d, dtype=torch.bfloat16, device="cuda"
)
workspace_size = cutlass_mla_get_workspace_size(
block_num * block_size, batch_size, num_kv_splits=num_kv_splits
)
workspace = torch.empty(workspace_size, device="cuda", dtype=torch.uint8)
quantiles = [0.5, 0.2, 0.8]
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: cutlass_mla_decode(
qn.transpose(0, 1),
qr,
kv_cache,
seq_lens,
block_table,
workspace,
1.44,
num_kv_splits,
),
quantiles=quantiles,
)
q_size = qn.numel() * qn.element_size() + qr.numel() * qr.element_size()
gbps = (
lambda ms: (
q_size + q_size * dv / d + kv_cache.numel() * kv_cache.element_size()
)
* 1e-9
/ (ms * 1e-3)
)
return gbps(ms), gbps(max_ms), gbps(min_ms)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--block-sizes",
nargs="+",
type=int,
default=[1, 32, 64, 128],
help="List of batch sizes",
)
parser.add_argument(
"--num-kv-splits",
nargs="+",
type=int,
default=[-1],
help="List of batch sizes",
)
args = parser.parse_args()
for block_size in args.block_sizes:
for kv_split in args.num_kv_splits:
print(f"block_size={block_size}, num_kv_splits={kv_split}: ")
benchmark.run(
print_data=True,
show_plots=True,
save_path="bench_blackwell_mla_res",
block_size=block_size,
num_kv_splits=kv_split,
)
print("Benchmark finished!")

View File

@@ -0,0 +1,57 @@
import argparse
import torch
import torch.nn.functional as F
import triton
import triton.testing
from sgl_kernel import dsv3_fused_a_gemm
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["num_tokens"],
x_vals=[i + 1 for i in range(16)],
x_log=False,
line_arg="impl",
line_vals=["torch", "sgl-kernel"],
line_names=["torch (bf16)", "dsv3_fused_a_gemm"],
styles=[("blue", "-"), ("orange", "-")],
ylabel="TFLOPs",
plot_name="bf16 dsv3 fused a GEMM throughput",
args={},
)
)
def benchmark(num_tokens, impl):
kHdIn = 7168
kHdOut = 2112
M, K, N = num_tokens, kHdIn, kHdOut
mat_a = torch.randn((M, K), dtype=torch.bfloat16, device="cuda").contiguous()
mat_b = torch.randn((N, K), dtype=torch.bfloat16, device="cuda").transpose(0, 1)
quantiles = [0.5, 0.2, 0.8]
if impl == "torch":
def runner():
F.linear(mat_a, mat_b.T)
elif impl == "sgl-kernel":
def runner():
dsv3_fused_a_gemm(mat_a, mat_b)
ms, min_ms, max_ms = triton.testing.do_bench(runner, quantiles=quantiles)
def tflops(t_ms):
flops = 2 * M * K * N
return flops / (t_ms * 1e-3) / 1e12
return tflops(ms), tflops(max_ms), tflops(min_ms)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
args = parser.parse_args()
benchmark.run(print_data=True, show_plots=True, save_path="bench_dsv3_gemm")

View File

@@ -0,0 +1,127 @@
import argparse
import torch
import torch.nn.functional as F
import triton
import triton.testing
from sgl_kernel import dsv3_router_gemm
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["num_tokens"],
x_vals=[i + 1 for i in range(16)],
x_log=False,
line_arg="impl",
line_vals=["torch-256", "sgl-kernel-256", "torch-384", "sgl-kernel-384"],
line_names=[
"torch-256",
"dsv3_router_gemm-256",
"torch-384",
"dsv3_router_gemm-384",
],
styles=[("blue", "-"), ("orange", "-"), ("green", "-"), ("red", "-")],
ylabel="TFLOPs",
plot_name="input-bf16-output-bf16 dsv3 router gemm throughput",
args={},
)
)
def benchmark_bf16_output(num_tokens, impl):
# M: num_tokens, K: hidden_dim, N: num_experts
M, K = num_tokens, 7168
if impl == "torch-256" or impl == "sgl-kernel-256":
N = 256
elif impl == "torch-384" or impl == "sgl-kernel-384":
N = 384
else:
raise ValueError(f"Unknown impl: {impl}")
mat_a = torch.randn((M, K), dtype=torch.bfloat16, device="cuda").contiguous()
mat_b = torch.randn((N, K), dtype=torch.bfloat16, device="cuda").contiguous()
quantiles = [0.5, 0.2, 0.8]
if impl == "torch-256" or impl == "torch-384":
def runner():
F.linear(mat_a, mat_b)
elif impl == "sgl-kernel-256" or impl == "sgl-kernel-384":
def runner():
dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.bfloat16)
ms, min_ms, max_ms = triton.testing.do_bench(runner, quantiles=quantiles)
def tflops(t_ms):
flops = 2 * M * K * N
return flops / (t_ms * 1e-3) / 1e12
return tflops(ms), tflops(max_ms), tflops(min_ms)
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["num_tokens"],
x_vals=[i + 1 for i in range(16)],
x_log=False,
line_arg="impl",
line_vals=["torch-256", "sgl-kernel-256", "torch-384", "sgl-kernel-384"],
line_names=[
"torch-256",
"dsv3_router_gemm-256",
"torch-384",
"dsv3_router_gemm-384",
],
styles=[("blue", "-"), ("orange", "-"), ("green", "-"), ("red", "-")],
ylabel="TFLOPs",
plot_name="input-bf16-output-fp32 dsv3 router gemm throughput",
args={},
)
)
def benchmark_float_output(num_tokens, impl):
# M: num_tokens, K: hidden_dim, N: num_experts
M, K = num_tokens, 7168
if impl == "torch-256" or impl == "sgl-kernel-256":
N = 256
elif impl == "torch-384" or impl == "sgl-kernel-384":
N = 384
else:
raise ValueError(f"Unknown impl: {impl}")
mat_a = torch.randn((M, K), dtype=torch.bfloat16, device="cuda").contiguous()
mat_b = torch.randn((N, K), dtype=torch.bfloat16, device="cuda").contiguous()
quantiles = [0.5, 0.2, 0.8]
if impl == "torch-256" or impl == "torch-384":
def runner():
F.linear(mat_a, mat_b).to(torch.float32)
elif impl == "sgl-kernel-256" or impl == "sgl-kernel-384":
def runner():
dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.float32)
ms, min_ms, max_ms = triton.testing.do_bench(runner, quantiles=quantiles)
def tflops(t_ms):
flops = 2 * M * K * N
return flops / (t_ms * 1e-3) / 1e12
return tflops(ms), tflops(max_ms), tflops(min_ms)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
args = parser.parse_args()
benchmark_bf16_output.run(
print_data=True, show_plots=True, save_path="bench_dsv3_router_gemm"
)
benchmark_float_output.run(
print_data=True, show_plots=True, save_path="bench_dsv3_router_gemm"
)

View File

@@ -0,0 +1,210 @@
import argparse
import copy
import csv
import itertools
import pytest
import torch
import triton
from flashinfer import mm_fp4
from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant
FLOAT4_E2M1_MAX = 6.0
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
def get_weight_shapes(args):
models_tps = args.tp_sizes
if models_tps == [4]:
return [[1024, 3584], [7168, 256], [7168, 2304], [9216, 3584]]
if models_tps == [8]:
return [[512, 3584], [7168, 128], [7168, 1152], [4608, 3584]]
return [
[1024, 3584],
[7168, 256],
[7168, 2304],
[9216, 3584],
[512, 3584],
[7168, 128],
[7168, 1152],
[4608, 3584],
]
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size"],
x_vals=[
1,
2,
4,
8,
16,
32,
64,
128,
256,
512,
1024,
2048,
3072,
4096,
8192,
16384,
],
# x_vals = [64],
x_log=False,
line_arg="provider",
line_vals=["cutlass", "cudnn", "trtllm"],
line_names=["baseline cutlass fp4", "cudnn fp4", "trtllm fp4"],
styles=[("red", "solid"), ("blue", "solid"), ("green", "solid")],
ylabel="latency (ms)",
plot_name="fp4_gemm_benchmark",
args={},
)
)
def benchmark(batch_size, provider, N, K, dtype, correctness, csv_file):
M = batch_size
packed_k = K
K = 2 * packed_k
a_dtype = torch.randn((M, K), dtype=dtype, device="cuda")
b_dtype = torch.randn((N, K), dtype=dtype, device="cuda")
a_global_scale = (
(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a_dtype.flatten(), dim=-1)
).to(torch.float32)
b_global_scale = (
(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(b_dtype.flatten(), dim=-1)
).to(torch.float32)
alpha = 1.0 / (a_global_scale * b_global_scale)
a_fp4, a_scale_interleaved = scaled_fp4_quant(a_dtype, a_global_scale)
# print("a_fp4", a_fp4)
b_fp4, b_scale_interleaved = scaled_fp4_quant(b_dtype, b_global_scale)
res_fi = torch.empty((M, N), dtype=dtype, device="cuda")
quantiles = [0.5, 0.2, 0.8]
if provider == "cutlass":
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: cutlass_scaled_fp4_mm(
a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, dtype
),
quantiles=quantiles,
)
if provider == "cudnn":
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: mm_fp4(
a_fp4,
b_fp4.T,
a_scale_interleaved,
b_scale_interleaved.T,
alpha,
dtype,
res_fi,
),
quantiles=quantiles,
)
if provider == "trtllm":
a_scale_interleaved = a_scale_interleaved.to(torch.uint8)
b_scale_interleaved = b_scale_interleaved.to(torch.uint8)
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: mm_fp4(
a_fp4,
b_fp4.T,
a_scale_interleaved,
b_scale_interleaved.T,
alpha,
dtype,
res_fi,
backend="trtllm",
),
quantiles=quantiles,
)
if correctness:
res_cutlass = cutlass_scaled_fp4_mm(
a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, dtype
)
mm_fp4(
a_fp4,
b_fp4.T,
a_scale_interleaved,
b_scale_interleaved.T,
alpha,
dtype,
res_fi,
backend="cudnn",
)
assert torch.allclose(
res_fi, res_cutlass, atol=1e-3, rtol=1e-3
), "cudnn fp4 doesn't match cutlass fp4"
mm_fp4(
a_fp4,
b_fp4.T,
a_scale_interleaved,
b_scale_interleaved.T,
alpha,
dtype,
res_fi,
backend="trtllm",
)
assert torch.allclose(
res_fi, res_cutlass, atol=1e-3, rtol=1e-3
), "trtllm fp4 doesn't match cutlass fp4"
if csv_file:
with open(csv_file, "a", newline="") as f:
writer = csv.writer(f)
writer.writerow([provider, M, N, K, ms])
return ms, min_ms, max_ms
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--tp-sizes",
nargs="+",
type=int,
default=[1],
help="List of tensor parallel sizes",
)
parser.add_argument(
"--dtype",
type=torch.dtype,
default=torch.bfloat16,
help="Data type",
)
parser.add_argument(
"--correctness",
action="store_true",
help="Check correctness",
)
parser.add_argument(
"--csv",
type=str,
default="results_cutlass_cudnn.csv",
help="CSV file to save results",
)
args = parser.parse_args()
if args.csv:
with open(args.csv, "w", newline="") as f:
writer = csv.writer(f)
writer.writerow(["provider", "m", "n", "k", "time_ms"])
NKs = get_weight_shapes(args)
for N, K in NKs:
print(f"DeepSeek-R1-0528-FP4 N={N} K={K}: ")
benchmark.run(
print_data=True,
show_plots=True,
save_path="bench_fp4_res",
N=N,
K=K,
dtype=args.dtype,
correctness=args.correctness,
csv_file=args.csv,
)
print("Benchmark finished!")

View File

@@ -0,0 +1,183 @@
import argparse
import copy
import itertools
import deep_gemm
import torch
import triton
from deep_gemm import get_col_major_tma_aligned_tensor
from sgl_kernel import fp8_blockwise_scaled_mm
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
from sglang.srt.layers.quantization.fp8_kernel import (
w8a8_block_fp8_matmul_triton as w8a8_block_fp8_matmul,
)
def get_weight_shapes(args):
models_tps = list(itertools.product(args.models, args.tp_sizes))
# NOTE(HandH1998): The weight shapes only works for DeepSeek-V3. Modify them, if you tune for another different model.
# cannot TP
total = [
(512 + 64, 7168),
((128 + 64) * 128, 7168),
(128 * (128 + 128), 512),
(7168, 16384),
(7168, 18432),
]
# N can TP
n_tp = [
(18432 * 2, 7168),
((128 + 64) * 128, 7168),
(128 * (128 + 128), 512),
(24576, 1536),
(4096, 7168),
]
# K can TP
k_tp = [(7168, 18432), (7168, 16384), (7168, 2048)]
# only support Deepseek-V3
SUPPORT_MODEL = ["deepseek-ai/DeepSeek-V3"]
weight_shapes = []
for model, tp_size in models_tps:
assert model in SUPPORT_MODEL
for t in total:
new_t = [t[0], t[1], model]
weight_shapes.append(new_t)
for n_t in n_tp:
new_t = [n_t[0] // tp_size, n_t[1], model]
weight_shapes.append(new_t)
for k_t in k_tp:
new_t = [k_t[0], k_t[1] // tp_size, model]
weight_shapes.append(new_t)
return weight_shapes
def cdiv(a: int, b: int) -> int:
"""Ceiling division."""
return -(a // -b)
def fp8_gemm_deepgemm(
x_fp8: torch.Tensor,
x_scale: torch.Tensor,
y_fp8: torch.Tensor,
y_scale: torch.Tensor,
m: int,
n: int,
k: int,
):
"""DeepGEMM implementation of FP8 GEMM"""
out = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)
# Run DeepGEMM kernel
deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale), (y_fp8, y_scale), out)
return out
def scale_shape(shape, group_shape):
assert len(shape) == len(group_shape)
return tuple(cdiv(shape[i], group_shape[i]) for i in range(len(group_shape)))
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size"],
x_vals=[1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096],
x_log=False,
line_arg="provider",
line_vals=["vllm", "sgl-kernel", "triton", "deepgemm"],
line_names=["vllm", "sgl-kernel", "sglang triton", "deepgemm"],
styles=[("blue", "-"), ("orange", "-"), ("red", "-"), ("yellow", "-")],
ylabel="GB/s",
plot_name="fp8 blockwise scaled matmul",
args={},
)
)
def benchmark(batch_size, provider, N, K):
M = batch_size
fp8_info = torch.finfo(torch.float8_e4m3fn)
fp8_max, fp8_min = fp8_info.max, fp8_info.min
a_fp32 = (torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
b_fp32 = (torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
scale_a_group_shape = (1, 128)
scale_b_group_shape = (128, 128)
scale_a_shape = scale_shape(a_fp8.shape, scale_a_group_shape)
scale_b_shape = scale_shape(b_fp8.shape, scale_b_group_shape)
scale_a = torch.randn(scale_a_shape, device="cuda", dtype=torch.float32)
scale_b = torch.randn(scale_b_shape, device="cuda", dtype=torch.float32)
quantiles = [0.5, 0.2, 0.8]
if provider == "sgl-kernel":
scale_a = scale_a.t().contiguous().t()
b_fp8, scale_b = b_fp8.t(), scale_b.t()
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: fp8_blockwise_scaled_mm(
a_fp8, b_fp8, scale_a, scale_b, torch.float16
),
quantiles=quantiles,
)
if provider == "vllm":
scale_a = scale_a.t().contiguous().t()
b_fp8, scale_b = b_fp8.t(), scale_b.t()
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, torch.float16),
quantiles=quantiles,
)
if provider == "triton":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: w8a8_block_fp8_matmul(
a_fp8, b_fp8, scale_a, scale_b, [128, 128], torch.float16
),
quantiles=quantiles,
)
if provider == "deepgemm":
scale_a_col_major = get_col_major_tma_aligned_tensor(scale_a.clone())
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: fp8_gemm_deepgemm(
a_fp8, scale_a_col_major, b_fp8, scale_b, M, N, K
),
quantiles=quantiles,
)
return ms * 1000, max_ms * 1000, min_ms * 1000 # convert to ms
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--models",
nargs="+",
type=str,
default=["deepseek-ai/DeepSeek-V3"],
help="List of models to benchmark",
)
parser.add_argument(
"--tp-sizes",
nargs="+",
type=int,
default=[1],
help="List of tensor parallel sizes",
)
args = parser.parse_args()
NK_model_names = get_weight_shapes(args)
for N, K, model_name in NK_model_names:
if N % 128 != 0 or K % 128 != 0:
print(f"Skip {N=}, {K=} now")
continue
print(f"{model_name} N={N} K={K}: ")
benchmark.run(
print_data=True,
show_plots=True,
save_path="bench_fp8_blockwise_res",
N=N,
K=K,
)
print("Benchmark finished!")

View File

@@ -0,0 +1,328 @@
import argparse
import random
from dataclasses import dataclass
from typing import List, Tuple
import deep_gemm
import torch
from sgl_kernel import fp8_blockwise_scaled_grouped_mm
def get_m_alignment_for_contiguous_layout():
return 128
def ceil_div(x: int, y: int) -> int:
return (x + y - 1) // y
def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
m, n = x.shape
pad_size = (128 - (n % 128)) % 128
x = torch.nn.functional.pad(x, (0, pad_size), value=0) if pad_size > 0 else x
x_view = x.view(m, -1, 128)
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
m, n = x.shape
x_padded = torch.zeros(
(ceil_div(m, 128) * 128, ceil_div(n, 128) * 128), dtype=x.dtype, device=x.device
)
x_padded[:m, :n] = x
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(
x_view.size(0), x_view.size(2)
)
def construct_contiguous_grouped(
num_groups: int, expected_m_per_group: int, k: int, n: int
) -> Tuple[
int,
Tuple[torch.Tensor, torch.Tensor],
Tuple[torch.Tensor, torch.Tensor],
torch.Tensor,
torch.Tensor,
torch.Tensor,
]:
alignment = get_m_alignment_for_contiguous_layout()
group_ms = [int(expected_m_per_group) for _ in range(num_groups)]
m = sum([ceil_div(x, alignment) * alignment for x in group_ms])
x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
y = torch.randn((num_groups, n, k), device="cuda", dtype=torch.bfloat16)
m_indices = torch.empty(m, device="cuda", dtype=torch.int32)
out = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)
start = 0
for i, group_m in enumerate(group_ms):
actual_end = start + group_m
aligned_end = start + ceil_div(group_m, alignment) * alignment
m_indices[start:actual_end] = i
m_indices[actual_end:aligned_end] = -1
start = aligned_end
assert m % 4 == 0, f"TMA alignment error: {m}"
x_fp8 = per_token_cast_to_fp8(x)
y_fp8 = (
torch.empty_like(y, dtype=torch.float8_e4m3fn),
torch.empty(
(num_groups, ceil_div(n, 128), k // 128), device="cuda", dtype=torch.float
),
)
for i in range(num_groups):
y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i])
return m, x_fp8, y_fp8, m_indices, out
def bench_deepgemm(
expected_m_per_group: int,
n: int,
k: int,
num_groups: int,
num_warmup: int,
num_run: int,
) -> Tuple[float, int]:
# construct tensors
m, x_fp8, y_fp8, m_indices, out = construct_contiguous_grouped(
num_groups, expected_m_per_group, k, n
)
def run_deepgemm():
deep_gemm.m_grouped_fp8_gemm_nt_contiguous(x_fp8, y_fp8, out, m_indices)
# warmup
for _ in range(num_warmup):
run_deepgemm()
torch.cuda.synchronize()
# run
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
latencies: list[float] = []
start_event.record()
for _ in range(num_run):
run_deepgemm()
end_event.record()
end_event.synchronize()
torch.cuda.synchronize()
avg = start_event.elapsed_time(end_event) / num_run * 1000 # us
return avg, m
def bench_cutlass(
expected_m_per_group: int,
n: int,
k: int,
num_groups: int,
num_warmup: int,
num_run: int,
) -> Tuple[float, int]:
device = "cuda"
alignment = 16
n_g = ceil_div(n, alignment) * alignment
k_g = ceil_div(k, alignment) * alignment
out_dtype = torch.bfloat16
expert_offsets = torch.zeros((num_groups + 1), device=device, dtype=torch.int32)
problem_sizes = torch.zeros((num_groups, 3), device=device, dtype=torch.int32)
layout_sfa = torch.zeros((num_groups, 5), device=device, dtype=torch.int32)
layout_sfb = torch.zeros((num_groups, 5), device=device, dtype=torch.int32)
a_tensors = []
b_tensors = []
a_scales_tensors = []
b_scales_tensors = []
# TODO(@TianQiLin666666): Unique group_ms in all bench function
group_ms = [
alignment * ceil_div(int(expected_m_per_group), alignment)
for _ in range(num_groups)
]
for g in range(num_groups):
m_g = group_ms[g]
expert_offsets[g + 1] = expert_offsets[g] + m_g
problem_sizes[g][:] = torch.tensor([m_g, n_g, k_g], device=device)
a_g, a_scale = per_token_cast_to_fp8(torch.randn((m_g, k_g), device=device))
b_g, b_scale = per_block_cast_to_fp8(torch.randn((n_g, k_g), device=device).t())
a_tensors.append(a_g)
b_tensors.append(b_g)
a_scales_tensors.append(a_scale)
b_scales_tensors.append(b_scale)
a_stack = torch.empty(
(expert_offsets[-1], k_g), device=device, dtype=torch.float8_e4m3fn
)
b_stack = torch.empty(
(num_groups, n_g, k_g), device=device, dtype=torch.float8_e4m3fn
)
for g in range(num_groups):
a_stack[expert_offsets[g] : expert_offsets[g + 1]] = a_tensors[g]
b_stack[g] = b_tensors[g].t()
b_stack = b_stack.transpose(1, 2)
a_scale_stack = torch.empty(
(expert_offsets[-1], k_g // 128), device=device, dtype=torch.float32
)
b_scale_stack = torch.empty(
(num_groups, n_g // 128, k_g // 128), device=device, dtype=torch.float32
)
for g in range(num_groups):
a_scale_stack[expert_offsets[g] : expert_offsets[g + 1]] = a_scales_tensors[g]
b_scale_stack[g] = b_scales_tensors[g].t()
b_scale_stack = b_scale_stack.transpose(1, 2)
c_out = torch.empty((expert_offsets[-1], n_g), device=device, dtype=out_dtype)
a_strides = torch.full(
(num_groups,), a_stack.stride(0), device=device, dtype=torch.int64
)
c_strides = torch.full(
(num_groups,), c_out.stride(0), device=device, dtype=torch.int64
)
workspace = torch.empty((1024 * 1024 * 1024), device=device, dtype=torch.uint8)
a_ptrs = torch.empty((num_groups,), device=device, dtype=torch.int64)
b_ptrs = torch.empty((num_groups,), device=device, dtype=torch.int64)
out_ptrs = torch.empty((num_groups,), device=device, dtype=torch.int64)
a_scales_ptrs = torch.empty((num_groups,), device=device, dtype=torch.int64)
b_scales_ptrs = torch.empty((num_groups,), device=device, dtype=torch.int64)
def run_cutlass():
fp8_blockwise_scaled_grouped_mm(
c_out,
a_ptrs,
b_ptrs,
out_ptrs,
a_scales_ptrs,
b_scales_ptrs,
a_stack,
b_stack,
a_scale_stack,
b_scale_stack,
a_strides,
a_strides,
c_strides,
layout_sfa,
layout_sfb,
problem_sizes,
expert_offsets[:-1],
workspace,
)
# warmup
for _ in range(num_warmup):
run_cutlass()
torch.cuda.synchronize()
# run
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
for _ in range(num_run):
run_cutlass()
end_event.record()
end_event.synchronize()
torch.cuda.synchronize()
avg = start_event.elapsed_time(end_event) / num_run * 1000 # us
return avg, expert_offsets[-1]
def bench_sglang_triton(
expected_m_per_group: int,
n: int,
k: int,
num_groups: int,
num_warmup: int,
num_run: int,
) -> Tuple[float, int]:
pass
benchmark_kernels = {
"deepgemm": bench_deepgemm,
"cutlass": bench_cutlass,
# "triton": bench_sglang_triton,
}
@dataclass
class ShapeArg:
expected_m_per_group: int
n: int
k: int
num_groups: int
def benchmark_one_shape(
shape_args: List[ShapeArg],
num_warmup: int,
num_run: int,
):
for shape in shape_args:
print(
f"\nBenchmark: expected_m_per_group={shape.expected_m_per_group}, n={shape.n}, k={shape.k}, num_groups={shape.num_groups}"
)
for kernel_name, kernel_func in benchmark_kernels.items():
average_time, m = kernel_func(
shape.expected_m_per_group,
shape.n,
shape.k,
shape.num_groups,
num_warmup,
num_run,
)
print(f"{kernel_name}: {average_time} us")
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--num-warmup", type=int, default=3)
parser.add_argument("--num-run", type=int, default=10)
shape_args = [
# Prefill, DeepSeek-R1, gateup, chunk_size = 4096, TP = 8
ShapeArg(expected_m_per_group=128, n=512, k=7168, num_groups=256),
# Prefill, DeepSeek-R1, gateup, chunk_size = 8192, TP = 8
ShapeArg(expected_m_per_group=256, n=512, k=7168, num_groups=256),
# Prefill, DeepSeek-R1, gateup, chunk_size = 8192, TP = 16
ShapeArg(expected_m_per_group=256, n=256, k=7168, num_groups=256),
# Prefill, DeepSeek-R1, gateup, chunk_size = 16384, TP = 16
ShapeArg(expected_m_per_group=512, n=256, k=7168, num_groups=256),
# Decode, DeepSeek-R1, gateup, bs = 32, TP = 8
ShapeArg(expected_m_per_group=1, n=512, k=7168, num_groups=256),
# Decode, DeepSeek-R1, gateup, bs = 64, TP = 16
ShapeArg(expected_m_per_group=2, n=256, k=7168, num_groups=256),
# Prefill, DeepSeek-R1, gateup, chunk_size = 8192, EP = 8
ShapeArg(expected_m_per_group=256, n=4096, k=7168, num_groups=32),
# Prefill, DeepSeek-R1, gateup, chunk_size = 16384, EP = 16
ShapeArg(expected_m_per_group=512, n=4096, k=7168, num_groups=16),
# Decode, DeepSeek-R1, gateup, bs = 128, EP = 8
ShapeArg(expected_m_per_group=4, n=4096, k=7168, num_groups=32),
# Decode, DeepSeek-R1, gateup, bs = 256, EP = 16
ShapeArg(expected_m_per_group=8, n=4096, k=7168, num_groups=16),
# Prefill, Qwen3-235B-A22B-FP8, gateup, chunk_size = 16384, TP = 4
ShapeArg(expected_m_per_group=1024, n=768, k=4096, num_groups=128),
# Prefill, Qwen3-235B-A22B-FP8, down, chunk_size = 16384, TP = 4
ShapeArg(expected_m_per_group=1024, n=4096, k=384, num_groups=128),
# Decode, Qwen3-235B-A22B-FP8, gateup, bs = 256, TP = 4
ShapeArg(expected_m_per_group=16, n=768, k=4096, num_groups=128),
# Decode, Qwen3-235B-A22B-FP8, down, bs = 256, TP = 4
ShapeArg(expected_m_per_group=16, n=4096, k=384, num_groups=128),
]
args = parser.parse_args()
benchmark_one_shape(shape_args, args.num_warmup, args.num_run)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,184 @@
import argparse
import copy
import itertools
from typing import Optional, Tuple
import torch
import triton
from sgl_kernel import fp8_scaled_mm as sgl_scaled_mm
from sgl_kernel import sgl_per_tensor_quant_fp8
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant
# Weight Shapes are in the format
# ([K, N], TP_SPLIT_DIM)
# Example:
# A shape of ([14336, 4096], 0) indicates the following GEMM shape,
# - TP1 : K = 14336, N = 4096
# - TP2 : K = 7168, N = 4096
# A shape of ([4096, 6144], 1) indicates the following GEMM shape,
# - TP1 : K = 4096, N = 6144
# - TP4 : K = 4096, N = 1536
# TP1 shapes
WEIGHT_SHAPES = {
"meta-llama/Llama-3.1-8B-Instruct": [
([4096, 6144], 1),
([4096, 4096], 0),
([4096, 28672], 1),
([14336, 4096], 0),
],
"meta-llama/Llama-3.3-70B-Instruct": [
([8192, 10240], 1),
([8192, 8192], 0),
([8192, 57344], 1),
([28672, 8192], 0),
],
"mistralai/Mistral-Large-Instruct-2407": [
([12288, 14336], 1),
([12288, 12288], 0),
([12288, 57344], 1),
([28672, 12288], 0),
],
"Qwen/Qwen2.5-7B-Instruct": [
([3584, 4608], 1),
([3584, 3584], 0),
([3584, 37888], 1),
([18944, 3584], 0),
],
"Qwen/Qwen2.5-32B-Instruct": [
([5120, 7168], 1),
([5120, 5120], 0),
([5120, 55296], 1),
([27648, 5120], 0),
],
"Qwen/Qwen2.5-72B-Instruct": [
([8192, 10240], 1),
([8192, 8192], 0),
([8192, 59136], 1),
([29568, 8192], 0),
],
"deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": [
([2048, 3072], 1),
([2048, 4096], 1),
([2048, 2048], 0),
([2048, 576], 0),
([2048, 21888], 1),
([10944, 2048], 0),
([2048, 2816], 1),
([1408, 2048], 0),
],
}
def sglang_scaled_fp8_quant(
input: torch.Tensor,
scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
fp8_type_: torch.dtype = torch.float8_e4m3fn
output = torch.empty_like(input, device=input.device, dtype=fp8_type_)
is_static = True
if scale is None:
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
is_static = False
sgl_per_tensor_quant_fp8(input, output, scale, is_static)
return output, scale
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size"],
x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048],
x_log=False,
line_arg="provider",
line_vals=[
"vllm-fp8-fp16",
"vllm-fp8-bf16",
"sglang-fp8-fp16",
"sglang-fp8-bf16",
],
line_names=[
"vllm-fp8-fp16",
"vllm-fp8-bf16",
"sglang-fp8-fp16",
"sglang-fp8-bf16",
],
styles=[("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")],
ylabel="GB/s",
plot_name="fp8 scaled matmul",
args={},
)
)
def benchmark(batch_size, provider, N, K):
# M, N, K = batch_size, 4096, 8192
M = batch_size
a = torch.ones((M, K), device="cuda") * 5.0
b = torch.ones((N, K), device="cuda") * 5.0
scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
scale_b = torch.randn((N,), device="cuda", dtype=torch.float32)
quantiles = [0.5, 0.2, 0.8]
dtype = torch.float16 if "fp16" in provider else torch.bfloat16
if "vllm-fp8" in provider:
a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)
b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
b_fp8 = b_fp8.t()
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype),
quantiles=quantiles,
)
elif "sglang-fp8" in provider:
a_fp8, scale_a_fp8 = sglang_scaled_fp8_quant(a, scale_a)
b_fp8, scale_b_fp8 = sglang_scaled_fp8_quant(b, scale_b)
b_fp8 = b_fp8.t()
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: sgl_scaled_mm(
a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype, bias=None
),
quantiles=quantiles,
)
gbps = lambda ms: (2 * M * N * K + M * N) * a.element_size() * 1e-9 / (ms * 1e-3)
return gbps(ms), gbps(max_ms), gbps(min_ms)
def prepare_shapes(args):
KN_model_names = []
models_tps = list(itertools.product(args.models, args.tp_sizes))
for model, tp_size in models_tps:
assert model in WEIGHT_SHAPES
for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
KN[tp_split_dim] = KN[tp_split_dim] // tp_size
KN.append(model)
KN_model_names.append(KN)
return KN_model_names
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--models",
nargs="+",
type=str,
default=["meta-llama/Llama-3.1-8B-Instruct"],
help="List of models to benchmark",
)
parser.add_argument(
"--tp-sizes",
nargs="+",
type=int,
default=[1],
help="List of tensor parallel sizes",
)
args = parser.parse_args()
KN_model_names = prepare_shapes(args)
for K, N, model_name in KN_model_names:
print(f"{model_name} N={N} K={K}: ")
benchmark.run(
print_data=True, show_plots=True, save_path="bench_fp8_res", N=N, K=K
)
print("Benchmark finished!")

View File

@@ -0,0 +1,146 @@
import argparse
import copy
import itertools
import torch
import triton
from sgl_kernel import int8_scaled_mm
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
def to_int8(tensor: torch.Tensor) -> torch.Tensor:
return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
WEIGHT_SHAPES = {
"meta-llama/Llama-3.1-8B-Instruct": [
([4096, 6144], 1),
([4096, 4096], 0),
([4096, 28672], 1),
([14336, 4096], 0),
],
"meta-llama/Llama-3.3-70B-Instruct": [
([8192, 10240], 1),
([8192, 8192], 0),
([8192, 57344], 1),
([28672, 8192], 0),
],
"mistralai/Mistral-Large-Instruct-2407": [
([12288, 14336], 1),
([12288, 12288], 0),
([12288, 57344], 1),
([28672, 12288], 0),
],
"Qwen/Qwen2.5-7B-Instruct": [
([3584, 4608], 1),
([3584, 3584], 0),
([3584, 37888], 1),
([18944, 3584], 0),
],
"Qwen/Qwen2.5-32B-Instruct": [
([5120, 7168], 1),
([5120, 5120], 0),
([5120, 55296], 1),
([27648, 5120], 0),
],
"Qwen/Qwen2.5-72B-Instruct": [
([8192, 10240], 1),
([8192, 8192], 0),
([8192, 59136], 1),
([29568, 8192], 0),
],
"deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": [
([2048, 3072], 1),
([2048, 4096], 1),
([2048, 2048], 0),
([2048, 576], 0),
([2048, 21888], 1),
([10944, 2048], 0),
([2048, 2816], 1),
([1408, 2048], 0),
],
}
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size"],
x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048],
x_log=False,
line_arg="provider",
line_vals=["vllm", "sgl-kernel"],
line_names=["vllm int8 gemm", "sgl-kernel int8 gemm"],
styles=[("blue", "-"), ("orange", "-")],
ylabel="GB/s",
plot_name="int8 scaled matmul",
args={},
)
)
def benchmark(batch_size, provider, N, K):
M = batch_size
a = to_int8(torch.randn((M, K), device="cuda") * 5)
b = to_int8(torch.randn((N, K), device="cuda").t() * 5)
scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
scale_b = torch.randn((N,), device="cuda", dtype=torch.float32)
bias = torch.randn((N,), device="cuda", dtype=torch.float16)
quantiles = [0.5, 0.2, 0.8]
if provider == "sgl-kernel":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: int8_scaled_mm(a, b, scale_a, scale_b, torch.float16, bias),
quantiles=quantiles,
)
if provider == "vllm":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: vllm_scaled_mm(a, b, scale_a, scale_b, torch.float16, bias),
quantiles=quantiles,
)
gbps = (
lambda ms: (
(2 * M * N * K - M * N) * a.element_size()
+ (3 * M * N) * scale_a.element_size()
)
* 1e-9
/ (ms * 1e-3)
)
return gbps(ms), gbps(max_ms), gbps(min_ms)
def prepare_shapes(args):
KN_model_names = []
models_tps = list(itertools.product(args.models, args.tp_sizes))
for model, tp_size in models_tps:
assert model in WEIGHT_SHAPES
for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
KN[tp_split_dim] = KN[tp_split_dim] // tp_size
KN.append(model)
KN_model_names.append(KN)
return KN_model_names
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--models",
nargs="+",
type=str,
default=["meta-llama/Llama-3.1-8B-Instruct"],
help="List of models to benchmark",
)
parser.add_argument(
"--tp-sizes",
nargs="+",
type=int,
default=[1],
help="List of tensor parallel sizes",
)
args = parser.parse_args()
KN_model_names = prepare_shapes(args)
for K, N, model_name in KN_model_names:
print(f"{model_name} N={N} K={K}: ")
benchmark.run(
print_data=True, show_plots=True, save_path="bench_int8_res", N=N, K=K
)
print("Benchmark finished!")

View File

@@ -0,0 +1,299 @@
import itertools
import math
import torch
import triton
import triton.language as tl
from sgl_kernel import lightning_attention_decode
def next_power_of_2(n):
return 2 ** (int(math.ceil(math.log(n, 2))))
@triton.jit
def _decode_kernel(
Q,
K,
V,
KV,
Out,
S,
b: tl.constexpr,
h: tl.constexpr,
n: tl.constexpr,
d: tl.constexpr,
d_original: tl.constexpr,
e: tl.constexpr,
e_original: tl.constexpr,
):
off_bh = tl.program_id(0)
off_h = off_bh % h
qk_offset = off_bh * n * d
v_offset = off_bh * n * e
o_offset = off_bh * n * e
kv_offset = off_bh * d * e
s = tl.load(S + off_h)
ratio = tl.exp(-s)
d_idx = tl.arange(0, d)
e_idx = tl.arange(0, e)
# Create masks for original dimensions
d_mask = d_idx < d_original
e_mask = e_idx < e_original
# Load with masking
q = tl.load(Q + qk_offset + d_idx, mask=d_mask, other=0.0)
k = tl.load(K + qk_offset + d_idx, mask=d_mask, other=0.0)
v = tl.load(V + v_offset + e_idx, mask=e_mask, other=0.0)
# Load KV with 2D masking
kv = tl.load(
KV + kv_offset + d_idx[:, None] * e + e_idx[None, :],
mask=(d_mask[:, None] & e_mask[None, :]),
other=0.0,
)
# Compute outer product using element-wise operations
k_v_prod = k[:, None] * v[None, :]
kv = ratio * kv + k_v_prod
# Store KV with 2D masking
tl.store(
KV + kv_offset + d_idx[:, None] * e + e_idx[None, :],
kv.to(KV.dtype.element_ty),
mask=(d_mask[:, None] & e_mask[None, :]),
)
# Compute matrix-vector multiplication using element-wise operations and reduction
o = tl.sum(q[:, None] * kv, axis=0)
# Store output with masking
tl.store(Out + o_offset + e_idx, o.to(Out.dtype.element_ty), mask=e_mask)
def triton_lightning_attn_decode(q, k, v, kv, s):
"""Triton implementation of Lightning Attention decode operation"""
b, h, n, d = q.shape
e = v.shape[-1]
assert n == 1, "Sequence length must be 1 in decode mode"
# Get padded dimensions (power of 2)
d_padded = next_power_of_2(d)
e_padded = next_power_of_2(e)
# Create output tensor (padded)
o_padded = torch.empty(b, h, n, e_padded, dtype=v.dtype, device=v.device)
# Create padded tensors without actually padding the data
q_padded = torch.empty(b, h, n, d_padded, dtype=q.dtype, device=q.device)
k_padded = torch.empty(b, h, n, d_padded, dtype=k.dtype, device=k.device)
v_padded = torch.empty(b, h, n, e_padded, dtype=v.dtype, device=v.device)
kv_padded = torch.empty(
b, h, d_padded, e_padded, dtype=torch.float32, device=kv.device
)
# Copy data to padded tensors
q_padded[..., :d] = q
k_padded[..., :d] = k
v_padded[..., :e] = v
kv_padded[..., :d, :e] = kv
# Launch kernel
grid = (b * h, 1)
_decode_kernel[grid](
q_padded,
k_padded,
v_padded,
kv_padded,
o_padded,
s,
b=b,
h=h,
n=n,
d=d_padded,
d_original=d,
e=e_padded,
e_original=e,
)
# Get unpadded outputs
o = o_padded[..., :e]
kv_out = kv_padded[..., :d, :e]
return o, kv_out
def lightning_attention_decode_naive(q, k, v, past_kv, slope):
"""Naive implementation of lightning attention decode"""
original_dtype = q.dtype
ratio = torch.exp(-slope) # [h, 1, 1]
kv = past_kv
b, h, n, d = q.shape
output = []
for i in range(n):
kv = ratio * kv.to(torch.float32) + torch.einsum(
"... n d, ... n e -> ... d e",
k[:, :, i : i + 1],
v[:, :, i : i + 1],
)
qkv = torch.einsum(
"... n e, ... e d -> ... n d",
q[:, :, i : i + 1].to(torch.float32),
kv.to(torch.float32),
)
output.append(qkv)
output = torch.cat(output, dim=-2)
return output.to(original_dtype), kv
def lightning_attention_decode_kernel(q, k, v, past_kv, slope, output, new_kv):
return lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv)
def calculate_diff(batch_size):
dtype = torch.bfloat16
device = torch.device("cuda")
num_heads = 64
head_dim = 96
seq_len = 1
q = torch.randn(
batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
)
k = torch.randn(
batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
)
v = torch.randn(
batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
)
past_kv = torch.randn(batch_size, num_heads, head_dim, head_dim, device=device)
slope = torch.randn(num_heads, 1, 1, device=device)
output_naive, new_kv_naive = lightning_attention_decode_naive(
q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone()
)
output_kernel = torch.empty_like(output_naive)
new_kv_kernel = torch.empty_like(new_kv_naive)
lightning_attention_decode_kernel(
q.clone(),
k.clone(),
v.clone(),
past_kv.clone(),
slope.clone(),
output_kernel,
new_kv_kernel,
)
output_triton, new_kv_triton = triton_lightning_attn_decode(
q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone()
)
if (
torch.allclose(output_naive, output_kernel, atol=1e-2, rtol=1e-2)
and torch.allclose(output_naive, output_triton, atol=1e-2, rtol=1e-2)
and torch.allclose(new_kv_naive, new_kv_kernel, atol=1e-2, rtol=1e-2)
and torch.allclose(new_kv_naive, new_kv_triton, atol=1e-2, rtol=1e-2)
):
print("✅ All implementations match")
else:
print("❌ Implementations differ")
batch_size_range = [i for i in range(1, 65)] # 1 to 128
configs = [(bs,) for bs in batch_size_range]
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size"],
x_vals=[list(_) for _ in configs],
line_arg="provider",
line_vals=["naive", "kernel", "triton"],
line_names=["PyTorch Naive", "SGL Kernel", "Triton"],
styles=[("blue", "-"), ("red", "-"), ("green", "-")],
ylabel="us",
plot_name="lightning-attention-decode-performance",
args={},
)
)
def benchmark(batch_size, provider):
dtype = torch.bfloat16
device = torch.device("cuda")
num_heads = 64
head_dim = 96
seq_len = 1
q = torch.randn(
batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
)
k = torch.randn(
batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
)
v = torch.randn(
batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
)
past_kv = torch.randn(batch_size, num_heads, head_dim, head_dim, device=device)
slope = torch.randn(num_heads, 1, 1, device=device)
quantiles = [0.5, 0.2, 0.8]
if provider == "naive":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: lightning_attention_decode_naive(
q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone()
),
quantiles=quantiles,
)
elif provider == "kernel":
output = torch.empty(
batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
)
new_kv = torch.empty(batch_size, num_heads, head_dim, head_dim, device=device)
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: lightning_attention_decode_kernel(
q.clone(),
k.clone(),
v.clone(),
past_kv.clone(),
slope.clone(),
output,
new_kv,
),
quantiles=quantiles,
)
elif provider == "triton":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: triton_lightning_attn_decode(
q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone()
),
quantiles=quantiles,
)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--save_path",
type=str,
default="./configs/benchmark_ops/lightning_attention_decode_sgl/",
help="Path to save lightning attention decode benchmark results",
)
args = parser.parse_args()
# Run correctness test
calculate_diff(batch_size=4)
# Run performance benchmark
benchmark.run(print_data=True)

View File

@@ -0,0 +1,401 @@
import argparse
import itertools
import torch
import triton
import triton.language as tl
from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
try:
from vllm import _custom_ops as ops
except ImportError:
ops = None
USE_RANDOM_PERM = False
def ceil_div(a, b):
return (a + b - 1) // b
@triton.jit
def moe_align_block_size_stage1(
topk_ids_ptr,
tokens_cnts_ptr,
num_experts: tl.constexpr,
numel: tl.constexpr,
tokens_per_thread: tl.constexpr,
):
pid = tl.program_id(0)
start_idx = pid * tokens_per_thread
off_c = (pid + 1) * num_experts
for i in range(tokens_per_thread):
if start_idx + i < numel:
idx = tl.load(topk_ids_ptr + start_idx + i)
token_cnt = tl.load(tokens_cnts_ptr + off_c + idx)
tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1)
@triton.jit
def moe_align_block_size_stage2(
tokens_cnts_ptr,
num_experts: tl.constexpr,
):
pid = tl.program_id(0)
last_cnt = 0
for i in range(1, num_experts + 1):
token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
last_cnt = last_cnt + token_cnt
tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)
@triton.jit
def moe_align_block_size_stage3(
total_tokens_post_pad_ptr,
tokens_cnts_ptr,
cumsum_ptr,
num_experts: tl.constexpr,
block_size: tl.constexpr,
):
last_cumsum = 0
off_cnt = num_experts * num_experts
for i in range(1, num_experts + 1):
token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1)
last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size
tl.store(cumsum_ptr + i, last_cumsum)
tl.store(total_tokens_post_pad_ptr, last_cumsum)
@triton.jit
def moe_align_block_size_stage4(
topk_ids_ptr,
sorted_token_ids_ptr,
expert_ids_ptr,
tokens_cnts_ptr,
cumsum_ptr,
num_experts: tl.constexpr,
block_size: tl.constexpr,
numel: tl.constexpr,
tokens_per_thread: tl.constexpr,
):
pid = tl.program_id(0)
start_idx = tl.load(cumsum_ptr + pid)
end_idx = tl.load(cumsum_ptr + pid + 1)
for i in range(start_idx, end_idx, block_size):
tl.store(expert_ids_ptr + i // block_size, pid)
start_idx = pid * tokens_per_thread
off_t = pid * num_experts
for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread, numel)):
expert_id = tl.load(topk_ids_ptr + i)
token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id)
rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id)
tl.store(sorted_token_ids_ptr + rank_post_pad, i)
tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1)
def moe_align_block_size_triton(
topk_ids: torch.Tensor,
num_experts: int,
block_size: int,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_pad: torch.Tensor,
) -> None:
numel = topk_ids.numel()
grid = (num_experts,)
tokens_cnts = torch.zeros(
(num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device
)
cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device)
tokens_per_thread = ceil_div(numel, num_experts)
moe_align_block_size_stage1[grid](
topk_ids,
tokens_cnts,
num_experts,
numel,
tokens_per_thread,
)
moe_align_block_size_stage2[grid](
tokens_cnts,
num_experts,
)
moe_align_block_size_stage3[(1,)](
num_tokens_post_pad,
tokens_cnts,
cumsum,
num_experts,
block_size,
)
moe_align_block_size_stage4[grid](
topk_ids,
sorted_token_ids,
expert_ids,
tokens_cnts,
cumsum,
num_experts,
block_size,
numel,
tokens_per_thread,
)
def calculate_diff(num_tokens, num_experts=256, block_size=128, topk=8):
topk_ids = torch.stack(
[
torch.randperm(num_experts, dtype=torch.int32, device="cuda")[:topk]
for _ in range(num_tokens)
]
)
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
sorted_ids_cuda = torch.empty(
(max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
)
sorted_ids_cuda.fill_(topk_ids.numel())
max_num_m_blocks = max_num_tokens_padded // block_size
expert_ids_cuda = torch.zeros(
(max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
)
num_tokens_post_pad_cuda = torch.empty(
(1), dtype=torch.int32, device=topk_ids.device
)
cumsum_buffer = torch.zeros(
num_experts + 1, dtype=torch.int32, device=topk_ids.device
)
sorted_ids_triton = torch.empty_like(sorted_ids_cuda)
sorted_ids_triton.fill_(topk_ids.numel())
expert_ids_triton = torch.zeros_like(expert_ids_cuda)
num_tokens_post_pad_triton = torch.empty_like(num_tokens_post_pad_cuda)
sorted_ids_vllm = torch.empty_like(sorted_ids_cuda)
sorted_ids_vllm.fill_(topk_ids.numel())
expert_ids_vllm = torch.zeros_like(expert_ids_cuda)
num_tokens_post_pad_vllm = torch.empty_like(num_tokens_post_pad_cuda)
# compare the performance of cuda, triton and vllm implementation
sgl_moe_align_block_size(
topk_ids,
num_experts,
block_size,
sorted_ids_cuda,
expert_ids_cuda,
num_tokens_post_pad_cuda,
cumsum_buffer,
)
moe_align_block_size_triton(
topk_ids,
num_experts,
block_size,
sorted_ids_triton,
expert_ids_triton,
num_tokens_post_pad_triton,
)
try:
ops.moe_align_block_size(
topk_ids,
num_experts,
block_size,
sorted_ids_vllm,
expert_ids_vllm,
num_tokens_post_pad_vllm,
)
print(f"✅ VLLM implementation works with {num_experts} experts!")
vllm_works = True
except Exception as e:
print(f"❌ VLLM implementation failed with {num_experts} experts: {e}")
vllm_works = False
if torch.allclose(expert_ids_cuda, expert_ids_triton) and torch.allclose(
num_tokens_post_pad_cuda, num_tokens_post_pad_triton
):
print("✅ SGL and Triton implementations match")
else:
print("❌ SGL and Triton implementations do not match")
print("SGL expert_ids:", expert_ids_cuda)
print("Triton expert_ids:", expert_ids_triton)
print("SGL num_tokens_post_pad:", num_tokens_post_pad_cuda)
print("Triton num_tokens_post_pad:", num_tokens_post_pad_triton)
if (
vllm_works
and torch.allclose(expert_ids_cuda, expert_ids_vllm)
and torch.allclose(num_tokens_post_pad_cuda, num_tokens_post_pad_vllm)
):
print("✅ SGL and VLLM implementations match")
else:
if not vllm_works:
print("⚠️ VLLM comparison skipped due to failure")
else:
print("❌ SGL and VLLM implementations do not match")
print("SGL expert_ids:", expert_ids_cuda)
print("VLLM expert_ids:", expert_ids_vllm)
print("SGL num_tokens_post_pad:", num_tokens_post_pad_cuda)
print("VLLM num_tokens_post_pad:", num_tokens_post_pad_vllm)
# Test range
num_tokens_range = [1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
num_experts_range = [8, 32, 64, 128, 256]
topk_range = [1, 2, 4, 8]
configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range))
def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor:
topk_ids = torch.zeros((num_tokens, topk), dtype=torch.int32, device="cuda")
for i in range(num_tokens):
topk_ids[i, :] = torch.randperm(num_experts, dtype=torch.int32, device="cuda")[
:topk
]
return topk_ids
def sgl_moe_align_block_size_with_empty(
topk_ids,
num_experts,
block_size,
sorted_ids,
expert_ids,
num_tokens_post_pad,
pad_sorted_token_ids=False,
):
if not pad_sorted_token_ids:
sorted_ids.fill_(topk_ids.numel())
cumsum_buffer = torch.empty(
num_experts + 1, dtype=torch.int32, device=topk_ids.device
)
sgl_moe_align_block_size(
topk_ids,
num_experts,
block_size,
sorted_ids.clone(),
expert_ids.clone(),
num_tokens_post_pad.clone(),
cumsum_buffer,
pad_sorted_token_ids,
)
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["num_tokens", "num_experts", "topk"],
x_vals=configs,
line_arg="provider",
line_vals=["sgl", "sgl_fusion", "triton"],
line_names=["SGL", "SGL Fusion", "Triton"],
styles=[("blue", "-"), ("red", "-"), ("green", "-")],
ylabel="us",
plot_name="moe-align-block-size-performance",
args={},
)
)
def benchmark(num_tokens, num_experts, topk, provider):
block_size = 128
if USE_RANDOM_PERM:
topk_ids = get_topk_ids(num_tokens, num_experts, topk)
else:
topk_ids = torch.randint(
0,
num_experts,
(num_tokens, topk),
dtype=torch.int32,
device="cuda",
)
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
sorted_ids = torch.empty(
(max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
)
max_num_m_blocks = max_num_tokens_padded // block_size
expert_ids = torch.empty(
(max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
)
num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
quantiles = [0.5, 0.2, 0.8]
if provider == "sgl":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: sgl_moe_align_block_size_with_empty(
topk_ids,
num_experts,
block_size,
sorted_ids,
expert_ids,
num_tokens_post_pad,
),
quantiles=quantiles,
)
elif provider == "sgl_fusion":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: sgl_moe_align_block_size_with_empty(
topk_ids,
num_experts,
block_size,
sorted_ids,
expert_ids,
num_tokens_post_pad,
pad_sorted_token_ids=True,
),
quantiles=quantiles,
)
elif provider == "triton":
sorted_ids.fill_(topk_ids.numel())
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: moe_align_block_size_triton(
topk_ids,
num_experts,
block_size,
sorted_ids.clone(),
expert_ids.clone(),
num_tokens_post_pad.clone(),
),
quantiles=quantiles,
)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--save_path",
type=str,
default="./configs/benchmark_ops/moe_align_blocks/",
help="Path to save moe align benchmark results",
)
parser.add_argument(
"--num_experts",
type=int,
default=256,
choices=[8, 16, 32, 64, 128, 256],
help="Number of experts for benchmark",
)
parser.add_argument(
"--topk",
type=int,
default=8,
choices=[2, 4, 8],
help="Top-k value for benchmark",
)
parser.add_argument(
"--skip_full_benchmark",
action="store_true",
help="Only run the calculate_diff function, skip full benchmarking",
)
args = parser.parse_args()
calculate_diff(num_tokens=1024, num_experts=args.num_experts, topk=args.topk)
if not args.skip_full_benchmark:
print(f"\n📊 Running performance benchmark for {args.num_experts} experts...")
benchmark.run(print_data=True)

View File

@@ -0,0 +1,93 @@
import torch
import triton
from sgl_kernel import ep_moe_post_reorder
from sglang.srt.layers.moe.ep_moe.kernels import post_reorder_triton_kernel
batch_sizes = [64, 128, 256, 512, 640, 768, 1024, 2048, 4096]
configs = [(bs,) for bs in batch_sizes]
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size"],
x_vals=[list(_) for _ in configs],
line_arg="provider",
line_vals=["cuda", "triton"],
line_names=["CUDA Kernel", "Triton Kernel"],
styles=[("green", "-"), ("orange", "-")],
ylabel="us",
plot_name="ep-moe-post-reorder-performance",
args={},
)
)
def benchmark(batch_size, provider):
dtype = torch.bfloat16
device = torch.device("cuda")
hidden_size, topk, start_expert_id, end_expert_id, block_size = 4096, 8, 0, 255, 512
def alloc_tensors():
down_output = torch.randn(
batch_size * topk, hidden_size, dtype=dtype, device=device
)
output = torch.zeros(batch_size, hidden_size, dtype=dtype, device=device)
src2dst = torch.randint(
0, batch_size * topk, (batch_size, topk), dtype=torch.int32, device=device
)
topk_ids = torch.randint(
start_expert_id,
end_expert_id + 1,
(batch_size, topk),
dtype=torch.int32,
device=device,
)
topk_weights = torch.rand(batch_size, topk, dtype=dtype, device=device)
return down_output, output, src2dst, topk_ids, topk_weights
quantiles = [0.5, 0.2, 0.8]
if provider == "cuda":
d_out, out, s2d, tk_ids, tk_weights = alloc_tensors()
def run_cuda():
ep_moe_post_reorder(
d_out,
out,
s2d,
tk_ids,
tk_weights,
start_expert_id,
end_expert_id,
topk,
)
ms, min_ms, max_ms = triton.testing.do_bench(run_cuda, quantiles=quantiles)
elif provider == "triton":
d_out, out, s2d, tk_ids, tk_weights = alloc_tensors()
def run_triton():
post_reorder_triton_kernel[(batch_size,)](
d_out.view(-1),
out.view(-1),
s2d.view(-1),
tk_ids.view(-1),
tk_weights.view(-1),
start_expert_id,
end_expert_id,
topk,
hidden_size,
0,
block_size,
)
ms, min_ms, max_ms = triton.testing.do_bench(run_triton, quantiles=quantiles)
else:
raise ValueError(f"Unknown provider: {provider}")
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
if __name__ == "__main__":
benchmark.run(print_data=True)

View File

@@ -0,0 +1,103 @@
import torch
import triton
from sgl_kernel import ep_moe_pre_reorder
from sglang.srt.layers.moe.ep_moe.kernels import pre_reorder_triton_kernel
batch_sizes = [64, 128, 256, 512, 640, 768, 1024, 2048, 4096]
configs = [(bs,) for bs in batch_sizes]
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size"],
x_vals=[list(_) for _ in configs],
line_arg="provider",
line_vals=["cuda", "triton"],
line_names=["CUDA Kernel", "Triton Kernel"],
styles=[("green", "-"), ("orange", "-")],
ylabel="us",
plot_name="ep-moe-pre-reorder-performance",
args={},
)
)
def benchmark(batch_size, provider):
dtype = torch.bfloat16
device = torch.device("cuda")
hidden_size, topk, start_expert_id, end_expert_id, block_size = (
4096,
8,
0,
255,
512,
)
# Allocate fresh tensors for every run to match bench_moe_fused_gate style
def alloc_tensors():
input_ = torch.randn(batch_size, hidden_size, dtype=dtype, device=device)
gateup_input = torch.zeros(
batch_size * topk, hidden_size, dtype=dtype, device=device
)
src2dst = torch.randint(
0, batch_size * topk, (batch_size, topk), dtype=torch.int32, device=device
)
topk_ids = torch.randint(
start_expert_id,
end_expert_id + 1,
(batch_size, topk),
dtype=torch.int32,
device=device,
)
a1_scales = torch.rand(
end_expert_id - start_expert_id + 1, dtype=torch.float32, device=device
)
return input_, gateup_input, src2dst, topk_ids, a1_scales
quantiles = [0.5, 0.2, 0.8]
if provider == "cuda":
inp, gout, s2d, tk_ids, scales = alloc_tensors()
def run_cuda():
ep_moe_pre_reorder(
inp,
gout,
s2d,
tk_ids,
scales,
start_expert_id,
end_expert_id,
topk,
True,
)
ms, min_ms, max_ms = triton.testing.do_bench(run_cuda, quantiles=quantiles)
elif provider == "triton":
inp, gout, s2d, tk_ids, scales = alloc_tensors()
def run_triton():
pre_reorder_triton_kernel[(batch_size,)](
inp.view(-1),
gout.view(-1),
s2d.view(-1),
tk_ids.view(-1),
scales,
start_expert_id,
end_expert_id,
topk,
hidden_size,
block_size,
True,
)
ms, min_ms, max_ms = triton.testing.do_bench(run_triton, quantiles=quantiles)
else:
raise ValueError(f"Unknown provider: {provider}")
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
if __name__ == "__main__":
benchmark.run(print_data=True)

View File

@@ -0,0 +1,77 @@
import itertools
import math
import torch
import triton
import triton.language as tl
from sgl_kernel import moe_fused_gate
from sglang.srt.layers.moe.topk import biased_grouped_topk
def biased_grouped_topk_org(scores, bias, num_expert_group, topk_group, topk):
return biased_grouped_topk(
scores,
scores,
bias,
topk=topk,
renormalize=True,
num_expert_group=num_expert_group,
topk_group=topk_group,
routed_scaling_factor=2.5, # DeepSeek-R1 : 2.5, Kimi K2: 2.872
)
def biased_grouped_topk_org_fuse_kernel(
scores, bias, num_expert_group, topk_group, topk
):
return moe_fused_gate(scores, bias, num_expert_group, topk_group, topk)
seq_length_range = [5000, 10000, 15000, 20000, 25000, 30000, 35000, 40000]
configs = [(sq,) for sq in seq_length_range]
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["seq_length"],
x_vals=[list(_) for _ in configs],
line_arg="provider",
line_vals=["original", "kernel"],
line_names=["Original", "SGL Kernel"],
styles=[("blue", "-"), ("red", "-")],
ylabel="us",
plot_name="moe-fused-gate-performance",
args={},
)
)
def benchmark(seq_length, provider):
dtype = torch.bfloat16
device = torch.device("cuda")
num_experts, num_expert_group, topk_group, topk = 256, 8, 4, 8
scores = torch.randn((seq_length, num_experts), device=device, dtype=dtype)
bias = torch.rand(num_experts, device=device, dtype=dtype)
quantiles = [0.5, 0.2, 0.8]
if provider == "original":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: biased_grouped_topk_org(
scores.clone(), bias.clone(), num_expert_group, topk_group, topk
),
quantiles=quantiles,
)
elif provider == "kernel":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: biased_grouped_topk_org_fuse_kernel(
scores.clone(), bias.clone(), num_expert_group, topk_group, topk
),
quantiles=quantiles,
)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
if __name__ == "__main__":
benchmark.run(print_data=True)

View File

@@ -0,0 +1,92 @@
import itertools
import torch
import triton
from sgl_kernel import ep_moe_silu_and_mul
from sglang.srt.layers.moe.ep_moe.kernels import silu_and_mul_triton_kernel
batch_size_range = [64, 128, 256, 512, 640, 768, 1024, 2048, 4096]
hidden_size_range = [1024, 2048, 4096, 8192]
block_size_range = [128, 256, 512]
configs = list(itertools.product(batch_size_range, hidden_size_range, block_size_range))
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size", "hidden_size", "block_size"],
x_vals=[list(cfg) for cfg in configs],
line_arg="provider",
line_vals=["cuda", "triton"],
line_names=["CUDA Kernel", "Triton Kernel"],
styles=[("green", "-"), ("orange", "-")],
ylabel="us",
plot_name="ep-moe-silu-and-mul-performance",
args={},
)
)
def benchmark(batch_size, hidden_size, block_size, provider):
dtype = torch.bfloat16
device = torch.device("cuda")
half_hidden_size = hidden_size // 2
start_expert_id, end_expert_id = 0, 255
block_size = 512
quantiles = [0.5, 0.2, 0.8]
def alloc_tensors():
gateup_output = torch.randn(batch_size, hidden_size, dtype=dtype, device=device)
down_input = torch.empty(
batch_size, half_hidden_size, dtype=dtype, device=device
)
reorder_topk_ids = torch.randint(
start_expert_id,
end_expert_id + 1,
(batch_size,),
dtype=torch.int32,
device=device,
)
scales = torch.rand(
end_expert_id - start_expert_id + 1, dtype=torch.float32, device=device
)
return gateup_output, down_input, reorder_topk_ids, scales
if provider == "cuda":
gateup, down, ids, scales = alloc_tensors()
def run_cuda():
ep_moe_silu_and_mul(
gateup,
down,
ids,
scales,
start_expert_id,
end_expert_id,
)
ms, min_ms, max_ms = triton.testing.do_bench(run_cuda, quantiles=quantiles)
elif provider == "triton":
gateup, down, ids, scales = alloc_tensors()
def run_triton():
silu_and_mul_triton_kernel[(batch_size,)](
gateup.view(-1),
down.view(-1),
hidden_size,
ids,
scales,
start_expert_id,
end_expert_id,
block_size,
)
ms, min_ms, max_ms = triton.testing.do_bench(run_triton, quantiles=quantiles)
else:
raise ValueError(f"Unknown provider: {provider}")
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
if __name__ == "__main__":
benchmark.run(print_data=True)

View File

@@ -0,0 +1,116 @@
import itertools
import pytest
import torch
import triton
from sgl_kernel import topk_softmax
from vllm import _custom_ops as vllm_custom_ops
def vllm_topk_softmax(gating_output, topk):
num_tokens, num_experts = gating_output.shape
topk_weights = torch.empty(
(num_tokens, topk), device=gating_output.device, dtype=torch.float32
)
topk_indices = torch.empty(
(num_tokens, topk), dtype=torch.int32, device=gating_output.device
)
token_expert_indices = torch.empty(
(num_tokens, topk), dtype=torch.int32, device=gating_output.device
)
torch.ops._moe_C.topk_softmax(
topk_weights, topk_indices, token_expert_indices, gating_output
)
return topk_weights, topk_indices
def sglang_topk_softmax(gating_output, topk):
num_tokens, num_experts = gating_output.shape
topk_weights = torch.empty(
(num_tokens, topk), device=gating_output.device, dtype=torch.float32
)
topk_indices = torch.empty(
(num_tokens, topk), dtype=torch.int32, device=gating_output.device
)
topk_softmax(
topk_weights=topk_weights,
topk_ids=topk_indices,
gating_output=gating_output,
)
return topk_weights, topk_indices
def calculate_diff(num_tokens, num_experts, topk):
gating_output = torch.randn(
(num_tokens, num_experts), device="cuda", dtype=torch.float32
)
weights_vllm, indices_vllm = vllm_topk_softmax(gating_output.clone(), topk)
weights_sglang, indices_sglang = sglang_topk_softmax(gating_output.clone(), topk)
weights_diff = torch.abs(weights_vllm - weights_sglang).mean().item()
indices_match = torch.equal(indices_vllm, indices_sglang)
if (
torch.allclose(weights_vllm, weights_sglang, atol=1e-3, rtol=1e-3)
and indices_match
):
print("✅ VLLM and SGLang topk_softmax implementations match")
else:
print(
f"❌ Implementations differ: Weights diff={weights_diff}, Indices match={indices_match}"
)
num_tokens_range = [128, 512, 1024, 2048, 4096, 8192, 16384, 32768]
num_experts_range = [32, 64, 128, 256, 12, 512]
topk_range = [1, 2, 4, 8]
configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range))
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["num_tokens", "num_experts", "topk"],
x_vals=configs,
line_arg="provider",
line_vals=["sglang", "vllm"],
line_names=["SGLang", "VLLM"],
styles=[("blue", "-"), ("green", "-")],
ylabel="Latency (us)",
plot_name="topk-softmax-performance",
args={},
)
)
def benchmark(num_tokens, num_experts, topk, provider):
gating_output = torch.randn(
(num_tokens, num_experts), device="cuda", dtype=torch.float32
)
if provider == "vllm" or provider == "vllm1":
fn = lambda: vllm_topk_softmax(gating_output, topk)
elif provider == "sglang" or provider == "sglang1":
fn = lambda: sglang_topk_softmax(gating_output, topk)
quantiles = [0.5, 0.2, 0.8]
ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
if __name__ == "__main__":
configs = [
(20, 256, 4),
(20, 256, 8),
(20, 12, 4),
(20, 12, 1),
(20, 512, 4),
(20, 512, 1),
]
for num_tokens, num_experts, topk in configs:
calculate_diff(num_tokens, num_experts, topk)
benchmark.run(print_data=True)

View File

@@ -0,0 +1,172 @@
import argparse
import copy
import itertools
import torch
import triton
from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant
FLOAT4_E2M1_MAX = 6.0
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
# Weight Shapes are in the format
# ([K, N], TP_SPLIT_DIM)
# Example:
# A shape of ([14336, 4096], 0) indicates the following GEMM shape,
# - TP1 : K = 14336, N = 4096
# - TP2 : K = 7168, N = 4096
# A shape of ([4096, 6144], 1) indicates the following GEMM shape,
# - TP1 : K = 4096, N = 6144
# - TP4 : K = 4096, N = 1536
# TP1 shapes
WEIGHT_SHAPES = {
"meta-llama/Llama-3.1-8B-Instruct": [
([4096, 6144], 1),
([4096, 4096], 0),
([4096, 28672], 1),
([14336, 4096], 0),
],
"meta-llama/Llama-3.3-70B-Instruct": [
([8192, 10240], 1),
([8192, 8192], 0),
([8192, 57344], 1),
([28672, 8192], 0),
],
"mistralai/Mistral-Large-Instruct-2407": [
([12288, 14336], 1),
([12288, 12288], 0),
([12288, 57344], 1),
([28672, 12288], 0),
],
"Qwen/Qwen2.5-7B-Instruct": [
([3584, 4608], 1),
([3584, 3584], 0),
([3584, 37888], 1),
([18944, 3584], 0),
],
"Qwen/Qwen2.5-32B-Instruct": [
([5120, 7168], 1),
([5120, 5120], 0),
([5120, 55296], 1),
([27648, 5120], 0),
],
"Qwen/Qwen2.5-72B-Instruct": [
([8192, 10240], 1),
([8192, 8192], 0),
([8192, 59136], 1),
([29568, 8192], 0),
],
"deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": [
([2048, 3072], 1),
([2048, 4096], 1),
([2048, 2048], 0),
([2048, 576], 0),
([2048, 21888], 1),
([10944, 2048], 0),
([2048, 2816], 1),
([1408, 2048], 0),
],
}
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size"],
x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048],
x_log=False,
line_arg="provider",
line_vals=[
"sglang-fp4-fp16",
"sglang-fp4-bf16",
],
line_names=[
"sglang-fp4-fp16",
"sglang-fp4-bf16",
],
styles=[("green", "-"), ("blue", "-")],
ylabel="TFLOPS",
plot_name="fp4 block scaled matmul",
args={},
)
)
def benchmark(batch_size, provider, N, K):
# M, N, K = batch_size, 4096, 8192
run_step = 100
dtype = torch.float16 if "fp16" in provider else torch.bfloat16
M = batch_size
a = torch.randn((M, K), dtype=dtype, device="cuda")
b = torch.randn((N, K), dtype=dtype, device="cuda")
a_global_scale = (
(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a.flatten(), dim=-1)
).to(torch.float32)
b_global_scale = (
(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(b.flatten(), dim=-1)
).to(torch.float32)
alpha = 1.0 / (a_global_scale * b_global_scale)
a_fp4, a_scale_interleaved = scaled_fp4_quant(a, a_global_scale)
b_fp4, b_scale_interleaved = scaled_fp4_quant(b, b_global_scale)
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
# Bridging the gap between CPU and GPU
for _ in range(25):
c = a @ b.t()
# Warmup
for _ in range(5):
cutlass_scaled_fp4_mm(
a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, dtype
)
start_event.record()
for _ in range(run_step):
cutlass_scaled_fp4_mm(
a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, dtype
)
end_event.record()
end_event.synchronize()
torch.cuda.synchronize()
ms = start_event.elapsed_time(end_event) / run_step
tflops = lambda ms: (2 * M * N * K) * 1e-9 / ms
return tflops(ms)
def prepare_shapes(args):
KN_model_names = []
models_tps = list(itertools.product(args.models, args.tp_sizes))
for model, tp_size in models_tps:
assert model in WEIGHT_SHAPES
for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
KN[tp_split_dim] = KN[tp_split_dim] // tp_size
KN.append(model)
KN_model_names.append(KN)
return KN_model_names
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--models",
nargs="+",
type=str,
default=["meta-llama/Llama-3.1-8B-Instruct"],
help="List of models to benchmark",
)
parser.add_argument(
"--tp-sizes",
nargs="+",
type=int,
default=[1],
help="List of tensor parallel sizes",
)
args = parser.parse_args()
KN_model_names = prepare_shapes(args)
for K, N, model_name in KN_model_names:
print(f"{model_name} N={N} K={K}: ")
benchmark.run(
print_data=True, show_plots=True, save_path="bench_fp4_res", N=N, K=K
)
print("Benchmark finished!")

View File

@@ -0,0 +1,98 @@
import itertools
import math
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import torch
import triton
import triton.testing
from sgl_kernel import sgl_per_tensor_quant_fp8
from vllm import _custom_ops as ops
from sglang.srt.utils import is_hip
_is_hip = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
def vllm_scaled_fp8_quant(
input: torch.Tensor,
scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
return ops.scaled_fp8_quant(input, scale)
def sglang_scaled_fp8_quant(
input: torch.Tensor,
scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
fp8_type_: torch.dtype = torch.float8_e4m3fn
output = torch.empty_like(input, device=input.device, dtype=fp8_type_)
is_static = True
if scale is None:
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
is_static = False
sgl_per_tensor_quant_fp8(input, output, scale, is_static)
return output, scale
def calculate_diff(batch_size: int, seq_len: int):
"""Calculate difference between VLLM and SGLang implementations."""
device = torch.device("cuda")
x = torch.rand((batch_size, seq_len), dtype=torch.float16, device=device)
vllm_out, vllm_scale = vllm_scaled_fp8_quant(x)
sglang_out, sglang_scale = sglang_scaled_fp8_quant(x)
scale_diff = torch.abs(vllm_scale - sglang_scale).item()
output_diff = torch.abs(vllm_out.float() - sglang_out.float()).mean().item()
if torch.allclose(
vllm_out.to(torch.float32), sglang_out.to(torch.float32), rtol=1e-3, atol=1e-5
) and torch.allclose(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-5):
print("✅ All implementations match")
else:
print("❌ Implementations differ")
batch_size_range = [16, 32, 64, 128]
seq_len_range = [64, 128, 256, 512, 1024, 2048]
configs = list(itertools.product(batch_size_range, seq_len_range))
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size", "seq_len"],
x_vals=configs,
line_arg="provider",
line_vals=["vllm", "sglang"],
line_names=["VLLM", "SGL Kernel"],
styles=[("blue", "-"), ("green", "-")],
ylabel="us",
plot_name="per-tensor-quant-fp8-performance",
args={},
)
)
def benchmark(batch_size, seq_len, provider):
dtype = torch.float16
device = torch.device("cuda")
x = torch.randn(batch_size * seq_len, 4096, device=device, dtype=dtype)
quantiles = [0.5, 0.2, 0.8]
if provider == "vllm":
fn = lambda: vllm_scaled_fp8_quant(x.clone())
elif provider == "sglang":
fn = lambda: sglang_scaled_fp8_quant(x.clone())
ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
if __name__ == "__main__":
calculate_diff(batch_size=4, seq_len=4096)
benchmark.run(print_data=True)

View File

@@ -0,0 +1,98 @@
import itertools
import time
from functools import partial
from pathlib import Path
import torch
import triton
from sglang.srt.bench_utils import bench_kineto
from sglang.srt.layers.quantization.fp8_kernel import (
create_per_token_group_quant_fp8_output_scale,
)
from sglang.srt.layers.quantization.fp8_kernel import (
per_token_group_quant_8bit as triton_per_token_group_quant_8bit,
)
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_8bit
from sglang.srt.utils import is_hip
_is_hip = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
num_tokens_range = [1, 4, 16, 64, 256, 768, 2048, 8192, 16384]
hidden_dim_range = [1536, 7168, 18432] # For DeepSeek V3/R1
group_size_range = [128] # For DeepSeek V3/R1
# TODO test int8
dst_dtype_range = [fp8_type_]
flags_range = [
dict(
column_major_scales=False,
scale_tma_aligned=False,
scale_ue8m0=False,
),
dict(
column_major_scales=True,
scale_tma_aligned=False,
scale_ue8m0=False,
),
dict(
column_major_scales=True,
scale_tma_aligned=True,
scale_ue8m0=False,
),
dict(
column_major_scales=True,
scale_tma_aligned=True,
scale_ue8m0=True,
),
]
configs = list(
itertools.product(
num_tokens_range,
hidden_dim_range,
group_size_range,
dst_dtype_range,
flags_range,
)
)
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["num_tokens", "hidden_dim", "group_size", "dst_dtype", "flags"],
x_vals=configs,
line_arg="provider",
line_vals=["triton", "sglang"],
line_names=["Triton", "SGL Kernel"],
styles=[("blue", "-"), ("green", "-")],
ylabel="us",
plot_name="per-token-group-quant-8bit-performance",
args={},
)
)
def benchmark(num_tokens, hidden_dim, group_size, dst_dtype, flags, provider):
if flags["scale_ue8m0"] and group_size != 128:
return
device = torch.device("cuda")
x = torch.randn(num_tokens, hidden_dim, device=device, dtype=torch.bfloat16)
fn, kernel_names = {
"triton": (triton_per_token_group_quant_8bit, "_per_token_group_quant_fp8"),
"sglang": (
sglang_per_token_group_quant_8bit,
"per_token_group_quant_8bit_kernel",
),
}[provider]
bench_fn = lambda: fn(x=x, group_size=group_size, dst_dtype=dst_dtype, **flags)
time_s = bench_kineto(bench_fn, kernel_names=kernel_names)
return time_s * 1e6
if __name__ == "__main__":
benchmark.run(print_data=True)

View File

@@ -0,0 +1,177 @@
import itertools
from typing import Optional, Tuple
import torch
import triton
import triton.testing
from sgl_kernel import sgl_per_token_quant_fp8
from vllm import _custom_ops as ops
from sglang.srt.utils import is_hip
_is_hip = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
# Get correct FP8 E4M3 maximum value
if _is_hip:
FP8_E4M3_MAX = 224.0 # ROCM uses 224.0
else:
# For CUDA, get the actual max value from the type
FP8_E4M3_MAX = float(torch.finfo(fp8_type_).max)
def torch_per_token_quant_fp8(
input: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Pure PyTorch reference implementation for per-token FP8 quantization."""
device = input.device
dtype = input.dtype
# Find max absolute value per token (row) - exactly like CUDA kernel
max_vals = torch.abs(input).max(dim=1)[0] # [num_tokens]
# Calculate scale per token - exactly like CUDA kernel: scale = max_value / FP8_E4M3_MAX
scales = max_vals / FP8_E4M3_MAX # [num_tokens]
# No special zero handling - directly compute 1.0 / scale like CUDA kernel
scale_inv = 1.0 / scales # [num_tokens]
# Quantize: input * scale_inv, then clamp to FP8 range
quantized_float = input * scale_inv.unsqueeze(1) # Broadcast scale_inv
quantized_float = torch.clamp(quantized_float, -FP8_E4M3_MAX, FP8_E4M3_MAX)
# Convert to FP8 - use more explicit conversion
quantized_fp8 = quantized_float.to(fp8_type_)
return quantized_fp8, scales
def vllm_per_token_quant_fp8(
input: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
return ops.scaled_fp8_quant(input, use_per_token_if_dynamic=True)
def sglang_per_token_quant_fp8(
input: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
scale = torch.zeros(input.size(0), device=input.device, dtype=torch.float32)
output = torch.empty_like(input, device=input.device, dtype=fp8_type_)
sgl_per_token_quant_fp8(input, output, scale)
return output, scale
def calculate_diff(batch_size: int, seq_len: int, hidden_dim: int):
"""Compare Torch reference, VLLM, and SGLang implementations."""
device = torch.device("cuda")
x = torch.rand(
(batch_size * seq_len, hidden_dim), dtype=torch.float16, device=device
)
# Get all three implementations
torch_out, torch_scale = torch_per_token_quant_fp8(x)
vllm_out, vllm_scale = vllm_per_token_quant_fp8(x)
sglang_out, sglang_scale = sglang_per_token_quant_fp8(x)
print(f"\n=== Comparison for hidden_dim={hidden_dim} ===")
# Compare scales
torch_vllm_scale_diff = torch.abs(torch_scale - vllm_scale).mean().item()
torch_sglang_scale_diff = torch.abs(torch_scale - sglang_scale).mean().item()
vllm_sglang_scale_diff = torch.abs(vllm_scale - sglang_scale).mean().item()
print(f"Scale differences:")
print(f" Torch vs VLLM: {torch_vllm_scale_diff:.8f}")
print(f" Torch vs SGLang: {torch_sglang_scale_diff:.8f}")
print(f" VLLM vs SGLang: {vllm_sglang_scale_diff:.8f}")
# Compare outputs
torch_vllm_out_diff = torch.abs(torch_out.float() - vllm_out.float()).mean().item()
torch_sglang_out_diff = (
torch.abs(torch_out.float() - sglang_out.float()).mean().item()
)
vllm_sglang_out_diff = (
torch.abs(vllm_out.float() - sglang_out.float()).mean().item()
)
print(f"Output differences:")
print(f" Torch vs VLLM: {torch_vllm_out_diff:.8f}")
print(f" Torch vs SGLang: {torch_sglang_out_diff:.8f}")
print(f" VLLM vs SGLang: {vllm_sglang_out_diff:.8f}")
# Check tolerances
rtol, atol = 1e-3, 1e-5
torch_vllm_match = torch.allclose(
torch_out.float(), vllm_out.float(), rtol=rtol, atol=atol
) and torch.allclose(torch_scale, vllm_scale, rtol=rtol, atol=atol)
torch_sglang_match = torch.allclose(
torch_out.float(), sglang_out.float(), rtol=rtol, atol=atol
) and torch.allclose(torch_scale, sglang_scale, rtol=rtol, atol=atol)
if hidden_dim == 1368:
rtol = 1e-2
# we found vllm sglang has diff when hidden dim is not dividable by 16
# and we believe SGLang is closer to Torch implementation
vllm_sglang_match = torch.allclose(
vllm_out.float(), sglang_out.float(), rtol=rtol, atol=atol
) and torch.allclose(vllm_scale, sglang_scale, rtol=rtol, atol=atol)
print(f"Matches (rtol={rtol}, atol={atol}):")
print(f" Torch vs VLLM: {'' if torch_vllm_match else ''}")
print(f" Torch vs SGLang: {'' if torch_sglang_match else ''}")
print(f" VLLM vs SGLang: {'' if vllm_sglang_match else ''}")
batch_size_range = [16, 32, 64, 128]
seq_len_range = [64, 128, 256, 512, 1024, 2048, 4096]
hidden_dim_range = [1368, 2048, 4096]
configs = list(itertools.product(batch_size_range, seq_len_range, hidden_dim_range))
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size", "seq_len", "hidden_dim"],
x_vals=configs,
line_arg="provider",
line_vals=["torch", "vllm", "sglang"],
line_names=["Torch Reference", "VLLM", "SGL Kernel"],
styles=[("red", "-"), ("blue", "-"), ("green", "-")],
ylabel="us",
plot_name="per-token-dynamic-quant-fp8-performance",
args={},
)
)
def benchmark_quantization(batch_size, seq_len, hidden_dim, provider):
dtype = torch.float16
device = torch.device("cuda")
x = torch.randn(batch_size * seq_len, hidden_dim, device=device, dtype=dtype)
quantiles = [0.5, 0.2, 0.8]
if provider == "torch":
fn = lambda: torch_per_token_quant_fp8(x.clone())
elif provider == "vllm":
fn = lambda: vllm_per_token_quant_fp8(x.clone())
elif provider == "sglang":
fn = lambda: sglang_per_token_quant_fp8(x.clone())
ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
if __name__ == "__main__":
# Test various hidden dimensions for correctness
test_dims = [1368, 2048, 4096]
for dim in test_dims:
calculate_diff(batch_size=4, seq_len=4096, hidden_dim=dim)
print("\n" + "=" * 60)
print("Starting performance benchmark...")
benchmark_quantization.run(print_data=True)

View File

@@ -0,0 +1,198 @@
import argparse
import copy
import itertools
import torch
import triton
from sgl_kernel import (
int8_scaled_mm,
qserve_w4a8_per_chn_gemm,
qserve_w4a8_per_group_gemm,
)
def to_int8(tensor: torch.Tensor) -> torch.Tensor:
return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
WEIGHT_SHAPES = {
"meta-llama/Llama-3.1-8B-Instruct": [
([4096, 6144], 1),
([4096, 4096], 0),
([4096, 28672], 1),
([14336, 4096], 0),
],
"meta-llama/Llama-3.3-70B-Instruct": [
([8192, 10240], 1),
([8192, 8192], 0),
([8192, 57344], 1),
([28672, 8192], 0),
],
"mistralai/Mistral-Large-Instruct-2407": [
([12288, 14336], 1),
([12288, 12288], 0),
([12288, 57344], 1),
([28672, 12288], 0),
],
"Qwen/Qwen2.5-7B-Instruct": [
([3584, 4608], 1),
([3584, 3584], 0),
([3584, 37888], 1),
([18944, 3584], 0),
],
"Qwen/Qwen2.5-32B-Instruct": [
([5120, 7168], 1),
([5120, 5120], 0),
([5120, 55296], 1),
([27648, 5120], 0),
],
"Qwen/Qwen2.5-72B-Instruct": [
([8192, 10240], 1),
([8192, 8192], 0),
([8192, 59136], 1),
([29568, 8192], 0),
],
"deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": [
([2048, 3072], 1),
([2048, 4096], 1),
([2048, 2048], 0),
([2048, 576], 0),
([2048, 21888], 1),
([10944, 2048], 0),
([2048, 2816], 1),
([1408, 2048], 0),
],
}
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size"],
x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048],
x_log=False,
line_arg="provider",
line_vals=["FP16", "W8A8", "Qserve_W4A8_Per_Channel", "Qserve_W4A8_Per_Group"],
line_names=["FP16", "W8A8", "Qserve_W4A8_Per_Channel", "Qserve_W4A8_Per_Group"],
styles=[("blue", "-"), ("orange", "-"), ("green", "-"), ("red", "-")],
ylabel="ms",
plot_name="FP16_vs_W8A8_vs_Qserve_W4A8_GEMM",
args={},
)
)
def benchmark(batch_size, provider, N, K):
M = batch_size
# For W8A8
a = to_int8(torch.randn((M, K), device="cuda") * 5)
b = to_int8(torch.randn((N, K), device="cuda").t() * 5)
a_fp16 = a.to(torch.float16)
b_fp16 = b.to(torch.float16)
scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
scale_b = torch.randn((N,), device="cuda", dtype=torch.float32)
# For Qserve W4A8 per channel
a_qserve_chn = a
# two int4s pack into one int8
b_qserve_chn = to_int8(torch.randn((N, K // 2), device="cuda") * 5)
# b_qserve_chn = b.t().contiguous()
scale_a_qserve_chn = scale_a.to(torch.float16)
scale_b_qserve_chn = scale_b.to(torch.float16)
szero_b_qserve_chn = torch.randn((N,), device="cuda", dtype=torch.float16)
a_sum_qserve_chn = torch.randn((M,), device="cuda", dtype=torch.float16)
# For Qserve W4A8 per group
group_size = 128
assert K % group_size == 0, "K must be divisible by group_size"
a_qserve_group = a
# two int4s pack into one int8
b_qserve_group = to_int8(torch.randn((N, K // 2), device="cuda") * 5)
# b_qserve_group = b.t().contiguous()
scale_a_qserve_group = scale_a.to(torch.float16)
scale_b_qserve_group = scale_b.to(torch.float16)
scale_i8_b_qserve_group = to_int8(
torch.randn((K // group_size, N), device="cuda", dtype=torch.float16)
)
zero_i8_b_qserve_group = to_int8(
torch.randn((K // group_size, N), device="cuda", dtype=torch.float16)
)
quantiles = [0.5, 0.2, 0.8]
if provider == "FP16":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: torch.matmul(a_fp16, b_fp16),
quantiles=quantiles,
)
if provider == "W8A8":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: int8_scaled_mm(a, b, scale_a, scale_b, torch.float16),
quantiles=quantiles,
)
if provider == "Qserve_W4A8_Per_Channel":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: qserve_w4a8_per_chn_gemm(
a_qserve_chn,
b_qserve_chn,
scale_b_qserve_chn,
scale_a_qserve_chn,
szero_b_qserve_chn,
a_sum_qserve_chn,
),
quantiles=quantiles,
)
if provider == "Qserve_W4A8_Per_Group":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: qserve_w4a8_per_group_gemm(
a_qserve_group,
b_qserve_group,
zero_i8_b_qserve_group,
scale_i8_b_qserve_group,
scale_b_qserve_group,
scale_a_qserve_group,
),
quantiles=quantiles,
)
return ms, max_ms, min_ms
def prepare_shapes(args):
KN_model_names = []
models_tps = list(itertools.product(args.models, args.tp_sizes))
for model, tp_size in models_tps:
assert model in WEIGHT_SHAPES
for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
KN[tp_split_dim] = KN[tp_split_dim] // tp_size
KN.append(model)
KN_model_names.append(KN)
return KN_model_names
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--models",
nargs="+",
type=str,
default=["meta-llama/Llama-3.1-8B-Instruct"],
help="List of models to benchmark",
)
parser.add_argument(
"--tp-sizes",
nargs="+",
type=int,
default=[1],
help="List of tensor parallel sizes",
)
args = parser.parse_args()
KN_model_names = prepare_shapes(args)
for K, N, model_name in KN_model_names:
print(f"{model_name} N={N} K={K}: ")
benchmark.run(
print_data=True,
show_plots=True,
save_path="bench_qserve_w4a8_gemm_res",
N=N,
K=K,
)
print("Benchmark finished!")

View File

@@ -0,0 +1,96 @@
import itertools
import torch
import triton
from sgl_kernel import FusedSetKVBufferArg
from sgl_kernel.testing.rotary_embedding import (
FlashInferRotaryEmbedding,
MHATokenToKVPool,
RotaryEmbedding,
create_inputs,
)
from sglang.srt.bench_utils import bench_kineto
configs = [
(batch_size, seq_len, save_kv_cache)
for batch_size, seq_len in (
(1, 1),
(32, 1),
(128, 1),
(512, 1),
(2, 512),
(4, 4096),
)
for save_kv_cache in (False, True)
]
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size", "seq_len", "save_kv_cache"],
x_vals=configs,
line_arg="provider",
line_vals=["sglang"],
line_names=["SGL Kernel"],
styles=[("green", "-")],
ylabel="us",
plot_name="bench_rotary_embedding",
args={},
)
)
def benchmark(batch_size, seq_len, save_kv_cache, provider):
device = torch.device("cuda")
num_q_heads = 32
num_kv_heads = 8
head_size = 64
dtype = torch.bfloat16
config = dict(
head_size=head_size,
rotary_dim=64,
max_position_embeddings=4096,
base=8000,
is_neox_style=True,
dtype=dtype,
)
rope_flashinfer = FlashInferRotaryEmbedding(**config).to(device)
pool_flashinfer = MHATokenToKVPool(head_num=num_kv_heads, head_dim=head_size)
inputs = create_inputs(
head_size=head_size,
batch_size=batch_size,
seq_len=seq_len,
device=device,
dtype=dtype,
num_q_heads=num_q_heads,
num_kv_heads=num_kv_heads,
)
query_flashinfer, key_flashinfer = inputs["query"].clone(), inputs["key"].clone()
bench_fn = lambda: rope_flashinfer.forward_cuda(
inputs["pos_ids"],
query_flashinfer,
key_flashinfer,
fused_set_kv_buffer_arg=(
FusedSetKVBufferArg(
value=inputs["value"],
k_buffer=pool_flashinfer.k_buffer[0].view(-1, num_kv_heads * head_size),
v_buffer=pool_flashinfer.v_buffer[0].view(-1, num_kv_heads * head_size),
k_scale=None,
v_scale=None,
cache_loc=inputs["out_cache_loc"],
)
if save_kv_cache
else None
),
)
time_s = bench_kineto(bench_fn, kernel_names="BatchQKApplyRotaryPosIds")
return time_s * 1e6
if __name__ == "__main__":
benchmark.run(print_data=True)

View File

@@ -0,0 +1,128 @@
import itertools
import sgl_kernel
import torch
import triton
import triton.testing
def torch_top_k_top_p_joint_sampling_from_probs(
normalized_prob, top_k, top_p, eps=1e-4
):
"""Reference PyTorch implementation of joint top-k top-p sampling."""
batch_size, vocab_size = normalized_prob.shape
samples = torch.empty(batch_size, dtype=torch.int64, device=normalized_prob.device)
for i in range(batch_size):
p_val = top_p[i].item()
k_val = top_k[i].item()
# top-p mask
sorted_prob, indices = torch.sort(normalized_prob[i], descending=False)
cdf = torch.cumsum(sorted_prob, dim=-1)
mask_top_p = torch.zeros(
vocab_size, dtype=torch.int32, device=normalized_prob.device
)
mask_top_p.scatter_add_(0, indices, (cdf > (1 - p_val) - eps).int())
# top-k mask
sorted_prob_desc, _ = torch.sort(normalized_prob[i], descending=True)
pivot = sorted_prob_desc[k_val - 1]
mask_top_k = (normalized_prob[i] >= pivot).int()
# joint mask
mask = torch.minimum(mask_top_p, mask_top_k).bool()
# sample from masked probs
masked_probs = normalized_prob[i] * mask
masked_probs = masked_probs / masked_probs.sum()
idx = torch.multinomial(masked_probs, 1)
samples[i] = idx
return samples
def calculate_diff(batch_size, vocab_size, p):
"""Compare Torch reference and SGLang kernel for correctness."""
torch.manual_seed(42)
if p == 0.1:
k = int(vocab_size * 0.5)
elif p == 0.5:
k = int(vocab_size * 0.1)
else:
raise ValueError("p not recognized")
device = torch.device("cuda")
pre_norm_prob = torch.rand(batch_size, vocab_size, device=device)
normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
top_p_tensor = torch.full((batch_size,), p, device=device)
top_k_tensor = torch.full((batch_size,), k, device=device)
torch_samples = torch_top_k_top_p_joint_sampling_from_probs(
normalized_prob, top_k_tensor, top_p_tensor
)
sglang_samples = sgl_kernel.top_k_top_p_sampling_from_probs(
normalized_prob, top_k_tensor, top_p_tensor, filter_apply_order="joint"
)
# parameter space
batch_size_range = [16, 64, 128]
vocab_size_range = [111, 32000]
p_range = [0.1, 0.5]
configs = list(itertools.product(batch_size_range, vocab_size_range, p_range))
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size", "vocab_size", "p"],
x_vals=configs,
line_arg="provider",
line_vals=["torch", "sglang"],
line_names=["Torch Reference", "SGL Kernel"],
styles=[("red", "-"), ("green", "-")],
ylabel="us",
plot_name="top-k-top-p-joint-sampling-performance",
args={},
)
)
def benchmark_sampling(batch_size, vocab_size, p, provider):
torch.manual_seed(42)
if p == 0.1:
k = int(vocab_size * 0.5)
elif p == 0.5:
k = int(vocab_size * 0.1)
else:
raise ValueError("p not recognized")
device = torch.device("cuda")
pre_norm_prob = torch.rand(batch_size, vocab_size, device=device)
normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
top_p_tensor = torch.full((batch_size,), p, device=device)
top_k_tensor = torch.full((batch_size,), k, device=device)
if provider == "torch":
fn = lambda: torch_top_k_top_p_joint_sampling_from_probs(
normalized_prob.clone(), top_k_tensor, top_p_tensor
)
elif provider == "sglang":
fn = lambda: sgl_kernel.top_k_top_p_sampling_from_probs(
normalized_prob.clone(),
top_k_tensor,
top_p_tensor,
filter_apply_order="joint",
)
ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=[0.5, 0.2, 0.8])
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
if __name__ == "__main__":
# Correctness check
for cfg in configs:
calculate_diff(*cfg)
print("\n" + "=" * 60)
print("Starting performance benchmark...")
benchmark_sampling.run(print_data=True)

75
sgl-kernel/build.sh Executable file
View File

@@ -0,0 +1,75 @@
#!/bin/bash
set -ex
PYTHON_VERSION=$1
CUDA_VERSION=$2
PYTHON_ROOT_PATH=/opt/python/cp${PYTHON_VERSION//.}-cp${PYTHON_VERSION//.}
if [ -z "$3" ]; then
ARCH=$(uname -i)
else
ARCH=$3
fi
echo "ARCH: $ARCH"
if [ ${ARCH} = "aarch64" ]; then
LIBCUDA_ARCH="sbsa"
BUILDER_NAME="pytorch/manylinuxaarch64-builder"
CMAKE_BUILD_PARALLEL_LEVEL=16
else
LIBCUDA_ARCH=${ARCH}
BUILDER_NAME="pytorch/manylinux2_28-builder"
fi
if [ ${CUDA_VERSION} = "12.9" ]; then
DOCKER_IMAGE="${BUILDER_NAME}:cuda${CUDA_VERSION}"
TORCH_INSTALL="pip install --no-cache-dir torch==2.8.0 --index-url https://download.pytorch.org/whl/cu129"
elif [ ${CUDA_VERSION} = "12.8" ]; then
DOCKER_IMAGE="${BUILDER_NAME}:cuda${CUDA_VERSION}"
TORCH_INSTALL="pip install --no-cache-dir torch==2.8.0 --index-url https://download.pytorch.org/whl/cu128"
else
DOCKER_IMAGE="${BUILDER_NAME}:cuda${CUDA_VERSION}"
TORCH_INSTALL="pip install --no-cache-dir torch==2.8.0 --index-url https://download.pytorch.org/whl/cu126"
fi
docker run --rm \
-v $(pwd):/sgl-kernel \
${DOCKER_IMAGE} \
bash -c "
# Install CMake (version >= 3.26) - Robust Installation
export CMAKE_VERSION_MAJOR=3.31
export CMAKE_VERSION_MINOR=1
# Setting these flags to reduce OOM chance only on ARM
if [ \"${ARCH}\" = \"aarch64\" ]; then
export CUDA_NVCC_FLAGS=\"-Xcudafe --threads=2\"
export MAKEFLAGS='-j2'
export CMAKE_BUILD_PARALLEL_LEVEL=2
export NINJAFLAGS='-j2'
fi
echo \"Downloading CMake from: https://cmake.org/files/v\${CMAKE_VERSION_MAJOR}/cmake-\${CMAKE_VERSION_MAJOR}.\${CMAKE_VERSION_MINOR}-linux-${ARCH}.tar.gz\"
wget https://cmake.org/files/v\${CMAKE_VERSION_MAJOR}/cmake-\${CMAKE_VERSION_MAJOR}.\${CMAKE_VERSION_MINOR}-linux-${ARCH}.tar.gz
tar -xzf cmake-\${CMAKE_VERSION_MAJOR}.\${CMAKE_VERSION_MINOR}-linux-${ARCH}.tar.gz
mv cmake-\${CMAKE_VERSION_MAJOR}.\${CMAKE_VERSION_MINOR}-linux-${ARCH} /opt/cmake
export PATH=/opt/cmake/bin:\$PATH
export LD_LIBRARY_PATH=/lib64:\$LD_LIBRARY_PATH
# Debugging CMake
echo \"PATH: \$PATH\"
which cmake
cmake --version
yum install numactl-devel -y && \
yum install libibverbs -y --nogpgcheck && \
ln -sv /usr/lib64/libibverbs.so.1 /usr/lib64/libibverbs.so && \
${PYTHON_ROOT_PATH}/bin/${TORCH_INSTALL} && \
${PYTHON_ROOT_PATH}/bin/pip install --no-cache-dir ninja setuptools==75.0.0 wheel==0.41.0 numpy uv scikit-build-core && \
export TORCH_CUDA_ARCH_LIST='8.0 8.9 9.0+PTX' && \
export CUDA_VERSION=${CUDA_VERSION} && \
mkdir -p /usr/lib/${ARCH}-linux-gnu/ && \
ln -s /usr/local/cuda-${CUDA_VERSION}/targets/${LIBCUDA_ARCH}-linux/lib/stubs/libcuda.so /usr/lib/${ARCH}-linux-gnu/libcuda.so && \
cd /sgl-kernel && \
ls -la ${PYTHON_ROOT_PATH}/lib/python${PYTHON_VERSION}/site-packages/wheel/ && \
PYTHONPATH=${PYTHON_ROOT_PATH}/lib/python${PYTHON_VERSION}/site-packages ${PYTHON_ROOT_PATH}/bin/python -m uv build --wheel -Cbuild-dir=build . --color=always --no-build-isolation && \
./rename_wheels.sh
"

View File

@@ -0,0 +1,117 @@
import ctypes
import os
import platform
import torch
SYSTEM_ARCH = platform.machine()
cuda_path = f"/usr/local/cuda/targets/{SYSTEM_ARCH}-linux/lib/libcudart.so.12"
if os.path.exists(cuda_path):
ctypes.CDLL(cuda_path, mode=ctypes.RTLD_GLOBAL)
from sgl_kernel import common_ops
from sgl_kernel.allreduce import *
from sgl_kernel.attention import (
cutlass_mla_decode,
cutlass_mla_get_workspace_size,
lightning_attention_decode,
merge_state,
merge_state_v2,
)
from sgl_kernel.cutlass_moe import cutlass_w4a8_moe_mm, get_cutlass_w4a8_moe_mm_data
from sgl_kernel.elementwise import (
FusedSetKVBufferArg,
apply_rope_with_cos_sin_cache_inplace,
downcast_fp8,
fused_add_rmsnorm,
gelu_and_mul,
gelu_tanh_and_mul,
gemma_fused_add_rmsnorm,
gemma_rmsnorm,
rmsnorm,
silu_and_mul,
)
if torch.version.hip is not None:
from sgl_kernel.elementwise import gelu_quick
from sgl_kernel.fused_moe import fused_marlin_moe
from sgl_kernel.gemm import (
awq_dequantize,
bmm_fp8,
cutlass_scaled_fp4_mm,
dsv3_fused_a_gemm,
dsv3_router_gemm,
fp8_blockwise_scaled_mm,
fp8_scaled_mm,
gptq_gemm,
gptq_marlin_gemm,
gptq_shuffle,
int8_scaled_mm,
qserve_w4a8_per_chn_gemm,
qserve_w4a8_per_group_gemm,
scaled_fp4_experts_quant,
scaled_fp4_grouped_quant,
scaled_fp4_quant,
sgl_per_tensor_quant_fp8,
sgl_per_token_group_quant_fp8,
sgl_per_token_group_quant_int8,
sgl_per_token_quant_fp8,
shuffle_rows,
silu_and_mul_scaled_fp4_grouped_quant,
)
from sgl_kernel.grammar import apply_token_bitmask_inplace_cuda
from sgl_kernel.kvcacheio import (
transfer_kv_all_layer,
transfer_kv_all_layer_mla,
transfer_kv_per_layer,
transfer_kv_per_layer_mla,
)
from sgl_kernel.marlin import (
awq_marlin_moe_repack,
awq_marlin_repack,
gptq_marlin_repack,
)
from sgl_kernel.memory import set_kv_buffer_kernel
from sgl_kernel.moe import (
apply_shuffle_mul_sum,
cutlass_fp4_group_mm,
ep_moe_post_reorder,
ep_moe_pre_reorder,
ep_moe_silu_and_mul,
fp8_blockwise_scaled_grouped_mm,
moe_align_block_size,
moe_fused_gate,
prepare_moe_input,
topk_softmax,
)
from sgl_kernel.sampling import (
min_p_sampling_from_probs,
top_k_mask_logits,
top_k_renorm_prob,
top_k_top_p_sampling_from_logits,
top_k_top_p_sampling_from_probs,
top_p_renorm_prob,
top_p_sampling_from_probs,
)
from sgl_kernel.speculative import (
build_tree_kernel_efficient,
segment_packbits,
tree_speculative_sampling_target_only,
verify_tree_greedy,
)
from sgl_kernel.top_k import fast_topk
from sgl_kernel.version import __version__
def create_greenctx_stream_by_value(*args, **kwargs):
from sgl_kernel.spatial import create_greenctx_stream_by_value as _impl
return _impl(*args, **kwargs)
def get_sm_available(*args, **kwargs):
from sgl_kernel.spatial import get_sm_available as _impl
return _impl(*args, **kwargs)

View File

@@ -0,0 +1,173 @@
from typing import List, Optional, Tuple
import torch
if torch.version.hip is not None:
# ROCM custom allreduce
def init_custom_ar(
meta: torch.Tensor,
rank_data: torch.Tensor,
handles: List[str],
offsets: List[int],
rank: int,
full_nvlink: bool,
) -> int:
return torch.ops.sgl_kernel.init_custom_ar.default(
meta, rank_data, handles, offsets, rank, full_nvlink
)
def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
torch.ops.sgl_kernel.all_reduce_reg.default(fa, inp, out)
def all_reduce_unreg(
fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor
) -> None:
torch.ops.sgl_kernel.all_reduce_unreg.default(fa, inp, reg_buffer, out)
def dispose(fa: int) -> None:
torch.ops.sgl_kernel.dispose.default(fa)
def meta_size() -> int:
return torch.ops.sgl_kernel.meta_size.default()
def register_buffer(
fa: int, t: torch.Tensor, handles: List[str], offsets: List[int]
) -> None:
return torch.ops.sgl_kernel.register_buffer.default(fa, t, handles, offsets)
def get_graph_buffer_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]:
return torch.ops.sgl_kernel.get_graph_buffer_ipc_meta.default(fa)
def register_graph_buffers(
fa: int, handles: List[str], offsets: List[List[int]]
) -> None:
torch.ops.sgl_kernel.register_graph_buffers.default(fa, handles, offsets)
def allocate_meta_buffer(size: int) -> torch.Tensor:
return torch.ops.sgl_kernel.allocate_meta_buffer.default(size)
def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
return torch.ops.sgl_kernel.get_meta_buffer_ipc_handle.default(inp)
# ROCM quick allreduce
def init_custom_qr(
rank: int, world_size: int, qr_max_size: Optional[int] = None
) -> int:
return torch.ops.sgl_kernel.init_custom_qr.default(
world_size, rank, qr_max_size
)
def qr_get_handle(fa: int) -> torch.Tensor:
return torch.ops.sgl_kernel.qr_get_handle.default(fa)
def qr_open_handles(fa: int, handles: list[torch.Tensor]) -> None:
torch.ops.sgl_kernel.qr_open_handles.default(fa, handles)
def qr_all_reduce(
fa: int,
profile: int,
inp: torch.Tensor,
out: torch.Tensor,
cast_bf162half: bool,
) -> None:
torch.ops.sgl_kernel.qr_all_reduce.default(
fa, profile, inp, out, cast_bf162half
)
def qr_destroy(fa: int) -> None:
torch.ops.sgl_kernel.qr_destroy.default(fa)
def qr_max_size() -> int:
return torch.ops.sgl_kernel.qr_max_size.default()
# mscclpp
def mscclpp_generate_unique_id() -> bytes:
raise NotImplementedError()
def mscclpp_init_context(
unique_id: bytes,
rank: int,
world_size: int,
scratch: torch.Tensor,
put_buffer: torch.Tensor,
nranks_per_node: int,
rank_to_node: List[int],
rank_to_ib: List[int],
context_selection: int,
) -> int:
raise NotImplementedError()
def mscclpp_allreduce(
context: int, inp: torch.Tensor, out: torch.Tensor, nthreads: int, nblocks: int
) -> None:
raise NotImplementedError()
else:
def init_custom_ar(
ipc_tensors: List[int], rank_data: torch.Tensor, rank: int, full_nvlink: bool
) -> int:
return torch.ops.sgl_kernel.init_custom_ar.default(
ipc_tensors, rank_data, rank, full_nvlink
)
def dispose(fa: int) -> None:
torch.ops.sgl_kernel.dispose.default(fa)
def all_reduce(
fa: int,
inp: torch.Tensor,
out: torch.Tensor,
reg_buffer: int,
reg_buffer_sz_bytes: int,
) -> None:
torch.ops.sgl_kernel.all_reduce.default(
fa, inp, out, reg_buffer, reg_buffer_sz_bytes
)
def get_graph_buffer_ipc_meta(fa) -> Tuple[List[int], List[int]]:
return torch.ops.sgl_kernel.get_graph_buffer_ipc_meta.default(fa)
def register_buffer(fa: int, fake_ipc_ptrs: List[int]) -> None:
return torch.ops.sgl_kernel.register_buffer.default(fa, fake_ipc_ptrs)
def register_graph_buffers(
fa: int, handles: List[List[int]], offsets: List[List[int]]
) -> None:
torch.ops.sgl_kernel.register_graph_buffers.default(fa, handles, offsets)
def meta_size() -> int:
return torch.ops.sgl_kernel.meta_size.default()
def mscclpp_generate_unique_id() -> torch.Tensor:
return torch.ops.sgl_kernel.mscclpp_generate_unique_id.default()
def mscclpp_init_context(
unique_id: torch.Tensor,
rank: int,
world_size: int,
scratch: torch.Tensor,
put_buffer: torch.Tensor,
nranks_per_node: int,
rank_to_node: List[int],
rank_to_ib: List[int],
context_selection: int,
) -> int:
return torch.ops.sgl_kernel.mscclpp_init_context.default(
unique_id,
rank,
world_size,
scratch,
put_buffer,
nranks_per_node,
rank_to_node,
rank_to_ib,
context_selection,
)
def mscclpp_allreduce(
context: int, inp: torch.Tensor, out: torch.Tensor, nthreads: int, nblocks: int
) -> None:
torch.ops.sgl_kernel.mscclpp_allreduce.default(
context, inp, out, nthreads, nblocks
)

View File

@@ -0,0 +1,138 @@
from typing import Optional, Tuple
import torch
def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
torch.ops.sgl_kernel.lightning_attention_decode.default(
q, k, v, past_kv, slope, output, new_kv
)
def merge_state(
v_a: torch.Tensor,
s_a: torch.Tensor,
v_b: torch.Tensor,
s_b: torch.Tensor,
v_merged: Optional[torch.Tensor] = None,
s_merged: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
s_a = s_a.to(torch.float32)
s_b = s_b.to(torch.float32)
# Avoid creating new tensors if they are already provided
if v_merged is None:
v_merged = torch.empty_like(v_a)
if s_merged is None:
s_merged = torch.empty_like(s_a)
torch.ops.sgl_kernel.merge_state.default(v_a, s_a, v_b, s_b, v_merged, s_merged)
return v_merged, s_merged
def merge_state_v2(
v_a: torch.Tensor,
s_a: torch.Tensor,
v_b: torch.Tensor,
s_b: torch.Tensor,
v_merged: Optional[torch.Tensor] = None,
s_merged: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
s_a = s_a.to(torch.float32)
s_b = s_b.to(torch.float32)
# TODO(DefTruth): Currently, the custom merge_attn_states kernel
# does not support the FP8 data type and non - CUDA devices.
# It may be necessary to fall back to using the Triton kernel.
# Avoid creating new tensors if they are already provided
if v_merged is None:
v_merged = torch.empty_like(v_a)
if s_merged is None:
s_merged = torch.empty_like(s_a)
torch.ops.sgl_kernel.merge_state_v2.default(v_a, s_a, v_b, s_b, v_merged, s_merged)
return v_merged, s_merged
def cutlass_mla_decode(
q_nope: torch.Tensor,
q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
seq_lens: torch.Tensor,
page_table: torch.Tensor,
workspace: torch.Tensor,
sm_scale: float,
num_kv_splits: int = 1, # Set to 1 to avoid cuda_graph issue by default.
) -> torch.Tensor:
assert q_nope.ndim == 3, f"q_nope must be a 3D tensor, but got {q_nope.ndim}"
assert q_pe.ndim == 3, f"q_pe must be a 3D tensor, but got {q_pe.ndim}"
assert (
kv_c_and_k_pe_cache.ndim == 3
), f"kv_c_and_k_pe_cache must be a 3D tensor, but got {kv_c_and_k_pe_cache.ndim}"
B_q, H, D_q_nope = q_nope.shape
B_q_2, H_2, D_q_pe = q_pe.shape
assert (B_q == B_q_2) and (H == H_2)
_, PAGE_SIZE, D_ckv = kv_c_and_k_pe_cache.shape
D_latent = 512
D_rope = 64
assert D_q_nope == D_latent
assert D_q_pe == D_rope
assert D_ckv == D_latent + D_rope
MAX_HEADS = 128
assert H <= MAX_HEADS, f"H must be <= {MAX_HEADS}, but got {H}"
if H < MAX_HEADS:
q_nope_padded = q_nope.new_empty((B_q, MAX_HEADS, D_q_nope))
q_nope_padded[:, :H] = q_nope
q_nope = q_nope_padded
q_pe_padded = q_pe.new_empty((B_q, MAX_HEADS, D_q_pe))
q_pe_padded[:, :H] = q_pe
q_pe = q_pe_padded
assert len(page_table.shape) == 2
B_block_table, block_num = page_table.shape
assert B_block_table == B_q
assert block_num > 0, f"block num must be greater than 0, got {block_num}"
assert block_num % (128 / PAGE_SIZE) == 0
# TODO(kaixih@nvidia): support fp8
assert q_nope.dtype in (
torch.float16,
torch.bfloat16,
), f"q_nope.dtype needs to be fp16 or bf16 but got {q_nope.dtype}."
assert q_nope.dtype == q_pe.dtype == kv_c_and_k_pe_cache.dtype
assert (
seq_lens.dtype == torch.int32
), f"seq_lens.dtype needs to be int32 but got {seq_lens.dtype}."
assert (
page_table.dtype == torch.int32
), f"page_table.dtype needs to be int32 but got {page_table.dtype}."
out = q_nope.new_empty((B_q, MAX_HEADS, D_latent))
torch.ops.sgl_kernel.cutlass_mla_decode.default(
out,
q_nope,
q_pe,
kv_c_and_k_pe_cache,
seq_lens,
page_table,
workspace,
sm_scale,
num_kv_splits,
)
return out[:, :H].contiguous()
def cutlass_mla_get_workspace_size(
max_seq_len: int,
num_batches: int,
sm_count: int = 0,
num_kv_splits: int = 1, # Set to 1 to avoid cuda_graph issue by default.
) -> int:
assert max_seq_len > 0, f"max_seq_len must be greater than 0, got {max_seq_len}"
assert num_batches > 0, f"num_batches must be greater than 0, got {num_batches}"
return torch.ops.sgl_kernel.cutlass_mla_get_workspace_size.default(
max_seq_len, num_batches, sm_count, num_kv_splits
)

View File

@@ -0,0 +1,112 @@
import torch
def get_cutlass_w4a8_moe_mm_data(
topk_ids: torch.Tensor,
expert_offsets: torch.Tensor,
problem_sizes1: torch.Tensor,
problem_sizes2: torch.Tensor,
input_permutation: torch.Tensor,
output_permutation: torch.Tensor,
num_experts: int,
n: int,
k: int,
):
"""
Prepare data necessary to perform CUTLASS grouped matrix multiplications
used in CUTLASS-based fused MoE.
The function takes in topk_ids (token-expert mapping) and uses it to
compute:
- expert_offsets: Indices that mark at which token index each expert begins
its computation after the input is sorted with
input_permutation. The number of tokens computed with
expert E is expert_offsets[E + 1] - expert_offsets[E]
- problem_sizes1, problem_sizes2: MxNxK sizes of each expert's
multiplication in two grouped MMs used in
the fused MoE operation.
- input_permutation: Permutation that must be used to shuffle the input
before executing the MMs.
- output_permutation: Permutation that must be used to shuffle the output
after executing the MMs.
"""
torch.ops.sgl_kernel.get_cutlass_w4a8_moe_mm_data.default(
topk_ids,
expert_offsets,
problem_sizes1,
problem_sizes2,
input_permutation,
output_permutation,
num_experts,
n,
k,
)
def cutlass_w4a8_moe_mm(
d: torch.Tensor,
a: torch.Tensor,
b: torch.Tensor,
a_scales: torch.Tensor,
b_scales: torch.Tensor,
experts_offsets: torch.tensor,
problem_sizes: torch.tensor,
a_strides: torch.tensor,
b_strides: torch.tensor,
d_strides: torch.tensor,
s_strides: torch.tensor,
chunk_size: int = 128,
topk: int = 8,
):
"""
Perform grouped matrix multiplication between int4 weights and fp8 activations.
This function executes multiple GEMM operations in parallel, which is useful for
scenarios like Mixture of Experts (MoE) where different inputs go through different
experts. The implementation leverages NVIDIA Hopper architecture features for
optimal performance with quantized weights.
Args:
d: Output matrices of shape [total_m, total_n]
a: Activation matrices in FP8 (float_e4m3_t) format
Each tensor should be of shape [total_m, K] in row-major layout
b: Weight matrices in packed int4 format
Each tensor should be of shape [E, N, K//2] in column-major layout
where each byte contains two 4-bit integers
a_scales: Scale factors for the inputs
b_scales: Scale factors for the quantized weights
Each tensor should be of shape [E, K//512, N*8]
experts_offsets: Tensor containing expert offsets for determining group boundaries
problem_sizes: with shape [num_experts, 3] (M, N, K for each group) (int32)
a_strides: Strides information for A matrices
b_strides: Strides information for B matrices
d_strides: Strides information for D matrices
s_strides: Strides information for b_scales matrices
chunk_size: Number of elements each scale value applies to (K//512), default to 128
Requirements:
- All tensors must be on a CUDA device
- Requires an NVIDIA Hopper GPU (H100)
- A tensors must be in float8_e4m3fn format
- B tensors must contain packed int4 values (stored as int8)
Note:
The function computes: D = (A * (B * scales))
for each group of tensors in parallel
"""
torch.ops.sgl_kernel.cutlass_w4a8_moe_mm.default(
d,
a,
b,
a_scales,
b_scales,
experts_offsets,
problem_sizes,
a_strides,
b_strides,
d_strides,
s_strides,
chunk_size,
topk,
)

View File

@@ -0,0 +1,369 @@
from dataclasses import dataclass
from typing import Optional
import torch
from sgl_kernel.utils import get_cuda_stream, is_arch_support_pdl
# These implementations extensively draw from and build upon the FlashInfer project https://github.com/flashinfer-ai/flashinfer
# Kudos to @yzh119
def rmsnorm(
input: torch.Tensor,
weight: torch.Tensor,
eps: float = 1e-6,
out: Optional[torch.Tensor] = None,
enable_pdl: Optional[bool] = None,
) -> torch.Tensor:
r"""Root mean square normalization.
``out[i] = (input[i] / RMS(input)) * weight[i]``
Parameters
----------
input: torch.Tensor
Input tensor, shape (batch_size, hidden_size).
weight: torch.Tensor
Weight tensor, shape (hidden_size,).
eps: float
Epsilon for numerical stability.
out: Optional[torch.Tensor]
The output tensor, if specified, the kernel will update this tensor inplace.
enable_pdl: Optional[bool]
Whether to enable `programmatic dependent launch
<https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_
If None, will be automatically enabled on Hopper architecture.
Returns
-------
output: torch.Tensor
Normalized tensor, shape (batch_size, hidden_size).
"""
if out is None:
out = torch.empty_like(input)
if enable_pdl is None:
enable_pdl = is_arch_support_pdl()
torch.ops.sgl_kernel.rmsnorm.default(out, input, weight, eps, enable_pdl)
return out
def fused_add_rmsnorm(
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
eps: float = 1e-6,
enable_pdl: Optional[bool] = None,
) -> None:
r"""Fused add root mean square normalization.
Step 1:
``residual[i] += input[i]``
Step 2:
``input[i] = (residual[i] / RMS(residual)) * weight[i]``
Parameters
----------
input: torch.Tensor
Input tensor, shape (batch_size, hidden_size).
residual: torch.Tensor
Residual tensor, shape (batch_size, hidden_size).
weight: torch.Tensor
Weight tensor, shape (hidden_size,).
eps: float
Epsilon for numerical stability.
enable_pdl: Optional[bool]
Whether to enable `programmatic dependent launch
<https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_
If None, will be automatically enabled on Hopper architecture.
"""
if enable_pdl is None:
enable_pdl = is_arch_support_pdl()
torch.ops.sgl_kernel.fused_add_rmsnorm.default(
input, residual, weight, eps, enable_pdl
)
def gemma_rmsnorm(
input: torch.Tensor,
weight: torch.Tensor,
eps: float = 1e-6,
out: Optional[torch.Tensor] = None,
enable_pdl: Optional[bool] = None,
) -> torch.Tensor:
r"""Gemma-style root mean square normalization.
``out[i] = (input[i] / RMS(input)) * (weight[i] + 1)``
Parameters
----------
input: torch.Tensor
Input tensor, shape (batch_size, hidden_size).
weight: torch.Tensor
Weight tensor, shape (hidden_size,).
eps: float
Epsilon for numerical stability.
out: Optional[torch.Tensor]
The output tensor, if specified, the kernel will update this tensor inplace.
enable_pdl: Optional[bool]
Whether to enable `programmatic dependent launch
<https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_
If None, will be automatically enabled on Hopper architecture.
Returns
-------
output: torch.Tensor
Gemma Normalized tensor, shape (batch_size, hidden_size).
"""
if out is None:
out = torch.empty_like(input)
if enable_pdl is None:
enable_pdl = is_arch_support_pdl()
torch.ops.sgl_kernel.gemma_rmsnorm.default(out, input, weight, eps, enable_pdl)
return out
def gemma_fused_add_rmsnorm(
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
eps: float = 1e-6,
enable_pdl: Optional[bool] = None,
) -> None:
r"""Gemma-style fused add root mean square normalization.
Step 1:
``residual[i] += input[i]``
Step 2:
``input[i] = (residual[i] / RMS(residual)) * (weight + 1)``
Parameters
----------
input: torch.Tensor
Input tensor, shape (batch_size, hidden_size).
residual: torch.Tensor
Residual tensor, shape (batch_size, hidden_size).
weight: torch.Tensor
Weight tensor, shape (hidden_size,).
eps: float
Epsilon for numerical stability.
enable_pdl: Optional[bool]
Whether to enable `programmatic dependent launch
<https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_
If None, will be automatically enabled on Hopper architecture.
"""
if enable_pdl is None:
enable_pdl = is_arch_support_pdl()
torch.ops.sgl_kernel.gemma_fused_add_rmsnorm.default(
input, residual, weight, eps, enable_pdl
)
def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None:
assert input.ndim == output.ndim, f"{input.ndim} != {output.ndim}"
assert (
input.shape[:-1] == output.shape[:-1]
), f"{input.shape[:-1]} != {output.shape[:-1]}"
assert (
input.shape[-1] == 2 * output.shape[-1]
), f"{input.shape[-1]} != {2 * output.shape[-1]}"
def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
if input.shape[-1] * input.dtype.itemsize % 16 != 0:
raise ValueError("The pointers must be multiple of 16 bytes.")
if out is not None:
_check_shape(input, out)
else:
out = torch.empty(
input.shape[:-1] + (input.shape[-1] // 2,),
device=input.device,
dtype=input.dtype,
)
torch.ops.sgl_kernel.silu_and_mul.default(out, input)
return out
def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
if input.shape[-1] * input.dtype.itemsize % 16 != 0:
raise ValueError("The pointers must be multiple of 16 bytes.")
if out is not None:
_check_shape(input, out)
else:
out = torch.empty(
input.shape[:-1] + (input.shape[-1] // 2,),
device=input.device,
dtype=input.dtype,
)
torch.ops.sgl_kernel.gelu_tanh_and_mul.default(out, input)
return out
def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
if input.shape[-1] * input.dtype.itemsize % 16 != 0:
raise ValueError("The pointers must be multiple of 16 bytes.")
if out is not None:
_check_shape(input, out)
else:
out = torch.empty(
input.shape[:-1] + (input.shape[-1] // 2,),
device=input.device,
dtype=input.dtype,
)
torch.ops.sgl_kernel.gelu_and_mul.default(out, input)
return out
if torch.version.hip is not None:
def gelu_quick(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
"""
Quick-GELU: y = x * sigmoid(1.702 * x)
The CUDA/HIP kernel uses 128-bit (16-byte) vector loads & stores,
so the last-dimension byte length must be a multiple of 16 bytes.
"""
if input.shape[-1] * input.dtype.itemsize % 16 != 0:
raise ValueError(
f"The last dimension ({input.shape[-1]}) x itemsize "
f"({input.dtype.itemsize}) must be a multiple of 16 bytes."
)
if out is not None:
assert input.shape == out.shape, f"{input.shape} != {out.shape}"
else:
out = torch.empty_like(input)
torch.ops.sgl_kernel.gelu_quick(out, input)
return out
@dataclass
class FusedSetKVBufferArg:
"""
value : Optional[torch.Tensor]
Value tensor, shape: ``(nnz, num_v_heads * head_size)``.
k_buffer : Optional[torch.Tensor]
Buffer for keys, shape: ``(nnz, num_k_heads * head_size)``.
v_buffer : Optional[torch.Tensor]
Buffer for values, shape: ``(nnz, num_v_heads * head_size)``.
k_scale : Optional[float]
Scale factor for keys.
v_scale : Optional[float]
Scale factor for values.
cache_loc : Optional[torch.Tensor]
Cache location tensor, used for indexing kv cache.
"""
value: torch.Tensor
k_buffer: torch.Tensor
v_buffer: torch.Tensor
k_scale: Optional[float]
v_scale: Optional[float]
cache_loc: torch.Tensor
def apply_rope_with_cos_sin_cache_inplace(
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
head_size: int,
cos_sin_cache: torch.Tensor,
is_neox: bool = True,
fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
enable_pdl: Optional[bool] = None,
) -> None:
r"""
Apply rotary embedding to keys and queries with precomputed cos/sin values.
This is designed to be compatible with the SGL/vLLM implementation.
The result is inplace applied to the input tensors.
Parameters
----------
positions : torch.Tensor
Position indices, shape: ``(nnz)``.
query : torch.Tensor
Query tensor, shape: ``(nnz, num_q_heads * head_size)``.
key : torch.Tensor
Key tensor, shape: ``(nnz, num_k_heads * head_size)``.
cos_sin_cache : torch.Tensor
Cosine and Sine cache tensor, shape: ``(max_seq_len, rotary_dim)``.
Cosine is the first half and Sine is the second half on rotary_dim.
is_neox : bool
Whether to use Neox style RoPE, default: ``True``.
* If ``True``, the last dimension of the query/key tensor is not interleaved, i.e.,
we rotate the first half dimensions ``([..., :head_dim//2])`` and the second half
dimensions ``([..., head_dim//2:])``.
* If ``False``, the last dimension of the query/key tensor is interleaved, i.e.,
we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``.
fused_set_kv_buffer_arg : FusedSetKVBufferArg
Fuse the set-kv-buffer operation into this kernel
Note
----
The rotary dimension is determined by the cosine cache and sine cache.
"""
if cos_sin_cache.dtype != torch.float32:
raise ValueError("cos_sin_cache should be float32")
if enable_pdl is None:
# the non-fused branch does not yet support PDL, but after we switch to our impl for that branch it will
enable_pdl = is_arch_support_pdl() and (fused_set_kv_buffer_arg is not None)
if (a := fused_set_kv_buffer_arg) is not None:
assert a.k_scale is None, "k_scale is not yet supported"
assert a.v_scale is None, "v_scale is not yet supported"
assert a.cache_loc.dtype == torch.int64, f"{a.cache_loc.dtype=}"
def _view_3d(x):
return x.view(x.shape[0], -1, head_size)
torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache.default(
_view_3d(query),
_view_3d(key),
_view_3d(query),
_view_3d(key),
cos_sin_cache,
positions.long(),
(not is_neox),
enable_pdl,
get_cuda_stream(),
(
_view_3d(fused_set_kv_buffer_arg.value)
if fused_set_kv_buffer_arg is not None
else None
),
(
_view_3d(fused_set_kv_buffer_arg.k_buffer)
if fused_set_kv_buffer_arg is not None
else None
),
(
_view_3d(fused_set_kv_buffer_arg.v_buffer)
if fused_set_kv_buffer_arg is not None
else None
),
(
fused_set_kv_buffer_arg.cache_loc
if fused_set_kv_buffer_arg is not None
else None
),
)
def downcast_fp8(
k: torch.Tensor,
v: torch.Tensor,
k_out: torch.Tensor,
v_out: torch.Tensor,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
loc: torch.Tensor,
mult: int = 1,
offset: int = 0,
) -> None:
torch.ops.sgl_kernel.downcast_fp8(
k, v, k_out, v_out, k_scale, v_scale, loc, mult, offset, get_cuda_stream()
)

View File

@@ -0,0 +1,288 @@
from functools import lru_cache
from typing import Optional, Union
import torch
import torch.nn as nn
try:
from sgl_kernel import flash_ops
except:
raise ImportError("Can not import sgl_kernel. Please check your installation.")
@lru_cache(maxsize=1)
def is_fa3_supported(device=None) -> bool:
# There some fa3 FYI
# FA3 can fail without a enough shared memory for a some shapes, such as higher
# hidden_dim or some special cases.
# Right now, fa3 is supported for sm80/sm87 and sm86/sm89. The main different
# Between sm80/sm87 and sm86/sm89 is the shared memory size. you can follow the link below for more information
# https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x
# And for sgl-kernel right now, we can build fa3 on sm80/sm86/sm89/sm90a.
# That means if you use A100/A*0/L20/L40/L40s/4090 you can use fa3.
return (torch.version.cuda >= "12.3") and (
torch.cuda.get_device_capability(device)[0] == 9
or torch.cuda.get_device_capability(device)[0] == 8
)
def maybe_contiguous(x):
return x.contiguous() if x is not None and x.stride(-1) != 1 else x
def flash_attn_with_kvcache(
q,
k_cache,
v_cache,
k=None,
v=None,
qv=None,
rotary_cos=None,
rotary_sin=None,
cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
cache_batch_idx: Optional[torch.Tensor] = None,
cache_leftpad: Optional[torch.Tensor] = None,
page_table: Optional[torch.Tensor] = None,
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_k_new: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None,
rotary_seqlens: Optional[torch.Tensor] = None,
q_descale: Optional[torch.Tensor] = None,
k_descale: Optional[torch.Tensor] = None,
v_descale: Optional[torch.Tensor] = None,
softmax_scale=None,
causal=False,
window_size=(-1, -1), # -1 means infinite context window
softcap=0.0, # 0.0 means deactivated
rotary_interleaved=True,
scheduler_metadata=None,
num_splits=0, # Can be tuned for speed
pack_gqa=None, # Can be tuned for speed
sm_margin=0, # Can be tuned if some SMs are used for communication
return_softmax_lse=False,
sinks=None,
):
"""
If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
the previous step, and update them with the new keys/values from the current step, and do
attention with the updated cache, all in 1 kernel.
If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
For example, the KV cache could be pre-allocated with the max sequence length, and you can use
cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1 1 1 1 0
1 1 1 1 1
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
0 0
0 0
0 0
1 0
1 1
If the row of the mask is all zero, the output will be zero.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between
[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
Note: Does not support backward pass.
Arguments:
q: (batch_size, seqlen, nheads, headdim)
k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,
or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache)
page_block_size must be a multiple of 256.
v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table,
or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache)
k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
k with k_cache, starting at the indices specified by cache_seqlens.
v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k.
qv [optional]: (batch_size, seqlen, nheads, headdim_v)
rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
KV cache.
cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
If the indices are not distinct, and k and v are provided, the values updated in the cache
might come from any of the duplicate indices.
cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0.
page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
softcap: float. Anything > 0 activates softcapping attention.
rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
(i.e. GPT-NeoX style).
num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
to automatically determine the number of splits.
Don't change this unless you know what you are doing.
return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.
Return:
out: (batch_size, seqlen, nheads, headdim).
softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
"""
assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
if softmax_scale is None:
softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (
-0.5
)
if cache_seqlens is not None and isinstance(cache_seqlens, int):
cache_seqlens = torch.full(
(k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
)
cache_seqlens = maybe_contiguous(cache_seqlens)
q, k_cache, k, v = [maybe_contiguous(x) for x in (q, k_cache, k, v)]
v_cache = (
v_cache.contiguous()
if v_cache.stride(-1) != 1 and v_cache.stride(-3) != 1
else v_cache
)
cu_seqlens_q, cu_seqlens_k_new = [
maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k_new)
]
page_table, cache_batch_idx, cache_leftpad = [
maybe_contiguous(x) for x in (page_table, cache_batch_idx, cache_leftpad)
]
rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
rotary_seqlens = maybe_contiguous(rotary_seqlens)
out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default(
q,
k_cache,
v_cache,
k,
v,
qv,
None, # out
cu_seqlens_q,
None, # cu_seqlens_k
cu_seqlens_k_new,
None, # seqused_q
cache_seqlens,
max_seqlen_q,
None, # max_seqlen_k
page_table,
cache_batch_idx,
cache_leftpad,
rotary_cos,
rotary_sin,
rotary_seqlens,
q_descale,
k_descale,
v_descale,
softmax_scale,
causal,
window_size[0],
window_size[1],
softcap,
rotary_interleaved,
scheduler_metadata,
num_splits,
pack_gqa,
sm_margin,
sinks,
)
# return (out, softmax_lse) if return_softmax_lse else out
return (out, softmax_lse, *rest) if return_softmax_lse else out
def flash_attn_varlen_func(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
seqused_q=None,
seqused_k=None,
softmax_scale=None,
causal=False,
qv=None,
q_descale=None,
k_descale=None,
v_descale=None,
window_size=(-1, -1),
softcap=0.0,
num_splits=1,
pack_gqa=None,
sm_margin=0,
return_softmax_lse=False,
sinks=None,
):
if not is_fa3_supported():
raise NotImplementedError(
"flash_attn at sgl-kernel is only supported on sm90 and above"
)
if softmax_scale is None:
softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (
-0.5
)
out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default(
q,
k,
v,
None, # k_new
None, # v_new
qv, # qv
None, # out
cu_seqlens_q,
cu_seqlens_k,
None, # cu_seqlens_k_new
seqused_q,
seqused_k,
max_seqlen_q,
max_seqlen_k,
None, # page_table,
None, # kv_batch_idx
None, # leftpad_k
None, # rotary cos
None, # rotary sin
None, # seqlens_rotary
q_descale,
k_descale,
v_descale,
softmax_scale,
causal,
window_size[0],
window_size[1],
softcap,
is_rotary_interleaved=False,
scheduler_metadata=None,
num_splits=num_splits,
pack_gqa=pack_gqa,
sm_margin=sm_margin,
sinks=sinks,
)
return (out, softmax_lse, *rest) if return_softmax_lse else out

View File

@@ -0,0 +1,225 @@
import functools
from typing import Optional
import torch
from sgl_kernel import silu_and_mul
def get_scalar_type(num_bits: int, has_zp: bool):
from sgl_kernel.scalar_type import scalar_types
if has_zp:
assert num_bits == 4
return scalar_types.uint4
else:
return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128
def fused_marlin_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
gating_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
g_idx1: Optional[torch.Tensor] = None,
g_idx2: Optional[torch.Tensor] = None,
sort_indices1: Optional[torch.Tensor] = None,
sort_indices2: Optional[torch.Tensor] = None,
w1_zeros: Optional[torch.Tensor] = None,
w2_zeros: Optional[torch.Tensor] = None,
workspace: Optional[torch.Tensor] = None,
num_bits: int = 8,
is_k_full: bool = True,
inplace: bool = False,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
weights, w1 and w2, and top-k gating mechanism.
Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- w1_scale (torch.Tensor): Scale to be used for w1.
- w2_scale (torch.Tensor): Scale to be used for w2.
- gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- g_idx1 (Optional[torch.Tensor]): The first set of act_order indices.
- g_idx2 (Optional[torch.Tensor]): The second set of act_order indices.
- sort_indices1 (Optional[torch.Tensor]): The first act_order input
permutation.
- sort_indices2 (Optional[torch.Tensor]): The second act_order input
permutation.
- topk_weights (torch.Tensor): Top-k weights.
- topk_ids (torch.Tensor): Indices of topk-k elements.
- w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1.
- w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2.
- num_bits (bool): The number of bits in expert weights quantization.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
# Delay the import to avoid circular dependency
from sglang.srt.layers.moe.fused_moe_triton import (
moe_align_block_size,
try_get_optimal_moe_config,
)
# Check constraints.
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
assert hidden_states.shape[1] == w1.shape[1] * 16, "Hidden size mismatch w1"
assert hidden_states.shape[1] == w2.shape[2] // (
num_bits // 2
), "Hidden size mismatch w2"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
assert hidden_states.dtype in [torch.float16, torch.bfloat16]
assert num_bits in [4, 8]
M, K = hidden_states.shape
E = w1.shape[0]
N = w2.shape[1] * 16
topk = topk_ids.shape[1]
get_config_func = functools.partial(
try_get_optimal_moe_config,
w1.shape,
w2.shape,
topk_ids.shape[1],
None,
is_marlin=True,
)
config = get_config_func(M)
block_size_m = config["BLOCK_SIZE_M"]
if global_num_experts == -1:
global_num_experts = E
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
topk_ids, block_size_m, global_num_experts
)
if workspace is None:
max_workspace_size = (max(2 * N, K) // 64) * (
sorted_token_ids.size(0) // block_size_m
)
device = hidden_states.device
sms = torch.cuda.get_device_properties(device).multi_processor_count
max_workspace_size = min(max_workspace_size, sms * 4)
workspace = torch.zeros(
max_workspace_size, dtype=torch.int, device=device, requires_grad=False
)
scalar_type1 = get_scalar_type(num_bits, w1_zeros is not None)
scalar_type2 = get_scalar_type(num_bits, w2_zeros is not None)
intermediate_cache2 = torch.empty(
(M * topk_ids.shape[1], N),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
intermediate_cache13 = torch.empty(
(M * topk_ids.shape[1] * max(2 * N, K),),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
intermediate_cache1 = intermediate_cache13[: M * topk_ids.shape[1] * 2 * N]
intermediate_cache1 = intermediate_cache1.view(-1, 2 * N)
intermediate_cache3 = intermediate_cache13[: M * topk_ids.shape[1] * K]
intermediate_cache3 = intermediate_cache3.view(-1, K)
use_atomic_add = (
hidden_states.dtype == torch.half
or torch.cuda.get_device_capability(hidden_states.device)[0] >= 9
)
intermediate_cache1 = torch.ops.sgl_kernel.moe_wna16_marlin_gemm.default(
hidden_states,
intermediate_cache1,
w1,
w1_scale,
w1_zeros,
g_idx1,
sort_indices1,
workspace,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
topk_weights,
moe_block_size=block_size_m,
top_k=topk,
mul_topk_weights=False,
is_ep=expert_map is not None,
b_q_type_id=scalar_type1.id,
size_m=M,
size_n=2 * N,
size_k=K,
is_k_full=is_k_full,
use_atomic_add=use_atomic_add,
use_fp32_reduce=True,
is_zp_float=False,
)
silu_and_mul(intermediate_cache1.view(-1, 2 * N), intermediate_cache2)
if expert_map is not None:
intermediate_cache3.zero_()
intermediate_cache3 = torch.ops.sgl_kernel.moe_wna16_marlin_gemm.default(
intermediate_cache2,
intermediate_cache3,
w2,
w2_scale,
w2_zeros,
g_idx2,
sort_indices2,
workspace,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
topk_weights,
moe_block_size=block_size_m,
top_k=1,
mul_topk_weights=True,
is_ep=expert_map is not None,
b_q_type_id=scalar_type2.id,
size_m=M * topk,
size_n=K,
size_k=N,
is_k_full=is_k_full,
use_atomic_add=use_atomic_add,
use_fp32_reduce=True,
is_zp_float=False,
).view(-1, topk, K)
output = hidden_states if inplace else torch.empty_like(hidden_states)
return torch.sum(
intermediate_cache3.view(*intermediate_cache3.shape), dim=1, out=output
)
def fused_marlin_moe_fake(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
gating_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
g_idx1: Optional[torch.Tensor] = None,
g_idx2: Optional[torch.Tensor] = None,
sort_indices1: Optional[torch.Tensor] = None,
sort_indices2: Optional[torch.Tensor] = None,
w1_zeros: Optional[torch.Tensor] = None,
w2_zeros: Optional[torch.Tensor] = None,
num_bits: int = 8,
is_k_full: bool = True,
) -> torch.Tensor:
return torch.empty_like(hidden_states)

View File

@@ -0,0 +1,550 @@
from typing import Optional, Tuple
import torch
from sgl_kernel.scalar_type import ScalarType
from sgl_kernel.utils import _get_cache_buf, get_cuda_stream
def awq_dequantize(
qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
) -> torch.ByteTensor:
return torch.ops.sgl_kernel.awq_dequantize.default(qweight, scales, qzeros)
def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
return torch.ops.sgl_kernel.int8_scaled_mm.default(
mat_a,
mat_b,
scales_a,
scales_b,
out_dtype,
bias,
)
def fp8_blockwise_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype):
return torch.ops.sgl_kernel.fp8_blockwise_scaled_mm.default(
mat_a,
mat_b,
scales_a,
scales_b,
out_dtype,
)
def fp8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
return torch.ops.sgl_kernel.fp8_scaled_mm.default(
mat_a,
mat_b,
scales_a,
scales_b,
out_dtype,
bias,
)
def _bmm_fp8_internal(
workspace_buffer: torch.Tensor,
A: torch.Tensor,
B: torch.Tensor,
D: torch.Tensor,
A_scale: torch.Tensor,
B_scale: torch.Tensor,
) -> None:
cublas_handle = torch.cuda.current_blas_handle()
torch.ops.sgl_kernel.bmm_fp8.default(
A,
B,
D,
A_scale,
B_scale,
workspace_buffer,
cublas_handle,
get_cuda_stream(),
)
def bmm_fp8(
A: torch.Tensor,
B: torch.Tensor,
A_scale: torch.Tensor,
B_scale: torch.Tensor,
dtype: torch.dtype,
out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if out is None:
out = torch.empty(
(A.shape[0], A.shape[1], B.shape[2]),
device=A.device,
dtype=dtype,
)
workspace_buffer = _get_cache_buf("bmm_fp8_workspace", 32 * 1024 * 1024, A.device)
_bmm_fp8_internal(workspace_buffer, A, B, out, A_scale, B_scale)
return out
def dsv3_fused_a_gemm(
mat_a: torch.Tensor,
mat_b: torch.Tensor,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if output is None:
output = torch.empty(
(mat_a.shape[0], mat_b.shape[1]),
device=mat_a.device,
dtype=mat_a.dtype,
)
torch.ops.sgl_kernel.dsv3_fused_a_gemm.default(output, mat_a, mat_b)
return output
def sgl_per_token_group_quant_fp8(
input: torch.Tensor,
output_q: torch.Tensor,
output_s: torch.Tensor,
group_size: int,
eps: float,
fp8_min: float,
fp8_max: float,
scale_ue8m0: bool,
) -> None:
torch.ops.sgl_kernel.sgl_per_token_group_quant_fp8.default(
input, output_q, output_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
)
def sgl_per_token_group_quant_int8(
input: torch.Tensor,
output_q: torch.Tensor,
output_s: torch.Tensor,
group_size: int,
eps: float,
int8_min: float,
int8_max: float,
) -> None:
torch.ops.sgl_kernel.sgl_per_token_group_quant_int8.default(
input, output_q, output_s, group_size, eps, int8_min, int8_max
)
def sgl_per_tensor_quant_fp8(
input: torch.Tensor,
output_q: torch.Tensor,
output_s: torch.Tensor,
is_static: bool,
) -> None:
torch.ops.sgl_kernel.sgl_per_tensor_quant_fp8.default(
input, output_q, output_s, is_static
)
def sgl_per_token_quant_fp8(
input: torch.Tensor,
output_q: torch.Tensor,
output_s: torch.Tensor,
) -> None:
torch.ops.sgl_kernel.sgl_per_token_quant_fp8.default(input, output_q, output_s)
def cutlass_scaled_fp4_mm(
a: torch.Tensor,
b: torch.Tensor,
block_scale_a: torch.Tensor,
block_scale_b: torch.Tensor,
alpha: torch.Tensor,
out_dtype: torch.dtype,
) -> torch.Tensor:
assert a.ndim == 2 and b.ndim == 2
m, n = a.shape[0], b.shape[0]
out = torch.empty((m, n), dtype=out_dtype, device=a.device)
torch.ops.sgl_kernel.cutlass_scaled_fp4_mm.default(
out, a, b, block_scale_a, block_scale_b, alpha
)
return out
def scaled_fp4_quant(
input: torch.Tensor, input_global_scale: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Quantize input tensor to FP4 and return quantized tensor and scale.
This function quantizes the last dimension of the given tensor `input`. For
every 16 consecutive elements, a single dynamically computed scaling factor
is shared. This scaling factor is quantized using the `input_global_scale`
and is stored in a swizzled layout (see
https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x).
Args:
input: The input tensor to be quantized to FP4
input_global_scale: A scalar scaling factor for the entire tensor.
Returns:
Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP4 but every
two values are packed into a uint8 and float8_e4m3 scaling factors
in a sizzled layout.
"""
assert input.ndim >= 1, f"input.ndim needs to be >= 1, but got {input.ndim}."
other_dims = 1 if input.ndim == 1 else -1
input = input.reshape(other_dims, input.shape[-1])
m, n = input.shape
block_size = 16
device = input.device
assert n % block_size == 0, f"last dim has to be multiple of 16, but got {n}."
assert input.dtype in (
torch.float16,
torch.bfloat16,
), f"input.dtype needs to be fp16 or bf16 but got {input.dtype}."
# Two fp4 values will be packed into an uint8.
output = torch.empty((m, n // 2), device=device, dtype=torch.uint8)
# We use the rounded values to store the swizzled values. Then, the scaling
# factors in float8_e4m3fn are packed into an int32 for every 4 values.
rounded_m = ((m + 128 - 1) // 128) * 128
scale_n = n // block_size
rounded_n = ((scale_n + 4 - 1) // 4) * 4
# padded part should be zeroed out
if rounded_n > scale_n:
output_scale = torch.zeros(
(rounded_m, rounded_n // 4), device=device, dtype=torch.int32
)
else:
output_scale = torch.empty(
(rounded_m, rounded_n // 4), device=device, dtype=torch.int32
)
torch.ops.sgl_kernel.scaled_fp4_quant.default(
output, input, output_scale, input_global_scale
)
output_scale = output_scale.view(torch.float8_e4m3fn)
return output, output_scale
def qserve_w4a8_per_chn_gemm(
in_feats: torch.Tensor,
kernel: torch.Tensor,
wscales: torch.Tensor,
ascales: torch.Tensor,
w_szs: torch.Tensor,
a_ssums: torch.Tensor,
out_feats: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if out_feats is None:
# NOTE(HandH1998): qserve_w4a8_per_chn_gemm only supports out dtype=torch.float16 now
out_feats = torch.empty(
(in_feats.shape[0], kernel.shape[0]),
device=in_feats.device,
dtype=torch.float16,
)
torch.ops.sgl_kernel.qserve_w4a8_per_chn_gemm.default(
in_feats, kernel, wscales, ascales, w_szs, a_ssums, out_feats
)
return out_feats
def qserve_w4a8_per_group_gemm(
in_feats: torch.Tensor,
kernel: torch.Tensor,
zeros: torch.Tensor,
scales_i8: torch.Tensor,
wscales: torch.Tensor,
ascales: torch.Tensor,
out_feats: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if out_feats is None:
# NOTE(HandH1998): qserve_w4a8_per_group_gemm only supports out dtype=torch.float16 now
out_feats = torch.empty(
(in_feats.shape[0], kernel.shape[0]),
device=in_feats.device,
dtype=torch.float16,
)
torch.ops.sgl_kernel.qserve_w4a8_per_group_gemm.default(
in_feats, kernel, zeros, scales_i8, wscales, ascales, out_feats
)
return out_feats
def dsv3_router_gemm(
hidden_states: torch.Tensor,
router_weights: torch.Tensor,
out_dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
output = torch.empty(
hidden_states.shape[0],
router_weights.shape[0],
device=hidden_states.device,
dtype=out_dtype,
)
torch.ops.sgl_kernel.dsv3_router_gemm(
output,
hidden_states,
router_weights,
)
return output
def shuffle_rows(input_tensor, dst2src_map, output_tensor_shape):
output_tensor = torch.empty(
output_tensor_shape,
device=input_tensor.device,
dtype=input_tensor.dtype,
)
torch.ops.sgl_kernel.shuffle_rows.default(input_tensor, dst2src_map, output_tensor)
return output_tensor
def scaled_fp4_grouped_quant(
input_tensor: torch.Tensor,
input_global_scale: torch.Tensor,
mask: torch.Tensor,
):
"""
Quantize input tensor to FP4 and return quantized tensor and scale, for
grouped gemm inputs (e.g., grouped_gemm_nt_masked for flashinfer).
Args:
input: The input tensor to be quantized to FP4, with shape (l, m, k)
l is number of groups, m is number of tokens per group, k is number of features.
input_global_scale: A scalar scaling factor for the entire tensor, with
shape (l,).
Outputs:
output: The quantized tensor in FP4, with shape (m, k // 2, l) but the physical
layout is (l, m, k // 2). `// 2` is because two fp4 values are packed into
an uint8.
output_scales: The blockscale tensor in FP8-E4M3, with shape (32, 4, rm, 4, rk, l)
but the physical layout is (l, rm, rk, 32, 4, 4).
Note:
For the shape of output_scales, `32 * 4 * rm` is a padded m to nearest multiple of 128.
`4 * rk` is a padded `k // 16` to nearest multiple of 4. These layout constants are
required by the NVIDIA Blackwell MMA operations.
"""
device = input_tensor.device
l, m, k = input_tensor.shape
sf_vec_size = 16
assert k % sf_vec_size == 0, f"k must be multiple of 16, but got {k}."
scale_k = k // sf_vec_size
padded_k = (scale_k + (4 - 1)) // 4 * 4
padded_k_int32 = padded_k // 4
padded_m = (m + (128 - 1)) // 128 * 128
output = torch.empty(l, m, k // 2, device=device, dtype=torch.uint8)
output_scales = torch.empty(
l, padded_m, padded_k_int32, device=device, dtype=torch.int32
)
torch.ops.sgl_kernel.silu_and_mul_scaled_fp4_experts_quant.default(
output.view(l * m, k // 2),
output_scales.view(l * padded_m, padded_k_int32),
input_tensor.view(l * m, k),
input_global_scale,
mask,
use_silu_and_mul=False,
)
# The physical layout of the output is (l, m, k // 2), but we want to return a
# logical layout (m, k // 2, l) required by the flashinfer masked group gemm.
output = output.permute(1, 2, 0)
# The physical layout of the output scales is already swizzled as (l, rm, rk, 32, 4, 4), a
# requirement for the flashinfer masked group gemm, where rm=m/128 and rk=k/4. The logic
# layout is (32, 4, rm, 4, rk, l).
output_scales = output_scales.view(torch.float8_e4m3fn).view(
l, padded_m // 128, padded_k // 4, 32, 4, 4
)
output_scales = output_scales.permute(3, 4, 1, 5, 2, 0)
return output, output_scales
def silu_and_mul_scaled_fp4_grouped_quant(
input_tensor: torch.Tensor,
input_global_scale: torch.Tensor,
mask: torch.Tensor,
):
"""
Quantize input tensor to FP4 and return quantized tensor and scale, for
grouped gemm inputs (e.g., grouped_gemm_nt_masked for flashinfer).
Args:
input: The input tensor to be quantized to FP4, with shape (l, m, k * 2)
l is number of groups, m is number of tokens per group, k is number of features.
input_global_scale: A scalar scaling factor for the entire tensor, with
shape (l,).
mask: The mask tensor, with shape (l,)
Outputs:
output: The quantized tensor in FP4, with shape (m, k // 2, l) but the physical
layout is (l, m, k // 2). `// 2` is because two fp4 values are packed into
an uint8.
output_scales: The blockscale tensor in FP8-E4M3, with shape (32, 4, rm, 4, rk, l)
but the physical layout is (l, rm, rk, 32, 4, 4).
Note:
For the shape of output_scales, `32 * 4 * rm` is a padded m to nearest multiple of 128.
`4 * rk` is a padded `k // 16` to nearest multiple of 4. These layout constants are
required by the NVIDIA Blackwell MMA operations.
"""
device = input_tensor.device
l, m, k_by_2 = input_tensor.shape
k = k_by_2 // 2
sf_vec_size = 16
assert k % sf_vec_size == 0, f"k must be multiple of 16, but got {k}."
scale_k = k // sf_vec_size
padded_k = (scale_k + (4 - 1)) // 4 * 4
padded_k_int32 = padded_k // 4
padded_m = (m + (128 - 1)) // 128 * 128
output = torch.empty(l, m, k // 2, device=device, dtype=torch.uint8)
output_scales = torch.empty(
l, padded_m, padded_k_int32, device=device, dtype=torch.int32
)
torch.ops.sgl_kernel.silu_and_mul_scaled_fp4_experts_quant.default(
output.view(l * m, k // 2),
output_scales.view(l * padded_m, padded_k_int32),
input_tensor.view(l * m, k_by_2),
input_global_scale,
mask,
use_silu_and_mul=True,
)
# The physical layout of the output is (l, m, k // 2), but we want to return a
# logical layout (m, k // 2, l) required by the flashinfer masked group gemm.
output = output.permute(1, 2, 0)
# The physical layout of the output scales is already swizzled as (l, rm, rk, 32, 4, 4), a
# requirement for the flashinfer masked group gemm, where rm=m/128 and rk=k/4. The logic
# layout is (32, 4, rm, 4, rk, l).
output_scales = output_scales.view(torch.float8_e4m3fn).view(
l, padded_m // 128, padded_k // 4, 32, 4, 4
)
output_scales = output_scales.permute(3, 4, 1, 5, 2, 0)
return output, output_scales
def scaled_fp4_experts_quant(
input_tensor: torch.Tensor,
input_global_scale: torch.Tensor,
expert_offsets: torch.Tensor,
blockscale_offsets: torch.Tensor,
topk: int,
expert_map: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Quantize input tensor to FP4 and return quantized tensor and scale, for
packed MoE Inputs.
Args:
input: The input tensor to be quantized to FP4
expert_map: The expert map tensor
input_global_scale: A scalar scaling factor for the entire tensor.
expert_offsets: The expert offsets tensor
blockscale_offsets: The blockscale offsets tensor
Outputs:
output: The quantized tensor in FP4
output_scales: The blockscale tensor in FP8-E4M3
"""
assert (
input_tensor.ndim == 2
), f"input.ndim needs to be == 2, but got {input_tensor.ndim}."
if expert_map is not None:
(m, k) = input_tensor.shape
output_tensor_shape = (m * topk, k)
input_tensor = shuffle_rows(input_tensor, expert_map, output_tensor_shape)
m_numtopk, k = input_tensor.shape
# Control the maximum number of tokens per expert supported by the
# NVFP4 MoE Expert Quantization. This is used to prevent the kernel
# from running out of memory. This value can also be increased to support
# larger models.
import os
MAX_TOKENS_PER_EXPERT = os.environ.get("MODELOPT_MAX_TOKENS_PER_EXPERT", 65536)
assert m_numtopk <= MAX_TOKENS_PER_EXPERT * topk, (
f"m_numtopk must be less than MAX_TOKENS_PER_EXPERT("
f"{MAX_TOKENS_PER_EXPERT})"
f" for cutlass_moe_fp4, observed m_numtopk = {m_numtopk}. Use"
f" MODELOPT_MAX_TOKENS_PER_EXPERT to set this value."
)
scales_k = k // 16
padded_k = (scales_k + (4 - 1)) // 4
# output is uint8 and packed fp4 values
output = torch.empty(
m_numtopk, k // 2, device=input_tensor.device, dtype=torch.uint8
)
# padded part should be zeroed out
if padded_k > scales_k:
output_scales = torch.zeros(
MAX_TOKENS_PER_EXPERT * topk,
padded_k,
dtype=torch.int32,
device=input_tensor.device,
)
else:
output_scales = torch.empty(
MAX_TOKENS_PER_EXPERT * topk,
padded_k,
dtype=torch.int32,
device=input_tensor.device,
)
torch.ops.sgl_kernel.scaled_fp4_experts_quant.default(
output,
output_scales,
input_tensor,
input_global_scale,
expert_offsets,
blockscale_offsets,
)
output_scales = output_scales.view(torch.float8_e4m3fn)
return output, output_scales
# GPTQ kernels
def gptq_marlin_gemm(
a: torch.Tensor,
c: Optional[torch.Tensor],
b_q_weight: torch.Tensor,
b_scales: torch.Tensor,
global_scale: Optional[torch.Tensor],
b_zeros: Optional[torch.Tensor],
g_idx: Optional[torch.Tensor],
perm: Optional[torch.Tensor],
workspace: torch.Tensor,
b_q_type: ScalarType,
size_m: int,
size_n: int,
size_k: int,
is_k_full: bool = True,
use_atomic_add: bool = False,
use_fp32_reduce: bool = False,
is_zp_float: bool = False,
) -> torch.Tensor:
return torch.ops.sgl_kernel.gptq_marlin_gemm(
a,
c,
b_q_weight,
b_scales,
global_scale,
b_zeros,
g_idx,
perm,
workspace,
b_q_type.id,
size_m,
size_n,
size_k,
is_k_full,
use_atomic_add,
use_fp32_reduce,
is_zp_float,
)
def gptq_gemm(
a: torch.Tensor,
b_q_weight: torch.Tensor,
b_gptq_qzeros: torch.Tensor,
b_gptq_scales: torch.Tensor,
b_g_idx: torch.Tensor,
use_shuffle: bool,
bit: int,
) -> torch.Tensor:
return torch.ops.sgl_kernel.gptq_gemm(
a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, use_shuffle, bit
)
def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, bit: int) -> None:
torch.torch.ops.sgl_kernel.gptq_shuffle(q_weight, q_perm, bit)

View File

@@ -0,0 +1,15 @@
from typing import List, Optional, Union
import torch
def apply_token_bitmask_inplace_cuda(
logits: torch.Tensor,
bitmask: torch.Tensor,
indices: Optional[Union[List[int], torch.Tensor]] = None,
) -> None:
if isinstance(indices, list):
indices = torch.tensor(indices, dtype=torch.int32, device=logits.device)
if indices is not None:
indices = indices.to(logits.device)
torch.ops.sgl_kernel.apply_token_bitmask_inplace_cuda(logits, bitmask, indices)

View File

@@ -0,0 +1,218 @@
from typing import List
import torch
def is_hip() -> bool:
return torch.version.hip is not None
_is_hip = is_hip()
def transfer_kv_per_layer(
src_k: torch.Tensor,
dst_k: torch.Tensor,
src_v: torch.Tensor,
dst_v: torch.Tensor,
src_indices: torch.Tensor,
dst_indices: torch.Tensor,
item_size: int,
block_quota: int = 2,
num_warps_per_block: int = 16 if _is_hip else 32,
):
torch.ops.sgl_kernel.transfer_kv_per_layer(
src_k,
dst_k,
src_v,
dst_v,
src_indices,
dst_indices,
item_size,
block_quota,
num_warps_per_block,
)
def transfer_kv_per_layer_pf_lf(
src_k: torch.Tensor,
dst_k: torch.Tensor,
src_v: torch.Tensor,
dst_v: torch.Tensor,
src_indices: torch.Tensor,
dst_indices: torch.Tensor,
layer_id: int,
item_size: int,
src_layout_dim: int,
block_quota: int = 2,
num_warps_per_block: int = 16 if _is_hip else 32,
):
torch.ops.sgl_kernel.transfer_kv_per_layer_pf_lf(
src_k,
dst_k,
src_v,
dst_v,
src_indices,
dst_indices,
layer_id,
item_size,
src_layout_dim,
block_quota,
num_warps_per_block,
)
def transfer_kv_all_layer(
src_k_layers: torch.Tensor,
dst_k_layers: torch.Tensor,
src_v_layers: torch.Tensor,
dst_v_layers: torch.Tensor,
src_indices: torch.Tensor,
dst_indices: torch.Tensor,
item_size: int,
num_layers: int,
block_quota: int = 2,
num_warps_per_block: int = 16 if _is_hip else 32,
):
torch.ops.sgl_kernel.transfer_kv_all_layer(
src_k_layers,
dst_k_layers,
src_v_layers,
dst_v_layers,
src_indices,
dst_indices,
item_size,
num_layers,
block_quota,
num_warps_per_block,
)
def transfer_kv_all_layer_lf_pf(
src_k_layers: torch.Tensor,
dst_k: torch.Tensor,
src_v_layers: torch.Tensor,
dst_v: torch.Tensor,
src_indices: torch.Tensor,
dst_indices: torch.Tensor,
item_size: int,
dst_layout_dim: int,
num_layers: int,
block_quota: int = 2,
num_warps_per_block: int = 16 if _is_hip else 32,
):
torch.ops.sgl_kernel.transfer_kv_all_layer_lf_pf(
src_k_layers,
dst_k,
src_v_layers,
dst_v,
src_indices,
dst_indices,
item_size,
dst_layout_dim,
num_layers,
block_quota,
num_warps_per_block,
)
def transfer_kv_direct(
src_layers: List[torch.Tensor],
dst_layers: List[torch.Tensor],
src_indices: torch.Tensor,
dst_indices: torch.Tensor,
page_size: int,
):
torch.ops.sgl_kernel.transfer_kv_direct(
src_layers, dst_layers, src_indices, dst_indices, page_size
)
def transfer_kv_per_layer_mla(
src: torch.Tensor,
dst: torch.Tensor,
src_indices: torch.Tensor,
dst_indices: torch.Tensor,
item_size: int,
block_quota: int = 2,
num_warps_per_block: int = 16 if _is_hip else 32,
):
torch.ops.sgl_kernel.transfer_kv_per_layer_mla(
src,
dst,
src_indices,
dst_indices,
item_size,
block_quota,
num_warps_per_block,
)
def transfer_kv_per_layer_mla_pf_lf(
src: torch.Tensor,
dst: torch.Tensor,
src_indices: torch.Tensor,
dst_indices: torch.Tensor,
layer_id: int,
item_size: int,
src_layout_dim: int,
block_quota: int = 2,
num_warps_per_block: int = 16 if _is_hip else 32,
):
torch.ops.sgl_kernel.transfer_kv_per_layer_mla_pf_lf(
src,
dst,
src_indices,
dst_indices,
layer_id,
item_size,
src_layout_dim,
block_quota,
num_warps_per_block,
)
def transfer_kv_all_layer_mla(
src_layers: torch.Tensor,
dst_layers: torch.Tensor,
src_indices: torch.Tensor,
dst_indices: torch.Tensor,
item_size: int,
num_layers: int,
block_quota: int = 2,
num_warps_per_block: int = 16 if _is_hip else 32,
):
torch.ops.sgl_kernel.transfer_kv_all_layer_mla(
src_layers,
dst_layers,
src_indices,
dst_indices,
item_size,
num_layers,
block_quota,
num_warps_per_block,
)
def transfer_kv_all_layer_mla_lf_pf(
src_layers: torch.Tensor,
dst: torch.Tensor,
src_indices: torch.Tensor,
dst_indices: torch.Tensor,
item_size: int,
dst_layout_dim: int,
num_layers: int,
block_quota: int = 2,
num_warps_per_block: int = 16 if _is_hip else 32,
):
torch.ops.sgl_kernel.transfer_kv_all_layer_mla_lf_pf(
src_layers,
dst,
src_indices,
dst_indices,
item_size,
dst_layout_dim,
num_layers,
block_quota,
num_warps_per_block,
)

View File

@@ -0,0 +1,44 @@
import torch
def gptq_marlin_repack(
b_q_weight,
perm,
size_k,
size_n,
num_bits,
) -> torch.Tensor:
return torch.ops.sgl_kernel.gptq_marlin_repack(
b_q_weight,
perm,
size_k,
size_n,
num_bits,
)
def awq_marlin_repack(
b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int
) -> torch.Tensor:
return torch.ops.sgl_kernel.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits)
def awq_marlin_moe_repack(
b_q_weight: torch.Tensor,
perm: torch.Tensor,
size_k: int,
size_n: int,
num_bits: int,
) -> torch.Tensor:
num_experts = b_q_weight.shape[0]
assert size_k % 16 == 0
output = torch.empty(
(num_experts, size_k // 16, size_n * (num_bits // 2)),
device=b_q_weight.device,
dtype=b_q_weight.dtype,
)
for e in range(num_experts):
output[e] = torch.ops.sgl_kernel.awq_marlin_repack(
b_q_weight[e], size_k, size_n, num_bits
)
return output

View File

@@ -0,0 +1,18 @@
import torch
def set_kv_buffer_kernel(
k_cache: torch.Tensor,
v_cache: torch.Tensor,
loc: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
fallback: bool = False,
):
try:
if fallback:
raise RuntimeError("Fallback to torch implementation")
torch.ops.sgl_kernel.store_kv_cache(k_cache, v_cache, loc, k, v)
except RuntimeError: # ok, fallback to torch implementation
k_cache[loc] = k
v_cache[loc] = v

View File

@@ -0,0 +1,261 @@
from typing import Any, Dict, Optional
import torch
def moe_align_block_size(
topk_ids,
num_experts,
block_size,
sorted_token_ids,
experts_ids,
num_tokens_post_pad,
cumsum_buffer,
pad_sorted_token_ids=False,
):
torch.ops.sgl_kernel.moe_align_block_size.default(
topk_ids,
num_experts,
block_size,
sorted_token_ids,
experts_ids,
num_tokens_post_pad,
cumsum_buffer,
pad_sorted_token_ids,
)
def topk_softmax(
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
gating_output: float,
renormalize: bool = False,
) -> None:
torch.ops.sgl_kernel.topk_softmax.default(
topk_weights, topk_ids, gating_output, renormalize
)
def moe_fused_gate(
input_tensor,
bias,
num_expert_group,
topk_group,
topk,
num_fused_shared_experts=0,
routed_scaling_factor=0,
apply_routed_scaling_factor_on_output=False,
):
# This fused kernel function is used to select topk expert in a hierarchical 2-layer fashion
# it split group of expert into num_expert_group, and use top2 expert weight sum in each group
# as the group weight to select expert groups and then select topk experts within the selected groups
# the #experts is decided by the input tensor shape and we currently only support power of 2 #experts
# and #experts should be divisible by num_expert_group. #expert/num_expert_group <= 32 is limited for now.
# for non-supported case, we suggest to use the biased_grouped_topk func in sglang.srt.layers.moe.topk
# num_fused_shared_experts: if > 0, the last several experts will be
# replaced with shared experts. the shared experts will be divided by the
# routed_scaling_factor - this is intended to cancel out later when routed+shared
# output is scaled so that shared experts are not scaled.
# routed_scaling_factor: if > 0, the experts will be scaled by this factor
# apply_routed_scaling_factor_on_output: if true, output will be
# scaled by the routed_scaling_factor
return torch.ops.sgl_kernel.moe_fused_gate.default(
input_tensor,
bias,
num_expert_group,
topk_group,
topk,
num_fused_shared_experts,
routed_scaling_factor,
apply_routed_scaling_factor_on_output,
)
def ep_moe_pre_reorder(
input_tensor,
gateup_input,
src2dst,
topk_ids,
a1_scales,
start_expert_id,
end_expert_id,
topk,
use_per_token_if_dynamic,
):
return torch.ops.sgl_kernel.ep_moe_pre_reorder.default(
input_tensor,
gateup_input,
src2dst,
topk_ids,
a1_scales,
start_expert_id,
end_expert_id,
topk,
use_per_token_if_dynamic,
)
def ep_moe_silu_and_mul(
gateup_output,
down_input,
reorder_topk_ids,
scales,
start_expert_id,
end_expert_id,
):
return torch.ops.sgl_kernel.ep_moe_silu_and_mul.default(
gateup_output,
down_input,
reorder_topk_ids,
scales,
start_expert_id,
end_expert_id,
)
def ep_moe_post_reorder(
down_output,
output,
src2dst,
topk_ids,
topk_weights,
start_expert_id,
end_expert_id,
topk,
):
return torch.ops.sgl_kernel.ep_moe_post_reorder.default(
down_output,
output,
src2dst,
topk_ids,
topk_weights,
start_expert_id,
end_expert_id,
topk,
)
def fp8_blockwise_scaled_grouped_mm(
output,
a_ptrs,
b_ptrs,
out_ptrs,
a_scales_ptrs,
b_scales_ptrs,
a,
b,
scales_a,
scales_b,
stride_a,
stride_b,
stride_c,
layout_sfa,
layout_sfb,
problem_sizes,
expert_offsets,
workspace,
):
torch.ops.sgl_kernel.fp8_blockwise_scaled_grouped_mm.default(
output,
a_ptrs,
b_ptrs,
out_ptrs,
a_scales_ptrs,
b_scales_ptrs,
a,
b,
scales_a,
scales_b,
stride_a,
stride_b,
stride_c,
layout_sfa,
layout_sfb,
problem_sizes,
expert_offsets,
workspace,
)
def prepare_moe_input(
topk_ids,
expert_offsets,
problem_sizes1,
problem_sizes2,
input_permutation,
output_permutation,
num_experts,
n,
k,
blockscale_offsets: Optional[torch.Tensor] = None,
):
torch.ops.sgl_kernel.prepare_moe_input.default(
topk_ids,
expert_offsets,
blockscale_offsets,
problem_sizes1,
problem_sizes2,
input_permutation,
output_permutation,
num_experts,
n,
k,
)
def apply_shuffle_mul_sum(
input,
output,
permutation,
factors,
):
torch.ops.sgl_kernel.apply_shuffle_mul_sum.default(
input, output, permutation, factors
)
def cutlass_fp4_group_mm(
a_fp4,
b_fp4,
a_blockscale,
b_blockscale,
alphas,
out_dtype,
device,
params: Dict[str, Any],
):
"""
An FP4 Blockscaled Group Gemm that takes in a_tensors, b_tensors and runs
the gemms for each combination based on the specified problem sizes.
This is used as the MoE gemm during NVFP4 Quantized FusedMoE forward.
- a/b_tensors: the NVFP4 a_ptrs and b_ptrs tensors which are quantized
input and expert weights.
- a_/b_scales: The blockscales in FP8-E4M3 precision
- ab_strides/c_strides: Strides for the a/b tensors between rows.
- expert_offsets/sf_offsets: Indices that mark at which token index
each expert begins its computation. The number of tokens
computed with expert E is expert_offsets[E + 1] -
expert_offsets[E] And the sf_size per expert is
sf_offset[E+1] - sf_offset[E]
- problem_sizes: MxNxK sizes of each expert's multiplication in two grouped
MMs used in the fused MoE operation.
"""
m_topk = a_fp4.shape[0]
n = b_fp4.shape[1]
c_shape = (m_topk, n)
c = torch.empty(c_shape, device=device, dtype=out_dtype)
torch.ops.sgl_kernel.cutlass_fp4_group_mm.default(
c,
a_fp4,
b_fp4,
a_blockscale,
b_blockscale,
alphas,
params["ab_strides"],
params["c_strides"],
params["problem_sizes"],
params["expert_offsets"],
params["blockscale_offsets"],
)
return c.to(dtype=out_dtype)

View File

@@ -0,0 +1,543 @@
from typing import Optional, Union
import torch
from sgl_kernel.utils import _to_tensor_scalar_tuple
def _top_k_renorm_probs_internal(
probs: torch.Tensor,
maybe_top_k_arr: Optional[torch.Tensor],
top_k_val: int,
) -> torch.Tensor:
probs = probs.float()
maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
renorm_probs = torch.empty_like(probs)
torch.ops.sgl_kernel.top_k_renorm_probs.default(
probs, renorm_probs, maybe_top_k_arr, top_k_val
)
return renorm_probs
def top_k_renorm_probs(
probs: torch.Tensor,
top_k: Union[torch.Tensor, int],
) -> torch.Tensor:
r"""Adapt from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
Fused GPU kernel for renormalizing probabilities by top-k thresholding.
Parameters
----------
probs: torch.Tensor
Probabilities, shape ``(batch_size, num_classes)``.
top_k: Union[torch.Tensor, int]
Either a scalar or a tensor of shape ``(batch_size,)``, representing the top-k threshold for for
for re-normalizing probabilities, should be in ``(0, num_classes)``.
If a scalar, the same threshold is used for all requests.
If a tensor, each request has its own threshold.
We keep the top-k probabilities, set the rest to zero, and renormalize the probabilities.
Returns
-------
renorm_probs: torch.Tensor
Renormalized probabilities, shape ``(batch_size, num_classes)``.
Note
----
This combination of ``top_k_renorm_probs`` and ``sampling_from_probs`` should be equivalent to
``top_k_sampling_from_probs``.
"""
return _top_k_renorm_probs_internal(probs, *_to_tensor_scalar_tuple(top_k))
top_k_renorm_prob = top_k_renorm_probs
def _top_p_renorm_probs_internal(
probs: torch.Tensor,
maybe_top_p_arr: Optional[torch.Tensor],
top_p_val: float,
) -> torch.Tensor:
probs = probs.float()
maybe_top_p_arr = maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
renorm_probs = torch.empty_like(probs)
torch.ops.sgl_kernel.top_p_renorm_probs.default(
probs, renorm_probs, maybe_top_p_arr, top_p_val
)
return renorm_probs
def top_p_renorm_probs(
probs: torch.Tensor,
top_p: Union[torch.Tensor, float],
) -> torch.Tensor:
r"""Adapt from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
Fused GPU kernel for renormalizing probabilities by top-p thresholding.
Parameters
----------
probs: torch.Tensor
Probabilities, shape ``(batch_size, num_classes)``.
top_p: Union[torch.Tensor, float]
Either a scalar or a tensor of shape ``(batch_size,)``, representing the top-p threshold for for
re-normalizing probabilities, should be in ``(0, 1)``.
If a scalar, the same threshold is used for all requests.
If a tensor, each request has its own threshold.
We mask out the probabilities less than `threshold` where the cumulative sum
of ``probs[probs >= threshold]`` is `top_p`, and renormalize the probabilities.
Returns
-------
renorm_probs: torch.Tensor
Renormalized probabilities, shape ``(batch_size, num_classes)``.
Note
----
This combination of ``top_p_renorm_probs`` and ``sampling_from_probs`` should be equivalent to
``top_p_sampling_from_probs``.
"""
return _top_p_renorm_probs_internal(probs, *_to_tensor_scalar_tuple(top_p))
top_p_renorm_prob = top_p_renorm_probs
def _top_p_sampling_from_probs_internal(
probs: torch.Tensor,
indices: Optional[torch.Tensor],
maybe_top_p_arr: Optional[torch.Tensor],
top_p_val: float,
deterministic: bool,
generator: Optional[torch.Generator],
) -> torch.Tensor:
with probs.device as device:
probs = probs.float()
maybe_top_p_arr = (
maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
)
samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
torch.ops.sgl_kernel.top_p_sampling_from_probs.default(
probs,
samples,
indices,
maybe_top_p_arr,
top_p_val,
deterministic,
generator,
)
return samples
def top_p_sampling_from_probs(
probs: torch.Tensor,
top_p: Union[torch.Tensor, float],
indices: Optional[torch.Tensor] = None,
deterministic: bool = True,
generator: Optional[torch.Generator] = None,
check_nan: bool = False,
) -> torch.Tensor:
r"""Adapt from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
Fused GPU kernel for top-p sampling (nucleus sampling) from probabilities,
this operator implements GPU-based rejection sampling without explicit sorting.
Check the `blog post <https://flashinfer.ai/2025/03/10/sampling.html>`_ for more details.
The multiple rounds of rejection sampling are implemented in a single CUDA kernel,
which is more efficient than the naive implementation that launches a series of kernels.
Parameters
----------
probs: torch.Tensor
Probabilities for sampling. When indices is not provided, shape should be ``(batch_size, num_classes)``
and the i-th output will be sampled from the i-th row of probabilities. When indices is provided,
shape should be ``(unique_batch_size, num_classes)`` where unique_batch_size is the number of unique
probability distributions.
top_p: Union[torch.Tensor, float]
Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for top-p sampling.
If a scalar, the same threshold is used for all requests.
If a tensor, each request has its own threshold.
indices: Optional[torch.Tensor]
Optional indices tensor of shape ``(batch_size,)`` that maps each output to a row in probs.
For example, if indices[i] = j, then the i-th output will be sampled from probs[j].
This allows reusing the same probability distribution for multiple outputs.
If indices is not provided, the i-th output will be sampled from the i-th row of probs.
deterministic: bool
Whether to use deterministic kernel implementation, default is ``True``.
generator: Optional[torch.Generator]
A random number generator for the operation.
check_nan: bool
Whether to check nan in :attr:`probs`, default is ``False``.
Returns
-------
samples: torch.Tensor
Sampled categories, shape ``(batch_size,)``.
Note
----
This function expects float32 inputs, and the output is int32.
"""
if check_nan:
if torch.any(torch.isnan(probs)):
raise ValueError("Input probs contains NaN.")
return _top_p_sampling_from_probs_internal(
probs, indices, *_to_tensor_scalar_tuple(top_p), deterministic, generator
)
def _top_k_top_p_sampling_from_probs_internal(
probs: torch.Tensor,
indices: Optional[torch.Tensor],
maybe_top_k_arr: Optional[torch.Tensor],
top_k_val: int,
maybe_top_p_arr: Optional[torch.Tensor],
top_p_val: float,
deterministic: bool,
generator: Optional[torch.Generator],
) -> torch.Tensor:
with probs.device as device:
probs = probs.float()
maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
maybe_top_p_arr = (
maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
)
samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
torch.ops.sgl_kernel.top_k_top_p_sampling_from_probs.default(
probs,
samples,
indices,
maybe_top_k_arr,
top_k_val,
maybe_top_p_arr,
top_p_val,
deterministic,
generator,
)
return samples
def top_k_top_p_sampling_from_probs(
probs: torch.Tensor,
top_k: Union[torch.Tensor, int],
top_p: Union[torch.Tensor, float],
indices: Optional[torch.Tensor] = None,
filter_apply_order: str = "top_k_first",
deterministic: bool = True,
generator: Optional[torch.Generator] = None,
check_nan: bool = False,
) -> torch.Tensor:
r"""Adapt from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
Fused GPU kernel for top-k and top-p sampling from probabilities,
this operator implements GPU-based rejection sampling without explicit sorting.
Check the `blog post <https://flashinfer.ai/2025/03/10/sampling.html>`_ for more details.
The multiple rounds of rejection sampling are implemented in a single CUDA kernel,
which is more efficient than the naive implementation that launches a series of kernels.
Parameters
----------
probs: torch.Tensor
Probabilities for sampling. When indices is not provided, shape should be ``(batch_size, num_classes)``
and the i-th output will be sampled from the i-th row of probabilities. When indices is provided,
shape should be ``(unique_batch_size, num_classes)`` where unique_batch_size is the number of unique
probability distributions.
top_k: Union[torch.Tensor, int]
Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for top-k sampling.
If a scalar, the same threshold is used for all requests.
If a tensor, each request has its own threshold.
top_p: Union[torch.Tensor, float]
Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for top-p sampling.
If a scalar, the same threshold is used for all requests.
If a tensor, each request has its own threshold.
indices: Optional[torch.Tensor]
Optional indices tensor of shape ``(batch_size,)`` that maps each output to a row in probs.
For example, if indices[i] = j, then the i-th output will be sampled from probs[j].
This allows reusing the same probability distribution for multiple outputs.
If indices is not provided, the i-th output will be sampled from the i-th row of probs.
filter_apply_order: str
The order of applying top-k and top-p sampling, should be either ``"top_k_first"`` or ``"joint"``.
If ``"top_k_first"``, we first apply top-k filter, then apply top-p sampling on the top-k results.
If ``"joint"``, we apply top-k and top-p filter simultaneously in each round. Default is ``"top_k_first"``.
deterministic: bool
Whether to use deterministic kernel implementation, default is ``True``.
generator: Optional[torch.Generator]
A random number generator for the operation.
check_nan: bool
Whether to check nan in :attr:`probs`, default is ``False``.
Returns
-------
samples: torch.Tensor
Sampled categories, shape ``(batch_size,)``.
Note
----
This function expects float32 inputs, and the output is int32.
"""
if filter_apply_order == "top_k_first":
renorm_probs = top_k_renorm_probs(probs, top_k)
return top_p_sampling_from_probs(
renorm_probs,
top_p,
indices,
deterministic,
check_nan=check_nan,
generator=generator,
)
elif filter_apply_order == "joint":
if check_nan:
if torch.any(torch.isnan(probs)):
raise ValueError("Input probs contains NaN.")
return _top_k_top_p_sampling_from_probs_internal(
probs,
indices,
*_to_tensor_scalar_tuple(top_k),
*_to_tensor_scalar_tuple(top_p),
deterministic,
generator,
)
else:
raise ValueError(f"Invalid filter_apply_order: {filter_apply_order}")
def _min_p_sampling_from_probs_internal(
probs: torch.Tensor,
indices: Optional[torch.Tensor],
maybe_min_p_arr: Optional[torch.Tensor],
min_p_val: float,
deterministic: bool,
generator: Optional[torch.Generator],
) -> torch.Tensor:
with probs.device as device:
probs = probs.float()
maybe_min_p_arr = (
maybe_min_p_arr.float() if maybe_min_p_arr is not None else None
)
samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
torch.ops.sgl_kernel.min_p_sampling_from_probs.default(
probs,
samples,
indices,
maybe_min_p_arr,
min_p_val,
deterministic,
generator,
)
return samples
def min_p_sampling_from_probs(
probs: torch.Tensor,
min_p: Union[torch.Tensor, float],
indices: Optional[torch.Tensor] = None,
deterministic: bool = True,
generator: Optional[torch.Generator] = None,
check_nan: bool = False,
) -> torch.Tensor:
r"""Adapt from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
Fused GPU kernel for `min_p sampling <https://arxiv.org/abs/2407.01082>`_ from probabilities,
this operator implements GPU-based rejection sampling without explicit sorting.
Check the `blog post <https://flashinfer.ai/2025/03/10/sampling.html>`_ for more details.
The multiple rounds of rejection sampling are implemented in a single CUDA kernel,
which is more efficient than the naive implementation that launches a series of kernels.
Parameters
----------
probs: torch.Tensor
Probabilities for sampling. When indices is not provided, shape should be ``(batch_size, num_classes)``
and the i-th output will be sampled from the i-th row of probabilities. When indices is provided,
shape should be ``(unique_batch_size, num_classes)`` where unique_batch_size is the number of unique
probability distributions.
min_p: Union[torch.Tensor, float]
Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for min-p sampling.
If a scalar, the same threshold is used for all requests.
If a tensor, each request has its own threshold.
indices: Optional[torch.Tensor]
Optional indices tensor of shape ``(batch_size,)`` that maps each output to a row in probs.
For example, if indices[i] = j, then the i-th output will be sampled from probs[j].
This allows reusing the same probability distribution for multiple outputs.
If indices is not provided, the i-th output will be sampled from the i-th row of probs.
deterministic: bool
Whether to use deterministic kernel implementation, default is ``True``.
generator: Optional[torch.Generator]
A random number generator for the operation.
check_nan: bool
Whether to check nan in :attr:`probs`, default is ``False``.
Returns
-------
samples: torch.Tensor
Sampled categories, shape ``(batch_size,)``.
Note
----
This function expects float32 inputs, and the output is int32.
"""
if check_nan:
if torch.any(torch.isnan(probs)):
raise ValueError("Input probs contains NaN.")
return _min_p_sampling_from_probs_internal(
probs, indices, *_to_tensor_scalar_tuple(min_p), deterministic, generator
)
def _top_k_mask_logits_internal(
logits: torch.Tensor,
maybe_top_k_arr: Optional[torch.Tensor],
top_k_val: int,
) -> torch.Tensor:
logits = logits.float()
maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
mask_logits = torch.empty_like(logits)
torch.ops.sgl_kernel.top_k_mask_logits.default(
logits, mask_logits, maybe_top_k_arr, top_k_val
)
return mask_logits
def top_k_mask_logits(
logits: torch.Tensor,
top_k: Union[torch.Tensor, int],
) -> torch.Tensor:
r"""Adapt from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
Fused GPU kernel for masking logits by top-k thresholding.
Parameters
----------
logits: torch.Tensor
Logits before softmax, shape ``(batch_size, num_classes)``.
top_k: Union[torch.Tensor, int]
Either a scalar or a tensor of shape ``(batch_size,)``, representing the top-k threshold for for
for masking logits, should be in ``(0, num_classes)``.
If a scalar, the same threshold is used for all requests.
If a tensor, each request has its own threshold.
We keep the top-k logits, set the rest to negative infinity.
Returns
-------
masked_logits: torch.Tensor
Masked logits, shape ``(batch_size, num_classes)``.
Examples
--------
>>> import torch
>>> import flashinfer
>>> torch.manual_seed(42)
>>> batch_size = 4
>>> vocab_size = 5
>>> top_k = 3
>>> logits = torch.randn(batch_size, vocab_size).to(0)
>>> logits
tensor([[ 1.9269, 1.4873, 0.9007, -2.1055, -0.7581],
[ 1.0783, 0.8008, 1.6806, 0.3559, -0.6866],
[-0.4934, 0.2415, -0.2316, 0.0418, -0.2516],
[ 0.8599, -0.3097, -0.3957, 0.8034, -0.6216]], device='cuda:0')
>>> masked_logits = flashinfer.sampling.top_k_mask_logits(logits, top_k)
>>> masked_logits
tensor([[ 1.9269, 1.4873, 0.9007, -inf, -inf],
[ 1.0783, 0.8008, 1.6806, -inf, -inf],
[ -inf, 0.2415, -0.2316, 0.0418, -inf],
[ 0.8599, -0.3097, -inf, 0.8034, -inf]], device='cuda:0')
Note
----
The combination of ``top_k_mask_logits`` and ``softmax`` should be equivalent to ``top_k_renorm_probs``.
See Also
--------
top_k_renorm_probs
"""
return _top_k_mask_logits_internal(logits, *_to_tensor_scalar_tuple(top_k))
def top_k_top_p_sampling_from_logits(
logits: torch.Tensor,
top_k: Union[torch.Tensor, int],
top_p: Union[torch.Tensor, float],
indices: Optional[torch.Tensor] = None,
filter_apply_order: str = "top_k_first",
deterministic: bool = True,
generator: Optional[torch.Generator] = None,
check_nan: bool = False,
) -> torch.Tensor:
r"""Adapt from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
Fused GPU kernel for top-k and top-p sampling from probabilities,
this operator implements GPU-based rejection sampling without explicit sorting.
Check the `blog post <https://flashinfer.ai/2025/03/10/sampling.html>`_ for more details.
The multiple rounds of rejection sampling are implemented in a single CUDA kernel,
which is more efficient than the naive implementation that launches a series of kernels.
Parameters
----------
logits: torch.Tensor
Pre-softmax logits for sampling. When indices is not provided, shape should be ``(batch_size, num_classes)``
and the i-th output will be sampled from the i-th row of logits. When indices is provided,
shape should be ``(unique_batch_size, num_classes)`` where unique_batch_size is the number of unique
probability distributions.
top_k: Union[torch.Tensor, int]
Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for top-k sampling.
If a scalar, the same threshold is used for all requests.
If a tensor, each request has its own threshold.
top_p: Union[torch.Tensor, float]
Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for top-p sampling.
If a scalar, the same threshold is used for all requests.
If a tensor, each request has its own threshold.
indices: Optional[torch.Tensor]
Optional indices tensor of shape ``(batch_size,)`` that maps each output to a row in probs.
For example, if indices[i] = j, then the i-th output will be sampled from probs[j].
This allows reusing the same probability distribution for multiple outputs.
If indices is not provided, the i-th output will be sampled from the i-th row of probs.
filter_apply_order: str
The order of applying top-k and top-p sampling, should be either ``"top_k_first"`` or ``"joint"``.
If ``"top_k_first"``, we first apply top-k filter, then apply top-p sampling on the top-k results.
If ``"joint"``, we apply top-k and top-p filter simultaneously in each round. Default is ``"top_k_first"``.
deterministic: bool
Whether to use deterministic kernel implementation, default is ``True``.
generator: Optional[torch.Generator]
A random number generator for the operation.
check_nan: bool
Whether to check nan in :attr:`probs`, default is ``False``.
Returns
-------
samples: torch.Tensor
Sampled categories, shape ``(batch_size,)``.
Note
----
This function expects float32 inputs, and the output is int32.
"""
if filter_apply_order == "top_k_first":
masked_logits = top_k_mask_logits(logits, top_k)
probs = torch.softmax(masked_logits, dim=-1)
return top_p_sampling_from_probs(
probs,
top_p,
indices,
deterministic,
check_nan=check_nan,
generator=generator,
)
elif filter_apply_order == "joint":
probs = torch.softmax(logits, dim=-1)
if check_nan:
if torch.any(torch.isnan(probs)):
raise ValueError("Input probs contains NaN.")
return _top_k_top_p_sampling_from_probs_internal(
probs,
indices,
*_to_tensor_scalar_tuple(top_k),
*_to_tensor_scalar_tuple(top_p),
deterministic,
generator,
)
else:
raise ValueError(f"Invalid filter_apply_order: {filter_apply_order}")

View File

@@ -0,0 +1,352 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import struct
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Union
_SCALAR_TYPES_ID_MAP = {}
# Mirrors enum in `core/scalar_type.hpp`
class NanRepr(Enum):
NONE = 0 # nans are not supported
IEEE_754 = 1 # nans are: Exp all 1s, mantissa not all 0s
EXTD_RANGE_MAX_MIN = 2 # nans are: Exp all 1s, mantissa all 1s
# This ScalarType class is a parallel implementation of the C++ ScalarType
# class found in csrc/core/scalar_type.hpp. These two classes should be kept
# in sync until the inductor fully supports custom C++ classes.
@dataclass(frozen=True)
class ScalarType:
"""
ScalarType can represent a wide range of floating point and integer
types, in particular it can be used to represent sub-byte data types
(something that torch.dtype currently does not support). It is also
capable of representing types with a bias, i.e.:
`stored_value = value + bias`,
this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias
of 8). The implementation for this class can be found in
csrc/core/scalar_type.hpp, these type signatures should be kept in sync
with that file.
"""
exponent: int
"""
Number of bits in the exponent if this is a floating point type
(zero if this an integer type)
"""
mantissa: int
"""
Number of bits in the mantissa if this is a floating point type,
or the number bits representing an integer excluding the sign bit if
this an integer type.
"""
signed: bool
"If the type is signed (i.e. has a sign bit)"
bias: int
"""
bias used to encode the values in this scalar type
(value = stored_value - bias, default 0) for example if we store the
type as an unsigned integer with a bias of 128 then the value 0 will be
stored as 128 and -1 will be stored as 127 and 1 will be stored as 129.
"""
_finite_values_only: bool = False
"""
Private: if infs are supported, used `has_infs()` instead.
"""
nan_repr: NanRepr = NanRepr.IEEE_754
"""
How NaNs are represent in this scalar type, returns NanRepr value.
(not applicable for integer types)
"""
def _floating_point_max_int(self) -> int:
assert (
self.mantissa <= 52 and self.exponent <= 11
), f"Cannot represent max/min as a double for type {self.__str__()}"
max_mantissa = (1 << self.mantissa) - 1
if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN:
max_mantissa = max_mantissa - 1
max_exponent = (1 << self.exponent) - 2
if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN or self.nan_repr == NanRepr.NONE:
assert (
self.exponent < 11
), f"Cannot represent max/min as a double for type {self.__str__()}"
max_exponent = max_exponent + 1
# adjust the exponent to match that of a double
# for now we assume the exponent bias is the standard 2^(e-1) -1, (where
# e is the exponent bits), there is some precedent for non-standard
# biases, example `float8_e4m3b11fnuz` here:
# https://github.com/jax-ml/ml_dtypes but to avoid premature over
# complication we are just assuming the standard exponent bias until
# there is a need to support non-standard biases
exponent_bias = (1 << (self.exponent - 1)) - 1
exponent_bias_double = (1 << 10) - 1 # double e = 11
max_exponent_double = max_exponent - exponent_bias + exponent_bias_double
# shift the mantissa and exponent into the proper positions for an
# IEEE double and bitwise-or them together.
return (max_mantissa << (52 - self.mantissa)) | (max_exponent_double << 52)
def _floating_point_max(self) -> float:
double_raw = self._floating_point_max_int()
return struct.unpack("!d", struct.pack("!Q", double_raw))[0]
def _raw_max(self) -> Union[int, float]:
if self.is_floating_point():
return self._floating_point_max()
else:
assert (
self.size_bits < 64 or self.size_bits == 64 and self.is_signed()
), "Cannot represent max as an int"
return (1 << self.mantissa) - 1
def _raw_min(self) -> Union[int, float]:
if self.is_floating_point():
assert (
self.is_signed()
), "We currently assume all floating point types are signed"
sign_bit_double = 1 << 63
max_raw = self._floating_point_max_int()
min_raw = max_raw | sign_bit_double
return struct.unpack("!d", struct.pack("!Q", min_raw))[0]
else:
assert (
not self.is_signed() or self.size_bits <= 64
), "Cannot represent min as a int64_t"
if self.is_signed():
return -(1 << (self.size_bits - 1))
else:
return 0
@functools.cached_property
def id(self) -> int:
"""
Convert the ScalarType to an int which can be passed to pytorch custom
ops. This layout of the int must be kept in sync with the C++
ScalarType's from_id method.
"""
val = 0
offset = 0
def or_and_advance(member, bit_width):
nonlocal val
nonlocal offset
bit_mask = (1 << bit_width) - 1
val = val | (int(member) & bit_mask) << offset
offset = offset + bit_width
or_and_advance(self.exponent, 8)
or_and_advance(self.mantissa, 8)
or_and_advance(self.signed, 1)
or_and_advance(self.bias, 32)
or_and_advance(self._finite_values_only, 1)
or_and_advance(self.nan_repr.value, 8)
assert offset <= 64, f"ScalarType fields too big {offset} to fit into an int64"
_SCALAR_TYPES_ID_MAP[val] = self
return val
@property
def size_bits(self) -> int:
return self.exponent + self.mantissa + int(self.signed)
def min(self) -> Union[int, float]:
"""
Min representable value for this scalar type.
(accounting for bias if there is one)
"""
return self._raw_min() - self.bias
def max(self) -> Union[int, float]:
"""
Max representable value for this scalar type.
(accounting for bias if there is one)
"""
return self._raw_max() - self.bias
def is_signed(self) -> bool:
"""
If the type is signed (i.e. has a sign bit), same as `signed`
added for consistency with:
https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
"""
return self.signed
def is_floating_point(self) -> bool:
"If the type is a floating point type"
return self.exponent != 0
def is_integer(self) -> bool:
"If the type is an integer type"
return self.exponent == 0
def has_bias(self) -> bool:
"If the type has a non-zero bias"
return self.bias != 0
def has_infs(self) -> bool:
"If the type is floating point and supports infinity"
return not self._finite_values_only
def has_nans(self) -> bool:
return self.nan_repr != NanRepr.NONE
def is_ieee_754(self) -> bool:
"""
If the type is a floating point type that follows IEEE 754
conventions
"""
return self.nan_repr == NanRepr.IEEE_754 and not self._finite_values_only
def __str__(self) -> str:
"""
naming generally follows: https://github.com/jax-ml/ml_dtypes
for floating point types (leading f) the scheme is:
`float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
flags:
- no-flags: means it follows IEEE 754 conventions
- f: means finite values only (no infinities)
- n: means nans are supported (non-standard encoding)
for integer types the scheme is:
`[u]int<size_bits>[b<bias>]`
- if bias is not present it means its zero
"""
if self.is_floating_point():
ret = (
"float"
+ str(self.size_bits)
+ "_e"
+ str(self.exponent)
+ "m"
+ str(self.mantissa)
)
if not self.is_ieee_754():
if self._finite_values_only:
ret = ret + "f"
if self.nan_repr != NanRepr.NONE:
ret = ret + "n"
return ret
else:
ret = ("int" if self.is_signed() else "uint") + str(self.size_bits)
if self.has_bias():
ret = ret + "b" + str(self.bias)
return ret
def __repr__(self) -> str:
return "ScalarType." + self.__str__()
# __len__ needs to be defined (and has to throw TypeError) for pytorch's
# opcheck to work.
def __len__(self) -> int:
raise TypeError
#
# Convenience Constructors
#
@classmethod
def int_(cls, size_bits: int, bias: Optional[int]) -> "ScalarType":
"Create a signed integer scalar type (size_bits includes sign-bit)."
ret = cls(0, size_bits - 1, True, bias if bias else 0)
ret.id # noqa B018: make sure the id is cached
return ret
@classmethod
def uint(cls, size_bits: int, bias: Optional[int]) -> "ScalarType":
"""Create a unsigned integer scalar type."""
ret = cls(0, size_bits, False, bias if bias else 0)
ret.id # noqa B018: make sure the id is cached
return ret
@classmethod
def float_IEEE754(cls, exponent: int, mantissa: int) -> "ScalarType":
"""
Create a standard floating point type
(i.e. follows IEEE 754 conventions).
"""
assert mantissa > 0 and exponent > 0
ret = cls(exponent, mantissa, True, 0)
ret.id # noqa B018: make sure the id is cached
return ret
@classmethod
def float_(
cls, exponent: int, mantissa: int, finite_values_only: bool, nan_repr: NanRepr
) -> "ScalarType":
"""
Create a non-standard floating point type
(i.e. does not follow IEEE 754 conventions).
"""
assert mantissa > 0 and exponent > 0
assert nan_repr != NanRepr.IEEE_754, (
"use `float_IEEE754` constructor for floating point types that "
"follow IEEE 754 conventions"
)
ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr)
ret.id # noqa B018: make sure the id is cached
return ret
@classmethod
def from_id(cls, scalar_type_id: int):
if scalar_type_id not in _SCALAR_TYPES_ID_MAP:
raise ValueError(f"scalar_type_id {scalar_type_id} doesn't exists.")
return _SCALAR_TYPES_ID_MAP[scalar_type_id]
# naming generally follows: https://github.com/jax-ml/ml_dtypes
# for floating point types (leading f) the scheme is:
# `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
# flags:
# - no-flags: means it follows IEEE 754 conventions
# - f: means finite values only (no infinities)
# - n: means nans are supported (non-standard encoding)
# for integer types the scheme is:
# `[u]int<size_bits>[b<bias>]`
# - if bias is not present it means its zero
class scalar_types:
int4 = ScalarType.int_(4, None)
uint4 = ScalarType.uint(4, None)
int8 = ScalarType.int_(8, None)
uint8 = ScalarType.uint(8, None)
float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN)
float8_e5m2 = ScalarType.float_IEEE754(5, 2)
float16_e8m7 = ScalarType.float_IEEE754(8, 7)
float16_e5m10 = ScalarType.float_IEEE754(5, 10)
# fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main
float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE)
# fp4, https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
float4_e2m1f = ScalarType.float_(2, 1, True, NanRepr.NONE)
# "gptq" types
uint2b2 = ScalarType.uint(2, 2)
uint3b4 = ScalarType.uint(3, 4)
uint4b8 = ScalarType.uint(4, 8)
uint8b128 = ScalarType.uint(8, 128)
# colloquial names
bfloat16 = float16_e8m7
float16 = float16_e5m10

View File

@@ -0,0 +1,293 @@
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
def maybe_contiguous(x):
return x.contiguous() if x is not None and x.stride(-1) != 1 else x
# Sparse attention utils
def convert_vertical_slash_indexes(
q_seqlens: torch.Tensor, # [BATCH, ]
kv_seqlens: torch.Tensor, # [BATCH, ]
vertical_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_V]
slash_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_S]
context_size: int,
block_size_M: int,
block_size_N: int,
causal: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
batch_size = slash_indexes.size(0)
num_heads = slash_indexes.size(1)
nnz_slash = slash_indexes.size(2)
nnz_vertical = vertical_indexes.size(2)
num_rows = (context_size + block_size_M - 1) // block_size_M
block_count = torch.zeros(
batch_size, num_heads, num_rows, dtype=q_seqlens.dtype, device=q_seqlens.device
)
block_offset = torch.zeros(
batch_size,
num_heads,
num_rows,
nnz_slash,
dtype=q_seqlens.dtype,
device=q_seqlens.device,
)
column_count = torch.zeros(
batch_size, num_heads, num_rows, dtype=q_seqlens.dtype, device=q_seqlens.device
)
column_index = torch.zeros(
batch_size,
num_heads,
num_rows,
nnz_vertical,
dtype=q_seqlens.dtype,
device=q_seqlens.device,
)
torch.ops.sgl_kernel.convert_vertical_slash_indexes.default(
block_count,
block_offset,
column_count,
column_index,
q_seqlens,
kv_seqlens,
vertical_indexes,
slash_indexes,
context_size,
block_size_M,
block_size_N,
causal,
)
return block_count, block_offset, column_count, column_index
def convert_vertical_slash_indexes_mergehead(
q_seqlens: torch.Tensor, # [BATCH, ]
kv_seqlens: torch.Tensor, # [BATCH, ]
vertical_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_V]
slash_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_S]
# [N_HEADS] : different head use different number of indices
vertical_indices_count: torch.Tensor,
slash_indices_count: torch.Tensor,
context_size: int,
block_size_M: int,
block_size_N: int,
causal: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
batch_size = slash_indexes.size(0)
num_heads = slash_indexes.size(1)
nnz_slash = slash_indexes.size(2)
nnz_vertical = vertical_indexes.size(2)
num_rows = (context_size + block_size_M - 1) // block_size_M
block_count = torch.empty(
batch_size, num_heads, num_rows, dtype=q_seqlens.dtype, device=q_seqlens.device
)
block_offset = torch.empty(
batch_size,
num_heads,
num_rows,
nnz_slash,
dtype=q_seqlens.dtype,
device=q_seqlens.device,
)
column_count = torch.empty(
batch_size, num_heads, num_rows, dtype=q_seqlens.dtype, device=q_seqlens.device
)
column_index = torch.empty(
batch_size,
num_heads,
num_rows,
nnz_vertical,
dtype=q_seqlens.dtype,
device=q_seqlens.device,
)
torch.ops.sgl_kernel.convert_vertical_slash_indexes_mergehead.default(
block_count,
block_offset,
column_count,
column_index,
q_seqlens,
kv_seqlens,
vertical_indexes,
slash_indexes,
vertical_indices_count,
slash_indices_count,
context_size,
block_size_M,
block_size_N,
causal,
)
return block_count, block_offset, column_count, column_index
def sparse_attn_func(
q,
k,
v,
block_count,
block_offset,
column_count,
column_index,
dropout_p=0.0,
softmax_scale=None,
causal=False,
softcap=0.0, # 0.0 means deactivated
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
*,
return_softmax_lse=False,
out=None,
):
"""Compute attention with vertical and slash sparsity patterns.
Most Arguments are the same with the flash_attn_func interface, except for 4 extra args:
block_count and block_offset for slash sparsity patterns, and
column_count and column_index for vertical sparsity patterns.
For more details please refer to Appendix C.4.2 of paper https://arxiv.org/abs/2407.02490.
Arguments:
q: (batch_size, seqlen, nheads, headdim)
k: (batch_size, seqlen, nheads_k, headdim)
v: (batch_size, seqlen, nheads_k, headdim)
block_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
block_offset: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_S)
column_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
column_index: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_V)
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
which is slightly slower and uses more memory. The forward pass is always deterministic.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
Return:
out: (batch_size, seqlen, nheads, headdim).
softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
"""
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
out, softmax_lse = torch.ops.sgl_kernel.fwd_sparse.default(
q,
k,
v,
block_count,
block_offset,
column_count,
column_index,
out,
alibi_slopes,
dropout_p,
softmax_scale,
causal,
softcap,
return_attn_probs and dropout_p > 0,
None,
)
return (out, softmax_lse) if return_softmax_lse else out
def sparse_attn_varlen_func(
q,
k,
v,
block_count,
block_offset,
column_count,
column_index,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p=0.0,
softmax_scale=None,
causal=False,
softcap=0.0, # 0.0 means deactivated
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
*,
return_softmax_lse=False,
out=None,
):
"""Compute attention with vertical and slash sparsity patterns.
Most Arguments are the same with the flash_attn_varlen_func interface, except for 4 extra args:
block_count and block_offset for slash sparsity patterns, and
column_count and column_index for vertical sparsity patterns.
For more details please refer to Appendix C.4.2 of paper https://arxiv.org/abs/2407.02490.
Arguments:
q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
block_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
block_offset: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_S)
column_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
column_index: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_V)
cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into q.
cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into kv.
max_seqlen_q: int. Maximum query sequence length in the batch.
max_seqlen_k: int. Maximum key sequence length in the batch.
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
softcap: float. Anything > 0 activates softcapping attention.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
which is slightly slower and uses more memory. The forward pass is always deterministic.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
Return:
out: (total, nheads, headdim).
softmax_lse [optional, if return_softmax_lse=True]: (nheads, total_q_seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
"""
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
out, softmax_lse = torch.ops.sgl_kernel.varlen_fwd_sparse.default(
q,
k,
v,
block_count,
block_offset,
column_count,
column_index,
out,
cu_seqlens_q,
cu_seqlens_k,
None,
alibi_slopes,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
False,
causal,
softcap,
return_attn_probs and dropout_p > 0,
None,
)
return (out, softmax_lse) if return_softmax_lse else out

View File

@@ -0,0 +1,63 @@
import torch
from torch.cuda.streams import ExternalStream
try:
from . import spatial_ops # triggers TORCH extension registration
except Exception as _e:
_spatial_import_error = _e
else:
_spatial_import_error = None
_IMPORT_ERROR = ImportError(
"Failed to load sgl_kernel.spatial_ops extension. Ensure CUDA Driver >= 12.4"
)
def create_greenctx_stream_by_value(
SM_a: int, SM_b: int, device_id: int = None
) -> tuple[ExternalStream, ExternalStream]:
"""
Create two streams for greenctx.
Args:
sm_A (int): The SM of stream A.
sm_B (int): The weight of stream B.
device_id (int): The device id.
Returns:
tuple[ExternalStream, ExternalStream]: The two streams.
"""
if _spatial_import_error is not None:
raise _IMPORT_ERROR from _spatial_import_error
if device_id is None:
device_id = torch.cuda.current_device()
res = torch.ops.sgl_kernel.create_greenctx_stream_by_value(SM_a, SM_b, device_id)
stream_a = ExternalStream(
stream_ptr=res[0], device=torch.device(f"cuda:{device_id}")
)
stream_b = ExternalStream(
stream_ptr=res[1], device=torch.device(f"cuda:{device_id}")
)
return stream_a, stream_b
def get_sm_available(device_id: int = None) -> int:
"""
Get the SMs available on the device.
Args:
device_id (int): The device id.
Returns:
int: The SMs available.
"""
if _spatial_import_error is not None:
raise _IMPORT_ERROR from _spatial_import_error
if device_id is None:
device_id = torch.cuda.current_device()
device_props = torch.cuda.get_device_properties(device_id)
# Get the number of Streaming Multiprocessors (SMs)
sm_count = device_props.multi_processor_count
return sm_count

View File

@@ -0,0 +1,107 @@
import torch
from sgl_kernel.utils import get_cuda_stream
def tree_speculative_sampling_target_only(
predicts: torch.Tensor, # mutable
accept_index: torch.Tensor, # mutable
accept_token_num: torch.Tensor, # mutable
candidates: torch.Tensor,
retrive_index: torch.Tensor,
retrive_next_token: torch.Tensor,
retrive_next_sibling: torch.Tensor,
uniform_samples: torch.Tensor,
uniform_samples_for_final_sampling: torch.Tensor,
target_probs: torch.Tensor,
draft_probs: torch.Tensor,
threshold_single: float = 1.0,
threshold_acc: float = 1.0,
deterministic: bool = True,
) -> None:
torch.ops.sgl_kernel.tree_speculative_sampling_target_only.default(
predicts,
accept_index,
accept_token_num,
candidates,
retrive_index,
retrive_next_token,
retrive_next_sibling,
uniform_samples,
uniform_samples_for_final_sampling,
target_probs,
draft_probs,
threshold_single,
threshold_acc,
deterministic,
get_cuda_stream(),
)
def verify_tree_greedy(
predicts: torch.Tensor, # mutable
accept_index: torch.Tensor, # mutable
accept_token_num: torch.Tensor, # mutable
candidates: torch.Tensor,
retrive_index: torch.Tensor,
retrive_next_token: torch.Tensor,
retrive_next_sibling: torch.Tensor,
target_predict: torch.Tensor,
) -> None:
torch.ops.sgl_kernel.verify_tree_greedy.default(
predicts,
accept_index,
accept_token_num,
candidates,
retrive_index,
retrive_next_token,
retrive_next_sibling,
target_predict,
get_cuda_stream(),
)
def build_tree_kernel_efficient(
parent_list: torch.Tensor,
selected_index: torch.Tensor,
verified_seq_len: torch.Tensor,
tree_mask: torch.Tensor,
positions: torch.Tensor,
retrive_index: torch.Tensor,
retrive_next_token: torch.Tensor,
retrive_next_sibling: torch.Tensor,
topk: int,
depth: int,
draft_token_num: int,
tree_mask_mode: int,
) -> None:
torch.ops.sgl_kernel.build_tree_kernel_efficient.default(
parent_list,
selected_index,
verified_seq_len,
tree_mask,
positions,
retrive_index,
retrive_next_token,
retrive_next_sibling,
topk,
depth,
draft_token_num,
tree_mask_mode,
)
def segment_packbits(
x: torch.Tensor,
input_indptr: torch.Tensor,
output_indptr: torch.Tensor,
y: torch.Tensor,
batch_size: int,
) -> None:
torch.ops.sgl_kernel.segment_packbits.default(
x,
input_indptr,
output_indptr,
y,
batch_size,
torch.cuda.current_stream().cuda_stream,
)

View File

@@ -0,0 +1,217 @@
from typing import Any, Dict, List, Optional, Tuple, Union
import pytest
import torch
from sgl_kernel import FusedSetKVBufferArg, apply_rope_with_cos_sin_cache_inplace
# vLLM torch native
def _apply_rotary_emb(
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
is_neox_style: bool,
) -> torch.Tensor:
"""
Args:
x: [num_tokens, num_heads, head_size]
cos: [num_tokens, head_size // 2]
sin: [num_tokens, head_size // 2]
is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
positional embeddings.
"""
cos = cos.unsqueeze(-2).to(x.dtype)
sin = sin.unsqueeze(-2).to(x.dtype)
if is_neox_style:
x1, x2 = torch.chunk(x, 2, dim=-1)
else:
x1 = x[..., ::2]
x2 = x[..., 1::2]
o1 = x1 * cos - x2 * sin
o2 = x2 * cos + x1 * sin
if is_neox_style:
return torch.cat((o1, o2), dim=-1)
else:
return torch.stack((o1, o2), dim=-1).flatten(-2)
class RotaryEmbedding(torch.nn.Module):
# Reference: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/rotary_embedding.py
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
dtype: torch.dtype,
) -> None:
super().__init__()
self.head_size = head_size
self.rotary_dim = rotary_dim
self.max_position_embeddings = max_position_embeddings
self.base = base
self.is_neox_style = is_neox_style
self.dtype = dtype
cache = self._compute_cos_sin_cache()
self.cos_sin_cache: torch.Tensor
self.register_buffer("cos_sin_cache", cache, persistent=False)
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
inv_freq = 1.0 / (
base
** (
torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim
)
)
return inv_freq
def _compute_cos_sin_cache(self) -> torch.Tensor:
"""Compute the cos and sin cache."""
inv_freq = self._compute_inv_freq(self.base)
t = torch.arange(self.max_position_embeddings, dtype=torch.float)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
return cache
def forward_native(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""A PyTorch-native implementation of forward()."""
if offsets is not None:
positions = positions + offsets
positions = positions.flatten()
num_tokens = positions.shape[0]
cos_sin = self.cos_sin_cache.index_select(0, positions)
# Modification: float32 is required for the rotary embedding to work correctly
query = query.to(torch.float32)
key = key.to(torch.float32)
cos, sin = cos_sin.chunk(2, dim=-1)
query_shape = query.shape
query = query.view(num_tokens, -1, self.head_size)
query_rot = query[..., : self.rotary_dim]
query_pass = query[..., self.rotary_dim :]
query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., : self.rotary_dim]
key_pass = key[..., self.rotary_dim :]
key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
# Modification: convert to the correct dtype
query = query.to(self.dtype)
key = key.to(self.dtype)
return query, key
class FlashInferRotaryEmbedding(RotaryEmbedding):
def forward_cuda(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
apply_rope_with_cos_sin_cache_inplace(
positions=positions,
query=query,
key=key,
fused_set_kv_buffer_arg=fused_set_kv_buffer_arg,
head_size=self.head_size,
cos_sin_cache=self.cos_sin_cache,
is_neox=self.is_neox_style,
)
return query, key
class MHATokenToKVPool:
KV_POOL_SIZE = 16384
def __init__(
self,
head_num: int,
head_dim: int,
):
self.head_num = head_num
self.head_dim = head_dim
self.size = MHATokenToKVPool.KV_POOL_SIZE
self.page_size = 1
self.store_dtype = torch.bfloat16
self.device = "cuda"
self.layer_num = 1
self.start_layer = 0
self._create_buffers()
def _create_buffers(self):
self.k_buffer = [
torch.zeros(
(self.size + self.page_size, self.head_num, self.head_dim),
dtype=self.store_dtype,
device=self.device,
)
for _ in range(self.layer_num)
]
self.v_buffer = [
torch.zeros(
(self.size + self.page_size, self.head_num, self.head_dim),
dtype=self.store_dtype,
device=self.device,
)
for _ in range(self.layer_num)
]
def set_kv_buffer(
self,
loc: torch.Tensor,
cache_k: torch.Tensor,
cache_v: torch.Tensor,
):
layer_id = 0
self.k_buffer[layer_id - self.start_layer][loc] = cache_k
self.v_buffer[layer_id - self.start_layer][loc] = cache_v
def create_inputs(
head_size: int,
batch_size: int,
seq_len: int,
device,
dtype: torch.dtype,
num_q_heads: int,
num_kv_heads: int,
):
pos_ids = torch.arange(seq_len, device=device).repeat(batch_size)
query = torch.randn(
batch_size * seq_len, num_q_heads * head_size, dtype=dtype, device=device
)
key = torch.randn(
batch_size * seq_len, num_kv_heads * head_size, dtype=dtype, device=device
)
value = torch.randn(
batch_size * seq_len, num_kv_heads * head_size, dtype=dtype, device=device
)
out_cache_loc = torch.randperm(
MHATokenToKVPool.KV_POOL_SIZE, dtype=torch.int64, device=device
)[: batch_size * seq_len].clone()
return dict(
pos_ids=pos_ids, query=query, key=key, value=value, out_cache_loc=out_cache_loc
)

View File

@@ -0,0 +1,11 @@
import torch
def fast_topk(values, topk, dim):
if topk == 1:
# Use max along the specified dimension to get both value and index
return torch.max(values, dim=dim, keepdim=True)
else:
# Use topk for efficiency with larger k values
# TODO: implement faster cuda kernels for large vocab sizes
return torch.topk(values, topk, dim=dim)

View File

@@ -0,0 +1,50 @@
# Copyright 2025 SGLang Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import functools
from typing import Dict, Tuple
import torch
def get_cuda_stream() -> int:
return torch.cuda.current_stream().cuda_stream
_cache_buf: Dict[Tuple[str, torch.device], torch.Tensor] = {}
def _get_cache_buf(name: str, bytes: int, device: torch.device) -> torch.Tensor:
key = (name, device)
buf = _cache_buf.get(key)
if buf is None:
buf = torch.empty(bytes, dtype=torch.uint8, device=device)
_cache_buf[key] = buf
return buf
def _to_tensor_scalar_tuple(x):
if isinstance(x, torch.Tensor):
return (x, 0)
else:
return (None, x)
@functools.lru_cache(maxsize=1)
def is_arch_support_pdl() -> bool:
# Hopper arch's compute capability == 9.0
device = torch.cuda.current_device()
major, minor = torch.cuda.get_device_capability(device)
return major >= 9

View File

@@ -0,0 +1 @@
__version__ = "0.3.8"

View File

@@ -0,0 +1,10 @@
# ninja log v5
4 12793 1756950714004023222 /home/git_sglang/sglang/sgl-kernel/build/temp.linux-x86_64-cpython-310/csrc/speculative/eagle_utils.o c4ef5c8f5ca38169
4 22797 1756950724004023564 /home/git_sglang/sglang/sgl-kernel/build/temp.linux-x86_64-cpython-310/csrc/grammar/apply_token_bitmask_inplace_cuda.o 983ca8e755dab2fa
3 24387 1756950725592023618 /home/git_sglang/sglang/sgl-kernel/build/temp.linux-x86_64-cpython-310/csrc/allreduce/custom_all_reduce.o 95a36ff854253806
4 29769 1756950730660023792 /home/git_sglang/sglang/sgl-kernel/build/temp.linux-x86_64-cpython-310/csrc/common_extension_rocm.o d99ffa7422128b8b
3 47747 1756950748952024417 /home/git_sglang/sglang/sgl-kernel/build/temp.linux-x86_64-cpython-310/csrc/allreduce/quick_all_reduce.o ffc38859847f9f7
32 22789 1756951323296044067 /home/git_sglang/sglang/sgl-kernel/build/temp.linux-x86_64-cpython-310/csrc/kvcacheio/transfer.o 51f92c934bf3b3a4
32 22927 1756951323432044072 /home/git_sglang/sglang/sgl-kernel/build/temp.linux-x86_64-cpython-310/csrc/moe/moe_align_kernel.o 3cbca550065cd70f
32 23521 1756951324028044092 /home/git_sglang/sglang/sgl-kernel/build/temp.linux-x86_64-cpython-310/csrc/elementwise/activation.o c8aa216837c116a0
33 26809 1756951327312044205 /home/git_sglang/sglang/sgl-kernel/build/temp.linux-x86_64-cpython-310/csrc/moe/moe_topk_softmax_kernels.o a28bfee01bc84adc

View File

@@ -0,0 +1,38 @@
ninja_required_version = 1.3
cxx = c++
nvcc = /opt/dtk/bin/hipcc
cflags = -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -g -fwrapv -O2 -fPIC -I/home/git_sglang/sglang/sgl-kernel/include -I/home/git_sglang/sglang/sgl-kernel/csrc -I/usr/local/lib/python3.10/dist-packages/torch/include -I/usr/local/lib/python3.10/dist-packages/torch/include/torch/csrc/api/include -I/usr/local/lib/python3.10/dist-packages/torch/include/TH -I/usr/local/lib/python3.10/dist-packages/torch/include/THC -I/usr/local/lib/python3.10/dist-packages/torch/include/THH -I/opt/dtk/include -I/usr/include/python3.10 -c
post_cflags = -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -O3 -Wno-switch-bool -Wno-macro-redefined -Wno-deprecated-declarations -w -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=common_ops -D_GLIBCXX_USE_CXX11_ABI=1 -std=c++17
cuda_cflags = -I/home/git_sglang/sglang/sgl-kernel/include -I/home/git_sglang/sglang/sgl-kernel/csrc -I/usr/local/lib/python3.10/dist-packages/torch/include -I/usr/local/lib/python3.10/dist-packages/torch/include/torch/csrc/api/include -I/usr/local/lib/python3.10/dist-packages/torch/include/TH -I/usr/local/lib/python3.10/dist-packages/torch/include/THC -I/usr/local/lib/python3.10/dist-packages/torch/include/THH -I/opt/dtk/include -I/usr/include/python3.10 -c
cuda_post_cflags = -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -fPIC -O3 -std=c++17 -D__HIP_PLATFORM_HCC__=1 --offload-arch=gfx928 --offload-arch=gfx936 --gpu-max-threads-per-block=1024 -Wno-macro-redefined '' -funroll-loops -Rpass-analysis=unroll-loops -w -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1014"' -DTORCH_EXTENSION_NAME=common_ops -D_GLIBCXX_USE_CXX11_ABI=1 -fno-gpu-rdc
cuda_dlink_post_cflags =
ldflags =
rule compile
command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags
depfile = $out.d
deps = gcc
rule cuda_compile
command = $nvcc $cuda_cflags -c $in -o $out $cuda_post_cflags
build /home/git_sglang/sglang/sgl-kernel/build/temp.linux-x86_64-cpython-310/csrc/allreduce/custom_all_reduce.o: cuda_compile /home/git_sglang/sglang/sgl-kernel/csrc/allreduce/custom_all_reduce.hip
build /home/git_sglang/sglang/sgl-kernel/build/temp.linux-x86_64-cpython-310/csrc/allreduce/quick_all_reduce.o: cuda_compile /home/git_sglang/sglang/sgl-kernel/csrc/allreduce/quick_all_reduce.hip
build /home/git_sglang/sglang/sgl-kernel/build/temp.linux-x86_64-cpython-310/csrc/common_extension_rocm.o: compile /home/git_sglang/sglang/sgl-kernel/csrc/common_extension_rocm.cc
build /home/git_sglang/sglang/sgl-kernel/build/temp.linux-x86_64-cpython-310/csrc/elementwise/activation.o: cuda_compile /home/git_sglang/sglang/sgl-kernel/csrc/elementwise/activation.hip
build /home/git_sglang/sglang/sgl-kernel/build/temp.linux-x86_64-cpython-310/csrc/grammar/apply_token_bitmask_inplace_cuda.o: cuda_compile /home/git_sglang/sglang/sgl-kernel/csrc/grammar/apply_token_bitmask_inplace_cuda.hip
build /home/git_sglang/sglang/sgl-kernel/build/temp.linux-x86_64-cpython-310/csrc/kvcacheio/transfer.o: cuda_compile /home/git_sglang/sglang/sgl-kernel/csrc/kvcacheio/transfer.hip
build /home/git_sglang/sglang/sgl-kernel/build/temp.linux-x86_64-cpython-310/csrc/moe/moe_align_kernel.o: cuda_compile /home/git_sglang/sglang/sgl-kernel/csrc/moe/moe_align_kernel.hip
build /home/git_sglang/sglang/sgl-kernel/build/temp.linux-x86_64-cpython-310/csrc/moe/moe_topk_softmax_kernels.o: cuda_compile /home/git_sglang/sglang/sgl-kernel/csrc/moe/moe_topk_softmax_kernels.hip
build /home/git_sglang/sglang/sgl-kernel/build/temp.linux-x86_64-cpython-310/csrc/speculative/eagle_utils.o: cuda_compile /home/git_sglang/sglang/sgl-kernel/csrc/speculative/eagle_utils.hip

View File

@@ -0,0 +1,21 @@
# Adapt from: https://github.com/neuralmagic/vllm-flash-attention/blob/main/cmake/utils.cmake
#
# Clear all `-gencode` flags from `CMAKE_CUDA_FLAGS` and store them in
# `CUDA_ARCH_FLAGS`.
#
# Example:
# CMAKE_CUDA_FLAGS="-Wall -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75"
# clear_cuda_arches(CUDA_ARCH_FLAGS)
# CUDA_ARCH_FLAGS="-gencode arch=compute_70,code=sm_70;-gencode arch=compute_75,code=sm_75"
# CMAKE_CUDA_FLAGS="-Wall"
#
macro(clear_cuda_arches CUDA_ARCH_FLAGS)
# Extract all `-gencode` flags from `CMAKE_CUDA_FLAGS`
string(REGEX MATCHALL "-gencode arch=[^ ]+" CUDA_ARCH_FLAGS
${CMAKE_CUDA_FLAGS})
# Remove all `-gencode` flags from `CMAKE_CUDA_FLAGS` since they will be modified
# and passed back via the `CUDA_ARCHITECTURES` property.
string(REGEX REPLACE "-gencode arch=[^ ]+ *" "" CMAKE_CUDA_FLAGS
${CMAKE_CUDA_FLAGS})
endmacro()

View File

@@ -0,0 +1,137 @@
// Adapted from: https://github.com/vllm-project/vllm/blob/v0.8.2/csrc/custom_all_reduce.cu
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/all.h>
#include "custom_all_reduce.cuh"
// Fake pointer type, must match fptr_t type in ops.h.
// We use this type alias to indicate when pointers are passed in as int64_t.
using fptr_t = int64_t;
static_assert(sizeof(void*) == sizeof(fptr_t));
fptr_t
init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs, torch::Tensor& rank_data, int64_t rank, bool full_nvlink) {
int world_size = fake_ipc_ptrs.size();
if (world_size > 8) throw std::invalid_argument("world size > 8 is not supported");
if (world_size % 2 != 0) throw std::invalid_argument("Odd num gpus is not supported for now");
if (rank < 0 || rank >= world_size) throw std::invalid_argument("invalid rank passed in");
sglang::Signal* ipc_ptrs[8];
for (int i = 0; i < world_size; i++) {
ipc_ptrs[i] = reinterpret_cast<sglang::Signal*>(fake_ipc_ptrs[i]);
}
return (fptr_t) new sglang::CustomAllreduce(
ipc_ptrs, rank_data.data_ptr(), rank_data.numel(), rank, world_size, full_nvlink);
}
/**
* Make sure tensor t's data lies completely within ((char)t.data_ptr()) +
* t.numel() * t.element_size(). This is slightly weaker than t.is_contiguous()
* because it allows transpose of contiguous slice (i.e. slicing the first
* dimension). Currently, we require this because stride information is not
* passed into the kernels and we treat input tensors as flat.
*
* Examples
* A = torch.zeros(3, 3, 3)
* 1. A: OK
* 2. A[1:]: OK
* 3. A.permute(2, 0, 1): OK
* 4. A[1:].permute(2, 0, 1): OK
* 5. A[None].expand(2, -1, -1, -1): Not OK
* 6. A[:, 1:, 1:]: Not OK
*/
bool _is_weak_contiguous(torch::Tensor& t) {
return t.is_contiguous() ||
(t.storage().nbytes() - t.storage_offset() * t.element_size() == t.numel() * t.element_size());
}
/**
* Performs an out-of-place allreduce and stores result in out.
*
* If _reg_buffer is null, assumes inp.data_ptr() is already IPC-registered.
* Otherwise, _reg_buffer is assumed to be IPC-registered and inp is first
* copied into _reg_buffer.
*/
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes) {
auto fa = reinterpret_cast<sglang::CustomAllreduce*>(_fa);
const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
auto stream = c10::cuda::getCurrentCUDAStream().stream();
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
TORCH_CHECK_EQ(inp.numel(), out.numel());
TORCH_CHECK(_is_weak_contiguous(out));
TORCH_CHECK(_is_weak_contiguous(inp));
auto input_size = inp.numel() * inp.element_size();
auto reg_buffer = reinterpret_cast<void*>(_reg_buffer);
if (reg_buffer) {
TORCH_CHECK_LE(input_size, reg_buffer_sz_bytes);
AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer, inp.data_ptr(), input_size, cudaMemcpyDeviceToDevice, stream));
} else {
reg_buffer = inp.data_ptr();
}
switch (out.scalar_type()) {
case at::ScalarType::Float: {
fa->allreduce<float>(
stream, reinterpret_cast<float*>(reg_buffer), reinterpret_cast<float*>(out.data_ptr()), out.numel());
break;
}
case at::ScalarType::Half: {
fa->allreduce<half>(
stream, reinterpret_cast<half*>(reg_buffer), reinterpret_cast<half*>(out.data_ptr()), out.numel());
break;
}
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
case at::ScalarType::BFloat16: {
fa->allreduce<nv_bfloat16>(
stream,
reinterpret_cast<nv_bfloat16*>(reg_buffer),
reinterpret_cast<nv_bfloat16*>(out.data_ptr()),
out.numel());
break;
}
#endif
default:
throw std::runtime_error("custom allreduce only supports float32, float16 and bfloat16");
}
}
void dispose(fptr_t _fa) {
delete reinterpret_cast<sglang::CustomAllreduce*>(_fa);
}
int64_t meta_size() {
return sizeof(sglang::Signal);
}
void register_buffer(fptr_t _fa, const std::vector<fptr_t>& fake_ipc_ptrs) {
auto fa = reinterpret_cast<sglang::CustomAllreduce*>(_fa);
TORCH_CHECK(fake_ipc_ptrs.size() == fa->world_size_);
void* ipc_ptrs[8];
for (int i = 0; i < fake_ipc_ptrs.size(); i++) {
ipc_ptrs[i] = reinterpret_cast<void*>(fake_ipc_ptrs[i]);
}
fa->register_buffer(ipc_ptrs);
}
// Use vector<int64_t> to represent byte data for python binding compatibility.
std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa) {
auto fa = reinterpret_cast<sglang::CustomAllreduce*>(_fa);
auto [handle, offsets] = fa->get_graph_buffer_ipc_meta();
std::vector<int64_t> bytes(handle.begin(), handle.end());
return std::make_tuple(bytes, offsets);
}
// Use vector<int64_t> to represent byte data for python binding compatibility.
void register_graph_buffers(
fptr_t _fa, const std::vector<std::vector<int64_t>>& handles, const std::vector<std::vector<int64_t>>& offsets) {
auto fa = reinterpret_cast<sglang::CustomAllreduce*>(_fa);
std::vector<std::string> bytes;
bytes.reserve(handles.size());
for (int i = 0; i < handles.size(); i++) {
bytes.emplace_back(handles[i].begin(), handles[i].end());
}
bytes.reserve(handles.size());
fa->register_graph_buffers(bytes, offsets);
}

View File

@@ -0,0 +1,489 @@
// Adapted from https://github.com/vllm-project/vllm/blob/v0.8.2/csrc/custom_all_reduce.cuh
#pragma once
#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <array>
#include <iostream>
#include <limits>
#include <map>
#include <unordered_map>
#include <vector>
#include "utils.h"
namespace sglang {
constexpr int kMaxBlocks = 36;
// Counter may overflow, but it's fine since unsigned int overflow is
// well-defined behavior.
using FlagType = uint32_t;
struct Signal {
alignas(128) FlagType self_counter[kMaxBlocks][8];
// Two sets of peer counters are needed for two syncs. The reason is that
// it's possible for peer GPU block to arrive at the second sync point while
// the current GPU block haven't passed the first sync point. Thus, peer GPU
// may write counter+1 while current GPU is busy waiting for counter. We use
// alternating counter array to avoid this possibility.
alignas(128) FlagType peer_counter[2][kMaxBlocks][8];
};
struct __align__(16) RankData {
const void* __restrict__ ptrs[8];
};
struct __align__(16) RankSignals {
Signal* signals[8];
};
// like std::array, but aligned
template <typename T, int sz>
struct __align__(alignof(T) * sz) array_t {
T data[sz];
using type = T;
static constexpr int size = sz;
};
// use packed type to maximize memory efficiency
// goal: generate ld.128 and st.128 instructions
template <typename T>
struct packed_t {
// the (P)acked type for load/store
using P = array_t<T, 16 / sizeof(T)>;
// the (A)ccumulator type for reduction
using A = array_t<float, 16 / sizeof(T)>;
};
#define DINLINE __device__ __forceinline__
// scalar cast functions
DINLINE float upcast_s(half val) {
return __half2float(val);
}
template <typename T>
DINLINE T downcast_s(float val);
template <>
DINLINE half downcast_s(float val) {
return __float2half(val);
}
// scalar add functions
// for some reason when compiling with Pytorch, the + operator for half and
// bfloat is disabled so we call the intrinsics directly
DINLINE half& assign_add(half& a, half b) {
a = __hadd(a, b);
return a;
}
DINLINE float& assign_add(float& a, float b) {
return a += b;
}
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
DINLINE float upcast_s(nv_bfloat16 val) {
return __bfloat162float(val);
}
template <>
DINLINE nv_bfloat16 downcast_s(float val) {
return __float2bfloat16(val);
}
DINLINE nv_bfloat16& assign_add(nv_bfloat16& a, nv_bfloat16 b) {
a = __hadd(a, b);
return a;
}
#endif
template <typename T, int N>
DINLINE array_t<T, N>& packed_assign_add(array_t<T, N>& a, array_t<T, N> b) {
#pragma unroll
for (int i = 0; i < N; i++) {
assign_add(a.data[i], b.data[i]);
}
return a;
}
template <typename T, int N>
DINLINE array_t<float, N> upcast(array_t<T, N> val) {
if constexpr (std::is_same<T, float>::value) {
return val;
} else {
array_t<float, N> out;
#pragma unroll
for (int i = 0; i < N; i++) {
out.data[i] = upcast_s(val.data[i]);
}
return out;
}
}
template <typename O>
DINLINE O downcast(array_t<float, O::size> val) {
if constexpr (std::is_same<typename O::type, float>::value) {
return val;
} else {
O out;
#pragma unroll
for (int i = 0; i < O::size; i++) {
out.data[i] = downcast_s<typename O::type>(val.data[i]);
}
return out;
}
}
static DINLINE void st_flag_release(FlagType* flag_addr, FlagType flag) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
asm volatile("st.release.sys.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
#else
asm volatile("membar.sys; st.volatile.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
#endif
}
static DINLINE FlagType ld_flag_acquire(FlagType* flag_addr) {
FlagType flag;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
asm volatile("ld.acquire.sys.global.u32 %0, [%1];" : "=r"(flag) : "l"(flag_addr));
#else
asm volatile("ld.volatile.global.u32 %0, [%1]; membar.gl;" : "=r"(flag) : "l"(flag_addr));
#endif
return flag;
}
static DINLINE void st_flag_volatile(FlagType* flag_addr, FlagType flag) {
asm volatile("st.volatile.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
}
static DINLINE FlagType ld_flag_volatile(FlagType* flag_addr) {
FlagType flag;
asm volatile("ld.volatile.global.u32 %0, [%1];" : "=r"(flag) : "l"(flag_addr));
return flag;
}
// is_start: whether this is the very first synchronization barrier.
// need_fence: whether a memory fence is needed. If true, a release-acquire
// semantic is used to enforce memory access order before and after this
// barrier.
template <int ngpus, bool is_start, bool need_fence = false>
DINLINE void multi_gpu_barrier(const RankSignals& sg, Signal* self_sg, int rank) {
if constexpr (!is_start) __syncthreads();
static_assert(!(is_start && need_fence)); // Start barrier shouldn't need fence.
if (threadIdx.x < ngpus) {
// Increment the counter. Technically we only need one counter, but we use
// multiple per block to eliminate the need to share the counter via smem.
auto val = self_sg->self_counter[blockIdx.x][threadIdx.x] += 1;
// Write the expected counter value to peer and wait for correct value from
// peer.
auto peer_counter_ptr = &sg.signals[threadIdx.x]->peer_counter[val % 2][blockIdx.x][rank];
auto self_counter_ptr = &self_sg->peer_counter[val % 2][blockIdx.x][threadIdx.x];
if constexpr (need_fence) {
st_flag_release(peer_counter_ptr, val);
while (ld_flag_acquire(self_counter_ptr) != val)
;
} else {
st_flag_volatile(peer_counter_ptr, val);
while (ld_flag_volatile(self_counter_ptr) != val)
;
}
}
if constexpr (is_start || need_fence) __syncthreads();
}
template <typename P, int ngpus, typename A>
DINLINE P packed_reduce(const P* ptrs[], int idx) {
A tmp = upcast(ptrs[0][idx]);
#pragma unroll
for (int i = 1; i < ngpus; i++) {
packed_assign_add(tmp, upcast(ptrs[i][idx]));
}
return downcast<P>(tmp);
}
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1) cross_device_reduce_1stage(
RankData* _dp, RankSignals sg, Signal* self_sg, T* __restrict__ result, int rank, int size) {
using P = typename packed_t<T>::P;
using A = typename packed_t<T>::A;
// note: we don't reorder the address so the accumulation order is the same
// for all ranks, ensuring bitwise identical results
auto dp = *_dp;
multi_gpu_barrier<ngpus, true>(sg, self_sg, rank);
// do the actual reduction
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += gridDim.x * blockDim.x) {
((P*)result)[idx] = packed_reduce<P, ngpus, A>((const P**)&dp.ptrs[0], idx);
}
multi_gpu_barrier<ngpus, false>(sg, self_sg, rank);
}
template <typename P>
DINLINE P* get_tmp_buf(Signal* sg) {
return (P*)(((Signal*)sg) + 1);
}
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1) cross_device_reduce_2stage(
RankData* _dp, RankSignals sg, Signal* self_sg, T* __restrict__ result, int rank, int size) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = gridDim.x * blockDim.x;
using P = typename packed_t<T>::P;
using A = typename packed_t<T>::A;
int part = size / ngpus;
int start = rank * part;
int end = rank == ngpus - 1 ? size : start + part;
int largest_part = part + size % ngpus;
const P* ptrs[ngpus];
P* tmps[ngpus];
#pragma unroll
for (int i = 0; i < ngpus; i++) {
int target = (rank + i) % ngpus;
ptrs[i] = (const P*)_dp->ptrs[target];
tmps[i] = get_tmp_buf<P>(sg.signals[target]);
}
auto tmp_out = tmps[0];
multi_gpu_barrier<ngpus, true>(sg, self_sg, rank);
// stage 1: reduce scatter
for (int idx = start + tid; idx < end; idx += stride) {
tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
}
multi_gpu_barrier<ngpus, false, true>(sg, self_sg, rank);
// stage 2: allgather. Note: it's important to match the tid between
// the two stages, because visibility across devices is only guaranteed
// between threads that have the same tid. If thread i computes the sum of
// start + i in the first stage, then thread i also gathers start + i from all
// ranks.
for (int idx = tid; idx < largest_part; idx += stride) {
#pragma unroll
for (int i = 0; i < ngpus; i++) {
int gather_from_rank = ((rank + i) % ngpus);
if (gather_from_rank == ngpus - 1 || idx < part) {
int dst_idx = gather_from_rank * part + idx;
((P*)result)[dst_idx] = tmps[i][idx];
}
}
}
}
using IPC_KEY = std::array<uint8_t, sizeof(cudaIpcMemHandle_t)>;
static_assert(sizeof(IPC_KEY) == sizeof(cudaIpcMemHandle_t));
static_assert(alignof(IPC_KEY) == alignof(cudaIpcMemHandle_t));
class CustomAllreduce {
public:
int rank_;
int world_size_;
bool full_nvlink_;
RankSignals sg_;
// Stores an map from a pointer to its peer pointters from all ranks.
std::unordered_map<void*, RankData*> buffers_;
Signal* self_sg_;
// Stores rank data from all ranks. This is mainly for cuda graph purposes.
// For cuda graph to work, all kernel arguments must be fixed during graph
// capture time. However, the peer pointers are not known during graph capture
// time. Therefore, during capture, we increment the rank data pointer and use
// that as the argument to the kernel. The kernel arguments are stored in
// graph_unreg_buffers_. The actual peer pointers will be filled in at the
// memory pointed to by the pointers in graph_unreg_buffers_ when
// the IPC handles are exchanged between ranks.
//
// The overall process looks like this:
// 1. Graph capture.
// 2. Each rank obtains the IPC handles for each addresses used during cuda
// graph capture using get_graph_buffer_ipc_meta.
// 3. (In Python) all gather the IPC handles.
// 4. Obtain the peer pointers by opening the IPC handles, and store them in
// the rank data array at corresponding positions.
RankData *d_rank_data_base_, *d_rank_data_end_;
std::vector<void*> graph_unreg_buffers_;
// a map from IPC handles to opened IPC pointers
std::map<IPC_KEY, char*> ipc_handles_;
/**
* Signals are an array of ipc-enabled buffers from all ranks.
* For each of the buffer, the layout is as follows:
* | -- sizeof(Signal) -- | ------ a few MB ----- |
* The first section is for allreduce synchronization, and the second section
* is for storing the intermediate results required by some allreduce algos.
*
* Note: this class does not own any device memory. Any required buffers
* are passed in from the constructor.
*/
CustomAllreduce(
Signal** signals, void* rank_data, size_t rank_data_sz, int rank, int world_size, bool full_nvlink = true)
: rank_(rank),
world_size_(world_size),
full_nvlink_(full_nvlink),
self_sg_(signals[rank]),
d_rank_data_base_(reinterpret_cast<RankData*>(rank_data)),
d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
for (int i = 0; i < world_size_; i++) {
sg_.signals[i] = signals[i];
}
}
char* open_ipc_handle(const void* ipc_handle) {
auto [it, new_handle] = ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr});
if (new_handle) {
char* ipc_ptr;
CHECK_CUDA_SUCCESS(cudaIpcOpenMemHandle(
(void**)&ipc_ptr, *((const cudaIpcMemHandle_t*)ipc_handle), cudaIpcMemLazyEnablePeerAccess));
it->second = ipc_ptr;
}
return it->second;
}
std::pair<std::string, std::vector<int64_t>> get_graph_buffer_ipc_meta() {
auto num_buffers = graph_unreg_buffers_.size();
auto handle_sz = sizeof(cudaIpcMemHandle_t);
std::string handles(handle_sz * num_buffers, static_cast<char>(0));
std::vector<int64_t> offsets(num_buffers);
for (int i = 0; i < num_buffers; i++) {
auto ptr = graph_unreg_buffers_[i];
void* base_ptr;
// note: must share the base address of each allocation, or we get wrong
// address
if (cuPointerGetAttribute(&base_ptr, CU_POINTER_ATTRIBUTE_RANGE_START_ADDR, (CUdeviceptr)ptr) != CUDA_SUCCESS)
throw std::runtime_error("failed to get pointer attr");
CHECK_CUDA_SUCCESS(cudaIpcGetMemHandle((cudaIpcMemHandle_t*)&handles[i * handle_sz], base_ptr));
offsets[i] = ((char*)ptr) - ((char*)base_ptr);
}
return std::make_pair(handles, offsets);
}
void check_rank_data_capacity(size_t num = 1) {
if (d_rank_data_base_ + num > d_rank_data_end_)
throw std::runtime_error(
"Rank data buffer is overflowed by " + std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
}
/**
* Register already-shared IPC pointers.
*/
void register_buffer(void** ptrs) {
check_rank_data_capacity();
RankData data;
for (int i = 0; i < world_size_; i++) {
data.ptrs[i] = ptrs[i];
}
auto d_data = d_rank_data_base_++;
CHECK_CUDA_SUCCESS(cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice));
buffers_[ptrs[rank_]] = d_data;
}
// Note: when registering graph buffers, we intentionally choose to not
// deduplicate the addresses. That means if the allocator reuses some
// addresses, they will be registered again. This is to account for the remote
// possibility of different allocation patterns between ranks. For example,
// rank 1 may get the same input address for the second allreduce, but rank 2
// got a different address. IPC handles have internal reference counting
// mechanism so overhead should be small.
void
register_graph_buffers(const std::vector<std::string>& handles, const std::vector<std::vector<int64_t>>& offsets) {
auto num_buffers = graph_unreg_buffers_.size();
check_rank_data_capacity(num_buffers);
std::vector<RankData> rank_data(num_buffers);
for (int i = 0; i < num_buffers; i++) {
auto self_ptr = graph_unreg_buffers_[i];
auto& rd = rank_data[i];
for (int j = 0; j < world_size_; j++) {
if (j != rank_) {
char* handle = open_ipc_handle(&handles[j][i * sizeof(cudaIpcMemHandle_t)]);
handle += offsets[j][i];
rd.ptrs[j] = handle;
} else {
rd.ptrs[j] = self_ptr;
}
}
}
CHECK_CUDA_SUCCESS(
cudaMemcpy(d_rank_data_base_, rank_data.data(), sizeof(RankData) * num_buffers, cudaMemcpyHostToDevice));
d_rank_data_base_ += num_buffers;
graph_unreg_buffers_.clear();
}
/**
* Performs allreduce, assuming input has already been registered.
*
* Block and grid default configs are results after careful grid search. Using
* 36 blocks give the best or close to the best runtime on the devices I
* tried: A100, A10, A30, T4, V100. You'll notice that NCCL kernels also only
* take a small amount of SMs. Not quite sure the underlying reason, but my
* guess is that too many SMs will cause contention on NVLink bus.
*/
template <typename T>
void allreduce(cudaStream_t stream, T* input, T* output, int size, int threads = 512, int block_limit = 36) {
auto d = packed_t<T>::P::size;
if (size % d != 0)
throw std::runtime_error(
"custom allreduce currently requires input length to be multiple "
"of " +
std::to_string(d));
if (block_limit > kMaxBlocks)
throw std::runtime_error(
"max supported block limit is " + std::to_string(kMaxBlocks) + ". Got " + std::to_string(block_limit));
RankData* ptrs;
cudaStreamCaptureStatus status;
CHECK_CUDA_SUCCESS(cudaStreamIsCapturing(stream, &status));
if (status == cudaStreamCaptureStatusActive) {
ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
graph_unreg_buffers_.push_back(input);
} else {
auto it = buffers_.find(input);
if (it == buffers_.end())
throw std::runtime_error(
"buffer address " + std::to_string(reinterpret_cast<uint64_t>(input)) + " is not registered!");
ptrs = it->second;
}
size /= d;
auto bytes = size * sizeof(typename packed_t<T>::P);
int blocks = std::min(block_limit, (size + threads - 1) / threads);
#define KL(ngpus, name) name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, rank_, size);
// TODO(hanzhi713): Threshold is different for A100 and H100.
// Add per device threshold.
#define REDUCE_CASE(ngpus) \
case ngpus: { \
if (world_size_ == 2) { \
KL(ngpus, cross_device_reduce_1stage); \
} else if (full_nvlink_) { \
if ((world_size_ <= 4 && bytes < 512 * 1024) || (world_size_ <= 8 && bytes < 256 * 1024)) { \
KL(ngpus, cross_device_reduce_1stage); \
} else { \
KL(ngpus, cross_device_reduce_2stage); \
} \
} \
break; \
}
switch (world_size_) {
REDUCE_CASE(2)
REDUCE_CASE(4)
REDUCE_CASE(6)
REDUCE_CASE(8)
default:
throw std::runtime_error(
"custom allreduce only supports num gpus in (2,4,6,8). Actual num "
"gpus = " +
std::to_string(world_size_));
}
#undef REDUCE_CASE
#undef KL
}
~CustomAllreduce() {
for (auto [_, ptr] : ipc_handles_) {
CHECK_CUDA_SUCCESS(cudaIpcCloseMemHandle(ptr));
}
}
};
/**
* To inspect PTX/SASS, copy paste this header file to compiler explorer and add
a template instantiation:
* template void sglang::CustomAllreduce::allreduce<half>(cudaStream_t, half *,
half *, int, int, int);
*/
} // namespace sglang

View File

@@ -0,0 +1,180 @@
// !!! This is a file automatically generated by hipify!!!
#include <ATen/hip/Exceptions.h>
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
#include <ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h>
#include <torch/all.h>
#include "custom_all_reduce_hip.cuh"
// fake pointer type, must match fptr_t type in ops.h
using fptr_t = int64_t;
static_assert(sizeof(void*) == sizeof(fptr_t));
fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
const std::vector<std::string>& handles,
const std::vector<int64_t>& offsets, int64_t rank,
bool full_nvlink) {
int world_size = offsets.size();
if (world_size > 8)
throw std::invalid_argument("world size > 8 is not supported");
if (world_size % 2 != 0)
throw std::invalid_argument("Odd num gpus is not supported for now");
if (world_size != handles.size())
throw std::invalid_argument(
"handles length should equal to offsets length");
if (rank < 0 || rank >= world_size)
throw std::invalid_argument("invalid rank passed in");
hipIpcMemHandle_t ipc_handles[8];
for (int i = 0; i < world_size; i++) {
std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(hipIpcMemHandle_t));
}
return (fptr_t) new sglang::CustomAllreduce(
reinterpret_cast<sglang::Signal*>(meta.data_ptr()), rank_data.data_ptr(),
rank_data.numel(), ipc_handles, offsets, rank, full_nvlink);
}
/**
* Make sure tensor t's data lies completely within ((char)t.data_ptr()) +
* t.numel() * t.element_size(). This is slightly weaker than t.is_contiguous()
* because it allows transpose of contiguous slice (i.e. slicing the first
* dimension). Currently, we require this because stride information is not
* passed into the kernels and we treat input tensors as flat.
*
* Examples
* A = torch.zeros(3, 3, 3)
* 1. A: OK
* 2. A[1:]: OK
* 3. A.permute(2, 0, 1): OK
* 4. A[1:].permute(2, 0, 1): OK
* 5. A[None].expand(2, -1, -1, -1): Not OK
* 6. A[:, 1:, 1:]: Not OK
*/
bool _is_weak_contiguous(torch::Tensor& t) {
return t.is_contiguous() ||
(t.storage().nbytes() - t.storage_offset() * t.element_size() ==
t.numel() * t.element_size());
}
void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
hipStream_t stream) {
auto fa = reinterpret_cast<sglang::CustomAllreduce*>(_fa);
TORCH_CHECK(_is_weak_contiguous(out));
switch (out.scalar_type()) {
case at::ScalarType::Float: {
fa->allreduce<float>(stream, reinterpret_cast<float*>(inp.data_ptr()),
reinterpret_cast<float*>(out.data_ptr()),
out.numel());
break;
}
case at::ScalarType::Half: {
fa->allreduce<half>(stream, reinterpret_cast<half*>(inp.data_ptr()),
reinterpret_cast<half*>(out.data_ptr()), out.numel());
break;
}
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
case at::ScalarType::BFloat16: {
fa->allreduce<nv_bfloat16>(
stream, reinterpret_cast<nv_bfloat16*>(inp.data_ptr()),
reinterpret_cast<nv_bfloat16*>(out.data_ptr()), out.numel());
break;
}
#endif
default:
throw std::runtime_error(
"custom allreduce only supports float32, float16 and bfloat16");
}
}
void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) {
const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(inp));
auto stream = c10::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
TORCH_CHECK_EQ(inp.numel(), out.numel());
_all_reduce(_fa, inp, out, stream);
}
void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
torch::Tensor& out) {
const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(inp));
auto stream = c10::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
auto input_size = inp.numel() * inp.element_size();
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
TORCH_CHECK_EQ(inp.numel(), out.numel());
TORCH_CHECK(input_size <= reg_buffer.numel() * reg_buffer.element_size(),
"registered buffer is too small to contain the input");
AT_CUDA_CHECK(hipMemcpyAsync(reg_buffer.data_ptr(), inp.data_ptr(),
input_size, hipMemcpyDeviceToDevice, stream));
_all_reduce(_fa, reg_buffer, out, stream);
}
void dispose(fptr_t _fa) {
auto fa = reinterpret_cast<sglang::CustomAllreduce*>(_fa);
delete fa;
}
int64_t meta_size() { return sizeof(sglang::Signal); }
void register_buffer(fptr_t _fa, torch::Tensor& t,
const std::vector<std::string>& handles,
const std::vector<int64_t>& offsets) {
auto fa = reinterpret_cast<sglang::CustomAllreduce*>(_fa);
fa->register_buffer(handles, offsets, t.data_ptr());
}
std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
fptr_t _fa) {
auto fa = reinterpret_cast<sglang::CustomAllreduce*>(_fa);
auto [handle_bytes, offsets] = fa->get_graph_buffer_ipc_meta();
auto options =
torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
auto handles =
torch::empty({static_cast<int64_t>(handle_bytes.size())}, options);
std::memcpy(handles.data_ptr(), handle_bytes.data(), handle_bytes.size());
return {handles, std::move(offsets)};
}
void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
const std::vector<std::vector<int64_t>>& offsets) {
auto fa = reinterpret_cast<sglang::CustomAllreduce*>(_fa);
fa->register_graph_buffers(handles, offsets);
}
void free_meta_buffer(void* buffer) { CUDACHECK(hipFree(buffer)); }
torch::Tensor get_meta_buffer_ipc_handle(torch::Tensor& inp) {
auto options =
torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
auto data_handle =
torch::empty({static_cast<int64_t>(sizeof(hipIpcMemHandle_t))}, options);
CUDACHECK(hipIpcGetMemHandle((hipIpcMemHandle_t*)data_handle.data_ptr(),
inp.data_ptr()));
return data_handle;
}
torch::Tensor allocate_meta_buffer(int64_t size) {
auto device_index = c10::hip::current_device();
at::DeviceGuard device_guard(at::Device(at::DeviceType::CUDA, device_index));
void* buffer;
hipStreamCaptureMode mode = hipStreamCaptureModeRelaxed;
auto stream = c10::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
AT_CUDA_CHECK(hipThreadExchangeStreamCaptureMode(&mode));
AT_CUDA_CHECK(
hipExtMallocWithFlags((void**)&buffer, size, hipDeviceMallocUncached));
AT_CUDA_CHECK(hipMemsetAsync(buffer, 0, size, stream));
AT_CUDA_CHECK(hipStreamSynchronize(stream));
AT_CUDA_CHECK(hipThreadExchangeStreamCaptureMode(&mode));
auto options = torch::TensorOptions()
.dtype(torch::kI8)
.device(torch::kCUDA, device_index);
return torch::from_blob(buffer, {size}, free_meta_buffer, options);
}
std::vector<uint8_t> get_device_bdf(int dev) {
char busIdStr[] = "0000:00:00.0";
std::vector<uint8_t> bdf(sizeof(busIdStr), 0);
CUDACHECK(hipDeviceGetPCIBusId((char*)bdf.data(), sizeof(busIdStr), dev));
bdf.resize(bdf.size() - 1); // remove trailing NULL
return bdf;
}

View File

@@ -0,0 +1,582 @@
// !!! This is a file automatically generated by hipify!!!
#pragma once
#include <hip/hip_runtime.h>
#ifdef USE_ROCM
#include <hip/hip_bf16.h>
typedef __hip_bfloat16 nv_bfloat16;
#else
#include <hip/hip_bf16.h>
#endif
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
#include <iostream>
#include <limits>
#include <map>
#include <unordered_map>
#include <vector>
#define CUDACHECK(cmd) \
do { \
hipError_t e = cmd; \
if (e != hipSuccess) { \
printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, hipGetErrorString(e)); \
exit(EXIT_FAILURE); \
} \
} while (0)
namespace sglang {
constexpr int kMaxBlocks = 64;
// note: we don't want to use atomics for signals because peer atomics are no
// supported on PCIe links
struct Signal {
alignas(128) uint32_t start[kMaxBlocks][8];
alignas(128) uint32_t end[kMaxBlocks][8];
alignas(128) uint32_t _flag[kMaxBlocks]; // incremental flags for each rank
};
#ifdef USE_ROCM
struct __align__(16) RankData {
const void* ptrs[8];
};
#else
struct __align__(16) RankData {
const void* __restrict__ ptrs[8];
};
#endif
struct __align__(16) RankSignals {
#ifndef USE_ROCM
volatile
#endif
Signal* signals[8];
};
// like std::array, but aligned
template <typename T, int sz>
struct __align__(alignof(T) * sz) array_t {
T data[sz];
using type = T;
static constexpr int size = sz;
};
// use packed type to maximize memory efficiency
// goal: generate ld.128 and st.128 instructions
template <typename T>
struct packed_t {
// the (P)acked type for load/store
using P = array_t<T, 16 / sizeof(T)>;
// the (A)ccumulator type for reduction
using A = array_t<float, 16 / sizeof(T)>;
};
#define DINLINE __device__ __forceinline__
// scalar cast functions
DINLINE float upcast_s(half val) {
return __half2float(val);
}
template <typename T>
DINLINE T downcast_s(float val);
template <>
DINLINE half downcast_s(float val) {
return __float2half(val);
}
// scalar add functions
// for some reason when compiling with Pytorch, the + operator for half and
// bfloat is disabled so we call the intrinsics directly
DINLINE half& assign_add(half& a, half b) {
a = __hadd(a, b);
return a;
}
DINLINE float& assign_add(float& a, float b) {
return a += b;
}
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
DINLINE float upcast_s(nv_bfloat16 val) {
return __bfloat162float(val);
}
template <>
DINLINE nv_bfloat16 downcast_s(float val) {
return __float2bfloat16(val);
}
DINLINE nv_bfloat16& assign_add(nv_bfloat16& a, nv_bfloat16 b) {
a = __hadd(a, b);
return a;
}
#endif
template <typename T, int N>
DINLINE array_t<T, N>& packed_assign_add(array_t<T, N>& a, array_t<T, N> b) {
#pragma unroll
for (int i = 0; i < N; i++) {
assign_add(a.data[i], b.data[i]);
}
return a;
}
template <typename T, int N>
DINLINE array_t<float, N> upcast(array_t<T, N> val) {
if constexpr (std::is_same<T, float>::value) {
return val;
} else {
array_t<float, N> out;
#pragma unroll
for (int i = 0; i < N; i++) {
out.data[i] = upcast_s(val.data[i]);
}
return out;
}
}
template <typename O>
DINLINE O downcast(array_t<float, O::size> val) {
if constexpr (std::is_same<typename O::type, float>::value) {
return val;
} else {
O out;
#pragma unroll
for (int i = 0; i < O::size; i++) {
out.data[i] = downcast_s<typename O::type>(val.data[i]);
}
return out;
}
}
// This function is meant to be used as the first synchronization in the all
// reduce kernel. Thus, it doesn't need to make any visibility guarantees for
// prior memory accesses. Note: volatile writes will not be reordered against
// other volatile writes.
template <int ngpus>
DINLINE void start_sync(
const RankSignals& sg,
#ifndef USE_ROCM
volatile
#endif
Signal* self_sg,
int rank) {
#ifdef USE_ROCM
uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
if (threadIdx.x < ngpus) {
// simultaneously write to the corresponding flag of all ranks.
// Latency = 1 p2p write
__hip_atomic_store(
&sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM);
// wait until we got true from all ranks
while (__hip_atomic_load(&self_sg->start[blockIdx.x][threadIdx.x], __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT) <
flag)
;
}
__syncthreads();
// use one thread to update flag
if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
#else
if (threadIdx.x < ngpus) {
// reset flag for next time
self_sg->end[blockIdx.x][threadIdx.x] = 0;
// simultaneously write to the corresponding flag of all ranks.
// Latency = 1 p2p write
sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1;
// wait until we got true from all ranks
while (!self_sg->start[blockIdx.x][threadIdx.x])
;
}
__syncthreads();
#endif
}
// This function is meant to be used as the second or the final synchronization
// barrier in the all reduce kernel. If it's the final synchronization barrier,
// we don't need to make any visibility guarantees for prior memory accesses.
template <int ngpus, bool final_sync = false>
DINLINE void end_sync(
const RankSignals& sg,
#ifndef USE_ROCM
volatile
#endif
Signal* self_sg,
int rank) {
#ifdef USE_ROCM
__syncthreads();
// eliminate the case that prior writes are not visible after signals become
// visible. Note that I did not managed to make this happen through a lot of
// testing. Might be the case that hardware provides stronger guarantee than
// the memory model.
uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
if (threadIdx.x < ngpus) {
// simultaneously write to the corresponding flag of all ranks.
// Latency = 1 p2p write
__hip_atomic_store(
&sg.signals[threadIdx.x]->end[blockIdx.x][rank],
flag,
final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE,
__HIP_MEMORY_SCOPE_SYSTEM);
// wait until we got true from all ranks
while (__hip_atomic_load(
&self_sg->end[blockIdx.x][threadIdx.x],
final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE,
__HIP_MEMORY_SCOPE_AGENT) < flag)
;
}
__syncthreads();
// use one thread to update flag
if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
#else
__syncthreads();
// eliminate the case that prior writes are not visible after signals become
// visible. Note that I did not managed to make this happen through a lot of
// testing. Might be the case that hardware provides stronger guarantee than
// the memory model.
if constexpr (!final_sync) __threadfence_system();
if (threadIdx.x < ngpus) {
// reset flag for next time
self_sg->start[blockIdx.x][threadIdx.x] = 0;
// simultaneously write to the corresponding flag of all ranks.
// Latency = 1 p2p write
sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1;
// wait until we got true from all ranks
while (!self_sg->end[blockIdx.x][threadIdx.x])
;
}
if constexpr (!final_sync) __syncthreads();
#endif
}
template <typename P, int ngpus, typename A>
DINLINE P packed_reduce(const P* ptrs[], int idx) {
A tmp = upcast(ptrs[0][idx]);
#pragma unroll
for (int i = 1; i < ngpus; i++) {
packed_assign_add(tmp, upcast(ptrs[i][idx]));
}
return downcast<P>(tmp);
}
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1) cross_device_reduce_1stage(
RankData* _dp,
RankSignals sg,
#ifndef USE_ROCM
volatile
#endif
Signal* self_sg,
T* __restrict__ result,
int rank,
int size) {
using P = typename packed_t<T>::P;
using A = typename packed_t<T>::A;
// note: we don't reorder the address so the accumulation order is the same
// for all ranks, ensuring bitwise identical results
auto dp = *_dp;
start_sync<ngpus>(sg, self_sg, rank);
// do the actual reduction
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += gridDim.x * blockDim.x) {
((P*)result)[idx] = packed_reduce<P, ngpus, A>((const P**)&dp.ptrs[0], idx);
}
end_sync<ngpus, true>(sg, self_sg, rank);
}
template <typename P>
#ifdef USE_ROCM
DINLINE P* get_tmp_buf(Signal* sg) {
#else
DINLINE P* get_tmp_buf(volatile Signal* sg) {
#endif
return (P*)(((Signal*)sg) + 1);
}
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1) cross_device_reduce_2stage(
RankData* _dp,
RankSignals sg,
#ifndef USE_ROCM
volatile
#endif
Signal* self_sg,
T* __restrict__ result,
int rank,
int size) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = gridDim.x * blockDim.x;
using P = typename packed_t<T>::P;
using A = typename packed_t<T>::A;
int part = size / ngpus;
int start = rank * part;
int end = rank == ngpus - 1 ? size : start + part;
int largest_part = part + size % ngpus;
const P* ptrs[ngpus];
P* tmps[ngpus];
#pragma unroll
for (int i = 0; i < ngpus; i++) {
int target = (rank + i) % ngpus;
ptrs[i] = (const P*)_dp->ptrs[target];
tmps[i] = get_tmp_buf<P>(sg.signals[target]);
}
auto tmp_out = tmps[0];
start_sync<ngpus>(sg, self_sg, rank);
// stage 1: reduce scatter
for (int idx = start + tid; idx < end; idx += stride) {
tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
}
end_sync<ngpus>(sg, self_sg, rank);
// stage 2: allgather. Note: it's important to match the tid between
// the two stages, because visibility across devices is only guaranteed
// between threads that have the same tid. If thread i computes the sum of
// start + i in the first stage, then thread i also gathers start + i from all
// ranks.
for (int idx = tid; idx < largest_part; idx += stride) {
#pragma unroll
for (int i = 0; i < ngpus; i++) {
int gather_from_rank = ((rank + i) % ngpus);
if (gather_from_rank == ngpus - 1 || idx < part) {
int dst_idx = gather_from_rank * part + idx;
((P*)result)[dst_idx] = tmps[i][idx];
}
}
}
}
using IPC_KEY = std::array<uint8_t, sizeof(hipIpcMemHandle_t)>;
static_assert(sizeof(IPC_KEY) == sizeof(hipIpcMemHandle_t));
static_assert(alignof(IPC_KEY) == alignof(hipIpcMemHandle_t));
class CustomAllreduce {
public:
int rank_;
int world_size_;
bool full_nvlink_;
// below are device pointers
RankSignals sg_;
std::unordered_map<void*, RankData*> buffers_;
Signal* self_sg_;
// stores the registered device pointers from all ranks
RankData *d_rank_data_base_, *d_rank_data_end_;
std::vector<void*> graph_unreg_buffers_;
// a map from IPC handles to opened IPC pointers
std::map<IPC_KEY, char*> ipc_handles_;
/**
* meta is a pointer to device metadata and temporary buffer for allreduce.
*
* There's a total of sizeof(Signal) of prefix before the actual data,
* so meta + 1 points to actual temporary buffer.
*
* note: this class does not own any device memory. Any required buffers
* are passed in from the constructor
*/
CustomAllreduce(
Signal* meta,
void* rank_data,
size_t rank_data_sz,
const hipIpcMemHandle_t* handles,
const std::vector<int64_t>& offsets,
int rank,
bool full_nvlink = true)
: rank_(rank),
world_size_(offsets.size()),
full_nvlink_(full_nvlink),
self_sg_(meta),
d_rank_data_base_(reinterpret_cast<RankData*>(rank_data)),
d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
for (int i = 0; i < world_size_; i++) {
Signal* rank_sg;
if (i != rank_) {
char* handle = open_ipc_handle(&handles[i]);
handle += offsets[i];
rank_sg = (Signal*)handle;
} else {
rank_sg = self_sg_;
}
sg_.signals[i] = rank_sg;
}
}
char* open_ipc_handle(const void* ipc_handle) {
auto [it, new_handle] = ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr});
if (new_handle) {
char* ipc_ptr;
CUDACHECK(hipIpcOpenMemHandle(
(void**)&ipc_ptr, *((const hipIpcMemHandle_t*)ipc_handle), hipIpcMemLazyEnablePeerAccess));
it->second = ipc_ptr;
}
return it->second;
}
std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta() {
auto num_buffers = graph_unreg_buffers_.size();
auto handle_sz = sizeof(hipIpcMemHandle_t);
std::vector<uint8_t> handles(handle_sz * num_buffers, 0);
std::vector<int64_t> offsets(num_buffers);
for (int i = 0; i < num_buffers; i++) {
auto ptr = graph_unreg_buffers_[i];
void* base_ptr;
// note: must share the base address of each allocation, or we get wrong
// address
if (hipPointerGetAttribute(
&base_ptr,
#ifdef USE_ROCM
HIP_POINTER_ATTRIBUTE_RANGE_START_ADDR,
#else
CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
#endif
(hipDeviceptr_t)ptr) != hipSuccess)
throw std::runtime_error("failed to get pointer attr");
CUDACHECK(hipIpcGetMemHandle((hipIpcMemHandle_t*)&handles[i * handle_sz], base_ptr));
offsets[i] = ((char*)ptr) - ((char*)base_ptr);
}
return std::make_pair(handles, offsets);
}
void check_rank_data_capacity(size_t num = 1) {
if (d_rank_data_base_ + num > d_rank_data_end_)
throw std::runtime_error(
"Rank data buffer is overflowed by " + std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
}
void register_buffer(const std::vector<std::string>& handles, const std::vector<int64_t>& offsets, void* self) {
check_rank_data_capacity();
RankData data;
for (int i = 0; i < world_size_; i++) {
if (i != rank_) {
char* handle = open_ipc_handle(handles[i].data());
handle += offsets[i];
data.ptrs[i] = handle;
} else {
data.ptrs[i] = self;
}
}
auto d_data = d_rank_data_base_++;
CUDACHECK(hipMemcpy(d_data, &data, sizeof(RankData), hipMemcpyHostToDevice));
buffers_[self] = d_data;
}
// note: when registering graph buffers, we intentionally choose to not
// deduplicate the addresses. That means if the allocator reuses some
// addresses, they will be registered again. This is to account for the remote
// possibility of different allocation patterns between ranks. For example,
// rank 1 may get the same input address for the second allreduce, but rank 2
// got a different address. IPC handles have internal reference counting
// mechanism so overhead should be small.
void
register_graph_buffers(const std::vector<std::string>& handles, const std::vector<std::vector<int64_t>>& offsets) {
auto num_buffers = graph_unreg_buffers_.size();
check_rank_data_capacity(num_buffers);
std::vector<RankData> rank_data(num_buffers);
for (int i = 0; i < num_buffers; i++) {
auto self_ptr = graph_unreg_buffers_[i];
auto& rd = rank_data[i];
for (int j = 0; j < world_size_; j++) {
if (j != rank_) {
char* handle = open_ipc_handle(&handles[j][i * sizeof(hipIpcMemHandle_t)]);
handle += offsets[j][i];
rd.ptrs[j] = handle;
} else {
rd.ptrs[j] = self_ptr;
}
}
}
CUDACHECK(hipMemcpy(d_rank_data_base_, rank_data.data(), sizeof(RankData) * num_buffers, hipMemcpyHostToDevice));
d_rank_data_base_ += num_buffers;
graph_unreg_buffers_.clear();
}
/**
* This is the result after careful grid search. Using 36 blocks give the best
* or close to the best runtime on the devices I tried: A100, A10, A30, T4,
* V100. You'll notice that NCCL kernels also only take a small amount of SMs.
* Not quite sure the underlying reason, but my guess is that too many SMs
* will cause contention on NVLink bus.
*/
template <typename T>
void allreduce(
hipStream_t stream,
T* input,
T* output,
int size,
#ifndef USE_ROCM
int threads = 512,
int block_limit = 36){
#else
int threads = 512,
int block_limit = 16) {
#endif
auto d = packed_t<T>::P::size;
if (size % d != 0)
throw std::runtime_error(
"custom allreduce currently requires input length to be multiple "
"of " +
std::to_string(d));
if (block_limit > kMaxBlocks)
throw std::runtime_error(
"max supported block limit is " + std::to_string(kMaxBlocks) + ". Got " + std::to_string(block_limit));
RankData* ptrs;
hipStreamCaptureStatus status;
CUDACHECK(hipStreamIsCapturing(stream, &status));
if (status == hipStreamCaptureStatusActive) {
ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
graph_unreg_buffers_.push_back(input);
} else {
auto it = buffers_.find(input);
if (it == buffers_.end())
throw std::runtime_error(
"buffer address " + std::to_string(reinterpret_cast<uint64_t>(input)) + " is not registered!");
ptrs = it->second;
}
size /= d;
auto bytes = size * sizeof(typename packed_t<T>::P);
int blocks = ::min(block_limit, (size + threads - 1) / threads);
#define KL(ngpus, name) \
hipLaunchKernelGGL( \
(name<T, ngpus>), dim3(blocks), dim3(threads), 0, stream, ptrs, sg_, self_sg_, output, rank_, size);
#define REDUCE_CASE(ngpus) \
case ngpus: { \
if (world_size_ == 2) { \
KL(ngpus, cross_device_reduce_1stage); \
} else if (full_nvlink_) { \
if ((world_size_ <= 4 && bytes < 512 * 1024) || (world_size_ <= 8 && bytes < 256 * 1024)) { \
KL(ngpus, cross_device_reduce_1stage); \
} else { \
KL(ngpus, cross_device_reduce_2stage); \
} \
} \
break; \
}
switch (world_size_) {
REDUCE_CASE(2)
REDUCE_CASE(4)
REDUCE_CASE(6)
REDUCE_CASE(8)
default:
throw std::runtime_error(
"custom allreduce only supports num gpus in (2,4,6,8). Actual num "
"gpus = " +
std::to_string(world_size_));
}
#undef REDUCE_CASE
#undef KL
}
~CustomAllreduce() {
for (auto [_, ptr] : ipc_handles_) {
CUDACHECK(hipIpcCloseMemHandle(ptr));
}
}
}; // namespace sglang
/**
* To inspect PTX/SASS, copy paste this header file to compiler explorer and add
a template instantiation:
* template void sglang::CustomAllreduce::allreduce<half>(hipStream_t, half *,
half *, int, int, int);
*/
} // namespace sglang

View File

@@ -0,0 +1,140 @@
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/all.h>
#include <torch/library.h>
#include "mscclpp_allreduce.cuh"
enum MscclContextSelection {
MSCCL1NODELL = 1,
MSCCL2NODELL = 2,
};
class MscclContext {
public:
MscclContextSelection selection_;
std::shared_ptr<sglang::Msccl1NodeLLcontext> msccl_1nodeLL_context;
std::shared_ptr<sglang::Msccl2NodeLLcontext> msccl_2nodeLL_context;
MscclContext(MscclContextSelection selection) : selection_(selection) {}
template <typename T>
void allreduce(
cudaStream_t stream, T* input, T* output, const size_t input_numel, int threads = 512, int block_limit = 21) {
if (selection_ == MSCCL1NODELL) {
msccl_1nodeLL_context->allreduce<T>(stream, input, output, input_numel, threads, block_limit);
} else if (selection_ == MSCCL2NODELL) {
msccl_2nodeLL_context->allreduce<T>(stream, input, output, input_numel, threads, block_limit);
}
}
};
using fptr_t = int64_t;
static_assert(sizeof(void*) == sizeof(fptr_t));
torch::Tensor _unique_id2tensor(const mscclpp::UniqueId& unique_id) {
auto options = torch::TensorOptions().dtype(torch::kByte).device(torch::kCPU);
auto tensor = torch::empty({static_cast<int64_t>(unique_id.size())}, options);
std::memcpy(tensor.data_ptr<uint8_t>(), unique_id.data(), unique_id.size());
return tensor;
}
// Function to convert vector of int32_t back to array of uint8_t
mscclpp::UniqueId _tensor2unique_id(const torch::Tensor& tensor) {
mscclpp::UniqueId unique_id;
std::memcpy(unique_id.data(), tensor.data_ptr<uint8_t>(), unique_id.size());
return unique_id;
}
torch::Tensor mscclpp_generate_unique_id() {
mscclpp::UniqueId unique_id = mscclpp::TcpBootstrap::createUniqueId();
return _unique_id2tensor(unique_id);
}
fptr_t mscclpp_init_context(
const torch::Tensor& unique_id,
const int64_t rank,
const int64_t world_size,
torch::Tensor& scratch,
torch::Tensor& put_buffer,
const int64_t nranks_per_node,
const std::vector<int64_t>& rank_to_node,
const std::vector<int64_t>& rank_to_ib,
const int64_t context_selection) {
MscclContext* context_ptr = new MscclContext(static_cast<MscclContextSelection>(context_selection));
mscclpp::UniqueId uid = _tensor2unique_id(unique_id);
if (context_selection == MSCCL1NODELL) {
void* scratch_ptr = reinterpret_cast<void*>(scratch.data_ptr());
const size_t scratch_bytes = scratch.numel() * scratch.element_size();
context_ptr->msccl_1nodeLL_context = std::make_shared<sglang::Msccl1NodeLLcontext>(
uid, rank, world_size, scratch_ptr, scratch_bytes, nranks_per_node, rank_to_node, rank_to_ib);
} else if (context_selection == MSCCL2NODELL) {
void* scratch_ptr = reinterpret_cast<void*>(scratch.data_ptr());
const size_t scratch_bytes = scratch.numel() * scratch.element_size();
void* put_buffer_ptr = reinterpret_cast<void*>(put_buffer.data_ptr());
const size_t put_buffer_bytes = put_buffer.numel() * put_buffer.element_size();
context_ptr->msccl_2nodeLL_context = std::make_shared<sglang::Msccl2NodeLLcontext>(
uid,
rank,
world_size,
scratch_ptr,
scratch_bytes,
put_buffer_ptr,
put_buffer_bytes,
nranks_per_node,
rank_to_node,
rank_to_ib);
} else {
throw std::runtime_error("invalid context selection");
}
return (fptr_t)context_ptr;
}
bool _mscclpp_is_weak_contiguous(torch::Tensor& t) {
return t.is_contiguous() ||
(t.storage().nbytes() - t.storage_offset() * t.element_size() == t.numel() * t.element_size());
}
void mscclpp_allreduce(fptr_t _context, torch::Tensor& inp, torch::Tensor& out, int64_t nthreads, int64_t nblocks) {
MscclContext* context = reinterpret_cast<MscclContext*>(_context);
const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
auto stream = c10::cuda::getCurrentCUDAStream().stream();
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
TORCH_CHECK_EQ(inp.numel(), out.numel());
TORCH_CHECK(_mscclpp_is_weak_contiguous(out));
TORCH_CHECK(_mscclpp_is_weak_contiguous(inp));
switch (out.scalar_type()) {
case at::ScalarType::Float: {
context->allreduce<float>(
stream,
reinterpret_cast<float*>(inp.data_ptr()),
reinterpret_cast<float*>(out.data_ptr()),
inp.numel(),
nthreads,
nblocks);
break;
}
case at::ScalarType::Half: {
context->allreduce<half>(
stream,
reinterpret_cast<half*>(inp.data_ptr()),
reinterpret_cast<half*>(out.data_ptr()),
inp.numel(),
nthreads,
nblocks);
break;
}
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
case at::ScalarType::BFloat16: {
context->allreduce<__nv_bfloat16>(
stream,
reinterpret_cast<__nv_bfloat16*>(inp.data_ptr()),
reinterpret_cast<__nv_bfloat16*>(out.data_ptr()),
inp.numel(),
nthreads,
nblocks);
break;
}
#endif
default:
throw std::runtime_error("custom allreduce only supports float32, float16 and bfloat16");
}
}

View File

@@ -0,0 +1,779 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
#pragma once
#ifdef USE_ROCM
#include <hip/hip_fp16.h>
#else
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#endif
#include <mscclpp/concurrency_device.hpp>
#include <mscclpp/core.hpp>
#include <mscclpp/memory_channel.hpp>
#include <mscclpp/memory_channel_device.hpp>
#include <mscclpp/nvls_device.hpp>
#include <mscclpp/port_channel.hpp>
#include <mscclpp/port_channel_device.hpp>
// comment this for test_mscclpp_allreduce.cu
#include "utils.h"
namespace sglang {
__device__ mscclpp::DeviceSyncer deviceSyncer;
__device__ mscclpp::DeviceSyncer allGatherDeviceSyncer;
__device__ mscclpp::DeviceSyncer reduceScatterDeviceSyncer;
__device__ mscclpp::DeviceSyncer ibDeviceSyncer;
template <typename To, typename From>
__forceinline__ __device__ To bit_cast(const From& src) {
static_assert(sizeof(To) == sizeof(From), "Size mismatch for bit_cast");
union {
From f;
To t;
} u;
u.f = src;
return u.t;
}
template <typename T>
__forceinline__ __device__ T add_elements(T a, T b) {
return a + b;
}
template <>
__forceinline__ __device__ __half2 add_elements(__half2 a, __half2 b) {
return __hadd2(a, b);
}
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
template <>
__forceinline__ __device__ __nv_bfloat162 add_elements(__nv_bfloat162 a, __nv_bfloat162 b) {
return __hadd2(a, b);
}
#endif
template <typename T>
__forceinline__ __device__ int4 add_vectors_helper(int4 a, int4 b) {
int4 ret;
ret.w = bit_cast<int, T>(add_elements(bit_cast<T, int>(a.w), bit_cast<T, int>(b.w)));
ret.x = bit_cast<int, T>(add_elements(bit_cast<T, int>(a.x), bit_cast<T, int>(b.x)));
ret.y = bit_cast<int, T>(add_elements(bit_cast<T, int>(a.y), bit_cast<T, int>(b.y)));
ret.z = bit_cast<int, T>(add_elements(bit_cast<T, int>(a.z), bit_cast<T, int>(b.z)));
return ret;
}
template <typename T>
__forceinline__ __device__ int4 add_vectors(int4 a, int4 b) {
return add_vectors_helper<T>(a, b);
}
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
template <>
__forceinline__ __device__ int4 add_vectors<__nv_bfloat16>(int4 a, int4 b) {
return add_vectors_helper<__nv_bfloat162>(a, b);
}
#endif
template <>
__forceinline__ __device__ int4 add_vectors<__half>(int4 a, int4 b) {
return add_vectors_helper<__half2>(a, b);
}
template <typename T>
__forceinline__ __device__ uint2 add_vectors_helper(uint2 a, uint2 b) {
uint2 ret;
ret.x = bit_cast<int, T>(add_elements(bit_cast<T, int>(a.x), bit_cast<T, int>(b.x)));
ret.y = bit_cast<int, T>(add_elements(bit_cast<T, int>(a.y), bit_cast<T, int>(b.y)));
return ret;
}
template <typename T>
__forceinline__ __device__ uint2 add_vectors(uint2 a, uint2 b) {
return add_vectors_helper<T>(a, b);
}
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
template <>
__forceinline__ __device__ uint2 add_vectors<__nv_bfloat16>(uint2 a, uint2 b) {
return add_vectors_helper<__nv_bfloat162>(a, b);
}
#endif
template <>
__forceinline__ __device__ uint2 add_vectors<__half>(uint2 a, uint2 b) {
return add_vectors_helper<__half2>(a, b);
}
template <typename T>
__forceinline__ __device__ int add_vectors_helper(int a, int b) {
return bit_cast<int, T>(add_elements(bit_cast<T, int>(a), bit_cast<T, int>(b)));
}
template <typename T>
__forceinline__ __device__ int add_vectors(int a, int b) {
return add_vectors_helper<T>(a, b);
}
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
template <>
__forceinline__ __device__ int add_vectors<__nv_bfloat16>(int a, int b) {
return add_vectors_helper<__nv_bfloat162>(a, b);
}
#endif
template <>
__forceinline__ __device__ int add_vectors<__half>(int a, int b) {
return add_vectors_helper<__half2>(a, b);
}
// -------------------------------------------------------
// allreduce_LL_1node using LLPacket, origin allreduce2
// -------------------------------------------------------
__device__ uint64_t globalFlag = 1;
template <typename TYPE>
__global__ void __launch_bounds__(1024, 1) allreduce_LL_1node(
mscclpp::MemoryChannelDeviceHandle* memChans,
TYPE* buff,
TYPE* scratch,
void* resultBuff,
int rank,
int worldSize,
size_t nelems) {
nelems = nelems / (sizeof(int) / sizeof(TYPE));
// This version of allreduce only works for single nodes
const int nPeers = worldSize - 1;
const size_t nPkts = nelems / 2;
const int nelemsPerRank = nelems / worldSize;
const int nPktsPerRank = nelemsPerRank / 2;
// flag for packets. Initially 1
const uint32_t flag = (uint32_t)globalFlag;
// thread block & channel info
const int nBlocksPerPeer = gridDim.x / nPeers;
const int localBlockIdx = blockIdx.x % nBlocksPerPeer;
const int peerIdx = blockIdx.x / nBlocksPerPeer;
const int remoteRank = peerIdx < rank ? peerIdx : peerIdx + 1;
mscclpp::MemoryChannelDeviceHandle memChan = memChans[peerIdx];
const int tid = threadIdx.x + localBlockIdx * blockDim.x;
// double buffering
size_t scratchBaseOffset = (flag & 1) ? 0 : nPkts * sizeof(mscclpp::LLPacket);
void* scratchBuff = (void*)((char*)scratch + scratchBaseOffset);
size_t scratchOffset = scratchBaseOffset + rank * nPktsPerRank * sizeof(mscclpp::LLPacket);
size_t scratchResultOffset =
(flag & 1) ? 2 * nPkts * sizeof(mscclpp::LLPacket) : 3 * nPkts * sizeof(mscclpp::LLPacket);
size_t srcOffset = remoteRank * nelemsPerRank * sizeof(int);
uint2* src = (uint2*)((char*)buff + rank * nelemsPerRank * sizeof(int));
uint2* dst = (uint2*)((char*)resultBuff + rank * nelemsPerRank * sizeof(int));
// step 1: write to scratch buffer
memChan.putPackets(scratchOffset, srcOffset, nelemsPerRank * sizeof(int), tid, blockDim.x * nBlocksPerPeer, flag);
// step 2: get data from scratch buffer, reduce data and write result to remote scratch buffer
for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nPktsPerRank; idx += blockDim.x * gridDim.x) {
uint2 data = make_uint2(0, 0);
for (int index = 0; index < nPeers; index++) {
const int remoteRank = index < rank ? index : index + 1;
mscclpp::LLPacket* dstPkt = (mscclpp::LLPacket*)scratchBuff + remoteRank * nPktsPerRank;
uint2 val = dstPkt[idx].read(flag);
data = add_vectors<TYPE>(val, data);
}
data = add_vectors<TYPE>(data, src[idx]);
dst[idx] = data;
mscclpp::LLPacket packet;
packet.data1 = data.x;
packet.flag1 = flag;
packet.data2 = data.y;
packet.flag2 = flag;
size_t offset = scratchResultOffset / sizeof(mscclpp::LLPacket) + (idx + rank * nPktsPerRank);
for (int index = 0; index < nPeers; index++) {
memChans[index].write(offset, packet);
}
}
// step 3: get data result from scratch buffer
mscclpp::LLPacket* dstPkt = (mscclpp::LLPacket*)((char*)scratch + scratchResultOffset);
const int dstOffset = remoteRank * nPktsPerRank;
uint2* result = (uint2*)((char*)resultBuff + remoteRank * nelemsPerRank * sizeof(int));
for (int idx = threadIdx.x + localBlockIdx * blockDim.x; idx < nPktsPerRank; idx += blockDim.x * nBlocksPerPeer) {
uint2 data = dstPkt[idx + dstOffset].read(flag);
result[idx].x = data.x;
result[idx].y = data.y;
}
if (threadIdx.x == 0 && blockIdx.x == 0) {
globalFlag += 1;
}
}
// -------------------------------------------------------
// allreduce_LL_2node using LLPacket, origin allreduce5
// -------------------------------------------------------
template <typename TYPE>
__global__ void __launch_bounds__(1024, 1) allreduce_LL_2node(
mscclpp::MemoryChannelDeviceHandle* memChans,
mscclpp::PortChannelDeviceHandle* portChans,
TYPE* buff,
TYPE* scratch,
TYPE* putBuff,
TYPE* resultBuff,
int rank,
int nRanksPerNode,
int worldSize,
size_t nelems) {
nelems = nelems / (sizeof(int) / sizeof(TYPE));
// This version of allreduce only works for single nodes
const int nPeersInNode = nRanksPerNode - 1;
const int nPkts = nelems / 2;
const int nelemsPerLocalRank = nelems / nRanksPerNode;
const int nPktsPerLocalRank = nelemsPerLocalRank / 2;
const int localRankId = rank % nRanksPerNode;
// flag for packets. Initially 1
const uint32_t flag = (uint32_t)globalFlag;
// thread block & channel info
const int nBlocksPerPeer = gridDim.x / nPeersInNode;
const int localBlockIdx = blockIdx.x % nBlocksPerPeer;
const int peerIdx = blockIdx.x / nBlocksPerPeer;
const int remoteRankIdx = peerIdx < localRankId ? peerIdx : peerIdx + 1;
mscclpp::MemoryChannelDeviceHandle memChan = memChans[peerIdx];
mscclpp::PortChannelDeviceHandle portChan = portChans[localRankId];
const int tid = threadIdx.x + localBlockIdx * blockDim.x;
// double buffering
size_t scratchBaseOffset = (flag & 1) ? 0 : nPkts * sizeof(mscclpp::LLPacket);
size_t putBaseOffset = (flag & 1) ? 0 : nPktsPerLocalRank * sizeof(mscclpp::LLPacket);
void* scratchBuff = (void*)((char*)scratch + scratchBaseOffset);
size_t scratchOffset = scratchBaseOffset + localRankId * nPktsPerLocalRank * sizeof(mscclpp::LLPacket);
size_t scratchResultOffset =
(flag & 1) ? 2 * nPkts * sizeof(mscclpp::LLPacket) : 3 * nPkts * sizeof(mscclpp::LLPacket);
size_t srcOffset = remoteRankIdx * nelemsPerLocalRank * sizeof(int);
uint2* src = (uint2*)((char*)buff + localRankId * nelemsPerLocalRank * sizeof(int));
uint2* dst = (uint2*)((char*)resultBuff + localRankId * nelemsPerLocalRank * sizeof(int));
// step 1: write to scratch buffer
if (nRanksPerNode > 1) {
memChan.putPackets(
scratchOffset, srcOffset, nelemsPerLocalRank * sizeof(int), tid, blockDim.x * nBlocksPerPeer, flag);
}
// step 2: get data from scratch buffer, do local reduce-scatter in each node.
mscclpp::LLPacket* putPkt = (mscclpp::LLPacket*)((char*)putBuff + putBaseOffset);
for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nPktsPerLocalRank; idx += blockDim.x * gridDim.x) {
uint2 data = make_uint2(0, 0);
for (int index = 0; index < nPeersInNode; index++) {
const int remoteRank = index < localRankId ? index : index + 1;
mscclpp::LLPacket* dstPkt = (mscclpp::LLPacket*)scratchBuff + remoteRank * nPktsPerLocalRank;
uint2 val = dstPkt[idx].read(flag);
data = add_vectors<TYPE>(val, data);
}
data = add_vectors<TYPE>(data, src[idx]);
putPkt[idx].write(data.x, data.y, flag);
dst[idx] = data;
}
deviceSyncer.sync(gridDim.x);
// step 3. send local reduced data to remote node.
if (threadIdx.x == 0 && blockIdx.x == 0) {
portChan.put(scratchOffset, putBaseOffset, nPktsPerLocalRank * sizeof(mscclpp::LLPacket));
if ((flag & 63) == 0) {
portChan.flush();
}
}
// step 4. try to read the data from scratch buffer and write to local peers
mscclpp::LLPacket* dstPkt = (mscclpp::LLPacket*)scratchBuff + localRankId * nPktsPerLocalRank;
for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nPktsPerLocalRank; idx += blockDim.x * gridDim.x) {
uint2 res = dst[idx];
uint2 val = dstPkt[idx].read(flag);
res = add_vectors<TYPE>(res, val);
mscclpp::LLPacket packet;
packet.data1 = res.x;
packet.flag1 = flag;
packet.data2 = res.y;
packet.flag2 = flag;
size_t offset = scratchResultOffset / sizeof(mscclpp::LLPacket) + (idx + localRankId * nPktsPerLocalRank);
for (int index = 0; index < nPeersInNode; index++) {
memChans[index].write(offset, packet);
}
dst[idx] = res;
}
// step 5: get data result from scratch buffer
dstPkt = (mscclpp::LLPacket*)((char*)scratch + scratchResultOffset);
const int dstOffset = remoteRankIdx * nPktsPerLocalRank;
uint2* result = (uint2*)((char*)resultBuff + remoteRankIdx * nelemsPerLocalRank * sizeof(int));
if (nRanksPerNode > 1) {
for (int idx = threadIdx.x + localBlockIdx * blockDim.x; idx < nPktsPerLocalRank;
idx += blockDim.x * nBlocksPerPeer) {
uint2 data = dstPkt[idx + dstOffset].read(flag);
result[idx] = data;
}
}
if (threadIdx.x == 0 && blockIdx.x == 0) {
globalFlag += 1;
}
}
static const mscclpp::Transport IBs[] = {
mscclpp::Transport::IB0,
mscclpp::Transport::IB1,
mscclpp::Transport::IB2,
mscclpp::Transport::IB3,
mscclpp::Transport::IB4,
mscclpp::Transport::IB5,
mscclpp::Transport::IB6,
mscclpp::Transport::IB7};
class MscclCommGroup {
public:
std::shared_ptr<mscclpp::Communicator> comm_;
const size_t rank_;
const size_t world_size_;
const std::vector<int64_t> rank_to_node_;
const std::vector<int64_t> rank_to_ib_;
MscclCommGroup(
mscclpp::UniqueId unique_id,
const size_t rank,
const size_t world_size,
const std::vector<int64_t>& rank_to_node,
const std::vector<int64_t>& rank_to_ib)
: rank_(rank), world_size_(world_size), rank_to_node_(rank_to_node), rank_to_ib_(rank_to_ib) {
auto bootstrap = std::make_shared<mscclpp::TcpBootstrap>(rank, world_size);
bootstrap->initialize(unique_id);
comm_ = std::make_shared<mscclpp::Communicator>(bootstrap);
}
template <typename T>
void allreduce(cudaStream_t stream, T* output, size_t input_numel, int threads = 512, int block_limit = 21) {
throw std::runtime_error("you should not call allreduce of a base context");
}
bool is_same_node(int r1, int r2) {
return rank_to_node_[r1] == rank_to_node_[r2];
}
void make_connection(
std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& same_node_connections,
std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& cross_node_connections) {
same_node_connections.clear();
cross_node_connections.clear();
std::unordered_map<int, mscclpp::NonblockingFuture<std::shared_ptr<mscclpp::Connection>>> conn_futures;
for (int r = 0; r < world_size_; ++r) {
if (r == rank_) continue;
mscclpp::Transport transport = is_same_node(r, rank_) ? mscclpp::Transport::CudaIpc : IBs[rank_to_ib_[r]];
conn_futures.emplace(r, comm_->connectOnSetup(r, 0, transport));
}
comm_->setup();
for (int r = 0; r < world_size_; ++r) {
if (r == rank_) continue;
if (is_same_node(r, rank_)) {
same_node_connections.emplace(r, conn_futures[r].get());
} else {
cross_node_connections.emplace(r, conn_futures[r].get());
}
}
}
void make_memory_channels_with_scratch(
void* tensor_ptr,
const size_t tensor_bytes,
void* scratch_ptr,
const size_t scratch_bytes,
const std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& connections,
std::unordered_map<int, std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>>& semaphores,
std::unordered_map<int, mscclpp::RegisteredMemory>& registered_memories,
std::unordered_map<int, mscclpp::MemoryChannel>& channels) {
channels.clear();
make_semaphores<mscclpp::MemoryDevice2DeviceSemaphore>(connections, semaphores);
register_tensor_with_connections(scratch_ptr, scratch_bytes, connections, registered_memories);
for (const auto& [peer, _] : connections) {
channels.emplace(
peer, mscclpp::MemoryChannel(semaphores[peer], registered_memories[peer], tensor_ptr, scratch_ptr));
}
}
void make_port_channels_with_scratch(
std::shared_ptr<mscclpp::ProxyService> proxyService,
void* tensor_ptr,
const size_t tensor_bytes,
void* scratch_ptr,
const size_t scratch_bytes,
const std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& connections,
std::unordered_map<int, std::shared_ptr<mscclpp::Host2DeviceSemaphore>>& semaphores,
std::unordered_map<int, mscclpp::RegisteredMemory>& registered_memories,
std::unordered_map<int, mscclpp::PortChannel>& channels) {
channels.clear();
make_semaphores<mscclpp::Host2DeviceSemaphore>(connections, semaphores);
mscclpp::TransportFlags flags;
for (const auto& [_, conn] : connections) {
flags |= conn->transport();
}
auto local_reg_memory = comm_->registerMemory(tensor_ptr, tensor_bytes, flags);
register_tensor_with_connections(scratch_ptr, scratch_bytes, connections, registered_memories);
std::unordered_map<int, mscclpp::SemaphoreId> semaphore_ids;
std::unordered_map<int, size_t> memory_ids;
memory_ids[rank_] = proxyService->addMemory(local_reg_memory);
for (const auto& [peer, memory] : registered_memories) {
if (peer == rank_) continue;
memory_ids[peer] = proxyService->addMemory(memory);
}
for (const auto& [peer, semaphore] : semaphores) {
semaphore_ids[peer] = proxyService->addSemaphore(semaphore);
}
for (const auto& [peer, _] : connections) {
channels.emplace(peer, proxyService->portChannel(semaphore_ids[peer], memory_ids[peer], memory_ids[rank_]));
}
}
template <typename SemaphoreType>
void make_semaphores(
const std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& connections,
std::unordered_map<int, std::shared_ptr<SemaphoreType>>& semaphores) {
semaphores.clear();
for (const auto& [peer, conn] : connections) {
semaphores[peer] = std::make_shared<SemaphoreType>(*comm_, conn);
}
comm_->setup();
}
void register_tensor_with_connections(
void* tensor_ptr,
size_t tensor_bytes,
const std::unordered_map<int, std::shared_ptr<mscclpp::Connection>>& connections,
std::unordered_map<int, mscclpp::RegisteredMemory>& registered_memories) {
registered_memories.clear();
mscclpp::TransportFlags all_transports;
for (const auto& [_, connection] : connections) {
all_transports |= connection->transport();
}
mscclpp::RegisteredMemory buf_reg_mem = comm_->registerMemory(tensor_ptr, tensor_bytes, all_transports);
registered_memories[rank_] = buf_reg_mem;
std::unordered_map<int, mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>> remote_mem_futures;
for (const auto& [r, connection] : connections) {
comm_->sendMemoryOnSetup(buf_reg_mem, r, 0);
auto remoteMemory = comm_->recvMemoryOnSetup(r, 0);
remote_mem_futures.emplace(r, remoteMemory);
}
comm_->setup();
for (auto& [r, mem_feature] : remote_mem_futures) {
registered_memories.emplace(r, mem_feature.get());
}
}
void make_device_memory_handle_base_on_new_ptr(
const std::unordered_map<int, mscclpp::MemoryChannel>& old_memory_channels,
std::unordered_map<int, mscclpp::RegisteredMemory>& registered_sm_memories,
std::unordered_map<int, std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>>& memory_semaphores,
std::unordered_map<int, mscclpp::MemoryChannel>& memory_channels,
mscclpp::GpuBuffer<mscclpp::MemoryChannelDeviceHandle>& device_memory_handle,
void* input,
void* scratch,
const cudaStream_t stream) {
memory_channels.clear();
for (const auto& [peer, channel] : old_memory_channels) {
memory_channels.emplace(
peer, mscclpp::MemoryChannel(memory_semaphores[peer], registered_sm_memories[peer], input, scratch));
}
std::vector<mscclpp::MemoryChannel> memory_channels_list;
for (int r = 0; r < world_size_; r++) {
if (r == rank_) continue;
if (is_same_node(r, rank_)) {
memory_channels_list.push_back(memory_channels[r]);
}
}
std::vector<mscclpp::MemoryChannelDeviceHandle> memory_channel_handlers(memory_channels_list.size());
std::transform(
memory_channels_list.begin(),
memory_channels_list.end(),
memory_channel_handlers.begin(),
[](const mscclpp::MemoryChannel& channel) { return channel.deviceHandle(); });
mscclpp::gpuMemcpyAsync<mscclpp::MemoryChannelDeviceHandle>(
device_memory_handle.data(),
memory_channel_handlers.data(),
memory_channel_handlers.size(),
stream,
cudaMemcpyHostToDevice);
}
};
class Msccl1NodeLLcontext {
private:
std::shared_ptr<MscclCommGroup> comm_group_ = nullptr;
void* scratch_;
const size_t scratch_bytes_;
std::unordered_map<int, std::shared_ptr<mscclpp::Connection>> same_node_connections_;
std::unordered_map<int, std::shared_ptr<mscclpp::Connection>> cross_node_connections_;
std::unordered_map<int, mscclpp::RegisteredMemory> registered_sm_memories_;
std::unordered_map<int, std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>> memory_semaphores_;
std::unordered_map<int, mscclpp::MemoryChannel> memory_channels_;
mscclpp::GpuBuffer<mscclpp::MemoryChannelDeviceHandle> d_memHandles_;
std::unordered_map<void*, std::unordered_map<int, mscclpp::MemoryChannel>> input_ptr2memory_channels_;
std::unordered_map<void*, mscclpp::GpuBuffer<mscclpp::MemoryChannelDeviceHandle>> input_ptr2d_memHandles_;
cudaStream_t h2d_stream;
const size_t nranks_per_node_;
public:
Msccl1NodeLLcontext(
mscclpp::UniqueId unique_id,
const size_t rank,
const size_t world_size,
void* scratch,
const size_t scratch_bytes,
const size_t nranks_per_node,
const std::vector<int64_t>& rank_to_node,
const std::vector<int64_t>& rank_to_ib)
: scratch_(scratch),
scratch_bytes_(scratch_bytes),
nranks_per_node_(nranks_per_node),
d_memHandles_(nranks_per_node - 1) {
CHECK_CUDA_SUCCESS(cudaStreamCreateWithFlags(&h2d_stream, cudaStreamNonBlocking));
comm_group_ = std::make_shared<MscclCommGroup>(unique_id, rank, world_size, rank_to_node, rank_to_ib);
comm_group_->make_connection(same_node_connections_, cross_node_connections_);
comm_group_->make_memory_channels_with_scratch(
scratch_,
scratch_bytes_,
scratch_,
scratch_bytes_,
same_node_connections_,
memory_semaphores_,
registered_sm_memories_,
memory_channels_);
std::vector<mscclpp::MemoryChannel> memory_channels_list;
for (int r = 0; r < comm_group_->world_size_; r++) {
if (r == comm_group_->rank_) continue;
memory_channels_list.push_back(memory_channels_[r]);
}
std::vector<mscclpp::MemoryChannelDeviceHandle> memory_channel_handlers(memory_channels_list.size());
std::transform(
memory_channels_list.begin(),
memory_channels_list.end(),
memory_channel_handlers.begin(),
[](const mscclpp::MemoryChannel& channel) { return channel.deviceHandle(); });
mscclpp::gpuMemcpy<mscclpp::MemoryChannelDeviceHandle>(
d_memHandles_.data(), memory_channel_handlers.data(), memory_channel_handlers.size(), cudaMemcpyHostToDevice);
}
~Msccl1NodeLLcontext() {
CHECK_CUDA_SUCCESS(cudaStreamDestroy(h2d_stream));
}
template <typename T>
void allreduce(cudaStream_t stream, T* input, T* output, size_t input_numel, int nthreads = 512, int nblocks = 21) {
dim3 nthrs(nthreads);
dim3 nblks(nblocks);
cudaStreamCaptureStatus capturing_status;
CHECK_CUDA_SUCCESS(cudaStreamIsCapturing(stream, &capturing_status));
mscclpp::MemoryChannelDeviceHandle* memChans;
if (capturing_status != cudaStreamCaptureStatusActive) {
std::unordered_map<int, mscclpp::MemoryChannel> memory_channels;
comm_group_->make_device_memory_handle_base_on_new_ptr(
memory_channels_,
registered_sm_memories_,
memory_semaphores_,
memory_channels,
d_memHandles_,
input,
scratch_,
h2d_stream);
CHECK_CUDA_SUCCESS(cudaStreamSynchronize(h2d_stream));
memChans = d_memHandles_.data();
} else {
void* input_void_ptr = reinterpret_cast<void*>(input);
if (input_ptr2d_memHandles_.find(input_void_ptr) == input_ptr2d_memHandles_.end()) {
std::unordered_map<int, mscclpp::MemoryChannel> memory_channels;
mscclpp::GpuBuffer<mscclpp::MemoryChannelDeviceHandle> device_memory_handle(comm_group_->world_size_ - 1);
comm_group_->make_device_memory_handle_base_on_new_ptr(
memory_channels_,
registered_sm_memories_,
memory_semaphores_,
memory_channels,
device_memory_handle,
input,
scratch_,
h2d_stream);
input_ptr2memory_channels_.emplace(input_void_ptr, memory_channels);
input_ptr2d_memHandles_.emplace(input_void_ptr, device_memory_handle);
}
auto it = input_ptr2d_memHandles_.find(input_void_ptr);
memChans = it->second.data();
}
allreduce_LL_1node<T><<<nblks, nthrs, 0, stream>>>(
memChans, (T*)input, (T*)scratch_, output, comm_group_->rank_, comm_group_->world_size_, input_numel);
cudaError_t status = cudaGetLastError();
if (status != cudaSuccess) {
printf("rank: %lu failed to launch allreduce_LL_1node: %s\n", comm_group_->rank_, cudaGetErrorString(status));
}
}
};
class Msccl2NodeLLcontext {
private:
std::shared_ptr<MscclCommGroup> comm_group_ = nullptr;
void* scratch_;
const size_t scratch_bytes_;
void* put_buffer_;
const size_t put_buffer_bytes_;
std::unordered_map<int, std::shared_ptr<mscclpp::Connection>> same_node_connections_;
std::unordered_map<int, std::shared_ptr<mscclpp::Connection>> cross_node_connections_;
std::unordered_map<int, mscclpp::RegisteredMemory> registered_sm_memories_;
std::unordered_map<int, mscclpp::RegisteredMemory> registered_port_memories_;
std::unordered_map<int, std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>> memory_semaphores_;
std::unordered_map<int, std::shared_ptr<mscclpp::Host2DeviceSemaphore>> port_semaphores_;
std::unordered_map<int, mscclpp::MemoryChannel> memory_channels_;
std::unordered_map<int, mscclpp::PortChannel> port_channels_;
mscclpp::GpuBuffer<mscclpp::MemoryChannelDeviceHandle> d_memHandles_;
mscclpp::GpuBuffer<mscclpp::PortChannelDeviceHandle> d_portHandles_;
std::shared_ptr<mscclpp::ProxyService> proxyService;
cudaStream_t h2d_stream;
const size_t nranks_per_node_;
std::unordered_map<void*, std::unordered_map<int, mscclpp::MemoryChannel>> input_ptr2memory_channels_;
std::unordered_map<void*, mscclpp::GpuBuffer<mscclpp::MemoryChannelDeviceHandle>> input_ptr2d_memHandles_;
public:
Msccl2NodeLLcontext(
mscclpp::UniqueId unique_id,
const size_t rank,
const size_t world_size,
void* scratch,
const size_t scratch_bytes,
void* put_buffer,
const size_t put_buffer_bytes,
const size_t nranks_per_node,
const std::vector<int64_t>& rank_to_node,
const std::vector<int64_t>& rank_to_ib)
: scratch_(scratch),
scratch_bytes_(scratch_bytes),
put_buffer_(put_buffer),
put_buffer_bytes_(put_buffer_bytes),
nranks_per_node_(nranks_per_node),
d_memHandles_(nranks_per_node - 1),
d_portHandles_(world_size - nranks_per_node) {
CHECK_CUDA_SUCCESS(cudaStreamCreateWithFlags(&h2d_stream, cudaStreamNonBlocking));
comm_group_ = std::make_shared<MscclCommGroup>(unique_id, rank, world_size, rank_to_node, rank_to_ib);
proxyService = std::make_shared<mscclpp::ProxyService>();
proxyService->startProxy();
comm_group_->make_connection(same_node_connections_, cross_node_connections_);
comm_group_->make_memory_channels_with_scratch(
scratch_,
scratch_bytes_,
scratch_,
scratch_bytes_,
same_node_connections_,
memory_semaphores_,
registered_sm_memories_,
memory_channels_);
comm_group_->make_port_channels_with_scratch(
proxyService,
put_buffer_,
put_buffer_bytes_,
scratch_,
scratch_bytes_,
cross_node_connections_,
port_semaphores_,
registered_port_memories_,
port_channels_);
std::vector<mscclpp::MemoryChannel> memory_channels_list;
std::vector<mscclpp::PortChannel> port_channels_list;
for (int r = 0; r < comm_group_->world_size_; r++) {
if (r == comm_group_->rank_) continue;
if (comm_group_->is_same_node(r, comm_group_->rank_)) {
memory_channels_list.push_back(memory_channels_[r]);
} else {
port_channels_list.push_back(port_channels_[r]);
}
}
std::vector<mscclpp::MemoryChannelDeviceHandle> memory_channel_handlers(memory_channels_list.size());
std::transform(
memory_channels_list.begin(),
memory_channels_list.end(),
memory_channel_handlers.begin(),
[](const mscclpp::MemoryChannel& channel) { return channel.deviceHandle(); });
mscclpp::gpuMemcpy<mscclpp::MemoryChannelDeviceHandle>(
d_memHandles_.data(), memory_channel_handlers.data(), memory_channel_handlers.size(), cudaMemcpyHostToDevice);
std::vector<mscclpp::PortChannelDeviceHandle> port_channel_handlers(port_channels_list.size());
std::transform(
port_channels_list.begin(),
port_channels_list.end(),
port_channel_handlers.begin(),
[](const mscclpp::PortChannel& channel) { return channel.deviceHandle(); });
mscclpp::gpuMemcpy<mscclpp::PortChannelDeviceHandle>(
d_portHandles_.data(), port_channel_handlers.data(), port_channel_handlers.size(), cudaMemcpyHostToDevice);
}
~Msccl2NodeLLcontext() {
CHECK_CUDA_SUCCESS(cudaStreamDestroy(h2d_stream));
if (proxyService) {
proxyService->stopProxy();
}
}
template <typename T>
void
allreduce(cudaStream_t stream, T* input, T* output, const size_t input_numel, int nthreads = 512, int nblocks = 21) {
dim3 nthrs(nthreads);
dim3 nblks(nblocks);
cudaStreamCaptureStatus capturing_status;
CHECK_CUDA_SUCCESS(cudaStreamIsCapturing(stream, &capturing_status));
mscclpp::MemoryChannelDeviceHandle* memChans;
if (capturing_status != cudaStreamCaptureStatusActive) {
std::unordered_map<int, mscclpp::MemoryChannel> memory_channels;
comm_group_->make_device_memory_handle_base_on_new_ptr(
memory_channels_,
registered_sm_memories_,
memory_semaphores_,
memory_channels,
d_memHandles_,
input,
scratch_,
h2d_stream);
CHECK_CUDA_SUCCESS(cudaStreamSynchronize(h2d_stream));
memChans = d_memHandles_.data();
} else {
void* input_void_ptr = reinterpret_cast<void*>(input);
if (input_ptr2d_memHandles_.find(input_void_ptr) == input_ptr2d_memHandles_.end()) {
std::unordered_map<int, mscclpp::MemoryChannel> memory_channels;
mscclpp::GpuBuffer<mscclpp::MemoryChannelDeviceHandle> device_memory_handle(7);
comm_group_->make_device_memory_handle_base_on_new_ptr(
memory_channels_,
registered_sm_memories_,
memory_semaphores_,
memory_channels,
device_memory_handle,
input,
scratch_,
h2d_stream);
input_ptr2memory_channels_.emplace(input_void_ptr, memory_channels);
input_ptr2d_memHandles_.emplace(input_void_ptr, device_memory_handle);
}
auto it = input_ptr2d_memHandles_.find(input_void_ptr);
memChans = it->second.data();
}
allreduce_LL_2node<T><<<nblks, nthrs, 0, stream>>>(
memChans,
d_portHandles_.data(),
(T*)input,
(T*)scratch_,
(T*)put_buffer_,
output,
comm_group_->rank_,
nranks_per_node_,
comm_group_->world_size_,
input_numel);
cudaError_t status = cudaGetLastError();
if (status != cudaSuccess) {
printf("rank: %lu failed to launch allreduce_LL_2node: %s\n", comm_group_->rank_, cudaGetErrorString(status));
}
}
};
} // namespace sglang

View File

@@ -0,0 +1,111 @@
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/all.h>
#ifdef USE_ROCM
#include "quick_all_reduce.h"
quickreduce::fptr_t init_custom_qr(int64_t rank, int64_t world_size, std::optional<int64_t> qr_max_size) {
if (world_size > 8) throw std::invalid_argument("world size > 8 is not supported");
if (world_size == 6) throw std::invalid_argument("world size == 6 is not supported");
if (world_size % 2 != 0) throw std::invalid_argument("Odd num gpus is not supported for now");
if (rank < 0 || rank >= world_size) throw std::invalid_argument("invalid rank passed in");
quickreduce::DeviceComms* fptr = new quickreduce::DeviceComms();
fptr->init(world_size, rank, qr_max_size);
return (quickreduce::fptr_t)fptr;
}
void qr_destroy(quickreduce::fptr_t _fa) {
if (_fa) {
auto fa = reinterpret_cast<quickreduce::DeviceComms*>(_fa);
fa->destroy();
delete fa;
}
}
torch::Tensor qr_get_handle(quickreduce::fptr_t _fa) {
auto fa = reinterpret_cast<quickreduce::DeviceComms*>(_fa);
hipIpcMemHandle_t handle = fa->get_handle();
auto options = torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
auto data_handle = torch::empty({static_cast<int64_t>(sizeof(hipIpcMemHandle_t))}, options);
std::memcpy(data_handle.data_ptr(), &handle, sizeof(hipIpcMemHandle_t));
return data_handle;
}
void qr_open_handles(quickreduce::fptr_t _fa, const std::vector<torch::Tensor>& handles) {
auto fa = reinterpret_cast<quickreduce::DeviceComms*>(_fa);
std::vector<hipIpcMemHandle_t> ipc_handles;
ipc_handles.reserve(handles.size());
for (auto& handle : handles) {
// Ensure the tensor is on the same device as the current device.
hipIpcMemHandle_t ipc_handle;
std::memcpy(&ipc_handle, handle.data_ptr(), sizeof(hipIpcMemHandle_t));
ipc_handles.push_back(ipc_handle);
}
fa->open_ipc_handles(ipc_handles);
}
void qr_all_reduce(
quickreduce::fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, int64_t quant_level, bool cast_bf2half) {
auto fa = reinterpret_cast<quickreduce::DeviceComms*>(_fa);
const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
auto stream = at::cuda::getCurrentHIPStreamMasqueradingAsCUDA();
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
TORCH_CHECK_EQ(inp.numel(), out.numel());
TORCH_CHECK_LE(out.numel(), fa->kMaxProblemSize);
if (out.scalar_type() == at::ScalarType::Half) {
fa->allreduce<half, false>(
reinterpret_cast<half*>(inp.data_ptr()),
reinterpret_cast<half*>(out.data_ptr()),
out.numel(),
quant_level,
stream);
} else if (out.scalar_type() == at::ScalarType::BFloat16) {
if (cast_bf2half) {
fa->allreduce<half, true>(
reinterpret_cast<half*>(inp.data_ptr()),
reinterpret_cast<half*>(out.data_ptr()),
out.numel(),
quant_level,
stream);
} else {
fa->allreduce<quickreduce::nv_bfloat16, false>(
reinterpret_cast<quickreduce::nv_bfloat16*>(inp.data_ptr()),
reinterpret_cast<quickreduce::nv_bfloat16*>(out.data_ptr()),
out.numel(),
quant_level,
stream);
}
} else {
throw std::runtime_error("quick allreduce only supports float16 and bfloat16");
}
}
int64_t qr_max_size() {
// The default is 2GB (2,147,483,648 bytes)
return static_cast<int64_t>(std::numeric_limits<int32_t>::max()) + 1;
}
#define INSTANTIATE_FOR_WORLDSIZE(T, Codec, cast_bf2half) \
template struct quickreduce::AllReduceTwoshot<T, Codec<T, 2>, cast_bf2half>; \
template struct quickreduce::AllReduceTwoshot<T, Codec<T, 4>, cast_bf2half>; \
template struct quickreduce::AllReduceTwoshot<T, Codec<T, 8>, cast_bf2half>;
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecFP, false)
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ4, false)
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ6, false)
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ8, false)
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecFP, true)
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ4, true)
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ6, true)
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ8, true)
INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecFP, false)
INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ4, false)
INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ6, false)
INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ8, false)
#endif // USE_ROCM

View File

@@ -0,0 +1,633 @@
#pragma once
#include <hip/hip_runtime.h>
#include "quick_all_reduce_base.h"
namespace quickreduce {
struct CodecBase {
const int thread;
const int rank;
const int group_leader;
__quickreduce_device_inline__ CodecBase(int thread, int rank)
: thread(thread), rank(rank), group_leader((threadIdx.x / kThreadGroupSize) * kThreadGroupSize) {
set_fp16_ovfl(true);
}
};
// Default full precision codec.
template <typename T, int world_size>
struct CodecFP : public CodecBase {
static constexpr int kWorldSize = world_size;
static constexpr int kRankAtoms = kAtoms / kWorldSize;
// Codec tile size process by this workgroup.
// Each thread processes atoms of f16x8_t (16B).
static constexpr int kRankTransmittedTileSize = kBlockSize * kRankAtoms * sizeof(int32x4_t);
static_assert(kRankTransmittedTileSize % 16 == 0, "kRankTransmittedTileSize must be 16B aligned.");
// Total tile size for the collective communication.
static constexpr int kTransmittedTileSize = kRankTransmittedTileSize * kWorldSize;
__quickreduce_device_inline__ CodecFP(int thread, int rank) : CodecBase(thread, rank) {}
__quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer, const int32x4_t* __restrict__ data) {
for (int i = 0; i < kRankAtoms; i++) {
__builtin_nontemporal_store(data[i], send_buffer + thread);
send_buffer += kAtomStride;
}
}
__quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, int32x4_t* __restrict__ data) {
for (int i = 0; i < kRankAtoms; i++) {
data[i] = __builtin_nontemporal_load(*recv_buffer + thread);
*recv_buffer += kAtomStride;
}
}
};
// Int4 symmetric quantization codec.
// We quantize the FP16 data to block-scaled Int4 in blocks of 4 *
// kThreadGroupSize.
template <typename T, int world_size>
struct CodecQ4 : public CodecBase {
static constexpr int kWorldSize = world_size;
// Codec tile size process by this workgroup.
// Each threads processes a fragment of fp16x8_t (16B),
// into a int4x8_t (4B) and a fp16 scale shared among 32 values.
static constexpr int kRankAtoms = kAtoms / kWorldSize;
static constexpr int kRankTileStride = 1152;
static constexpr int kRankTileScaleOffset = 1024;
static constexpr int kRankTransmittedTileSize = kRankTileStride * kRankAtoms;
static_assert(kRankTransmittedTileSize % 16 == 0, "kRankTransmittedTileSize must be 16B aligned.");
static constexpr int kRankBufferTileStride = kRankTileStride / sizeof(int32x4_t);
// Total tile size for the collective communication.
static constexpr int kTransmittedTileSize = kRankTransmittedTileSize * kWorldSize;
// Constants configuration
// {-1/8.0h, -1/8.0h}, f16x2_t
static constexpr int kScaleFactor = std::is_same<T, half>::value ? 0xB000B000 : 0xBE00BE00;
// {1e-7, 1e-7}, f16x2_t
static constexpr int kScaleEpsilon = std::is_same<T, half>::value ? 0x00010001 : 0x33D733D7;
// {-8, -8}, f16x2_t
static constexpr int kRangeMin = std::is_same<T, half>::value ? 0xC800C800 : 0xC100C100;
// {+7, +7}, f16x2_t
static constexpr int kRangeMax = std::is_same<T, half>::value ? 0x47004700 : 0x40E040E0;
// {+8, +8}, int16x2_t
static constexpr int kRangeBias = 0x00080008;
__quickreduce_device_inline__ CodecQ4(int thread, int rank) : CodecBase(thread, rank) {}
__quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer, const int32x4_t* __restrict__ data) {
for (int k = 0; k < kRankAtoms; k++) {
int32x4_t const atom = data[k];
// Compute the absolute maximum of the atom in the thread group
// In 2 blocks of values, upper/lower halves of the f16x2_t
int wblockmax = group_abs_max<T>(atom);
// Derive scales
int decoding_scale;
int encoding_scale;
decoding_scale = packed_mul<T>(wblockmax, kScaleFactor);
encoding_scale = packed_add<T>(decoding_scale, kScaleEpsilon);
encoding_scale = packed_rcp<T>(encoding_scale);
// Apply scales to get quantized values
int32x4_t w;
for (int i = 0; i < 4; i++) {
w[i] = packed_mul<T>(atom[i], encoding_scale);
w[i] = packed_max<T>(w[i], kRangeMin);
w[i] = packed_min<T>(w[i], kRangeMax);
}
// Convert from f16x2_t to uint16x2_t
int32x4_t q;
{
int16_t* qi = reinterpret_cast<int16_t*>(&q);
T* wh = reinterpret_cast<T*>(&w);
for (int i = 0; i < 8; i++)
qi[i] = (int16_t)rintf(T2float_cast(wh[i]));
for (int i = 0; i < 4; i++) {
q[i] = packed_add<int16_t>(q[i], kRangeBias);
}
}
// Pack 8 x q4 into int32_t
int qw = q[0] | (q[1] << 4) | (q[2] << 8) | (q[3] << 12);
// Write quantized atom to send_buffer
// note: only the group leader stores the scale
uint8_t* atom_ptr = reinterpret_cast<uint8_t*>(send_buffer + k * kRankBufferTileStride);
int32_t* qw_ptr = reinterpret_cast<int32_t*>(atom_ptr) + thread;
int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) + (thread / 8);
__builtin_nontemporal_store(qw, qw_ptr);
if (threadIdx.x == group_leader) {
__builtin_nontemporal_store(decoding_scale, qs_ptr);
}
}
}
__quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, int32x4_t* __restrict__ data) {
for (int k = 0; k < kRankAtoms; k++) {
// Directly read quantized atom from recv_buffer
uint8_t* atom_ptr = reinterpret_cast<uint8_t*>(*recv_buffer);
int32_t* qw_ptr = reinterpret_cast<int32_t*>(atom_ptr) + thread;
int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) + (thread / 8);
int32_t qw = __builtin_nontemporal_load(qw_ptr);
int qs = __builtin_nontemporal_load(qs_ptr);
*recv_buffer += kRankBufferTileStride;
// Unpack q4 into f16x8_t
int32x4_t w;
{
static constexpr uint kMask000F = 0x000F000F;
static constexpr uint kHalf2_1024 = 0x64006400; // {1024.0, 1024.0}, fp16x2_t
static uint constexpr kHalf2_1032 = 0xE408E408; // {-1032.0, -1032.0}, fp16x2_t
for (int i = 0; i < 4; i++) {
if constexpr (std::is_same<T, half>::value) {
int32_t q4 = ((qw >> (i * 4)) & kMask000F) | kHalf2_1024;
w[i] = packed_add<half>(q4, kHalf2_1032);
} else {
int32_t int16_2 = (qw >> (i * 4)) & kMask000F;
int16_t low = static_cast<int16_t>(int16_2 & 0xFFFF);
int16_t high = static_cast<int16_t>((int16_2 >> 16) & 0xFFFF);
nv_bfloat16 bf_low = __float2bfloat16(static_cast<float>(low));
nv_bfloat16 bf_high = __float2bfloat16(static_cast<float>(high));
nv_bfloat162 bf2 = __halves2bfloat162(bf_low, bf_high);
int32_t packed_bf16 = *reinterpret_cast<int32_t*>(&bf2);
w[i] = packed_add<nv_bfloat16>(packed_bf16, kRangeMin);
}
}
}
// Apply decoding scales
for (int i = 0; i < 4; i++) {
w[i] = packed_mul<T>(w[i], qs);
}
data[k] = w;
}
}
};
// Int6 symmetric quantization codec.
// We quantize the FP16 data to block-scaled Int6 in blocks of 4 *
// kThreadGroupSize.
template <typename T, int world_size>
struct CodecQ6 : public CodecBase {
static constexpr int kWorldSize = world_size;
// Codec tile size process by this workgroup.
// Each threads processes a fragment of fp16x8_t (16B),
// into a int6x8_t (4B + 2B) and a fp16 scale shared among 32 values.
static constexpr int kRankAtoms = kAtoms / kWorldSize;
static constexpr int kRankTileStride = 1664;
static constexpr int kRankTileQ2Offset = 1024;
static constexpr int kRankTileScaleOffset = 1536;
static constexpr int kRankTransmittedTileSize = kRankTileStride * kRankAtoms;
static_assert(kRankTransmittedTileSize % 16 == 0, "kRankTransmittedTileSize must be 16B aligned.");
static constexpr int kRankBufferTileStride = kRankTileStride / sizeof(int32x4_t);
// Total tile size for the collective communication.
static constexpr int kTransmittedTileSize = kRankTransmittedTileSize * kWorldSize;
// Constants configuration
// {-1/32.0h, -1/32.0h}, fp16x2_t
static constexpr int kScaleFactor = std::is_same<T, half>::value ? 0xA800A800 : 0xBD00BD00;
// {1e-7, 1e-7}, fp16x2_t
static constexpr int kScaleEpsilon = std::is_same<T, half>::value ? 0x00010001 : 0x33D733D7;
// {-32, -32}, fp16x2_t
static constexpr int kRangeMin = std::is_same<T, half>::value ? 0xD000D000 : 0xC200C200;
// {+31, +31}, fp16x2_t
static constexpr int kRangeMax = std::is_same<T, half>::value ? 0x4FC04FC0 : 0x41F841F8;
// {+32, +32}, int16x2_t
static constexpr int kRangeBias = 0x00200020;
__quickreduce_device_inline__ CodecQ6(int thread, int rank) : CodecBase(thread, rank) {}
__quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer, const int32x4_t* __restrict__ data) {
for (int k = 0; k < kRankAtoms; k++) {
int32x4_t const atom = data[k];
// Compute the absolute maximum of the atom in the thread group
// In 2 blocks of values, upper/lower halves of the f16x2_t
int wblockmax = group_abs_max<T>(atom);
// Derive scales
int decoding_scale;
int encoding_scale;
decoding_scale = packed_mul<T>(wblockmax, kScaleFactor);
encoding_scale = packed_add<T>(decoding_scale, kScaleEpsilon);
encoding_scale = packed_rcp<T>(encoding_scale);
// Apply scales to get quantized values
int32x4_t w;
for (int i = 0; i < 4; i++) {
w[i] = packed_mul<T>(atom[i], encoding_scale);
w[i] = packed_max<T>(w[i], kRangeMin);
w[i] = packed_min<T>(w[i], kRangeMax);
}
// Convert from f16x2_t to uint16x2_t
int32x4_t q;
{
int16_t* qi = reinterpret_cast<int16_t*>(&q);
T* wh = reinterpret_cast<T*>(&w);
for (int i = 0; i < 8; i++)
qi[i] = (int16_t)rintf(T2float_cast(wh[i]));
for (int i = 0; i < 4; i++) {
q[i] = packed_add<int16_t>(q[i], kRangeBias);
}
}
// Pack 8 x q6 into int32_t + int16_t
uint32_t q4w;
uint16_t q2w = 0;
q4w = (q[0] & 0x000F000F) | ((q[1] & 0x000F000F) << 4) | ((q[2] & 0x000F000F) << 8) | ((q[3] & 0x000F000F) << 12);
{
int16_t* tw = reinterpret_cast<int16_t*>(&q);
#pragma unroll
for (int i = 0; i < 8; i++) {
q2w |= (tw[i] >> 4) << (i * 2);
}
}
// Write quantized atom to send_buffer
// note: only the group leader stores the scale
uint8_t* atom_ptr = reinterpret_cast<uint8_t*>(send_buffer + k * kRankBufferTileStride);
uint32_t* q4w_ptr = reinterpret_cast<uint32_t*>(atom_ptr) + thread;
uint16_t* q2w_ptr = reinterpret_cast<uint16_t*>(atom_ptr + kRankTileQ2Offset) + thread;
int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) + (thread / 8);
__builtin_nontemporal_store(q4w, q4w_ptr);
__builtin_nontemporal_store(q2w, q2w_ptr);
if (threadIdx.x == group_leader) {
__builtin_nontemporal_store(decoding_scale, qs_ptr);
}
}
}
__quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, int32x4_t* __restrict__ data) {
for (int k = 0; k < kRankAtoms; k++) {
// Directly read quantized atom from recv_buffer
uint8_t* atom_ptr = reinterpret_cast<uint8_t*>(*recv_buffer);
uint32_t* q4w_ptr = reinterpret_cast<uint32_t*>(atom_ptr) + thread;
uint16_t* q2w_ptr = reinterpret_cast<uint16_t*>(atom_ptr + kRankTileQ2Offset) + thread;
int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) + (thread / 8);
uint32_t q4w = __builtin_nontemporal_load(q4w_ptr);
uint16_t q2w = __builtin_nontemporal_load(q2w_ptr);
int qs = __builtin_nontemporal_load(qs_ptr);
*recv_buffer += kRankBufferTileStride;
// Unpack q6 into fp16x8_t
int32x4_t w;
{
static uint constexpr kMask000F = 0x000F000F;
static uint constexpr kHalf2_1024 = 0x64006400; // {1024.0, 1024.0}, fp16x2_t
static uint constexpr kHalf2_1056 = 0xE420E420; // {-1056.0, -1056.0}, fp16x2_t
#pragma unroll
for (int i = 0; i < 4; i++) {
int32_t q4 = q4w & kMask000F;
int32_t q2 = (q2w & 0x3) | ((q2w & 0xC) << 14);
q4w >>= 4;
q2w >>= 4;
if constexpr (std::is_same<T, half>::value) {
int32_t q6 = q4 | (q2 << 4) | kHalf2_1024;
asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(w[i]) : "v"(q6), "v"(kHalf2_1056));
} else {
int32_t int16_2 = q4 | (q2 << 4);
int16_t low = static_cast<int16_t>(int16_2 & 0xFFFF);
int16_t high = static_cast<int16_t>((int16_2 >> 16) & 0xFFFF);
nv_bfloat16 bf_low = __float2bfloat16(static_cast<float>(low));
nv_bfloat16 bf_high = __float2bfloat16(static_cast<float>(high));
nv_bfloat162 bf2 = __halves2bfloat162(bf_low, bf_high);
int32_t packed_bf16 = *reinterpret_cast<int32_t*>(&bf2);
w[i] = packed_add<nv_bfloat16>(packed_bf16, kRangeMin);
}
}
}
// Apply decoding scales
for (int i = 0; i < 4; i++) {
w[i] = packed_mul<T>(w[i], qs);
}
// That's pretty much it...
data[k] = w;
}
}
};
// Int8 symmetric quantization codec.
// We quantize the FP16 data to block-scaled Int8 in blocks of 4 *
// kThreadGroupSize.
template <typename T, int world_size>
struct CodecQ8 : public CodecBase {
static constexpr int kWorldSize = world_size;
// Codec tile size process by this workgroup.
// Each threads processes a fragment of f16x8_t (16B),
// into a int8x8_t (8B) and a f16 scale shared among 32 values.
static constexpr int kRankAtoms = kAtoms / kWorldSize;
static constexpr int kRankTileStride = 2176;
static constexpr int kRankTileScaleOffset = 2048;
static constexpr int kRankTransmittedTileSize = kRankTileStride * kRankAtoms;
static_assert(kRankTransmittedTileSize % 16 == 0, "kRankTileSize must be 16B aligned.");
static constexpr int kRankBufferTileStride = kRankTileStride / sizeof(int32x4_t);
// Total tile size for the collective communication.
static constexpr int kTransmittedTileSize = kRankTransmittedTileSize * kWorldSize;
// Constants configuration
// {-1/128.0h, -1/128.0h}, f16x2_t
static constexpr int kScaleFactor = std::is_same<T, half>::value ? 0xA000A000 : 0xBC00BC00;
// {1e-7, 1e-7}, f16x2_t
static constexpr int kScaleEpsilon = std::is_same<T, half>::value ? 0x00010001 : 0x33D733D7;
// {-128, -128}, f16x2_t
static constexpr int kRangeMin = std::is_same<T, half>::value ? 0xD800D800 : 0xC300C300;
// {+127, +127}, f16x2_t
static constexpr int kRangeMax = std::is_same<T, half>::value ? 0x57F057F0 : 0x42FE42FE;
// {+128, +128}, int16x2_t
static constexpr int kRangeBias = 0x00800080;
__quickreduce_device_inline__ CodecQ8(int thread, int rank) : CodecBase(thread, rank) {}
__quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer, int32x4_t const* __restrict__ data) {
for (int k = 0; k < kRankAtoms; k++) {
int32x4_t const atom = data[k];
// Compute the absolute maximum of the atom in the thread group
// In 2 blocks of values, upper/lower halves of the f16x2_t
int wblockmax = group_abs_max<T>(atom);
// Derive scales
int decoding_scale;
int encoding_scale;
decoding_scale = packed_mul<T>(wblockmax, kScaleFactor);
encoding_scale = packed_add<T>(decoding_scale, kScaleEpsilon);
encoding_scale = packed_rcp<T>(encoding_scale);
// Apply scales to get quantized values
int32x4_t w;
for (int i = 0; i < 4; i++) {
w[i] = packed_mul<T>(atom[i], encoding_scale);
w[i] = packed_max<T>(w[i], kRangeMin);
w[i] = packed_min<T>(w[i], kRangeMax);
}
// Convert from f16x2_t to uint16x2_t
int32x4_t q;
{
int16_t* qi = reinterpret_cast<int16_t*>(&q);
T* wh = reinterpret_cast<T*>(&w);
for (int i = 0; i < 8; i++)
qi[i] = (int16_t)rintf(T2float_cast(wh[i]));
for (int i = 0; i < 4; i++) {
q[i] = packed_add<int16_t>(q[i], kRangeBias);
}
}
// Pack 8 x q8 into int32x2_t
int32x2_t qw;
qw[0] = q[0] | (q[1] << 8);
qw[1] = q[2] | (q[3] << 8);
// Write quantized atom to send_buffer
// note: only the group leader stores the scale
uint8_t* atom_ptr = reinterpret_cast<uint8_t*>(send_buffer + k * kRankBufferTileStride);
int32x2_t* qw_ptr = reinterpret_cast<int32x2_t*>(atom_ptr) + thread;
int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) + (thread / 8);
__builtin_nontemporal_store(qw, qw_ptr);
if (threadIdx.x == group_leader) {
__builtin_nontemporal_store(decoding_scale, qs_ptr);
}
}
}
__quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, int32x4_t* __restrict__ data) {
for (int k = 0; k < kRankAtoms; k++) {
// Directly read quantized atom from recv_buffer
uint8_t* atom_ptr = reinterpret_cast<uint8_t*>(*recv_buffer);
int32x2_t* qw_ptr = reinterpret_cast<int32x2_t*>(atom_ptr) + thread;
int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) + (thread / 8);
int32x2_t qw = __builtin_nontemporal_load(qw_ptr);
int qs = __builtin_nontemporal_load(qs_ptr);
*recv_buffer += kRankBufferTileStride;
// Unpack q8 into fp16x8_t
int32x4_t w;
{
static uint constexpr kMask00FF = 0x00FF00FF;
// {1024.0, 1024.0}, fp16x2_t
static uint constexpr kHalf2_1024 = 0x64006400;
// {-1152.0, -1152.0}, fp16x2_t
static uint constexpr kHalf2_1152 = 0xE480E480;
#pragma unroll
for (int i = 0; i < 4; i++) {
if constexpr (std::is_same<T, half>::value) {
int32_t q8 = ((qw[i / 2] >> ((i % 2) * 8)) & kMask00FF) | kHalf2_1024;
w[i] = packed_add<half>(q8, kHalf2_1152);
} else {
int32_t int16_2 = (qw[i / 2] >> ((i % 2) * 8)) & kMask00FF;
int16_t low = static_cast<int16_t>(int16_2 & 0xFFFF);
int16_t high = static_cast<int16_t>((int16_2 >> 16) & 0xFFFF);
nv_bfloat16 bf_low = __float2bfloat16(static_cast<float>(low));
nv_bfloat16 bf_high = __float2bfloat16(static_cast<float>(high));
nv_bfloat162 bf2 = __halves2bfloat162(bf_low, bf_high);
int32_t packed_bf16 = *reinterpret_cast<int32_t*>(&bf2);
w[i] = packed_add<nv_bfloat16>(packed_bf16, kRangeMin);
}
}
}
// Apply decoding scales
for (int i = 0; i < 4; i++) {
w[i] = packed_mul<T>(w[i], qs);
}
data[k] = w;
}
}
};
// Twoshot All Reduce
template <typename T, class Codec, bool cast_bf2half>
struct AllReduceTwoshot {
static_assert(sizeof(T) == 2);
static constexpr int kWorldSize = Codec::kWorldSize;
__device__ static void
run(T const* __restrict__ input,
T* __restrict__ output,
uint32_t const N, // number of elements
int const block, // block index
int const rank, // rank index
uint8_t** __restrict__ buffer_list, // communication buffers
uint32_t const data_offset, // offset to start of the data buffer
uint32_t flag_color) {
// Topology
int thread = threadIdx.x + threadIdx.y * kWavefront;
uint8_t* rank_buffer = buffer_list[rank];
Codec codec(thread, rank);
int block_id = blockIdx.x;
int grid_size = gridDim.x;
// --------------------------------------------------------
// Read input into registers
int32x4_t tA[kAtoms];
BufferResource src_buffer(const_cast<T*>(input), N * sizeof(T));
uint32_t src_offset = block * kTileSize + thread * sizeof(int32x4_t);
for (int i = 0; i < kAtoms; i++) {
tA[i] = buffer_load_dwordx4(src_buffer.descriptor, src_offset, 0, 0);
src_offset += kAtomStride * sizeof(int32x4_t);
if constexpr (cast_bf2half) {
const nv_bfloat162* bf_buf = reinterpret_cast<const nv_bfloat162*>(&tA[i]);
half2 half_buf[4];
#pragma unroll
for (int j = 0; j < 4; ++j) {
float2 f = __bfloat1622float2(bf_buf[j]);
half_buf[j] = __float22half2_rn(f);
}
tA[i] = *reinterpret_cast<const int32x4_t*>(half_buf);
}
}
// --------------------------------------------------------
// Phase-1A: Write segment data into the communication buffer of the target
// rank responsible for this segment.
uint32_t comm_data0_offset = data_offset + block_id * Codec::kTransmittedTileSize;
uint32_t comm_data1_offset = grid_size * Codec::kTransmittedTileSize + comm_data0_offset;
uint32_t comm_flags0_offset = block_id * (kWorldSize * sizeof(uint32_t));
uint32_t comm_flags1_offset = grid_size * (kWorldSize * sizeof(uint32_t)) + comm_flags0_offset;
for (int r = 0; r < kWorldSize; r++) {
int32x4_t* send_buffer =
reinterpret_cast<int32x4_t*>(buffer_list[r] + comm_data0_offset + rank * Codec::kRankTransmittedTileSize);
codec.send(send_buffer, &tA[r * Codec::kRankAtoms]);
}
__syncthreads();
if (thread < kWorldSize) {
int r = thread;
uint32_t* flag_ptr = reinterpret_cast<uint32_t*>(buffer_list[r] + comm_flags0_offset + rank * sizeof(uint32_t));
set_sync_flag(flag_ptr, flag_color);
}
// --------------------------------------------------------
// Phase-1B: Reduce the segment data from the communication buffers.
int32x4_t tR[Codec::kRankAtoms] = {};
{
// Read the data from the communication buffer.
int32x4_t* recv_buffer = reinterpret_cast<int32x4_t*>(rank_buffer + comm_data0_offset);
uint32_t* flag_ptr = reinterpret_cast<uint32_t*>(rank_buffer + comm_flags0_offset);
for (int r = 0; r < kWorldSize; r++) {
// Wait for the flags to be set.
if (thread == 0) {
wait_sync_flag(&flag_ptr[r], flag_color);
}
__syncthreads();
// note: we reuse tA as temp buffer here
codec.recv(&recv_buffer, tA);
for (int i = 0; i < Codec::kRankAtoms; i++) {
packed_assign_add<T>(&tR[i], &tA[i]);
}
}
}
// Phase-2: Write the reduced segment to every other rank
for (int r = 0; r < kWorldSize; r++) {
int32x4_t* send_buffer =
reinterpret_cast<int32x4_t*>(buffer_list[r] + comm_data1_offset + rank * Codec::kRankTransmittedTileSize);
codec.send(send_buffer, tR);
}
__syncthreads();
if (thread < kWorldSize) {
int r = thread;
uint32_t* flag_ptr = reinterpret_cast<uint32_t*>(buffer_list[r] + comm_flags1_offset + rank * sizeof(uint32_t));
set_sync_flag(flag_ptr, flag_color);
}
// Phase-2: Read the gather segments from the rank's communication buffer.
{
// Read the data from the communication buffer.
int32x4_t* recv_buffer = reinterpret_cast<int32x4_t*>(rank_buffer + comm_data1_offset);
uint32_t* flag_ptr = reinterpret_cast<uint32_t*>(rank_buffer + comm_flags1_offset);
for (int r = 0; r < kWorldSize; r++) {
// Wait for the flags to be set.
if (thread == 0) {
wait_sync_flag(&flag_ptr[r], flag_color);
}
__syncthreads();
// Gather all reduced and final rank segments into tA.
codec.recv(&recv_buffer, &tA[r * Codec::kRankAtoms]);
}
}
// --------------------------------------------------------
// Write the result to output.
BufferResource dst_buffer(output, N * sizeof(T));
uint32_t dst_offset = block * kTileSize + thread * sizeof(int32x4_t);
for (int i = 0; i < kAtoms; i++) {
if constexpr (cast_bf2half) {
const half2* half_buf = reinterpret_cast<const half2*>(&tA[i]);
nv_bfloat162 bf16_buf[4];
#pragma unroll
for (int j = 0; j < 4; ++j) {
float2 f = __half22float2(half_buf[j]);
bf16_buf[j] = __float22bfloat162_rn(f);
}
buffer_store_dwordx4(*reinterpret_cast<const int32x4_t*>(bf16_buf), dst_buffer.descriptor, dst_offset, 0, 0);
} else {
buffer_store_dwordx4(tA[i], dst_buffer.descriptor, dst_offset, 0, 0);
}
dst_offset += kAtomStride * sizeof(int32x4_t);
}
}
};
} // namespace quickreduce

View File

@@ -0,0 +1,233 @@
#pragma once
#include <hip/hip_runtime.h>
#include <vector>
#include "quick_all_reduce.cuh"
#define HIP_CHECK(err) \
do { \
hipError_t err_ = (err); \
if (err_ != hipSuccess) { \
std::printf("HIP error %d at %s:%d. %s\n", err_, __FILE__, __LINE__, hipGetErrorString(err_)); \
throw std::runtime_error("HIP error"); \
} \
} while (0)
namespace quickreduce {
using fptr_t = int64_t;
static_assert(sizeof(void*) == sizeof(fptr_t));
template <typename AllReduceKernel, typename T>
__global__ __quickreduce_launch_bounds_two_shot__ static void allreduce_prototype_twoshot(
T const* A,
T* B,
uint32_t N,
uint32_t num_blocks,
int rank,
uint8_t** dbuffer_list,
uint32_t data_offset,
uint32_t flag_color) {
int block = blockIdx.x;
int grid = gridDim.x;
while (block < num_blocks) {
AllReduceKernel::run(A, B, N, block, rank, dbuffer_list, data_offset, flag_color);
block += grid;
flag_color++;
}
}
#define TWOSHOT_DISPATCH(__codec) \
if (world_size == 2) { \
using LineCodec = __codec<T, 2>; \
using AllReduceKernel = AllReduceTwoshot<T, LineCodec, cast_bf2half>; \
hipLaunchKernelGGL( \
(allreduce_prototype_twoshot<AllReduceKernel, T>), \
dim3(grid), \
dim3(kBlockTwoShot), \
0, \
stream, \
A, \
B, \
N, \
num_blocks, \
rank, \
dbuffer_list, \
data_offset, \
flag_color); \
} else if (world_size == 4) { \
using LineCodec = __codec<T, 4>; \
using AllReduceKernel = AllReduceTwoshot<T, LineCodec, cast_bf2half>; \
hipLaunchKernelGGL( \
(allreduce_prototype_twoshot<AllReduceKernel, T>), \
dim3(grid), \
dim3(kBlockTwoShot), \
0, \
stream, \
A, \
B, \
N, \
num_blocks, \
rank, \
dbuffer_list, \
data_offset, \
flag_color); \
} else if (world_size == 8) { \
using LineCodec = __codec<T, 8>; \
using AllReduceKernel = AllReduceTwoshot<T, LineCodec, cast_bf2half>; \
hipLaunchKernelGGL( \
(allreduce_prototype_twoshot<AllReduceKernel, T>), \
dim3(grid), \
dim3(kBlockTwoShot), \
0, \
stream, \
A, \
B, \
N, \
num_blocks, \
rank, \
dbuffer_list, \
data_offset, \
flag_color); \
}
enum QuickReduceQuantLevel {
F16 = 0,
INT8 = 1,
INT6 = 2,
INT4 = 3,
};
struct DeviceComms {
// Max problem size is 2GB (in bytes) or half of uint32_t max value.
int64_t kMaxProblemSize = static_cast<int64_t>(std::numeric_limits<int32_t>::max()) + 1;
// Max TP-8
static int constexpr kMaxWorldSize = 8;
bool initialized = false;
uint32_t flag_color = 1;
int world_size;
int rank;
uint8_t* dbuffer;
uint8_t** dbuffer_list;
hipIpcMemHandle_t buffer_ipc_handle;
std::vector<hipIpcMemHandle_t> all_buffer_ipc_handles;
std::vector<uint8_t*> buffer_list;
uint32_t data_offset;
DeviceComms() : initialized(false), world_size(1), rank(0) {}
~DeviceComms() {
destroy();
}
void init(int world_size, int rank, std::optional<int64_t> max_problem_size = std::nullopt) {
destroy();
this->world_size = world_size;
this->rank = rank;
if (max_problem_size.has_value() && max_problem_size.value() > 0) {
this->kMaxProblemSize = max_problem_size.value();
}
// Allocate buffer size for worst case: F16 2-stage buffer.
uint32_t flags_buffer_size = 2 * world_size * kMaxNumBlocks * sizeof(uint32_t);
static int64_t data_buffer_size = 2 * this->kMaxProblemSize;
int64_t total_buffer_size = flags_buffer_size + data_buffer_size;
data_offset = flags_buffer_size;
HIP_CHECK(hipExtMallocWithFlags((void**)&dbuffer, total_buffer_size, hipDeviceMallocUncached));
// Clear the flags buffer.
HIP_CHECK(hipMemset(dbuffer, 0, flags_buffer_size));
// Device-side list of IPC buffers.
buffer_list.resize(world_size);
HIP_CHECK(hipMalloc(&dbuffer_list, world_size * sizeof(uint8_t*)));
// Create IPC handles for rank's communication buffer.
all_buffer_ipc_handles.resize(world_size);
HIP_CHECK(hipIpcGetMemHandle(&buffer_ipc_handle, dbuffer));
initialized = true;
}
int get_world_size() {
return world_size;
}
int get_rank() {
return rank;
}
bool status() {
return initialized;
}
hipIpcMemHandle_t const get_handle() {
return buffer_ipc_handle;
}
void destroy() {
if (initialized) {
for (int i = 0; i < world_size; i++) {
if (i != rank) {
HIP_CHECK(hipIpcCloseMemHandle(dbuffer_list[i]));
}
}
HIP_CHECK(hipFree(dbuffer));
HIP_CHECK(hipFree(dbuffer_list));
initialized = false;
}
}
void open_ipc_handles(std::vector<hipIpcMemHandle_t> const& ipc_handles) {
assert(ipc_handles.size() == all_buffer_ipc_handles.size());
for (int i = 0; i < world_size; i++) {
all_buffer_ipc_handles[i] = ipc_handles[i];
}
// Open device memory access to the IPC communication buffers.
// Note: For our own rank, we do not need to open a handle.
for (int i = 0; i < world_size; i++) {
if (i != rank) {
HIP_CHECK(
hipIpcOpenMemHandle((void**)&buffer_list[i], all_buffer_ipc_handles[i], hipIpcMemLazyEnablePeerAccess));
} else {
buffer_list[i] = dbuffer;
}
}
HIP_CHECK(hipMemcpy(dbuffer_list, buffer_list.data(), world_size * sizeof(uint8_t*), hipMemcpyHostToDevice));
}
template <typename T, bool cast_bf2half>
void allreduce(T const* A, T* B, uint32_t N, int quant_level, hipStream_t stream) {
if (world_size != 2 && world_size != 4 && world_size != 8) {
throw std::runtime_error("All Reduce not supported for world_size = " + std::to_string(world_size));
}
// Configuration.
uint32_t msg_size = N * sizeof(T);
uint32_t num_blocks = divceil(msg_size, kTileSize);
uint32_t grid = min(kMaxNumBlocks, num_blocks);
auto quant_level_ = static_cast<QuickReduceQuantLevel>(quant_level);
switch (quant_level_) {
case QuickReduceQuantLevel::INT8:
TWOSHOT_DISPATCH(CodecQ8)
break;
case QuickReduceQuantLevel::INT6:
TWOSHOT_DISPATCH(CodecQ6)
break;
case QuickReduceQuantLevel::INT4:
TWOSHOT_DISPATCH(CodecQ4)
break;
default:
TWOSHOT_DISPATCH(CodecFP)
break;
}
HIP_CHECK(cudaGetLastError());
// Rotate the flag color.
flag_color += divceil(N, grid);
}
};
} // namespace quickreduce

View File

@@ -0,0 +1,318 @@
#pragma once
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
#include <cstdint>
#define __quickreduce_device_inline__ __device__ __forceinline__
#define __quickreduce_launch_bounds_two_shot__ __launch_bounds__(256, 4)
#define __quickreduce_launch_bounds_one_shot__ __launch_bounds__(512, 4)
namespace quickreduce {
typedef __hip_bfloat16 nv_bfloat16;
typedef __hip_bfloat162 nv_bfloat162;
using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
// Setup acquire-release semantics for vector memory reads (mubuf instruction)
// as per architecture.
#if defined(__gfx942__)
// CDNA3: Scope bits sc0, sc1
#define MUBUF_ACQUIRE 16
#define MUBUF_RELEASE 16
#elif (defined(__gfx908__) || defined(__gfx90a__))
// CDNA1 and CDNA2 - glc bit
#define MUBUF_ACQUIRE 1
#define MUBUF_RELEASE 0
#endif
static constexpr int kNegOne = 0xBC00BC00; // {-1, -1}, fp16x2_t
// Number of atoms (4xf16x2_t) processed by a single thread
static constexpr int kAtoms = 8;
// We use a workgroup of 256 threads
static constexpr int kBlockSize = 256;
static constexpr int kAtomStride = kBlockSize;
// Size and atom stride of source/destination data that the block will
// process.
// Workgroup scope = Tile = (256 threads x 8 atoms x 16B)
static constexpr int kTileSize = kBlockSize * kAtoms * sizeof(int32x4_t);
// Max number of blocks. 304 CUs on MI300
static constexpr int kMaxNumBlocks = 304 * 4;
// Standard CDNA wavefront size.
static constexpr int kWavefront = 64;
// 256 thread, 4 wavefronts.
static dim3 constexpr kBlockTwoShot = {kWavefront, kBlockSize / kWavefront, 1};
// Number of threads in a group for quantization
// It corresponds to 32 F16 elements in quantization block
static constexpr int kThreadGroupSize = 8;
// Methods
__quickreduce_device_inline__ __host__ unsigned long divceil(unsigned long x, unsigned long y) {
return ((x + y - 1) / y);
}
union BufferResource {
__quickreduce_device_inline__ constexpr BufferResource() : config(0x00020000U) {}
__quickreduce_device_inline__ constexpr BufferResource(void* buffer_address, uint32_t buffer_size)
: address(buffer_address), range(buffer_size), config(0x00020000U) {}
int32x4_t descriptor;
struct {
void* address; // 8B, out of which first 48b is address, and 16b is stride
// (unused)
uint32_t range; // Byte range for the buffer resource
uint32_t config; // Constant, DFMT=32b
};
};
__quickreduce_device_inline__ static int32x4_t buffer_load_dwordx4(
int32x4_t srsrc, int32_t voffset, int32_t soffset, int32_t aux) __asm("llvm.amdgcn.raw.buffer.load.v4i32");
__quickreduce_device_inline__ static void
buffer_store_dwordx4(int32x4_t data, int32x4_t srsrc, int32_t voffset, int32_t soffset, int32_t aux) __asm(
"llvm.amdgcn.raw.buffer.store.v4i32");
__quickreduce_device_inline__ static void set_fp16_ovfl(bool const value) {
#if defined(__gfx942__)
if (value) {
asm volatile("s_setreg_imm32_b32 0xdc1, 1;" ::);
} else {
asm volatile("s_setreg_imm32_b32 0xdc1, 0;" ::);
}
#endif
}
union bf162_int_union {
int i;
nv_bfloat162 bf2;
};
template <typename T>
__quickreduce_device_inline__ void packed_assign_add(int32x4_t* A, int32x4_t* B);
template <>
__quickreduce_device_inline__ void packed_assign_add<half>(int32x4_t* A, int32x4_t* B) {
int32x4_t& tR_fragment = A[0];
int32x4_t& tA_fragment = B[0];
asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(tR_fragment[0]) : "v"(tR_fragment[0]), "v"(tA_fragment[0]));
asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(tR_fragment[1]) : "v"(tR_fragment[1]), "v"(tA_fragment[1]));
asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(tR_fragment[2]) : "v"(tR_fragment[2]), "v"(tA_fragment[2]));
asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(tR_fragment[3]) : "v"(tR_fragment[3]), "v"(tA_fragment[3]));
}
template <>
__quickreduce_device_inline__ void packed_assign_add<nv_bfloat16>(int32x4_t* A, int32x4_t* B) {
nv_bfloat162* tA = reinterpret_cast<nv_bfloat162*>(A);
nv_bfloat162* tB = reinterpret_cast<nv_bfloat162*>(B);
#pragma unroll
for (int i = 0; i < 4; i++) {
tA[i] = __hadd2(tA[i], tB[i]);
}
}
template <typename T>
__quickreduce_device_inline__ int packed_max(int a, int b);
template <>
__quickreduce_device_inline__ int packed_max<half>(int a, int b) {
int result;
asm volatile("v_pk_max_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b));
return result;
}
template <>
__quickreduce_device_inline__ int packed_max<nv_bfloat16>(int a, int b) {
bf162_int_union A, B, R;
A.i = a;
B.i = b;
R.bf2 = __hmax2(A.bf2, B.bf2);
return R.i;
}
template <typename T>
__quickreduce_device_inline__ int packed_min(int a, int b);
template <>
__quickreduce_device_inline__ int packed_min<half>(int a, int b) {
int result;
asm volatile("v_pk_min_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b));
return result;
}
template <>
__quickreduce_device_inline__ int packed_min<nv_bfloat16>(int a, int b) {
bf162_int_union A, B, R;
A.i = a;
B.i = b;
R.bf2 = __hmin2(A.bf2, B.bf2);
return R.i;
}
template <typename T>
__quickreduce_device_inline__ int packed_abs_max(int a, int b);
template <>
__quickreduce_device_inline__ int packed_abs_max<half>(int a, int b) {
half2 wmaxh2 = __builtin_bit_cast(half2, a);
half2 wminh2 = __builtin_bit_cast(half2, b);
half2 wblockmaxh2;
wblockmaxh2.x = __hgt(__habs(wmaxh2.x), __habs(wminh2.x)) ? wmaxh2.x : wminh2.x;
wblockmaxh2.y = __hgt(__habs(wmaxh2.y), __habs(wminh2.y)) ? wmaxh2.y : wminh2.y;
return __builtin_bit_cast(int, wblockmaxh2);
}
template <>
__quickreduce_device_inline__ int packed_abs_max<nv_bfloat16>(int a, int b) {
bf162_int_union A, B, R;
A.i = a;
B.i = b;
R.bf2.x = __hgt(__habs(A.bf2.x), __habs(B.bf2.x)) ? A.bf2.x : B.bf2.x;
R.bf2.y = __hgt(__habs(A.bf2.y), __habs(B.bf2.y)) ? A.bf2.y : B.bf2.y;
return R.i;
}
template <typename T>
__quickreduce_device_inline__ int packed_add(int a, int b);
template <>
__quickreduce_device_inline__ int packed_add<half>(int a, int b) {
int result;
asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b));
return result;
}
template <>
__quickreduce_device_inline__ int packed_add<nv_bfloat16>(int a, int b) {
bf162_int_union A, B, R;
A.i = a;
B.i = b;
R.bf2 = __hadd2(A.bf2, B.bf2);
return R.i;
}
template <>
__quickreduce_device_inline__ int packed_add<int16_t>(int a, int b) {
int result;
asm volatile("v_pk_add_i16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b));
return result;
}
template <typename T>
__quickreduce_device_inline__ int packed_sub(int a, int b);
template <>
__quickreduce_device_inline__ int packed_sub<half>(int a, int b) {
int result;
// MI300 lacks packed fp16 sub instruction. So we do -1 * min + max
asm volatile("v_pk_fma_f16 %0, %1, %2 %3" : "=v"(result) : "v"(kNegOne), "v"(b), "v"(a));
return result;
}
template <>
__quickreduce_device_inline__ int packed_sub<nv_bfloat16>(int a, int b) {
bf162_int_union A, B, R;
A.i = a;
B.i = b;
R.bf2 = __hsub2(A.bf2, B.bf2);
return R.i;
}
template <typename T>
__quickreduce_device_inline__ int packed_mul(int a, int b);
template <>
__quickreduce_device_inline__ int packed_mul<half>(int a, int b) {
int result;
asm volatile("v_pk_mul_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b));
return result;
}
template <>
__quickreduce_device_inline__ int packed_mul<nv_bfloat16>(int a, int b) {
nv_bfloat162* tA = reinterpret_cast<nv_bfloat162*>(&a);
nv_bfloat162* tB = reinterpret_cast<nv_bfloat162*>(&b);
nv_bfloat162 tR = __hmul2(*tA, *tB);
return *(reinterpret_cast<int*>(&tR));
}
template <typename T>
__quickreduce_device_inline__ int packed_rcp(int a);
template <>
__quickreduce_device_inline__ int packed_rcp<half>(int a) {
return __builtin_bit_cast(int, h2rcp(__builtin_bit_cast(half2, a)));
}
template <>
__quickreduce_device_inline__ int packed_rcp<nv_bfloat16>(int a) {
bf162_int_union A, R;
A.i = a;
R.bf2 = h2rcp(A.bf2);
return R.i;
}
// changes dtype
__quickreduce_device_inline__ float T2float_cast(half a) {
return __half2float(a);
}
__quickreduce_device_inline__ float T2float_cast(nv_bfloat16 a) {
return __bfloat162float(a);
}
template <typename T>
__quickreduce_device_inline__ int group_abs_max(int32x4_t atom) {
const int group_leader = (threadIdx.x / kThreadGroupSize) * kThreadGroupSize;
int wmax, wmin, wblockmax;
int a, b;
a = packed_max<T>(atom[0], atom[1]);
b = packed_max<T>(atom[2], atom[3]);
wmax = packed_max<T>(a, b);
a = packed_min<T>(atom[0], atom[1]);
b = packed_min<T>(atom[2], atom[3]);
wmin = packed_min<T>(a, b);
// Reduce the max among a group of threads
// Note: This is basically 2 blocks of values setup as the
// upper/lower halves of the f16x2_t
for (int i = 1; i < kThreadGroupSize; i <<= 1) {
int x = __shfl_down(wmax, i);
wmax = packed_max<T>(wmax, x);
int y = __shfl_down(wmin, i);
wmin = packed_min<T>(wmin, y);
}
wblockmax = packed_abs_max<T>(wmax, wmin);
// Share with the cohort
wblockmax = __shfl(wblockmax, group_leader);
return wblockmax;
}
__quickreduce_device_inline__ void set_sync_flag(uint32_t* flag_ptr, uint32_t flag) {
__atomic_store_n(flag_ptr, flag, __ATOMIC_RELEASE);
}
__quickreduce_device_inline__ void wait_sync_flag(uint32_t* flag_ptr, uint32_t flag) {
while (__atomic_load_n(flag_ptr, __ATOMIC_RELAXED) != flag) {
}
}
} // namespace quickreduce

View File

@@ -0,0 +1,153 @@
/*
* this file is used to test mscclpp_allreduce.cu using mpirun
* this file is adapted from https://github.com/flashinfer-ai/flashinfer/blob/v0.2.5/src/test_sum_all_reduce.cu
usage:
cd PATH-TO-THIS-FILE
export MPI_HOME=/usr/local/mpi
# export MPI_HOME=/opt/hpcx/ompi/
export MSCCLPP_HOME=/workspace/test/mscclpp
nvcc -O2 -arch=native -std=c++17 test_mscclpp_allreduce.cu \
-o test_mscclpp_allreduce -D_GLIBCXX_USE_CXX11_ABI=0 \
-I${MSCCLPP_HOME}/include -L${MSCCLPP_HOME}/build -lmscclpp \
-lnccl -I${MPI_HOME}/include -L${MPI_HOME}/lib -lmpi
/opt/hpcx/ompi/bin/
mpirun --allow-run-as-root -H 127.0.0.1:8 -np 8 \
--map-by ppr:8:node \
--mca btl_openib_warn_no_device_params_found 0 \
--mca btl_tcp_if_include bond0 \
--allow-run-as-root -np 8 \
-x NCCL_RUNTIME_CONNECT=0 -x NCCL_IB_GID_INDEX=3 -x NCCL_DEBUG=WARN \
-x LD_PRELOAD=${MSCCLPP_HOME}/build/libmscclpp.so ./test_mscclpp_allreduce
*/
#include <mpi.h>
#include <thrust/detail/raw_pointer_cast.h>
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#ifndef CHECK_CUDA_SUCCESS
#define CHECK_CUDA_SUCCESS(cmd) \
do { \
cudaError_t e = cmd; \
if (e != cudaSuccess) { \
printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e)); \
exit(EXIT_FAILURE); \
} \
} while (0)
#endif
#include <cstdint>
#include "mscclpp_allreduce.cuh"
template <typename T>
bool isclose(T a, T b, float rtol = 1e-5, float atol = 1e-8) {
return fabs(a - b) <= (atol + rtol * fabs(b));
}
int main(int argc, char* argv[]) {
// init mpi
MPI_Init(&argc, &argv);
printf("MPI Initialized.\n");
int nranks, rank;
// get work size and rank id
MPI_Comm_size(MPI_COMM_WORLD, &nranks);
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
cudaSetDevice(rank);
printf("nranks: %d, rank: %d\n", nranks, rank);
// init host and device buffers
using T = float;
using ReduceT = float;
const size_t num_elems = 2 * 1024 * 1024;
std::vector<T> host_buf(num_elems);
for (uint32_t i = 0; i < num_elems; ++i) {
host_buf[i] = T(i + rank);
}
thrust::device_vector<T> device_buf(host_buf);
const size_t buf_size_in_bytes = num_elems * sizeof(T);
std::vector<T> host_result_buf(num_elems);
thrust::device_vector<T> device_result_buf(host_result_buf);
std::vector<T> host_scratch_buf(num_elems * 8);
for (uint32_t i = 0; i < num_elems; ++i) {
host_scratch_buf[i] = 1;
}
thrust::device_vector<T> device_scratch_buf(host_scratch_buf);
std::vector<T> host_put_buf(num_elems);
thrust::device_vector<T> device_put_buf(host_put_buf);
mscclpp::UniqueId unique_id;
if (rank == 0) unique_id = mscclpp::TcpBootstrap::createUniqueId();
MPI_Bcast(&unique_id, sizeof(unique_id), MPI_BYTE, 0, MPI_COMM_WORLD);
std::vector<int64_t> rank_to_node(nranks);
std::vector<int64_t> rank_to_ib(nranks);
for (int i = 0; i < nranks; i++) {
rank_to_node[i] = i / 8;
rank_to_ib[i] = i % 8;
}
cudaStream_t s;
CHECK_CUDA_SUCCESS(cudaStreamCreate(&s));
CHECK_CUDA_SUCCESS(cudaStreamSynchronize(s));
if (nranks == 8) {
auto context = std::make_shared<sglang::Msccl1NodeLLcontext>(
unique_id,
rank,
nranks,
thrust::raw_pointer_cast(device_scratch_buf.data()),
buf_size_in_bytes * 8,
rank_to_node,
rank_to_ib);
printf("rank: %d, Msccl1NodeLLcontext setup.\n", rank);
MPI_Barrier(MPI_COMM_WORLD);
context->allreduce<T>(
s,
thrust::raw_pointer_cast(device_buf.data()),
thrust::raw_pointer_cast(device_result_buf.data()),
device_buf.size());
} else if (nranks == 16) {
// TODO: this branch is untested since there is something wrong with mpirun in my test machince
auto context = std::make_shared<sglang::Msccl2NodeLLcontext>(
unique_id,
rank,
nranks,
thrust::raw_pointer_cast(device_scratch_buf.data()),
buf_size_in_bytes * 8,
thrust::raw_pointer_cast(device_put_buf.data()),
buf_size_in_bytes,
rank_to_node,
rank_to_ib);
printf("rank: %d, Msccl2NodeLLcontext setup.\n", rank);
MPI_Barrier(MPI_COMM_WORLD);
context->allreduce<T>(
s,
thrust::raw_pointer_cast(device_buf.data()),
thrust::raw_pointer_cast(device_result_buf.data()),
device_buf.size());
}
// check result correctness
thrust::host_vector<T> host_buf_result = device_result_buf;
size_t num_results_error_atol_1e_3_rtol_1e_3 = 0;
bool nan_detected = false;
for (uint32_t i = 0; i < num_elems; ++i) {
T expected = T(i * nranks + (nranks - 1) * nranks / 2);
if (std::isnan(float(host_buf_result[i]))) {
nan_detected = true;
}
if (!isclose(float(host_buf_result[i]), float(expected), 1e-3, 1e-3)) {
num_results_error_atol_1e_3_rtol_1e_3++;
}
}
float result_accuracy = 1. - float(num_results_error_atol_1e_3_rtol_1e_3) / float(num_elems);
printf("rank: %d, nan_detected: %d accuracy: %f\n", rank, nan_detected, result_accuracy);
CHECK_CUDA_SUCCESS(cudaStreamDestroy(s));
MPI_Finalize();
return 0;
}

View File

@@ -0,0 +1,55 @@
// Adapted from
// https://github.com/flashinfer-ai/flashinfer/blob/55576c626421b5ee7e7ebe74afd26465c8ae863f/csrc/cascade.cu
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <flashinfer/attention/cascade.cuh>
#include "pytorch_extension_utils.h"
using namespace flashinfer;
void merge_state(
at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b, at::Tensor v_merged, at::Tensor s_merged) {
CHECK_INPUT(v_a);
CHECK_INPUT(s_a);
CHECK_INPUT(v_b);
CHECK_INPUT(s_b);
auto device = v_a.device();
CHECK_EQ(s_a.device(), device);
CHECK_EQ(v_b.device(), device);
CHECK_EQ(s_b.device(), device);
CHECK_DIM(3, v_a);
CHECK_DIM(2, s_a);
CHECK_DIM(3, v_b);
CHECK_DIM(2, s_b);
CHECK_SHAPE(v_a, v_b);
CHECK_SHAPE(s_a, s_b);
CHECK_EQ(v_a.size(0), s_a.size(0));
CHECK_EQ(v_a.size(1), s_b.size(1));
unsigned int seq_len = v_a.size(0);
unsigned int num_heads = v_a.size(1);
unsigned int head_dim = v_a.size(2);
const c10::cuda::OptionalCUDAGuard device_guard(v_a.device());
auto stream = at::cuda::getCurrentCUDAStream();
bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(v_a.scalar_type(), c_type, [&] {
cudaError_t status = MergeState(
static_cast<c_type*>(v_a.data_ptr()),
static_cast<float*>(s_a.data_ptr()),
static_cast<c_type*>(v_b.data_ptr()),
static_cast<float*>(s_b.data_ptr()),
static_cast<c_type*>(v_merged.data_ptr()),
static_cast<float*>(s_merged.data_ptr()),
seq_len,
num_heads,
head_dim,
stream);
TORCH_CHECK(status == cudaSuccess, "MergeState kernel launch failed: ", cudaGetErrorString(status));
return true;
});
TORCH_CHECK(success, "MergeState kernel launch failed: unsupported data type");
}

View File

@@ -0,0 +1,269 @@
/*
Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cutlass/cutlass.h>
#include <cutlass/kernel_hardware_info.h>
#include <torch/all.h>
#include <cute/tensor.hpp>
#include <iostream>
#include "cutlass_sm100_mla/device/sm100_mla.hpp"
#include "cutlass_sm100_mla/kernel/sm100_mla_tile_scheduler.hpp"
// clang-format off
#if !defined(CUDA_VERSION) || CUDA_VERSION < 12040
void cutlass_mla_decode(
torch::Tensor const& out,
torch::Tensor const& q_nope,
torch::Tensor const& q_pe,
torch::Tensor const& kv_c_and_k_pe_cache,
torch::Tensor const& seq_lens,
torch::Tensor const& page_table,
torch::Tensor const& workspace,
int64_t num_kv_splits) {
TORCH_CHECK(false, "CUDA version must be >= 12.4 for cutlass_mla_decode");
}
int64_t cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count, int64_t num_kv_splits) {
TORCH_CHECK(false, "CUDA version must be >= 12.4 for cutlass_mla_get_workspace_size");
}
#else
#define CUTLASS_CHECK(status) \
{ \
cutlass::Status error = status; \
TORCH_CHECK(error == cutlass::Status::kSuccess, cutlassGetStatusString(error)); \
}
using namespace cute;
using namespace cutlass::fmha::kernel;
template <bool v>
struct IsPersistent {
static const bool value = v;
};
template <typename T, bool IsPaged128, typename PersistenceOption = IsPersistent<true>>
struct MlaSm100 {
using Element = T;
using ElementAcc = float;
using ElementOut = T;
using TileShape = Shape<_128, _128, Shape<_512, _64>>;
using TileShapeH = cute::tuple_element_t<0, TileShape>;
using TileShapeD = cute::tuple_element_t<2, TileShape>;
// H K (D_latent D_rope) B
using ProblemShape = cute::tuple<TileShapeH, int, TileShapeD, int>;
using StrideQ = cute::tuple<int64_t, _1, int64_t>; // H D B
using StrideK = cute::tuple<int64_t, _1, int64_t>; // K D B
using StrideO = StrideK; // H D B
using StrideLSE = cute::tuple<_1, int>; // H B
using TileScheduler =
std::conditional_t<PersistenceOption::value, Sm100MlaPersistentTileScheduler, Sm100MlaIndividualTileScheduler>;
using FmhaKernel = cutlass::fmha::kernel::Sm100FmhaMlaKernelTmaWarpspecialized<
TileShape,
Element,
ElementAcc,
ElementOut,
ElementAcc,
TileScheduler,
/*kIsCpAsync=*/!IsPaged128>;
using Fmha = cutlass::fmha::device::MLA<FmhaKernel>;
};
template <typename T>
typename T::Fmha::Arguments args_from_options(
at::Tensor const& out,
at::Tensor const& q_nope,
at::Tensor const& q_pe,
at::Tensor const& kv_c_and_k_pe_cache,
at::Tensor const& seq_lens,
at::Tensor const& page_table,
double sm_scale,
int64_t num_kv_splits) {
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = q_nope.device().index();
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
int batches = q_nope.size(0);
int page_count_per_seq = page_table.size(1);
int page_count_total = kv_c_and_k_pe_cache.size(0);
int page_size = kv_c_and_k_pe_cache.size(1);
int max_seq_len = page_size * page_count_per_seq;
using TileShapeH = typename T::TileShapeH;
using TileShapeD = typename T::TileShapeD;
auto problem_shape = cute::make_tuple(TileShapeH{}, max_seq_len, TileShapeD{}, batches);
auto [H, K, D, B] = problem_shape;
auto [D_latent, D_rope] = D;
float scale = float(sm_scale);
using StrideQ = typename T::StrideQ;
using StrideK = typename T::StrideK;
using StrideO = typename T::StrideO;
using StrideLSE = typename T::StrideLSE;
StrideQ stride_Q_nope = cute::make_tuple(
static_cast<int64_t>(q_nope.stride(1)), _1{}, static_cast<int64_t>(q_nope.stride(0)));
StrideQ stride_Q_pe = cute::make_tuple(
static_cast<int64_t>(q_pe.stride(1)), _1{}, static_cast<int64_t>(q_pe.stride(0)));
StrideK stride_C = cute::make_tuple(
static_cast<int64_t>(0 + D_latent + D_rope), _1{}, static_cast<int64_t>(page_size * (D_latent + D_rope)));
StrideLSE stride_PT = cute::make_stride(_1{}, page_count_per_seq);
StrideLSE stride_LSE = cute::make_tuple(_1{}, 0 + H);
StrideO stride_O = cute::make_tuple(static_cast<int64_t>(0 + D_latent), _1{}, static_cast<int64_t>(0 + H * D_latent));
using Element = typename T::Element;
using ElementOut = typename T::ElementOut;
using ElementAcc = typename T::ElementAcc;
auto Q_nope_ptr = static_cast<Element*>(q_nope.data_ptr());
auto Q_pe_ptr = static_cast<Element*>(q_pe.data_ptr());
auto C_ptr = static_cast<Element*>(kv_c_and_k_pe_cache.data_ptr());
typename T::Fmha::Arguments arguments{
problem_shape,
{scale,
Q_nope_ptr,
stride_Q_nope,
Q_pe_ptr,
stride_Q_pe,
C_ptr,
stride_C,
C_ptr + D_latent,
stride_C,
static_cast<int*>(seq_lens.data_ptr()),
static_cast<int*>(page_table.data_ptr()),
stride_PT,
page_count_total,
page_size},
{static_cast<ElementOut*>(out.data_ptr()), stride_O, static_cast<ElementAcc*>(nullptr), stride_LSE},
hw_info,
// TODO(trevor-m): Change split_kv back to -1 when
// https://github.com/NVIDIA/cutlass/issues/2274 is fixed. Split_kv=1 will
// perform worse with larger context length and smaller batch sizes.
static_cast<int>(num_kv_splits), // split_kv
nullptr, // is_var_split_kv
};
// TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute
// split_kv automatically based on batch size and sequence length to balance
// workload across available SMs. Consider using var_split_kv for manual
// control if needed.
T::Fmha::set_split_kv(arguments);
return arguments;
}
template <typename Element, bool IsPaged128, typename PersistenceOption>
void runMla(
at::Tensor const& out,
at::Tensor const& q_nope,
at::Tensor const& q_pe,
at::Tensor const& kv_c_and_k_pe_cache,
at::Tensor const& seq_lens,
at::Tensor const& page_table,
at::Tensor const& workspace,
double sm_scale,
int64_t num_kv_splits,
cudaStream_t stream) {
using MlaSm100Type = MlaSm100<Element, IsPaged128, PersistenceOption>;
typename MlaSm100Type::Fmha fmha;
auto arguments = args_from_options<MlaSm100Type>(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, sm_scale, num_kv_splits);
CUTLASS_CHECK(fmha.can_implement(arguments));
CUTLASS_CHECK(fmha.initialize(arguments, workspace.data_ptr(), stream));
CUTLASS_CHECK(fmha.run(arguments, workspace.data_ptr(), stream));
}
#define DISPATCH_BOOL(expr, const_expr, ...) \
[&]() -> bool { \
if (expr) { \
constexpr bool const_expr = true; \
return __VA_ARGS__(); \
} else { \
constexpr bool const_expr = false; \
return __VA_ARGS__(); \
} \
}()
void cutlass_mla_decode(
torch::Tensor const& out,
torch::Tensor const& q_nope,
torch::Tensor const& q_pe,
torch::Tensor const& kv_c_and_k_pe_cache,
torch::Tensor const& seq_lens,
torch::Tensor const& page_table,
torch::Tensor const& workspace,
double sm_scale,
int64_t num_kv_splits) {
auto in_dtype = q_nope.dtype();
at::cuda::CUDAGuard device_guard{(char)q_nope.get_device()};
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(q_nope.get_device());
const int page_size = kv_c_and_k_pe_cache.size(1);
// NOTE(alcanderian): IsPersistent has bug with manual split_kv.
// Kernel will hang if batch is too large with large num_kv_splits. (for example bs=8, num_kv_splits=8)
// Maybe per batch split kv will fix this.
DISPATCH_BOOL(page_size == 128, IsPaged128, [&] {
DISPATCH_BOOL(num_kv_splits <= 1, NotManualSplitKV, [&] {
if (in_dtype == at::ScalarType::Half) {
runMla<cutlass::half_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
} else if (in_dtype == at::ScalarType::BFloat16) {
runMla<cutlass::bfloat16_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
} else if (in_dtype == at::ScalarType::Float8_e4m3fn) {
runMla<cutlass::float_e4m3_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
} else {
TORCH_CHECK(false, "Unsupported input data type of MLA");
}
return true;
});
return true;
});
}
int64_t cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count, int64_t num_kv_splits) {
// Workspace size depends on ElementAcc and ElementLSE (same as ElementAcc)
// which are float, so Element type here doesn't matter.
using MlaSm100Type = MlaSm100<cutlass::half_t, true>;
// Get split kv. Requires problem shape and sm_count only.
typename MlaSm100Type::Fmha::Arguments arguments;
using TileShapeH = typename MlaSm100Type::TileShapeH;
using TileShapeD = typename MlaSm100Type::TileShapeD;
arguments.problem_shape =
cute::make_tuple(TileShapeH{}, static_cast<int>(max_seq_len), TileShapeD{}, static_cast<int>(num_batches));
// Assumes device 0 when getting sm_count.
arguments.hw_info.sm_count =
sm_count <= 0 ? cutlass::KernelHardwareInfo::query_device_multiprocessor_count(/*device_id=*/0) : sm_count;
arguments.split_kv = static_cast<int>(num_kv_splits);
MlaSm100Type::Fmha::set_split_kv(arguments);
return MlaSm100Type::Fmha::get_workspace_size(arguments);
}
#endif
// clang-format on

View File

@@ -0,0 +1,358 @@
/***************************************************************************************************
* Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*!
\file
\brief An universal device layer for cutlass 3.x-style kernels.
*/
// clang-format off
#pragma once
// common
#include "cutlass/cutlass.h"
#include "cutlass/device_kernel.h"
#if !defined(__CUDACC_RTC__)
#include "cutlass/cluster_launch.hpp"
#include "cutlass/trace.h"
#endif // !defined(__CUDACC_RTC__)
#include "../kernel/sm100_fmha_mla_tma_warpspecialized.hpp"
#include "../kernel/sm100_fmha_mla_reduction.hpp"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass::fmha::device {
using namespace cute;
using namespace cutlass::fmha::kernel;
////////////////////////////////////////////////////////////////////////////////
////////////////////////////// CUTLASS 3.x API /////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
template<
class Kernel_
>
class MLA {
public:
using Kernel = Kernel_;
using ReductionKernel = cutlass::fmha::kernel::Sm100FmhaMlaReductionKernel<
typename Kernel::ElementOut,
typename Kernel::ElementAcc,
typename Kernel::ElementAcc,
Kernel::TileShapeH::value,
Kernel::TileShapeL::value,
256 /*Max split*/
>;
/// Argument structure: User API
using KernelArguments = typename Kernel::Arguments;
using ReductionArguments = typename ReductionKernel::Arguments;
using Arguments = KernelArguments;
/// Argument structure: Kernel API
using KernelParams = typename Kernel::Params;
using ReductionParams = typename ReductionKernel::Params;
struct Params {
KernelParams fmha_params;
ReductionParams reduction_params;
};
private:
/// Kernel API parameters object
Params params_;
bool is_initialized(bool set = false) {
static bool initialized = false;
if (set) initialized = true;
return initialized;
}
static ReductionArguments to_reduction_args(Arguments const& args) {
auto [H, K, D, B] = args.problem_shape;
return ReductionArguments{
nullptr, args.epilogue.ptr_o, nullptr, args.epilogue.ptr_lse,
args.mainloop.softmax_scale, B, args.split_kv, K, args.mainloop.ptr_seq,
args.ptr_split_kv, Kernel::TileShapeS::value
};
}
public:
/// Access the Params structure
Params const& params() const {
return params_;
}
static void set_split_kv (KernelArguments& args) {
if (args.split_kv >= 1) return;
auto [H, K, D, B] = args.problem_shape;
int sm_count = args.hw_info.sm_count;
int max_splits = ceil_div(K, 128);
int sms_per_batch = max(1, sm_count / B);
int split_heur = min(max_splits, sms_per_batch);
int waves = ceil_div(B * split_heur, sm_count);
int k_waves = ceil_div(max_splits, split_heur);
int split_wave_aware = ceil_div(max_splits, k_waves);
args.split_kv = split_wave_aware;
}
/// Determines whether the GEMM can execute the given problem.
static Status
can_implement(Arguments const& args) {
if (! Kernel::can_implement(args)) {
return Status::kInvalid;
}
if (! ReductionKernel::can_implement(to_reduction_args(args))) {
return Status::kInvalid;
}
return Status::kSuccess;
}
/// Gets the workspace size
static size_t
get_workspace_size(Arguments const& args) {
size_t workspace_bytes = 0;
workspace_bytes += Kernel::get_workspace_size(args);
workspace_bytes += ReductionKernel::get_workspace_size(to_reduction_args(args));
return workspace_bytes;
}
/// Computes the maximum number of active blocks per multiprocessor
static int maximum_active_blocks(int /* smem_capacity */ = -1) {
CUTLASS_TRACE_HOST("MLA::maximum_active_blocks()");
int max_active_blocks = -1;
int smem_size = Kernel::SharedStorageSize;
// first, account for dynamic smem capacity if needed
cudaError_t result;
if (smem_size >= (48 << 10)) {
CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
result = cudaFuncSetAttribute(
device_kernel<Kernel>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size);
if (cudaSuccess != result) {
result = cudaGetLastError(); // to clear the error bit
CUTLASS_TRACE_HOST(
" cudaFuncSetAttribute() returned error: "
<< cudaGetErrorString(result));
return -1;
}
}
// query occupancy after setting smem size
result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks,
device_kernel<Kernel>,
Kernel::MaxThreadsPerBlock,
smem_size);
if (cudaSuccess != result) {
result = cudaGetLastError(); // to clear the error bit
CUTLASS_TRACE_HOST(
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: "
<< cudaGetErrorString(result));
return -1;
}
CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
return max_active_blocks;
}
/// Initializes GEMM state from arguments.
Status
initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
CUTLASS_TRACE_HOST("MLA::initialize() - workspace "
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
// Initialize the workspace
Status status = Kernel::initialize_workspace(args, workspace, stream);
if (status != Status::kSuccess) {
return status;
}
status = ReductionKernel::initialize_workspace(to_reduction_args(args), workspace, stream);
if (status != Status::kSuccess) {
return status;
}
KernelParams kernel_params = Kernel::to_underlying_arguments(args, workspace);
ReductionArguments reduction_args = to_reduction_args(args);
if (reduction_args.split_kv > 1) {
reduction_args.ptr_oaccum = kernel_params.epilogue.ptr_o_acc;
reduction_args.ptr_lseaccum = kernel_params.epilogue.ptr_lse_acc;
}
ReductionParams reduction_params = ReductionKernel::to_underlying_arguments(reduction_args, workspace);
// Initialize the Params structure
params_ = Params {kernel_params, reduction_params};
if (is_initialized()) return Status::kSuccess;
// account for dynamic smem capacity if needed
// no dynamic smem is needed for reduction kernel
int smem_size = Kernel::SharedStorageSize;
if (smem_size >= (48 << 10)) {
CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
cudaError_t result = cudaFuncSetAttribute(
device_kernel<Kernel>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size);
if (cudaSuccess != result) {
result = cudaGetLastError(); // to clear the error bit
CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result));
return Status::kErrorInternal;
}
}
is_initialized(true);
return Status::kSuccess;
}
/// Update API is preserved in 3.0, but does not guarantee a lightweight update of params.
Status
update(Arguments const& args, void* workspace = nullptr) {
CUTLASS_TRACE_HOST("MLA()::update() - workspace: " << workspace);
size_t workspace_bytes = get_workspace_size(args);
if (workspace_bytes > 0 && nullptr == workspace) {
return Status::kErrorWorkspaceNull;
}
auto fmha_params = Kernel::to_underlying_arguments(args, workspace);
ReductionArguments reduction_args = to_reduction_args(args);
if (reduction_args.split_kv > 1) {
reduction_args.ptr_oaccum = fmha_params.epilogue.ptr_o_acc;
reduction_args.ptr_lseaccum = fmha_params.epilogue.ptr_lse_acc;
}
ReductionParams reduction_params = ReductionKernel::to_underlying_arguments(reduction_args, workspace);
// Initialize the Params structure
params_ = Params {fmha_params, reduction_params};
return Status::kSuccess;
}
/// Primary run() entry point API that is static allowing users to create and manage their own params.
/// Supplied params struct must be construct by calling Kernel::to_underling_arguments()
static Status
run(Params& params, cudaStream_t stream = nullptr) {
CUTLASS_TRACE_HOST("MLA::run()");
dim3 const block = Kernel::get_block_shape();
dim3 const grid = Kernel::get_grid_shape(params.fmha_params);
// configure smem size and carveout
int smem_size = Kernel::SharedStorageSize;
Status launch_result;
// Use extended launch API only for mainloops that use it
if constexpr(Kernel::ArchTag::kMinComputeCapability >= 90) {
dim3 cluster(cute::size<0>(typename Kernel::ClusterShape{}),
cute::size<1>(typename Kernel::ClusterShape{}),
cute::size<2>(typename Kernel::ClusterShape{}));
void const* kernel = (void const*) device_kernel<Kernel>;
void* kernel_params[] = {&params.fmha_params};
launch_result = ClusterLauncher::launch(grid, cluster, block, smem_size, stream, kernel, kernel_params);
}
else {
launch_result = Status::kSuccess;
device_kernel<Kernel><<<grid, block, smem_size, stream>>>(params.fmha_params);
}
cudaError_t result = cudaGetLastError();
if (cudaSuccess != result or Status::kSuccess != launch_result) {
//return Status::kSuccess;
CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result);
return Status::kErrorInternal;
}
if (params.reduction_params.split_kv > 1) {
// launch reduction kernel
dim3 const block = ReductionKernel::get_block_shape();
dim3 const grid = ReductionKernel::get_grid_shape(params.reduction_params);
device_kernel<ReductionKernel><<<grid, block, 0, stream>>>(params.reduction_params);
cudaError_t result = cudaGetLastError();
if (cudaSuccess == result) {
return Status::kSuccess;
}
else {
CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result);
return Status::kErrorInternal;
}
}
else {
return Status::kSuccess;
}
}
//
// Non-static launch overloads that first create and set the internal params struct of this kernel handle.
//
/// Launches the kernel after first constructing Params internal state from supplied arguments.
Status
run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
Status status = initialize(args, workspace, stream);
if (Status::kSuccess == status) {
status = run(params_, stream);
}
return status;
}
/// Launches the kernel after first constructing Params internal state from supplied arguments.
Status
operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
return run(args, workspace, stream);
}
/// Overload that allows a user to re-launch the same kernel without updating internal params struct.
Status
run(cudaStream_t stream = nullptr) {
return run(params_, stream);
}
/// Overload that allows a user to re-launch the same kernel without updating internal params struct.
Status
operator()(cudaStream_t stream = nullptr) {
return run(params_, stream);
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::fmha::device
////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,198 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
// clang-format off
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/arch/arch.h"
#include "cute/tensor.hpp"
namespace cutlass::fmha::kernel {
using namespace cute;
template<
class ElementOut,
class ElementAcc,
class ElementScale,
size_t kNumHeads,
size_t kHeadDimLatent,
int kMaxSplits
>
struct Sm100FmhaMlaReductionKernel {
static const int SharedStorageSize = 0;
static const int MaxThreadsPerBlock = 128;
static const int MinBlocksPerMultiprocessor = 1;
using ArchTag = cutlass::arch::Sm100;
static_assert(kHeadDimLatent % MaxThreadsPerBlock == 0);
struct Arguments {
ElementAcc* ptr_oaccum = nullptr;
ElementOut* ptr_o = nullptr;
ElementAcc* ptr_lseaccum = nullptr;
ElementAcc* ptr_lse = nullptr;
ElementScale scale = 1.f;
int num_batches = 0;
int split_kv = -1;
int dim_k = -1;
int* ptr_seq = nullptr;
int* ptr_split_kv = nullptr;
int tile_shape_s = 128;
};
using Params = Arguments;
static Params to_underlying_arguments(Arguments const& args, void* workspace) {
return {args.ptr_oaccum, args.ptr_o, args.ptr_lseaccum, args.ptr_lse,
args.scale, args.num_batches, args.split_kv, args.dim_k, args.ptr_seq,
args.ptr_split_kv, args.tile_shape_s};
}
static size_t get_workspace_size(Arguments const& /*args*/) {
return 0;
}
static Status initialize_workspace(
Arguments const& /*args*/, void* /*ws*/, cudaStream_t /*stream*/) {
return Status::kSuccess;
}
static dim3 get_grid_shape(Params const& params) {
return dim3(kNumHeads, 1, params.num_batches);
}
static dim3 get_block_shape() {
return dim3(MaxThreadsPerBlock, 1, 1);
}
static bool can_implement(Arguments const& args) {
if (args.num_batches <= 0) return false;
if (args.split_kv <= 0) return false;
return true;
}
CUTLASS_DEVICE void operator() (Params const& params, char* smem_raw) {
if (params.split_kv <= 1) return;
auto blk_coord = make_coord(blockIdx.x, _0{}, blockIdx.z);
__shared__ ElementAcc sLseScale[kMaxSplits];
const size_t offset_lseaccum = get<0>(blk_coord) + kNumHeads * params.split_kv * get<2>(blk_coord);
const size_t offset_lse = get<0>(blk_coord) + kNumHeads * get<2>(blk_coord);
Tensor gLSEaccum = make_tensor(make_gmem_ptr(params.ptr_lseaccum + offset_lseaccum),
make_shape(params.split_kv), Stride<Int<kNumHeads>>{});
Tensor gLSE = make_tensor(make_gmem_ptr(params.ptr_lse + offset_lse),
Shape<_1>{}, Stride<_1>{});
auto dim_k = params.ptr_seq == nullptr ? params.dim_k : params.ptr_seq[get<2>(blk_coord)];
auto local_split_kv = params.ptr_split_kv == nullptr ? params.split_kv : params.ptr_split_kv[get<2>(blk_coord)];
auto k_tile_total = ceil_div(dim_k, params.tile_shape_s);
auto k_tile_per_cta = ceil_div(k_tile_total, local_split_kv);
local_split_kv = ceil_div(k_tile_total, k_tile_per_cta);
int warp_idx = cutlass::canonical_warp_idx_sync();
if (warp_idx == 0) {
constexpr int kNLsePerThread = cute::ceil_div(kMaxSplits, 32);
ElementAcc local_lse[kNLsePerThread];
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kNLsePerThread; ++i) {
const int split = i * 32 + threadIdx.x;
local_lse[i] = split < local_split_kv ? gLSEaccum(split) : -std::numeric_limits<ElementAcc>::infinity();
}
ElementAcc lse_max = -std::numeric_limits<ElementAcc>::infinity();
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kNLsePerThread; ++i) {
lse_max = max(lse_max, local_lse[i]);
}
CUTLASS_PRAGMA_UNROLL
for (int offset = 16; offset >= 1; offset /= 2) {
lse_max = max(lse_max, __shfl_xor_sync(0xffffffff, lse_max, offset));
}
lse_max = lse_max == -std::numeric_limits<ElementAcc>::infinity() ? 0.0f : lse_max; // In case all local LSEs are -inf
lse_max = __shfl_sync(0xffffffff, lse_max, 0);
ElementAcc sum_lse = 0;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kNLsePerThread; ++i) {
sum_lse = sum_lse + expf(local_lse[i] - lse_max);
}
CUTLASS_PRAGMA_UNROLL
for (int offset = 16; offset >= 1; offset /= 2) {
sum_lse = sum_lse + __shfl_xor_sync(0xffffffff, sum_lse, offset);
}
sum_lse = __shfl_sync(0xffffffff, sum_lse, 0);
ElementAcc global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? std::numeric_limits<ElementAcc>::infinity() : logf(sum_lse) + lse_max;
if (threadIdx.x == 0 and params.ptr_lse != nullptr) {
gLSE(0) = global_lse;
}
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kNLsePerThread; ++i) {
const int split = i * 32 + threadIdx.x;
if (split < local_split_kv) {
sLseScale[split] = expf(local_lse[i] - global_lse);
}
}
}
__syncthreads();
constexpr int Elements = kHeadDimLatent / MaxThreadsPerBlock;
const size_t offset_oaccum = kHeadDimLatent * params.split_kv * (get<0>(blk_coord) + kNumHeads * get<2>(blk_coord));
Tensor gOaccum = make_tensor(make_gmem_ptr(params.ptr_oaccum + offset_oaccum),
Shape<Int<kHeadDimLatent>>{}, Stride<_1>{});
ElementAcc local_val[Elements] = {0};
for (int split = 0; split < local_split_kv; ++split) {
ElementAcc lse_scale = sLseScale[split];
CUTLASS_PRAGMA_UNROLL
for(int i = 0; i < Elements; ++i) {
local_val[i] += lse_scale * gOaccum(threadIdx.x + MaxThreadsPerBlock * i);
}
gOaccum.data() = gOaccum.data() + kHeadDimLatent;
}
auto ptr_o_local = params.ptr_o + (get<0>(blk_coord) + get<2>(blk_coord) * kNumHeads) * kHeadDimLatent;
Tensor gO = make_tensor(make_gmem_ptr(ptr_o_local), Shape<Int<kHeadDimLatent>>{}, Stride<_1>{});
CUTLASS_PRAGMA_UNROLL
for(int i = 0; i < Elements; ++i) {
gO(threadIdx.x + MaxThreadsPerBlock * i) = static_cast<ElementOut>(local_val[i]);
}
}
};
} // namespace cutlass::fmha::kernel

View File

@@ -0,0 +1,160 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
// clang-format off
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/kernel_hardware_info.h"
namespace cutlass::fmha::kernel {
////////////////////////////////////////////////////////////////////////////////
struct Sm100MlaIndividualTileScheduler {
struct Params {
dim3 grid;
};
bool valid_ = true;
CUTLASS_DEVICE
Sm100MlaIndividualTileScheduler(Params const&) {}
template<class ProblemShape, class ClusterShape>
static Params to_underlying_arguments(
ProblemShape const& problem_shape, KernelHardwareInfo hw_info,
ClusterShape const& cluster_shape, int const& split_kv) {
using namespace cute;
dim3 grid(get<0>(cluster_shape), get<3>(problem_shape) /* Batch */, split_kv /*Maximum Split KV*/);
return Params{ grid };
}
static dim3 get_grid_shape(Params const& params) {
return params.grid;
}
CUTLASS_DEVICE
bool is_valid() {
return valid_;
}
CUTLASS_DEVICE
auto get_block_coord() {
using namespace cute;
return make_coord(blockIdx.x, _0{}, blockIdx.y, blockIdx.z);
}
CUTLASS_DEVICE
Sm100MlaIndividualTileScheduler& operator++() {
valid_ = false;
return *this;
}
};
////////////////////////////////////////////////////////////////////////////////
struct Sm100MlaPersistentTileScheduler {
struct Params {
int num_blocks;
FastDivmod divmod_m_block;
FastDivmod divmod_b;
FastDivmod divmod_split_kv;
KernelHardwareInfo hw_info;
};
int block_idx = 0;
Params params;
CUTLASS_DEVICE
Sm100MlaPersistentTileScheduler(Params const& params) : block_idx(blockIdx.x), params(params) {}
template<class ProblemShape, class ClusterShape>
static Params to_underlying_arguments(
ProblemShape const& problem_shape, KernelHardwareInfo hw_info,
ClusterShape const& cluster_shape, int const& split_kv) {
using namespace cute;
// Get SM count if needed, otherwise use user supplied SM count
int sm_count = hw_info.sm_count;
if (sm_count <= 1 || sm_count % size<0>(cluster_shape) != 0) {
CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
" For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
}
CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
hw_info.sm_count = sm_count;
int num_m_blocks = size<0>(cluster_shape);
int num_blocks = num_m_blocks * get<3>(problem_shape) /* Batch */;
num_blocks *= split_kv; /* Maximum Split KV*/
return Params {
num_blocks,
{ num_m_blocks}, { get<3>(problem_shape) }, {split_kv},
hw_info
};
}
static dim3 get_grid_shape(Params const& params) {
dim3 grid(std::min(params.num_blocks, params.hw_info.sm_count), 1, 1);
return grid;
}
CUTLASS_DEVICE
bool is_valid() {
return block_idx < params.num_blocks;
}
CUTLASS_DEVICE
auto get_block_coord() {
using namespace cute;
int block_decode = block_idx;
int m_block, bidb, n_split_kv;
params.divmod_m_block(block_decode, m_block, block_decode);
params.divmod_b(block_decode, bidb, block_decode);
params.divmod_split_kv(block_decode, n_split_kv, block_decode);
return make_coord(m_block, _0{}, bidb, n_split_kv);
}
CUTLASS_DEVICE
Sm100MlaPersistentTileScheduler& operator++() {
block_idx += gridDim.x;
return *this;
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::fmha::kernel

View File

@@ -0,0 +1,154 @@
/* Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <torch/all.h>
#define THREADS_PER_BLOCK 128
template <typename T>
__global__ void lightning_attention_decode_kernel(
const T* __restrict__ q, // [b, h, 1, d]
const T* __restrict__ k, // [b, h, 1, d]
const T* __restrict__ v, // [b, h, 1, e]
const float* __restrict__ past_kv, // [b, h, d, e]
const float* __restrict__ slope, // [h, 1, 1]
T* __restrict__ output, // [b, h, 1, e]
float* __restrict__ new_kv, // [b, h, d, e]
const int batch_size,
const int num_heads,
const int qk_dim,
const int v_dim) {
extern __shared__ char smem[];
T* __restrict__ q_shared = reinterpret_cast<T*>(smem);
T* __restrict__ k_shared = reinterpret_cast<T*>(smem + qk_dim * sizeof(T));
T* __restrict__ v_shared = reinterpret_cast<T*>(smem + 2 * qk_dim * sizeof(T));
float* __restrict__ new_kv_shared = reinterpret_cast<float*>(smem + (2 * qk_dim + v_dim) * sizeof(T));
T* __restrict__ output_shared =
reinterpret_cast<T*>(smem + (2 * qk_dim + v_dim) * sizeof(T) + qk_dim * (v_dim + 1) * sizeof(float));
const int32_t tid = threadIdx.x;
const int32_t current_head = blockIdx.x;
const int32_t b = current_head / num_heads;
const int32_t h = current_head % num_heads;
if (b >= batch_size) return;
const int32_t qk_offset = b * num_heads * qk_dim + h * qk_dim;
const int32_t v_offset = b * num_heads * v_dim + h * v_dim;
const int32_t kv_offset = b * num_heads * qk_dim * v_dim + h * qk_dim * v_dim;
// Load q, k, v into shared memory
for (int d = tid; d < qk_dim; d += blockDim.x) {
q_shared[d] = q[qk_offset + d];
k_shared[d] = k[qk_offset + d];
}
for (int e = tid; e < v_dim; e += blockDim.x) {
v_shared[e] = v[v_offset + e];
}
__syncthreads();
const float ratio = expf(-1.0f * slope[h]);
// Compute new_kv
for (int d = tid; d < qk_dim; d += blockDim.x) {
const T k_val = k_shared[d];
for (int e = 0; e < v_dim; ++e) {
const int past_kv_idx = kv_offset + d * v_dim + e;
const T v_val = v_shared[e];
const float new_val = ratio * past_kv[past_kv_idx] + k_val * v_val;
const int shared_idx = d * (v_dim + 1) + e;
new_kv_shared[shared_idx] = new_val;
}
}
__syncthreads();
// Store new_kv to global memory
for (int idx = tid; idx < qk_dim * v_dim; idx += blockDim.x) {
const int d = idx / v_dim;
const int e = idx % v_dim;
const int shared_idx = d * (v_dim + 1) + e;
const int global_idx = kv_offset + idx;
new_kv[global_idx] = new_kv_shared[shared_idx];
}
__syncthreads();
// Compute output
for (int e = tid; e < v_dim; e += blockDim.x) {
float sum = 0.0f;
for (int d = 0; d < qk_dim; ++d) {
const int shared_idx = d * (v_dim + 1) + e;
sum += q_shared[d] * new_kv_shared[shared_idx];
}
output_shared[e] = static_cast<T>(sum);
}
__syncthreads();
// Store output to global memory
if (tid == 0) {
for (int e = 0; e < v_dim; ++e) {
output[v_offset + e] = output_shared[e];
}
}
}
void lightning_attention_decode(
const torch::Tensor& q,
const torch::Tensor& k,
const torch::Tensor& v,
const torch::Tensor& past_kv,
const torch::Tensor& slope,
torch::Tensor output,
torch::Tensor new_kv) {
TORCH_CHECK(q.is_contiguous(), "q must be contiguous");
TORCH_CHECK(k.is_contiguous(), "k must be contiguous");
TORCH_CHECK(v.is_contiguous(), "v must be contiguous");
TORCH_CHECK(past_kv.is_contiguous(), "past_kv must be contiguous");
auto batch_size = q.size(0);
auto num_heads = q.size(1);
auto qk_dim = q.size(3);
auto v_dim = v.size(3);
dim3 block(THREADS_PER_BLOCK);
dim3 grid(batch_size * num_heads);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half, at::ScalarType::BFloat16, q.scalar_type(), "lightning_attention_decode_kernel", ([&] {
size_t smem_size = (2 * qk_dim + 2 * v_dim) * sizeof(scalar_t) + qk_dim * (v_dim + 1) * sizeof(float);
lightning_attention_decode_kernel<scalar_t><<<grid, block, smem_size, stream>>>(
q.data_ptr<scalar_t>(),
k.data_ptr<scalar_t>(),
v.data_ptr<scalar_t>(),
past_kv.data_ptr<float>(),
slope.data_ptr<float>(),
output.data_ptr<scalar_t>(),
new_kv.data_ptr<float>(),
batch_size,
num_heads,
qk_dim,
v_dim);
}));
}

View File

@@ -0,0 +1,204 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <algorithm>
#include <optional>
#include "pytorch_extension_utils.h"
// Helper functions to convert between different data types
// (float, half, bfloat16) for the merge attention states kernel.
inline __device__ float to_float(float u) {
return u;
}
inline __device__ float to_float(half u) {
return __half2float(u);
}
inline __device__ float to_float(__nv_bfloat16 u) {
return __bfloat162float(u);
}
inline __device__ void from_float(float& d, float s) {
d = s;
}
inline __device__ void from_float(half& d, float s) {
d = __float2half(s);
}
inline __device__ void from_float(__nv_bfloat16& d, float s) {
d = __float2bfloat16(s);
}
// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
template <typename scalar_t, const uint NUM_THREADS>
__global__ void merge_attn_states_kernel(
scalar_t* output,
float* output_lse,
const scalar_t* prefix_output,
const float* prefix_lse,
const scalar_t* suffix_output,
const float* suffix_lse,
const uint num_tokens,
const uint num_heads,
const uint head_size) {
using pack_128b_t = uint4;
const uint pack_size = 16 / sizeof(scalar_t);
const uint threads_per_head = head_size / pack_size;
const uint global_idx = blockIdx.x * NUM_THREADS + threadIdx.x;
const uint token_head_threads = num_tokens * num_heads * threads_per_head;
if (global_idx >= token_head_threads) return;
// global_idx -> token_idx + head_idx + pack_idx
const uint token_head_idx = global_idx / threads_per_head;
const uint pack_idx = global_idx % threads_per_head;
const uint token_idx = token_head_idx / num_heads;
const uint head_idx = token_head_idx % num_heads;
const uint pack_offset = pack_idx * pack_size; // (0~15)*8, etc.
const uint head_offset = token_idx * num_heads * head_size + head_idx * head_size;
const scalar_t* prefix_head_ptr = prefix_output + head_offset;
const scalar_t* suffix_head_ptr = suffix_output + head_offset;
scalar_t* output_head_ptr = output + head_offset;
// float p_lse = prefix_lse[head_idx * num_tokens + token_idx];
// float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
float p_lse = prefix_lse[token_idx * num_heads + head_idx];
float s_lse = suffix_lse[token_idx * num_heads + head_idx];
p_lse = std::isinf(p_lse) ? -std::numeric_limits<float>::infinity() : p_lse;
s_lse = std::isinf(s_lse) ? -std::numeric_limits<float>::infinity() : s_lse;
const float max_lse = fmaxf(p_lse, s_lse);
p_lse = p_lse - max_lse;
s_lse = s_lse - max_lse;
const float p_se = expf(p_lse);
const float s_se = expf(s_lse);
const float out_se = p_se + s_se;
const float p_scale = p_se / out_se;
const float s_scale = s_se / out_se;
if (pack_offset < head_size) {
// Pack 128b load
pack_128b_t p_out_pack = reinterpret_cast<const pack_128b_t*>(prefix_head_ptr)[pack_offset / pack_size];
pack_128b_t s_out_pack = reinterpret_cast<const pack_128b_t*>(suffix_head_ptr)[pack_offset / pack_size];
pack_128b_t o_out_pack;
#pragma unroll
for (uint i = 0; i < pack_size; ++i) {
// Always use float for FMA to keep high precision.
// half(uint16_t), bfloat16, float -> float.
const float p_out_f = to_float(reinterpret_cast<const scalar_t*>(&p_out_pack)[i]);
const float s_out_f = to_float(reinterpret_cast<const scalar_t*>(&s_out_pack)[i]);
// fma: a * b + c = p_out_f * p_scale + (s_out_f * s_scale)
const float o_out_f = p_out_f * p_scale + (s_out_f * s_scale);
// float -> half(uint16_t), bfloat16, float.
from_float(reinterpret_cast<scalar_t*>(&o_out_pack)[i], o_out_f);
}
// Pack 128b storage
reinterpret_cast<pack_128b_t*>(output_head_ptr)[pack_offset / pack_size] = o_out_pack;
}
// We only need to write to output_lse once per head.
if (output_lse != nullptr && pack_idx == 0) {
float out_lse = logf(out_se) + max_lse;
output_lse[token_idx * num_heads + head_idx] = out_lse;
}
}
// The following macro is used to dispatch the conversion function based on
// the output data type. The FN is a macro that calls a function with
// template<typename scalar_t>.
#define DISPATCH_BY_SCALAR_DTYPE(scalar_dtype, fn) \
{ \
if (scalar_dtype == at::ScalarType::Float) { \
fn(float); \
} else if (scalar_dtype == at::ScalarType::Half) { \
fn(half); \
} else if (scalar_dtype == at::ScalarType::BFloat16) { \
fn(__nv_bfloat16); \
} else { \
TORCH_CHECK(false, "Unsupported data type of O: ", scalar_dtype); \
} \
}
#define LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS) \
{ \
merge_attn_states_kernel<scalar_t, NUM_THREADS><<<grid, block, 0, stream>>>( \
reinterpret_cast<scalar_t*>(output.data_ptr()), \
reinterpret_cast<float*>(output_lse.data_ptr()), \
reinterpret_cast<scalar_t*>(prefix_output.data_ptr()), \
reinterpret_cast<float*>(prefix_lse.data_ptr()), \
reinterpret_cast<scalar_t*>(suffix_output.data_ptr()), \
reinterpret_cast<float*>(suffix_lse.data_ptr()), \
num_tokens, \
num_heads, \
head_size); \
}
/*@brief Merges the attention states from prefix and suffix
* into the output tensor. NUM_TOKENS: n, NUM_HEADS: h, HEAD_SIZE: d
*
* @param output [n,h,d] The output tensor to store the merged attention states.
* @param output_lse [h,d] Optional tensor to store the log-sum-exp values.
* @param prefix_output [n,h,d] The prefix attention states.
* @param prefix_lse [n,h] The log-sum-exp values for the prefix attention
* states.
* @param suffix_output [n,h,d] The suffix attention states.
* @param suffix_lse [n,h] The log-sum-exp values for the suffix attention
* states.
*/
template <typename scalar_t>
void merge_attn_states_launcher(
const at::Tensor& prefix_output, // [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
const at::Tensor& prefix_lse, // [NUM_TOKENS, NUM_HEADS]
const at::Tensor& suffix_output, // [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
const at::Tensor& suffix_lse, // [NUM_TOKENS, NUM_HEADS]
at::Tensor& output, // [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
at::Tensor& output_lse // [NUM_TOKENS, NUM_HEADS]
) {
constexpr uint NUM_THREADS = 128;
const uint num_tokens = output.size(0);
const uint num_heads = output.size(1);
const uint head_size = output.size(2);
const uint pack_size = 16 / sizeof(scalar_t);
TORCH_CHECK(head_size % pack_size == 0, "headsize must be multiple of pack_size:", pack_size);
// Process one pack elements per thread. for float, the
// pack_size is 4 for half/bf16, the pack_size is 8.
const uint threads_per_head = head_size / pack_size;
const uint total_threads = num_tokens * num_heads * threads_per_head;
dim3 block(NUM_THREADS);
dim3 grid((total_threads + NUM_THREADS - 1) / NUM_THREADS);
const c10::cuda::OptionalCUDAGuard device_guard(prefix_output.device());
auto stream = at::cuda::getCurrentCUDAStream();
LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS);
}
#define CALL_MERGE_ATTN_STATES_LAUNCHER(scalar_t) \
{ merge_attn_states_launcher<scalar_t>(v_a, s_a, v_b, s_b, v_merged, s_merged); }
void merge_state_v2(
at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b, at::Tensor v_merged, at::Tensor s_merged) {
// Input tensors must be contiguous
CHECK_INPUT(v_a); // v_a prefix_output (seq_len, num_heads, head_dim)
CHECK_INPUT(s_a); // s_a prefix_lse (seq_len, num_heads)
CHECK_INPUT(v_b); // v_b suffix_output (seq_len, num_heads, head_dim)
CHECK_INPUT(s_b); // s_b suffix_lse (seq_len, num_heads)
// v_merged output (seq_len, num_heads, head_dim)
// s_merged output_lse (seq_len, num_heads)
auto device = v_a.device();
CHECK_EQ(s_a.device(), device);
CHECK_EQ(v_b.device(), device);
CHECK_EQ(s_b.device(), device);
CHECK_DIM(3, v_a);
CHECK_DIM(2, s_a);
CHECK_DIM(3, v_b);
CHECK_DIM(2, s_b);
CHECK_SHAPE(v_a, v_b);
CHECK_SHAPE(s_a, s_b);
CHECK_EQ(v_a.size(0), s_a.size(0));
CHECK_EQ(v_a.size(1), s_b.size(1));
DISPATCH_BY_SCALAR_DTYPE(v_merged.dtype(), CALL_MERGE_ATTN_STATES_LAUNCHER);
}

View File

@@ -0,0 +1,462 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
// This file is for blocksparse attention utils cuda kernel.
#include <assert.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda.h>
#include <torch/all.h>
// Save the start index of each block in the given range into block_offset.
// Returns the updated block count.
__device__ int64_t save_blocks(
int* block_offset,
int64_t range_start,
int64_t range_end,
int64_t block_size,
int64_t input_block_count,
int64_t kv_seqlen) {
if (range_start >= kv_seqlen) {
return input_block_count;
}
if (range_end > kv_seqlen) {
range_end = kv_seqlen;
}
int64_t current_block_count = input_block_count;
for (int idx = range_start; idx < range_end; idx += block_size) {
block_offset[current_block_count++] = idx;
}
return current_block_count;
}
// CUDA kernel: convert sparse vertical/slash indices to block/column offsets.
__global__ void convert_vertical_slash_indexes_kernel(
const int* q_seqlens, // [BATCH, ]
const int* kv_seqlens, // [BATCH, ]
const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S]
int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
int64_t N_HEADS,
int64_t N_ROWS,
int64_t BLOCK_SIZE_M,
int64_t BLOCK_SIZE_N,
int64_t NNZ_V,
int64_t NNZ_S,
bool causal // True for intra, False for succ
) {
const int batch_idx = blockIdx.y;
const int head_idx = blockIdx.x;
const int group_idx = blockIdx.z;
int64_t q_seqlen = q_seqlens[batch_idx];
int64_t kv_seqlen = kv_seqlens[batch_idx];
int64_t block_idx_m = group_idx * blockDim.x + threadIdx.x;
int64_t start_m = block_idx_m * BLOCK_SIZE_M;
if (start_m >= q_seqlen) {
return;
}
int64_t end_m = start_m + BLOCK_SIZE_M;
vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V;
slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S;
int64_t row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m;
block_count += row_offset;
block_offset += row_offset * NNZ_S;
column_count += row_offset;
column_index += row_offset * NNZ_V;
bool has_slash = true;
int64_t tmp_col_cnt = 0, tmp_blk_cnt = 0;
int64_t s = 0, v = 0;
int64_t v_idx = vertical_indexes[v++];
int64_t s_idx = slash_indexes[s++];
if (causal) {
while (s_idx >= end_m + (kv_seqlen - q_seqlen) && s < NNZ_S) {
s_idx = slash_indexes[s++];
}
if (s_idx > end_m + (kv_seqlen - q_seqlen)) has_slash = false;
s_idx = max((kv_seqlen - q_seqlen) + end_m - s_idx, BLOCK_SIZE_M);
} else {
while (s_idx >= end_m + kv_seqlen && s < NNZ_S) {
s_idx = slash_indexes[s++];
}
if (s_idx > end_m + kv_seqlen) has_slash = false;
s_idx = max(kv_seqlen + end_m - s_idx, BLOCK_SIZE_M);
}
int64_t range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx;
if (!has_slash) {
if (causal) {
range_start = (kv_seqlen - q_seqlen) + end_m;
range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N;
} else {
range_start = kv_seqlen;
range_end = kv_seqlen + BLOCK_SIZE_N;
}
}
bool slash_finished = false;
while (1) {
if (v_idx < range_end) {
if (v_idx < range_start) {
column_index[tmp_col_cnt++] = v_idx;
}
if (v < NNZ_V) {
v_idx = vertical_indexes[v++];
} else {
if (causal)
v_idx = end_m + BLOCK_SIZE_N + (kv_seqlen - q_seqlen);
else
v_idx = end_m + BLOCK_SIZE_N + kv_seqlen;
}
} else {
if ((s < NNZ_S && causal) || (s < NNZ_S && !causal && slash_indexes[s] >= start_m)) {
if (causal)
s_idx = max((kv_seqlen - q_seqlen) + end_m - slash_indexes[s++], BLOCK_SIZE_M);
else
s_idx = max(kv_seqlen + end_m - slash_indexes[s++], BLOCK_SIZE_M);
} else {
if (v == NNZ_V || (v_idx > range_start && causal)) {
// add the last vertical if no more slash
if (v == NNZ_V && !causal && v_idx < kv_seqlen) {
column_index[tmp_col_cnt++] = v_idx;
}
tmp_blk_cnt = save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
break;
} else {
if (causal) {
range_start = (kv_seqlen - q_seqlen) + end_m;
range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N;
} else {
// if slash_finished but there are vertical left, save current
// blocks
tmp_blk_cnt = save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
range_start = kv_seqlen;
range_end = kv_seqlen + BLOCK_SIZE_N;
}
slash_finished = true;
}
}
if (!slash_finished) {
if (s_idx > range_end + BLOCK_SIZE_M) {
tmp_blk_cnt = save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
range_start = s_idx - BLOCK_SIZE_M;
range_end = s_idx;
} else if (s_idx > range_end) {
range_end += BLOCK_SIZE_M;
}
}
}
}
block_count[0] = tmp_blk_cnt;
column_count[0] = tmp_col_cnt;
}
// Host function: launches the kernel with 64 threads per block.
void convert_vertical_slash_indexes_64x64(
const int* q_seqlens, // [BATCH, ]
const int* kv_seqlens, // [BATCH, ]
const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S]
int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
int64_t BATCH_SIZE,
int64_t N_HEADS,
int64_t N_ROWS,
int64_t BLOCK_SIZE_M,
int64_t BLOCK_SIZE_N,
int64_t NNZ_V,
int64_t NNZ_S,
bool causal) {
const int N_THREADS = 64;
const dim3 dimBlock((int32_t)N_THREADS);
const dim3 dimGrid(
(int32_t)N_HEADS, (int32_t)BATCH_SIZE, ((int32_t)N_ROWS + (int32_t)N_THREADS - 1) / (int32_t)N_THREADS);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
convert_vertical_slash_indexes_kernel<<<dimGrid, dimBlock, 0, stream>>>(
q_seqlens,
kv_seqlens,
vertical_indexes,
slash_indexes,
block_count,
block_offset,
column_count,
column_index,
N_HEADS,
N_ROWS,
BLOCK_SIZE_M,
BLOCK_SIZE_N,
NNZ_V,
NNZ_S,
causal);
}
// Host function: prepares tensor pointers and launches the CUDA kernel.
void convert_vertical_slash_indexes(
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
torch::Tensor q_seqlens, // [BATCH, ]
torch::Tensor kv_seqlens, // [BATCH, ]
torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S]
int64_t context_size,
int64_t block_size_M,
int64_t block_size_N,
bool causal) {
cudaSetDevice(q_seqlens.get_device());
int64_t batch_size = slash_indexes.size(0);
int64_t num_heads = slash_indexes.size(1);
int64_t nnz_slash = slash_indexes.size(2);
int64_t nnz_vertical = vertical_indexes.size(2);
int64_t num_rows = (context_size + block_size_M - 1) / block_size_M;
convert_vertical_slash_indexes_64x64(
q_seqlens.data_ptr<int>(),
kv_seqlens.data_ptr<int>(),
vertical_indexes.data_ptr<int>(),
slash_indexes.data_ptr<int>(),
block_count.data_ptr<int>(),
block_offset.data_ptr<int>(),
column_count.data_ptr<int>(),
column_index.data_ptr<int>(),
batch_size,
num_heads,
num_rows,
block_size_M,
block_size_N,
nnz_vertical,
nnz_slash,
causal);
}
// --- mergehead kernels --- //
// Kernel: like above, but supports per-head variable NNZ_V/NNZ_S.
__global__ void convert_vertical_slash_indexes_kernel_mergehead(
const int* q_seqlens, // [BATCH, ]
const int* kv_seqlens, // [BATCH, ]
const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S]
const int* per_head_vertical_topkv,
const int* per_head_slash_topkv,
int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
int64_t N_HEADS,
int64_t N_ROWS,
int64_t BLOCK_SIZE_M,
int64_t BLOCK_SIZE_N,
int64_t NNZ_V,
int64_t NNZ_S,
bool causal // True for intra, False for succ
) {
const int batch_idx = blockIdx.y;
const int head_idx = blockIdx.x;
const int group_idx = blockIdx.z;
int64_t q_seqlen = q_seqlens[batch_idx];
int64_t kv_seqlen = kv_seqlens[batch_idx];
int64_t block_idx_m = group_idx * blockDim.x + threadIdx.x;
int64_t start_m = block_idx_m * BLOCK_SIZE_M;
if (start_m >= q_seqlen) {
return;
}
int64_t end_m = start_m + BLOCK_SIZE_M;
vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V;
slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S;
int64_t row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m;
block_count += row_offset;
block_offset += row_offset * NNZ_S;
column_count += row_offset;
column_index += row_offset * NNZ_V;
// MergeHead: each head has it's unique max topk NNZ_VNNZ_S. (NNZ_VNNZ_S
// above is buffer size, use to compute offset)
NNZ_S = per_head_slash_topkv[head_idx];
NNZ_V = per_head_vertical_topkv[head_idx];
bool has_slash = true;
int64_t tmp_col_cnt = 0, tmp_blk_cnt = 0;
int64_t s = 0, v = 0;
int64_t v_idx = vertical_indexes[v++];
int64_t s_idx = slash_indexes[s++];
if (causal) {
while (s_idx >= end_m + (kv_seqlen - q_seqlen) && s < NNZ_S) {
s_idx = slash_indexes[s++];
}
if (s_idx > end_m + (kv_seqlen - q_seqlen)) has_slash = false;
s_idx = max((kv_seqlen - q_seqlen) + end_m - s_idx, BLOCK_SIZE_M);
} else {
while (s_idx >= end_m + kv_seqlen && s < NNZ_S) {
s_idx = slash_indexes[s++];
}
if (s_idx > end_m + kv_seqlen) has_slash = false;
s_idx = max(kv_seqlen + end_m - s_idx, BLOCK_SIZE_M);
}
int64_t range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx;
if (!has_slash) {
if (causal) {
range_start = (kv_seqlen - q_seqlen) + end_m;
range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N;
} else {
range_start = kv_seqlen;
range_end = kv_seqlen + BLOCK_SIZE_N;
}
}
bool slash_finished = false;
while (1) {
if (v_idx < range_end) {
if (v_idx < range_start) {
column_index[tmp_col_cnt++] = v_idx;
}
if (v < NNZ_V) {
v_idx = vertical_indexes[v++];
} else {
if (causal)
v_idx = end_m + BLOCK_SIZE_N + (kv_seqlen - q_seqlen);
else
v_idx = end_m + BLOCK_SIZE_N + kv_seqlen;
}
} else {
if ((s < NNZ_S && causal) || (s < NNZ_S && !causal && slash_indexes[s] >= start_m)) {
if (causal)
s_idx = max((kv_seqlen - q_seqlen) + end_m - slash_indexes[s++], BLOCK_SIZE_M);
else
s_idx = max(kv_seqlen + end_m - slash_indexes[s++], BLOCK_SIZE_M);
} else {
if (v == NNZ_V || (v_idx > range_start && causal)) {
// add the last vertical if no more slash
if (v == NNZ_V && !causal && v_idx < kv_seqlen) {
column_index[tmp_col_cnt++] = v_idx;
}
tmp_blk_cnt = save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
break;
} else {
if (causal) {
range_start = (kv_seqlen - q_seqlen) + end_m;
range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N;
} else {
// if slash_finished but there are vertical left, save current
// blocks
tmp_blk_cnt = save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
range_start = kv_seqlen;
range_end = kv_seqlen + BLOCK_SIZE_N;
}
slash_finished = true;
}
}
if (!slash_finished) {
if (s_idx > range_end + BLOCK_SIZE_M) {
tmp_blk_cnt = save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
range_start = s_idx - BLOCK_SIZE_M;
range_end = s_idx;
} else if (s_idx > range_end) {
range_end += BLOCK_SIZE_M;
}
}
}
}
block_count[0] = tmp_blk_cnt;
column_count[0] = tmp_col_cnt;
}
// Launch the mergehead kernel with 64 threads per block.
void convert_vertical_slash_indexes_64x64_mergehead(
const int* q_seqlens, // [BATCH, ]
const int* kv_seqlens, // [BATCH, ]
const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S]
int* per_head_vertical_topkv,
int* per_head_slash_topkv,
int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
int64_t BATCH_SIZE,
int64_t N_HEADS,
int64_t N_ROWS,
int64_t BLOCK_SIZE_M,
int64_t BLOCK_SIZE_N,
int64_t NNZ_V,
int64_t NNZ_S,
bool causal) {
const int N_THREADS = 64;
const dim3 dimBlock(N_THREADS);
const dim3 dimGrid(N_HEADS, BATCH_SIZE, (N_ROWS + N_THREADS - 1) / N_THREADS);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
convert_vertical_slash_indexes_kernel_mergehead<<<dimGrid, dimBlock, 0, stream>>>(
q_seqlens,
kv_seqlens,
vertical_indexes,
slash_indexes,
per_head_vertical_topkv,
per_head_slash_topkv,
block_count,
block_offset,
column_count,
column_index,
N_HEADS,
N_ROWS,
BLOCK_SIZE_M,
BLOCK_SIZE_N,
NNZ_V,
NNZ_S,
causal);
}
// Host wrapper for mergehead kernel.
void convert_vertical_slash_indexes_mergehead(
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
torch::Tensor q_seqlens, // [BATCH, ]
torch::Tensor kv_seqlens, // [BATCH, ]
torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S]
torch::Tensor vertical_indices_count, // [N_HEADS, ]
torch::Tensor slash_indices_count,
int64_t context_size,
int64_t block_size_M,
int64_t block_size_N,
bool causal) {
cudaSetDevice(q_seqlens.get_device());
int batch_size = slash_indexes.size(0);
int num_heads = slash_indexes.size(1);
int nnz_slash = slash_indexes.size(2);
int nnz_vertical = vertical_indexes.size(2);
int num_rows = (context_size + block_size_M - 1) / block_size_M;
convert_vertical_slash_indexes_64x64_mergehead(
q_seqlens.data_ptr<int>(),
kv_seqlens.data_ptr<int>(),
vertical_indexes.data_ptr<int>(),
slash_indexes.data_ptr<int>(),
vertical_indices_count.data_ptr<int>(),
slash_indices_count.data_ptr<int>(),
block_count.data_ptr<int>(),
block_offset.data_ptr<int>(),
column_count.data_ptr<int>(),
column_index.data_ptr<int>(),
batch_size,
num_heads,
num_rows,
block_size_M,
block_size_N,
nnz_vertical,
nnz_slash,
causal);
}

View File

@@ -0,0 +1,455 @@
/* Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/all.h>
#include <torch/library.h>
#include "sgl_kernel_ops.h"
TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
/*
* From csrc/allreduce
*/
m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
m.def("register_graph_buffers", &register_graph_buffers);
m.def("dispose", &dispose);
m.def("meta_size", &meta_size);
m.def("register_buffer", &register_buffer);
m.def(
"init_custom_ar(int[] ipc_tensors, Tensor rank_data, "
"int rank, bool full_nvlink) -> int");
m.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);
m.def(
"all_reduce(int fa, Tensor inp, Tensor! out, int reg_buffer, "
"int reg_buffer_sz_bytes) -> ()");
m.impl("all_reduce", torch::kCUDA, &all_reduce);
m.def("mscclpp_generate_unique_id", &mscclpp_generate_unique_id);
m.def(
"mscclpp_init_context(Tensor unique_id, int rank, int world_size, Tensor scratch, Tensor put_buffer, "
"int nranks_per_node, int[] rank_to_node, int[] rank_to_ib, int context_selection) -> int");
m.impl("mscclpp_init_context", torch::kCUDA, &mscclpp_init_context);
m.def("mscclpp_allreduce(int context, Tensor inp, Tensor! out, int nthreads, int nblocks) -> ()");
m.impl("mscclpp_allreduce", torch::kCUDA, &mscclpp_allreduce);
/*
* From csrc/attention
*/
m.def(
"lightning_attention_decode(Tensor q, Tensor k, Tensor v, Tensor past_kv, Tensor slope, Tensor! output, Tensor! "
"new_kv) -> ()");
m.impl("lightning_attention_decode", torch::kCUDA, &lightning_attention_decode);
m.def("merge_state(Tensor v_a, Tensor s_a, Tensor v_b, Tensor s_b, Tensor! v_merged, Tensor! s_merged) -> ()");
m.impl("merge_state", torch::kCUDA, &merge_state);
m.def("merge_state_v2(Tensor v_a, Tensor s_a, Tensor v_b, Tensor s_b, Tensor! v_merged, Tensor! s_merged) -> ()");
m.impl("merge_state_v2", torch::kCUDA, &merge_state_v2);
m.def(
"cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe, Tensor kv_c_and_k_pe_cache, Tensor seq_lens, Tensor "
"page_table, Tensor! workspace, float sm_scale, int num_kv_splits) -> ()");
m.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);
m.def("cutlass_mla_get_workspace_size", &cutlass_mla_get_workspace_size);
/*
* From csrc/elementwise
*/
m.def("rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, bool enable_pdl) -> ()");
m.impl("rmsnorm", torch::kCUDA, &rmsnorm);
m.def("fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, bool enable_pdl) -> ()");
m.impl("fused_add_rmsnorm", torch::kCUDA, &sgl_fused_add_rmsnorm);
m.def("gemma_rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, bool enable_pdl) -> ()");
m.impl("gemma_rmsnorm", torch::kCUDA, &gemma_rmsnorm);
m.def("gemma_fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, bool enable_pdl) -> ()");
m.impl("gemma_fused_add_rmsnorm", torch::kCUDA, &gemma_fused_add_rmsnorm);
m.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
m.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
m.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()");
m.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul);
m.def("gelu_and_mul(Tensor! out, Tensor input) -> ()");
m.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);
m.def(
"apply_rope_pos_ids_cos_sin_cache(Tensor q, Tensor k, Tensor! q_rope, Tensor! k_rope, Tensor cos_sin_cache, "
"Tensor pos_ids, bool interleave, bool enable_pdl, int cuda_stream, "
"Tensor? v, Tensor!? k_buffer, Tensor!? v_buffer, Tensor? kv_cache_loc) -> ()");
m.impl("apply_rope_pos_ids_cos_sin_cache", torch::kCUDA, &apply_rope_pos_ids_cos_sin_cache);
m.def(
"downcast_fp8(Tensor k, Tensor v, Tensor k_out, Tensor v_out, Tensor k_scale, Tensor v_scale, Tensor loc, int "
"mult, int offset, int cuda_stream) -> ()");
m.impl("downcast_fp8", torch::kCUDA, &downcast_fp8);
/*
* From csrc/gemm
*/
m.def("awq_dequantize(Tensor qweight, Tensor scales, Tensor qzeros) -> Tensor");
m.impl("awq_dequantize", torch::kCUDA, &awq_dequantize);
m.def(
"int8_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype, Tensor? "
"bias) -> Tensor");
m.impl("int8_scaled_mm", torch::kCUDA, &int8_scaled_mm);
m.def(
"fp8_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype, Tensor? "
"bias) -> Tensor");
m.impl("fp8_scaled_mm", torch::kCUDA, &fp8_scaled_mm);
m.def(
"fp8_blockwise_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype) -> "
"Tensor");
m.impl("fp8_blockwise_scaled_mm", torch::kCUDA, &fp8_blockwise_scaled_mm);
m.def(
"sgl_per_token_group_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, int group_size,"
" float eps, float fp8_min, float fp8_max, bool scale_ue8m0) -> ()");
m.impl("sgl_per_token_group_quant_fp8", torch::kCUDA, &sgl_per_token_group_quant_fp8);
m.def(
"sgl_per_token_group_quant_int8(Tensor input, Tensor output_q, Tensor output_s, int group_size,"
" float eps, float int8_min, float int8_max) -> ()");
m.impl("sgl_per_token_group_quant_int8", torch::kCUDA, &sgl_per_token_group_quant_int8);
m.def("sgl_per_tensor_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, bool is_static) -> ()");
m.impl("sgl_per_tensor_quant_fp8", torch::kCUDA, &sgl_per_tensor_quant_fp8);
m.def("sgl_per_token_quant_fp8(Tensor input, Tensor output_q, Tensor output_s) -> ()");
m.impl("sgl_per_token_quant_fp8", torch::kCUDA, &sgl_per_token_quant_fp8);
m.def(
"cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b,"
" Tensor block_scale_a, Tensor block_scale_b,"
" Tensor alpha) -> ()");
m.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm);
m.def(
"scaled_fp4_quant(Tensor! output, Tensor! input,"
" Tensor! output_scale, Tensor! input_scale) -> ()");
m.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant);
m.def("dsv3_fused_a_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()");
m.impl("dsv3_fused_a_gemm", torch::kCUDA, &dsv3_fused_a_gemm);
// Compute NVFP4 experts quantization.
m.def(
"scaled_fp4_experts_quant(Tensor! output, Tensor! output_scale,"
"Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts,"
"Tensor output_scale_offset_by_experts) -> ()");
m.impl("scaled_fp4_experts_quant", torch::kCUDA, &scaled_fp4_experts_quant);
m.def(
"silu_and_mul_scaled_fp4_experts_quant(Tensor! output, Tensor! output_scale,"
"Tensor input, Tensor input_global_scale, Tensor mask, bool use_silu_and_mul) -> ()");
m.impl("silu_and_mul_scaled_fp4_experts_quant", torch::kCUDA, &silu_and_mul_scaled_fp4_experts_quant);
m.def(
"cutlass_fp4_group_mm(Tensor! output, Tensor a, Tensor b,"
"Tensor a_blockscale, Tensor b_blockscale, Tensor alphas,"
"Tensor ab_strides, Tensor c_strides, Tensor problem_sizes,"
" Tensor expert_offsets, Tensor sf_offsets) -> ()");
m.impl("cutlass_fp4_group_mm", torch::kCUDA, &cutlass_fp4_group_mm);
m.def("dsv3_router_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()");
m.impl("dsv3_router_gemm", torch::kCUDA, &dsv3_router_gemm);
/*
* From csrc/gemm/gptq
*/
m.def(
"gptq_marlin_gemm(Tensor! a, Tensor? c_or_none,"
"Tensor! b_q_weight, Tensor! b_scales, Tensor? global_scale_or_none,"
"Tensor? b_zeros_or_none, Tensor? g_idx_or_none, Tensor? perm_or_none,"
"Tensor! workspace, int b_q_type_id, int size_m, int size_n, int size_k,"
"bool is_k_full, bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) -> Tensor");
m.impl("gptq_marlin_gemm", torch::kCUDA, &gptq_marlin_gemm);
m.def(
"gptq_gemm(Tensor a, Tensor b_q_weight, Tensor b_gptq_qzeros, Tensor b_gptq_scales, Tensor b_g_idx, bool "
"use_shuffle, int bit) -> Tensor");
m.impl("gptq_gemm", torch::kCUDA, &gptq_gemm);
m.def("gptq_shuffle(Tensor! q_weight, Tensor q_perm, int bit) -> ()");
m.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle);
m.def("gptq_marlin_repack(Tensor! b_q_weight, Tensor! perm, int size_k, int size_n, int num_bits) -> Tensor");
m.impl("gptq_marlin_repack", torch::kCUDA, &gptq_marlin_repack);
m.def("awq_marlin_repack(Tensor! b_q_weight, int size_k, int size_n, int num_bits) -> Tensor");
m.impl("awq_marlin_repack", torch::kCUDA, &awq_marlin_repack);
/*
* From csrc/moe
*/
m.def(
"moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! "
"experts_ids, Tensor! num_tokens_post_pad, Tensor! cumsum_buffer, bool "
"pad_sorted_token_ids) -> ()");
m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
m.def("topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor gating_output, bool renormalize) -> ()");
m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
m.def(
"moe_fused_gate(Tensor input, Tensor bias, int num_expert_group, int topk_group, int topk, int "
"num_fused_shared_experts, float routed_scaling_factor, bool apply_routed_scaling_factor_on_output) -> "
"(Tensor[])");
m.impl("moe_fused_gate", torch::kCUDA, &moe_fused_gate);
m.def(
"ep_moe_pre_reorder(Tensor input, Tensor gateup_input, Tensor src2dst, Tensor topk_ids, Tensor "
"a1_scales, int start_expert_id, int end_expert_id, int topk, bool use_per_token_if_dynamic) -> ()");
m.impl("ep_moe_pre_reorder", torch::kCUDA, &ep_moe_pre_reorder);
m.def(
"ep_moe_silu_and_mul(Tensor gateup_output, Tensor down_input, Tensor reorder_topk_ids, Tensor scales, int "
"start_expert_id, int end_expert_id) -> ()");
m.impl("ep_moe_silu_and_mul", torch::kCUDA, &ep_moe_silu_and_mul);
m.def(
"ep_moe_post_reorder(Tensor down_output, Tensor output, Tensor src2dst, Tensor topk_ids, Tensor "
"topk_weights, int start_expert_id, int end_expert_id, int topk) -> ()");
m.impl("ep_moe_post_reorder", torch::kCUDA, &ep_moe_post_reorder);
m.def(
"fp8_blockwise_scaled_grouped_mm(Tensor output, Tensor a_ptrs, Tensor b_ptrs, Tensor out_ptrs, Tensor "
"a_scales_ptrs, Tensor b_scales_ptrs, Tensor a, Tensor b, Tensor scales_a, Tensor scales_b, Tensor "
"stride_a, Tensor stride_b, Tensor stride_c, Tensor layout_sfa, Tensor layout_sfb, Tensor problem_sizes, Tensor "
"expert_offsets, Tensor workspace) -> ()");
m.impl("fp8_blockwise_scaled_grouped_mm", torch::kCUDA, &fp8_blockwise_scaled_grouped_mm);
m.def(
"prepare_moe_input(Tensor topk_ids, Tensor expert_offsets, Tensor? blockscale_offsets, Tensor problem_sizes1,"
" Tensor problem_sizes2, Tensor input_permutation, Tensor output_permutation, int num_experts, int n, int k) -> "
"()");
m.impl("prepare_moe_input", torch::kCUDA, &prepare_moe_input);
m.def("shuffle_rows(Tensor input, Tensor dst2src_map, Tensor output) -> ()");
m.impl("shuffle_rows", torch::kCUDA, &shuffle_rows);
m.def("apply_shuffle_mul_sum(Tensor input, Tensor output, Tensor permutation, Tensor? factors) -> ()");
m.impl("apply_shuffle_mul_sum", torch::kCUDA, &apply_shuffle_mul_sum);
/*
* From csrc/moe/marlin_moe_wna16
*/
m.def(
"moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none,"
"Tensor! b_q_weight, Tensor! b_scales, Tensor? b_zeros_or_none,"
"Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace,"
"Tensor sorted_token_ids,"
"Tensor! expert_ids, Tensor! num_tokens_past_padded,"
"Tensor! topk_weights, int moe_block_size, int top_k, "
"bool mul_topk_weights, bool is_ep, int b_q_type_id,"
"int size_m, int size_n, int size_k,"
"bool is_k_full, bool use_atomic_add,"
"bool use_fp32_reduce, bool is_zp_float) -> Tensor");
m.impl("moe_wna16_marlin_gemm", torch::kCUDA, &moe_wna16_marlin_gemm);
/*
* From csrc/moe/cutlass_moe/w4a8
*/
m.def(
"get_cutlass_w4a8_moe_mm_data(Tensor topk_ids, Tensor! expert_offsets, "
" Tensor! problem_sizes1, Tensor! problem_sizes2, "
" Tensor! input_permutation, "
" Tensor! output_permutation, int num_experts, "
" int n, int k) -> ()");
m.impl("get_cutlass_w4a8_moe_mm_data", torch::kCUDA, &get_cutlass_w4a8_moe_mm_data);
m.def(
"cutlass_w4a8_moe_mm(Tensor! d, Tensor a, Tensor b, "
" Tensor a_scales, Tensor b_scales, Tensor expert_offsets, "
" Tensor problem_sizes, Tensor a_strides, "
" Tensor b_strides, Tensor d_strides, Tensor s_strides,"
" int chunk_size, int topk) -> ()");
m.impl("cutlass_w4a8_moe_mm", torch::kCUDA, &cutlass_w4a8_moe_mm);
/*
* From csrc/speculative
*/
m.def(
"tree_speculative_sampling_target_only(Tensor! predicts, Tensor! accept_index, Tensor! accept_token_num, "
"Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, "
"Tensor uniform_samples, Tensor uniform_samples_for_final_sampling, Tensor target_probs, Tensor draft_probs, "
"float threshold_single, float threshold_acc, "
"bool deterministic, int cuda_stream) -> ()");
m.impl("tree_speculative_sampling_target_only", torch::kCUDA, &tree_speculative_sampling_target_only);
m.def(
"verify_tree_greedy(Tensor! predicts, Tensor! accept_index, Tensor! accept_token_num, "
"Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, "
"Tensor target_predict, int cuda_stream) -> ()");
m.impl("verify_tree_greedy", torch::kCUDA, &verify_tree_greedy);
m.def(
"build_tree_kernel_efficient(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, "
"Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, Tensor! retrive_next_token, "
"Tensor! retrive_next_sibling, int topk, int depth, int draft_token_num, int tree_mask_mode) -> "
"()");
m.impl("build_tree_kernel_efficient", torch::kCUDA, &build_tree_kernel_efficient);
m.def(
"segment_packbits(Tensor x, Tensor input_indptr, Tensor output_indptr, Tensor! y, int batch_size, "
"int cuda_stream) -> ()");
m.impl("segment_packbits", torch::kCUDA, &segment_packbits);
/*
* From csrc/kvcacheio
*/
m.def(
"transfer_kv_per_layer(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
"dst_indices, int item_size, int block_quota, int num_warps_per_block) -> ()");
m.impl("transfer_kv_per_layer", torch::kCUDA, &transfer_kv_per_layer);
m.def(
"transfer_kv_per_layer_pf_lf(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
"dst_indices, int layer_id, int item_size, int src_layout_dim, int block_quota, int num_warps_per_block) -> ()");
m.impl("transfer_kv_per_layer_pf_lf", torch::kCUDA, &transfer_kv_per_layer_pf_lf);
m.def(
"transfer_kv_all_layer(Tensor src_k_layers, Tensor dst_k_layers, Tensor src_v_layers, Tensor dst_v_layers, "
"Tensor src_indices, Tensor dst_indices, int item_size, int num_layers, int block_quota, int "
"num_warps_per_block) -> ()");
m.impl("transfer_kv_all_layer", torch::kCUDA, &transfer_kv_all_layer);
m.def(
"transfer_kv_all_layer_lf_pf(Tensor src_k_layers, Tensor dst_k, Tensor src_v_layers, Tensor dst_v, "
"Tensor src_indices, Tensor dst_indices, int item_size, int dst_layout_dim, int num_layers, int block_quota, int "
"num_warps_per_block) -> ()");
m.impl("transfer_kv_all_layer_lf_pf", torch::kCUDA, &transfer_kv_all_layer_lf_pf);
m.def(
"transfer_kv_per_layer_mla(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int item_size, int "
"block_quota, int num_warps_per_block) -> ()");
m.impl("transfer_kv_per_layer_mla", torch::kCUDA, &transfer_kv_per_layer_mla);
m.def(
"transfer_kv_per_layer_mla_pf_lf(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int layer_id, "
"int item_size, int src_layout_dim, int block_quota, int num_warps_per_block) -> ()");
m.impl("transfer_kv_per_layer_mla_pf_lf", torch::kCUDA, &transfer_kv_per_layer_mla_pf_lf);
m.def(
"transfer_kv_all_layer_mla(Tensor src_layers, Tensor dst_layers, Tensor src_indices, Tensor dst_indices, int "
"item_size, int num_layers, int block_quota, int num_warps_per_block) -> ()");
m.impl("transfer_kv_all_layer_mla", torch::kCUDA, &transfer_kv_all_layer_mla);
m.def(
"transfer_kv_all_layer_mla_lf_pf(Tensor src_layers, Tensor dst, Tensor src_indices, Tensor dst_indices, "
"int item_size, int dst_layout_dim, int num_layers, int block_quota, int num_warps_per_block) -> ()");
m.impl("transfer_kv_all_layer_mla_lf_pf", torch::kCUDA, &transfer_kv_all_layer_mla_lf_pf);
m.def(
"transfer_kv_direct(Tensor[] src_layers, Tensor[] dst_layers, Tensor src_indices, Tensor dst_indices, int "
"page_size) -> ()");
m.impl("transfer_kv_direct", torch::kCUDA, &transfer_kv_direct);
/*
* From csrc/memory
*/
m.def("store_kv_cache(Tensor k_cache, Tensor v_cache, Tensor out_loc, Tensor k, Tensor v) -> ()");
m.impl("store_kv_cache", &store_kv_cache);
/*
* From FlashInfer
*/
m.def(
"bmm_fp8(Tensor A, Tensor B, Tensor! D, Tensor A_scale, Tensor B_scale, Tensor workspace_buffer, int "
"cublas_handle, int cuda_stream) -> ()",
{at::Tag::needs_fixed_stride_order});
m.impl("bmm_fp8", torch::kCUDA, &bmm_fp8);
m.def(
"min_p_sampling_from_probs(Tensor probs, Tensor output, Tensor? maybe_indices, Tensor? maybe_min_p_arr, float "
"min_p_val, bool deterministic, Generator? gen) -> ()");
m.impl("min_p_sampling_from_probs", torch::kCUDA, &min_p_sampling_from_probs);
m.def("top_k_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_k_arr, int top_k_val) -> ()");
m.impl("top_k_renorm_probs", torch::kCUDA, &top_k_renorm_probs);
m.def("top_p_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_p_arr, float top_p_val) -> ()");
m.impl("top_p_renorm_probs", torch::kCUDA, &top_p_renorm_probs);
m.def(
"top_p_sampling_from_probs(Tensor probs, Tensor output, Tensor? maybe_indices, Tensor? "
"maybe_top_p_arr, float top_p_val, bool deterministic, Generator? gen) -> ()");
m.impl("top_p_sampling_from_probs", torch::kCUDA, &top_p_sampling_from_probs);
m.def(
"top_k_top_p_sampling_from_probs(Tensor probs, Tensor output, Tensor? maybe_indices, Tensor? maybe_top_k_arr, "
"float top_k_val, Tensor? maybe_top_p_arr, float top_p_val, bool deterministic, Generator? gen) -> ()");
m.impl("top_k_top_p_sampling_from_probs", torch::kCUDA, &top_k_top_p_sampling_from_probs);
m.def("top_k_mask_logits(Tensor logits, Tensor mask_logits, Tensor? maybe_top_k_arr, int top_k_val) -> ()");
m.impl("top_k_mask_logits", torch::kCUDA, &top_k_mask_logits);
/*
* From Sparse Flash Attention
*/
m.def(
"fwd_sparse(Tensor! q, Tensor k, Tensor v, "
"Tensor block_count, Tensor block_offset, Tensor column_count, Tensor column_index, "
"Tensor!? out, Tensor? alibi_slopes, "
"float p_dropout, float softmax_scale, bool is_causal, "
"float softcap, bool return_softmax, Generator? gen)"
"-> Tensor[]");
m.impl("fwd_sparse", torch::kCUDA, &flash::mha_fwd_sparse);
m.def(
"varlen_fwd_sparse(Tensor! q, Tensor k, Tensor v, "
"Tensor block_count, Tensor block_offset, Tensor column_count, Tensor column_index, "
"Tensor!? out, Tensor cu_seqlens_q, "
"Tensor cu_seqlens_k, Tensor? seqused_k, Tensor? alibi_slopes, "
"int max_seqlen_q, int max_seqlen_k, float p_dropout, float softmax_scale, bool zero_tensors, "
"bool is_causal, float softcap, bool return_softmax, "
"Generator? gen) -> Tensor[]");
m.impl("varlen_fwd_sparse", torch::kCUDA, &flash::mha_varlen_fwd_sparse);
// Sparse Attention utils
m.def(
"convert_vertical_slash_indexes("
" Tensor! block_count, Tensor! block_offset, "
" Tensor! column_count, Tensor! column_index, "
" Tensor q_seqlens, Tensor q_seqlens, "
" Tensor vertical_indexes, Tensor slash_indexes, "
" int context_size, int block_size_M, int block_size_N, "
" bool causal) -> ()");
m.impl("convert_vertical_slash_indexes", torch::kCUDA, &convert_vertical_slash_indexes);
m.def(
"convert_vertical_slash_indexes_mergehead("
" Tensor! block_count, Tensor! block_offset, "
" Tensor! column_count, Tensor! column_index, "
" Tensor q_seqlens, Tensor q_seqlens, "
" Tensor vertical_indexes, Tensor slash_indexes, "
" Tensor vertical_indices_count, Tensor slash_indices_count, "
" int context_size, int block_size_M, int block_size_N, "
" bool causal) -> ()");
m.impl("convert_vertical_slash_indexes_mergehead", torch::kCUDA, &convert_vertical_slash_indexes_mergehead);
/*
* From csrc/grammar
*/
m.def("apply_token_bitmask_inplace_cuda(Tensor logits, Tensor bitmask, Tensor? indices=None) -> ()");
m.impl("apply_token_bitmask_inplace_cuda", &ApplyTokenBitmaskInplace);
/*
* From csrc/gemm (QServe)
*/
m.def(
"qserve_w4a8_per_chn_gemm(Tensor _in_feats, Tensor _kernel, Tensor _wscales, Tensor _ascales, Tensor _w_szs, "
"Tensor _a_ssums, Tensor! _out_feats) -> ()");
m.impl("qserve_w4a8_per_chn_gemm", torch::kCUDA, &qserve_w4a8_per_chn_gemm);
m.def(
"qserve_w4a8_per_group_gemm(Tensor _in_feats, Tensor _kernel, Tensor _zeros, Tensor _scales_i8, Tensor _wscales, "
"Tensor _ascales, Tensor! _out_feats) -> ()");
m.impl("qserve_w4a8_per_group_gemm", torch::kCUDA, &qserve_w4a8_per_group_gemm);
}
REGISTER_EXTENSION(common_ops)

View File

@@ -0,0 +1,168 @@
/* Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/library.h>
#include "sgl_kernel_ops.h"
TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
/*
* From csrc/activation
*/
m.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
m.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
m.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()");
m.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul);
m.def("gelu_and_mul(Tensor! out, Tensor input) -> ()");
m.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);
m.def("gelu_quick(Tensor! out, Tensor input) -> ()");
m.impl("gelu_quick", torch::kCUDA, &gelu_quick);
/*
* From csrc/allreduce
*/
m.def(
"init_custom_ar(Tensor meta, Tensor rank_data, "
"str[] handles, int[] offsets, int rank, "
"bool full_nvlink) -> int");
m.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);
m.def("all_reduce_reg(int fa, Tensor inp, Tensor! out) -> ()");
m.impl("all_reduce_reg", torch::kCUDA, &all_reduce_reg);
m.def(
"all_reduce_unreg(int fa, Tensor inp, Tensor reg_buffer, Tensor! out) -> "
"()");
m.impl("all_reduce_unreg", torch::kCUDA, &all_reduce_unreg);
m.def("dispose", &dispose);
m.def("meta_size", &meta_size);
m.def(
"register_buffer(int fa, Tensor t, str[] handles, "
"int[] offsets) -> ()");
m.impl("register_buffer", torch::kCUDA, &register_buffer);
m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
m.def("register_graph_buffers", &register_graph_buffers);
m.def("allocate_meta_buffer", &allocate_meta_buffer);
m.impl("allocate_meta_buffer", torch::kCUDA, &allocate_meta_buffer);
m.def("get_meta_buffer_ipc_handle", &get_meta_buffer_ipc_handle);
m.impl("get_meta_buffer_ipc_handle", torch::kCPU, &get_meta_buffer_ipc_handle);
// quick allreduce
#ifdef USE_ROCM
m.def(
"qr_all_reduce(int fa, Tensor inp, Tensor out, int quant_level, bool "
"cast_bf2half) -> ()");
m.impl("qr_all_reduce", torch::kCUDA, &qr_all_reduce);
m.def("init_custom_qr", &init_custom_qr);
m.def("qr_destroy", &qr_destroy);
m.def("qr_get_handle", &qr_get_handle);
m.def("qr_open_handles(int _fa, Tensor[](b!) handles) -> ()");
m.impl("qr_open_handles", torch::kCPU, &qr_open_handles);
// Max input size in bytes
m.def("qr_max_size", &qr_max_size);
#endif
/*
* From csrc/moe
*/
m.def(
"moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! "
"experts_ids, Tensor! num_tokens_post_pad, Tensor! cumsum_buffer, bool "
"pad_sorted_token_ids) -> ()");
m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
m.def("topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor gating_output, bool renormalize) -> ()");
m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
/*
* From csrc/speculative
*/
m.def(
"verify_tree_greedy(Tensor! predicts, Tensor! accept_index, Tensor! accept_token_num, "
"Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, "
"Tensor target_predict, int cuda_stream) -> ()");
m.impl("verify_tree_greedy", torch::kCUDA, &verify_tree_greedy);
m.def(
"build_tree_kernel_efficient(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, "
"Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, Tensor! retrive_next_token, "
"Tensor! retrive_next_sibling, int topk, int depth, int draft_token_num, int tree_mask_mode) -> "
"()");
m.impl("build_tree_kernel_efficient", torch::kCUDA, &build_tree_kernel_efficient);
/*
* From XGrammar
*/
m.def("apply_token_bitmask_inplace_cuda(Tensor logits, Tensor bitmask, Tensor? indices=None) -> ()");
m.impl("apply_token_bitmask_inplace_cuda", &ApplyTokenBitmaskInplace);
/*
* From csrc/kvcacheio
*/
m.def(
"transfer_kv_per_layer(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
"dst_indices, int item_size, int block_quota, int num_warps_per_block) -> ()");
m.impl("transfer_kv_per_layer", torch::kCUDA, &transfer_kv_per_layer);
m.def(
"transfer_kv_per_layer_pf_lf(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
"dst_indices, int layer_id, int item_size, int src_layout_dim, int block_quota, int num_warps_per_block) -> ()");
m.impl("transfer_kv_per_layer_pf_lf", torch::kCUDA, &transfer_kv_per_layer_pf_lf);
m.def(
"transfer_kv_all_layer(Tensor src_k_layers, Tensor dst_k_layers, Tensor src_v_layers, Tensor dst_v_layers, "
"Tensor src_indices, Tensor dst_indices, int item_size, int num_layers, int block_quota, int "
"num_warps_per_block) -> ()");
m.impl("transfer_kv_all_layer", torch::kCUDA, &transfer_kv_all_layer);
m.def(
"transfer_kv_all_layer_lf_pf(Tensor src_k_layers, Tensor dst_k, Tensor src_v_layers, Tensor dst_v, "
"Tensor src_indices, Tensor dst_indices, int item_size, int dst_layout_dim, int num_layers, int block_quota, int "
"num_warps_per_block) -> ()");
m.impl("transfer_kv_all_layer_lf_pf", torch::kCUDA, &transfer_kv_all_layer_lf_pf);
m.def(
"transfer_kv_per_layer_mla(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int item_size, int "
"block_quota, int num_warps_per_block) -> ()");
m.impl("transfer_kv_per_layer_mla", torch::kCUDA, &transfer_kv_per_layer_mla);
m.def(
"transfer_kv_per_layer_mla_pf_lf(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int layer_id, "
"int item_size, int src_layout_dim, int block_quota, int num_warps_per_block) -> ()");
m.impl("transfer_kv_per_layer_mla_pf_lf", torch::kCUDA, &transfer_kv_per_layer_mla_pf_lf);
m.def(
"transfer_kv_all_layer_mla(Tensor src_layers, Tensor dst_layers, Tensor src_indices, Tensor dst_indices, int "
"item_size, int num_layers, int block_quota, int num_warps_per_block) -> ()");
m.impl("transfer_kv_all_layer_mla", torch::kCUDA, &transfer_kv_all_layer_mla);
m.def(
"transfer_kv_all_layer_mla_lf_pf(Tensor src_layers, Tensor dst, Tensor src_indices, Tensor dst_indices, "
"int item_size, int dst_layout_dim, int num_layers, int block_quota, int num_warps_per_block) -> ()");
m.impl("transfer_kv_all_layer_mla_lf_pf", torch::kCUDA, &transfer_kv_all_layer_mla_lf_pf);
m.def(
"transfer_kv_direct(Tensor[] src_layers, Tensor[] dst_layers, Tensor src_indices, Tensor dst_indices, int "
"page_size) -> ()");
m.impl("transfer_kv_direct", torch::kCUDA, &transfer_kv_direct);
}
REGISTER_EXTENSION(common_ops)

View File

@@ -0,0 +1,83 @@
cmake_minimum_required(VERSION 3.18 FATAL_ERROR)
project(sgl_kernel)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)
find_package(Python COMPONENTS Interpreter Development.Module ${SKBUILD_SABI_COMPONENT} REQUIRED)
execute_process(
COMMAND ${Python_EXECUTABLE}
-c "import torch; print(torch.utils.cmake_prefix_path)"
OUTPUT_VARIABLE TORCH_PY_PREFIX
OUTPUT_STRIP_TRAILING_WHITESPACE
)
message(STATUS ${TORCH_PY_PREFIX})
list(APPEND CMAKE_PREFIX_PATH ${TORCH_PY_PREFIX}/Torch)
find_package(Torch REQUIRED)
include_directories(
${TORCH_INCLUDE_DIRS}
${TORCH_INSTALL_PREFIX}/include
${Python_INCLUDE_DIRS}
${CMAKE_CURRENT_SOURCE_DIR}/../../csrc
${CMAKE_CURRENT_SOURCE_DIR}/../../include
)
# Platform-specific library directory
if(CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64|AMD64")
set(PLAT_LIB_DIR "/usr/lib/x86_64-linux-gnu")
elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64|arm64")
set(PLAT_LIB_DIR "/usr/lib/aarch64-linux-gnu")
elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "ppc64le|ppc64")
set(PLAT_LIB_DIR "/usr/lib/powerpc64le-linux-gnu")
else()
set(PLAT_LIB_DIR "/usr/lib/${CMAKE_SYSTEM_PROCESSOR}-linux-gnu")
endif()
link_directories(${PLAT_LIB_DIR})
# Conda library path support
if(DEFINED ENV{CONDA_PREFIX})
set(CONDA_LIB_DIR "$ENV{CONDA_PREFIX}/lib")
message(STATUS "Using Conda lib dir: ${CONDA_LIB_DIR}")
link_directories(${CONDA_LIB_DIR})
set(CONDA_INCLUDE_DIR "$ENV{CONDA_PREFIX}/include")
include_directories(${CONDA_INCLUDE_DIR})
# Look for libnuma in Conda's lib directory
find_library(NUMA_LIB numa HINTS "${CONDA_LIB_DIR}")
if(NUMA_LIB)
message(STATUS "Found libnuma: ${NUMA_LIB}")
else()
message(FATAL_ERROR "libnuma not found in Conda environment at ${CONDA_LIB_DIR}\n"
"Please install it using: conda install libnuma numactl\n")
endif()
endif()
file(GLOB SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/*.cpp")
if(NOT DEFINED ENV{SGLANG_CPU_FP8_CVT_FTZ})
set(ENV{SGLANG_CPU_FP8_CVT_FTZ} "1")
endif()
if("$ENV{SGLANG_CPU_FP8_CVT_FTZ}" STREQUAL "1")
message(STATUS "Enabling macro: SGLANG_CPU_FP8_CVT_FTZ")
add_compile_definitions(SGLANG_CPU_FP8_CVT_FTZ)
endif()
add_compile_options(
-O3
-Wno-unknown-pragmas
-march=native
-fopenmp
)
Python_add_library(common_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES})
target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES} ${NUMA_LIB})
target_include_directories(common_ops PRIVATE ${TORCH_INCLUDE_DIRS})
install(TARGETS common_ops
LIBRARY DESTINATION sgl_kernel
)

View File

@@ -0,0 +1,79 @@
#include "common.h"
#include "vec.h"
namespace {
template <typename scalar_t, typename func_t, typename vec_func_t>
void act_and_mul_kernel_impl(
scalar_t* __restrict__ output,
const scalar_t* __restrict__ input,
int64_t num_tokens,
int64_t dim,
const func_t& f,
const vec_func_t& vf) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
constexpr int64_t kVecSize = bVec::size();
at::parallel_for(0, num_tokens, 0, [&](int64_t begin, int64_t end) {
for (int64_t i = begin; i < end; ++i) {
// local ptrs
const scalar_t* __restrict__ input_ptr = input + i * 2 * dim;
const scalar_t* __restrict__ input_other_ptr = input_ptr + dim;
scalar_t* __restrict__ output_ptr = output + i * dim;
int64_t d;
#pragma GCC unroll 4
for (d = 0; d <= dim - kVecSize; d += kVecSize) {
bVec x_bvec = bVec::loadu(input_ptr + d);
fVec x_fvec0, x_fvec1;
std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
bVec y_bvec = bVec::loadu(input_other_ptr + d);
fVec y_fvec0, y_fvec1;
std::tie(y_fvec0, y_fvec1) = at::vec::convert_to_float(y_bvec);
x_fvec0 = vf(x_fvec0);
x_fvec1 = vf(x_fvec1);
x_fvec0 = x_fvec0 * y_fvec0;
x_fvec1 = x_fvec1 * y_fvec1;
x_bvec = convert_from_float_ext<scalar_t>(x_fvec0, x_fvec1);
x_bvec.store(output_ptr + d);
}
#pragma GCC unroll 4
for (; d < dim; ++d) {
float x_val = static_cast<float>(input_ptr[d]);
float y_val = static_cast<float>(input_other_ptr[d]);
output_ptr[d] = f(x_val) * y_val;
}
}
});
}
} // anonymous namespace
// input : {num_tokens, 2 * d}
// output : {num_tokens, d}
at::Tensor silu_and_mul_cpu(at::Tensor& input) {
RECORD_FUNCTION("sgl-kernel::silu_and_mul_cpu", std::vector<c10::IValue>({input}));
auto sizes = input.sizes().vec();
int64_t last_dim = input.ndimension() - 1;
int64_t d = sizes[last_dim] / 2;
sizes[last_dim] = d;
int64_t num_tokens = input.numel() / input.size(-1);
at::Tensor out = at::empty(sizes, input.options());
AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "silu_and_mul", [&] {
using Vec = at::vec::Vectorized<float>;
act_and_mul_kernel_impl(
out.data_ptr<scalar_t>(),
input.data_ptr<scalar_t>(),
num_tokens,
d,
[](float x) { return x / (1.f + std::exp(-x)); },
[](Vec x) { return x / (Vec(1.f) + x.neg().exp()); });
});
return out;
}

123
sgl-kernel/csrc/cpu/bmm.cpp Normal file
View File

@@ -0,0 +1,123 @@
#include "common.h"
#include "gemm.h"
#include "vec.h"
namespace {
template <typename scalar_t>
void bmm_kernel_impl(
scalar_t* __restrict__ out,
const scalar_t* __restrict__ mat1,
const scalar_t* __restrict__ mat2,
int64_t B,
int64_t M,
int64_t N,
int64_t K,
int64_t mat1_strideB,
int64_t mat1_strideM,
int64_t out_strideB,
int64_t out_strideM,
float scale = 0.f) {
constexpr int64_t BLOCK_M = block_size_m();
constexpr int64_t BLOCK_N = block_size_n();
const int64_t MB = div_up(M, BLOCK_M);
const int64_t NB = div_up(N, BLOCK_N);
// mat2 contiguous in [B, N, K]
int64_t mat2_strideB = N * K;
int64_t mat2_strideN = K;
const bool use_brgemm = can_use_brgemm<scalar_t>(M);
// parallel on [B, MB, NB]
at::parallel_for(0, B * MB * NB, 0, [&](int64_t begin, int64_t end) {
int64_t bs{0}, mb{0}, nb{0};
data_index_init(begin, bs, B, mb, MB, nb, NB);
// for brgemm, use float32 for accumulate
alignas(64) float Ctmp[BLOCK_M * BLOCK_N];
for (int i = begin; i < end; ++i) {
UNUSED(i);
int mb_start = mb * BLOCK_M;
int mb_size = std::min(M - mb_start, BLOCK_M);
int nb_start = nb * BLOCK_N;
int nb_size = std::min(N - nb_start, BLOCK_N);
tinygemm_kernel<scalar_t>(
/* A */ mat1 + bs * mat1_strideB + mb_start * mat1_strideM,
/* B */ mat2 + bs * mat2_strideB + nb_start * mat2_strideN /* nb * BLOCK_N * K */,
/* C */ out + bs * out_strideB + mb_start * out_strideM + nb_start,
/* Ctmp*/ Ctmp,
/* M */ mb_size,
/* N */ nb_size,
/* K */ K,
/* lda */ mat1_strideM,
/* ldb */ nb_size,
/* ldc */ out_strideM,
/* brg */ use_brgemm);
// move to the next index
data_index_step(bs, B, mb, MB, nb, NB);
}
if (use_brgemm) {
at::native::cpublas::brgemm_release();
}
});
}
} // anonymous namespace
// mat1 : [B, M, K]
// mat2 : [B, N, K] or [B, OC, IC]
// out : [B, M, N]
// scale: [] 0-dim tensor for per tensor quant
//
void bmm_cpu(
at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, const std::optional<at::Tensor>& scale) {
RECORD_FUNCTION("sgl-kernel::bmm_cpu", std::vector<c10::IValue>({out, mat1, mat2}));
auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
// input and out could be non-contiguous
// weight needs to be contiguous in [OC, IC] order
CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1);
CHECK_LAST_DIM_CONTIGUOUS_INPUT(out);
CHECK_INPUT(mat2);
CHECK_DIM(3, out);
CHECK_DIM(3, mat1);
CHECK_DIM(3, mat2);
int64_t B = mat1.size(0);
int64_t M = mat1.size(1);
int64_t N = mat2.size(1);
int64_t K = mat1.size(2);
TORCH_CHECK(!scale.has_value(), "bmm: do not support fp8 weight for now.")
TORCH_CHECK(N % 32 == 0, "tinygemm requires N to be 32x.");
int64_t mat1_strideB = mat1.stride(0);
int64_t mat1_strideM = mat1.stride(1);
int64_t out_strideB = out.stride(0);
int64_t out_strideM = out.stride(1);
// check shapes
TORCH_CHECK(mat2.size(0) == B && mat2.size(2) == K, "bmm: mat2 shape mismatch!");
TORCH_CHECK(out.size(0) == B && out.size(1) == M, "bmm: out shape mismatch!");
AT_DISPATCH_REDUCED_FLOATING_TYPES(mat1.scalar_type(), "bmm_kernel_impl", [&] {
bmm_kernel_impl<scalar_t>(
out.data_ptr<scalar_t>(),
mat1.data_ptr<scalar_t>(),
packed_w.data_ptr<scalar_t>(),
B,
M,
N,
K,
mat1_strideB,
mat1_strideM,
out_strideB,
out_strideM);
});
}

View File

@@ -0,0 +1,324 @@
#pragma once
#include <ATen/ATen.h>
#include <ATen/Parallel.h>
#include <ATen/record_function.h>
#if defined(_OPENMP)
#include <omp.h>
#endif
namespace {
// dispatch bool
#define AT_DISPATCH_BOOL(BOOL_V, BOOL_NAME, ...) \
[&] { \
if (BOOL_V) { \
constexpr bool BOOL_NAME = true; \
return __VA_ARGS__(); \
} else { \
constexpr bool BOOL_NAME = false; \
return __VA_ARGS__(); \
} \
}()
// dispatch: bfloat16, float16, int8_t, fp8_e4m3
#define CPU_DISPATCH_PACKED_TYPES(TYPE, ...) \
[&] { \
switch (TYPE) { \
case at::ScalarType::BFloat16: { \
using packed_t = at::BFloat16; \
return __VA_ARGS__(); \
} \
case at::ScalarType::Half: { \
using packed_t = at::Half; \
return __VA_ARGS__(); \
} \
case at::ScalarType::Char: { \
using packed_t = int8_t; \
return __VA_ARGS__(); \
} \
case at::ScalarType::Float8_e4m3fn: { \
using packed_t = at::Float8_e4m3fn; \
return __VA_ARGS__(); \
} \
default: \
TORCH_CHECK(false, "Unsupported floating data type.\n"); \
} \
}()
// dispatch with mixed dtypes (TYPE1, TYPE2):
// TYPE1: the primary dtype (input, output, weight);
// TYPE2: the secondary dtype (bias, etc.).
#define CPU_DISPATCH_REDUCED_FLOATING_TYPES_EXT(TYPE1, TYPE2, ...) \
[&] { \
if (TYPE2 == at::kFloat) { \
switch (TYPE1) { \
case at::ScalarType::BFloat16: { \
using scalar_t = at::BFloat16; \
using param_t = float; \
return __VA_ARGS__(); \
} \
case at::ScalarType::Half: { \
using scalar_t = at::Half; \
using param_t = float; \
return __VA_ARGS__(); \
} \
default: \
TORCH_CHECK(false, "Unsupported floating data type.\n"); \
} \
} else { \
TORCH_CHECK(TYPE1 == TYPE2); \
switch (TYPE1) { \
case at::ScalarType::BFloat16: { \
using scalar_t = at::BFloat16; \
using param_t = at::BFloat16; \
return __VA_ARGS__(); \
} \
case at::ScalarType::Half: { \
using scalar_t = at::Half; \
using param_t = at::Half; \
return __VA_ARGS__(); \
} \
default: \
TORCH_CHECK(false, "Unsupported floating data type.\n"); \
} \
} \
}()
#define UNUSED(x) (void)(x)
#define CHECK_CPU(x) TORCH_CHECK(x.device().type() == at::kCPU, #x " must be a CPU tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_LAST_DIM_CONTIGUOUS(x) \
TORCH_CHECK(x.strides()[x.strides().size() - 1] == 1, #x "must be contiguous at last dimension")
#define CHECK_INPUT(x) \
CHECK_CPU(x); \
CHECK_CONTIGUOUS(x)
#define CHECK_LAST_DIM_CONTIGUOUS_INPUT(x) \
CHECK_CPU(x); \
CHECK_LAST_DIM_CONTIGUOUS(x)
#define CHECK_DIM(d, x) TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor")
#define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)
// [NB] Parallel Routines
//
// * at::parallel_for - applies for most of generic use cases, this will be compiled
// against openmp in default torch release.
//
// * parallel_for - same function as above, can choose payload partition scheme in
// balance211.
//
// * parallel_2d - parallel for 2 dimensions, used in GEMM, etc.
// this one will do payload balance across 2 dimensions.
//
// grain size for each thread
constexpr int GRAIN_SIZE = 1024;
template <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
inline T div_up(T x, T y) {
return (x + y - 1) / y;
}
// you can only use at::get_thread_num() with at::parallel_for()
// as it is lazy initialized, otherwise it will always return 0.
inline int get_thread_num() {
#if defined(_OPENMP)
return omp_get_thread_num();
#else
return 0;
#endif
}
// balance payload across each thread
template <typename T>
inline void balance211(T n, T nth, T ith, T& n_start, T& n_end) {
#if 0
// onednn partition pattern
T& n_my = n_end;
if (nth <= 1 || n == 0) {
n_start = 0;
n_my = n;
} else {
T n1 = div_up(n, nth);
T n2 = n1 - 1;
T T1 = n - n2 * nth;
n_my = ith < T1 ? n1 : n2;
n_start = ith <= T1 ? ith*n1 : T1 * n1 + (ith - T1) * n2;
}
n_end += n_start;
#else
// pytorch aten partition pattern
T n_my = div_up(n, nth);
n_start = ith * n_my;
n_end = std::min(n_start + n_my, n);
#endif
}
template <typename func_t>
inline void parallel_for(int n, const func_t& f) {
#if defined(_OPENMP)
#pragma omp parallel
{
int nth = omp_get_num_threads();
int ith = omp_get_thread_num();
int tbegin, tend;
balance211(n, nth, ith, tbegin, tend);
f(tbegin, tend);
}
#else
f(0, n);
#endif
}
// for 1d parallel, use `actual_nth`
// for 2d parallel, use even nths, e.g. 43->42
int inline adjust_num_threads(int m) {
int actual_nth = at::get_num_threads();
if (m == 1) {
return actual_nth;
}
return std::max(1, (actual_nth >> 1) * 2);
}
template <typename func_t>
inline void parallel_2d(int m, int n, const func_t& f) {
// make sure we have even num_threads
int nth = adjust_num_threads(m);
// [NOTE] thread blocking:
//
// 1) prefer square block per thread
// 2) use even number of CPU cores
// 3) use all `num_threads` cores
//
// we have:
// TM * TN = T
// BM / TM = BN / TN
// then:
// TM = ((BM / BN) * T) ^ 0.5
//
float r = float(m) / n;
int nth_m = std::ceil(std::sqrt(r * nth));
int nth_n = 1;
for (; nth_m > 0; --nth_m) {
nth_n = nth / nth_m;
if (nth_m * nth_n == nth) {
break;
}
}
#if defined(_OPENMP)
#pragma omp parallel num_threads(nth)
{
int ith = omp_get_thread_num();
int ith_m = ith / nth_n;
int ith_n = ith % nth_n;
int thread_block_m = div_up(m, nth_m);
int thread_block_n = div_up(n, nth_n);
int begin_m = ith_m * thread_block_m;
int end_m = std::min(m, begin_m + thread_block_m);
int begin_n = ith_n * thread_block_n;
int end_n = std::min(n, begin_n + thread_block_n);
f(begin_m, end_m, begin_n, end_n);
}
#else
f(0, m, 0, n);
#endif
}
// limit max cache blocks
// when we need to do pre-unpack for weights, e.g. fp8
#define MAX_CACHE_BLOCK_SIZE 4
template <typename T>
inline int get_cache_blocks(int chunk_size) {
// L2 2MB and ratio of 50%
const int L2_size = 2048 * 1024 >> 1;
return std::max(1, int(L2_size / (chunk_size * sizeof(T))));
}
template <>
inline int get_cache_blocks<at::Float8_e4m3fn>(int chunk_size) {
// fp8 uses bf16 as accumulate type
int cache_block_size = get_cache_blocks<at::BFloat16>(chunk_size);
return std::min(MAX_CACHE_BLOCK_SIZE, cache_block_size);
}
// 2d sequential loop in range : [mb0, mb1), [nb0, nb1)
template <typename T, typename func_t>
inline void loop_2d(int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1, int64_t chunk_size, const func_t& f) {
// get number of blocks for L2 in most inner loop
int64_t cache_blocks_nb = get_cache_blocks<T>(chunk_size);
// loop order: [NB / cache_blocks_nb, MB, cache_blocks_nb]
// TODO: implement reverse order of [MB / cache_blocks_mb, NB, cache_blocks_mb]
for (int64_t nbb = nb0; nbb < nb1; nbb += cache_blocks_nb) {
for (int64_t mb = mb0; mb < mb1; ++mb) {
for (int64_t nb = nbb; nb < std::min(nbb + cache_blocks_nb, nb1); ++nb) {
f(mb, nb, nb - nbb);
}
}
}
}
// data indexing for dimension collapse
template <typename T>
inline T data_index_init(T offset) {
return offset;
}
template <typename T, typename... Args>
inline T data_index_init(T offset, T& x, const T& X, Args&&... args) {
offset = data_index_init(offset, std::forward<Args>(args)...);
x = offset % X;
return offset / X;
}
inline bool data_index_step() {
return true;
}
template <typename T, typename... Args>
inline bool data_index_step(T& x, const T& X, Args&&... args) {
if (data_index_step(std::forward<Args>(args)...)) {
x = ((x + 1) == X) ? 0 : (x + 1);
return x == 0;
}
return false;
}
// forced unroll for perf critical path
#if __has_attribute(always_inline)
#define ALWAYS_INLINE __attribute__((__always_inline__)) inline
#else
#define ALWAYS_INLINE inline
#endif
template <int n>
struct Unroll {
template <typename Func, typename... Args>
ALWAYS_INLINE void operator()(const Func& f, Args... args) const {
Unroll<n - 1>{}(f, args...);
f(std::integral_constant<int, n - 1>{}, args...);
}
};
template <>
struct Unroll<1> {
template <typename Func, typename... Args>
ALWAYS_INLINE void operator()(const Func& f, Args... args) const {
f(std::integral_constant<int, 0>{}, args...);
}
};
} // anonymous namespace

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,723 @@
#include "common.h"
#include "gemm.h"
#include "vec.h"
namespace {
// [NOTE]: extend attention for CPU
// 1. tune BLOCK_M and BLOCK_N
// 2. can handle non-contiguous k_exttend and v_extend
// 3. computes attention for prefix and extend separately
// 4. TODO: vectorize `pack_vnni` and `pack_vnni2`
//
template <typename index_t>
inline index_t get_index(index_t* ind, int i) {
return (ind == nullptr) ? (index_t)i : ind[i];
}
#if defined(CPU_CAPABILITY_AVX512)
// key: from [N, 32] to [32/2, N, 2]
template <typename scalar_t, typename index_t>
inline void pack_vnni_Nx32(
scalar_t* __restrict__ dst,
const scalar_t* __restrict__ src,
const index_t* __restrict__ ind,
int N,
int ld_src,
int ld_dst) {
__m512i vinputs[16];
int n = 0;
for (; n < N; ++n) {
index_t index = get_index(ind, n);
vinputs[n] = _mm512_loadu_si512(src + index * ld_src);
}
// padding with zero to avoid uninitialized vectors
for (; n < 16; ++n) {
vinputs[n] = _mm512_set1_epi32(0);
}
// pack key
transpose_16x16_32bit(vinputs);
const __mmask16 vmask = (1 << N) - 1;
for (int k = 0; k < 16; ++k) {
_mm512_mask_storeu_epi32(dst + k * ld_dst * 2, vmask, vinputs[k]);
}
}
// value: from [K, 32] to [K/2, 32, 2]
template <typename scalar_t, typename index_t>
inline void pack_vnni_Kx32(
scalar_t* __restrict__ dst,
const scalar_t* __restrict__ src,
const index_t* __restrict__ ind,
int K,
int ld_src,
int ld_dst) {
__m512i vinputs[2];
int k = 0;
for (; k < K; ++k) {
index_t index = get_index(ind, k);
vinputs[k] = _mm512_loadu_si512(src + index * ld_src);
}
// padding with zero to avoid uninitialized vectors
for (; k < 2; ++k) {
vinputs[k] = _mm512_set1_epi32(0);
}
// pack value
__m512i d0, d1;
std::tie(d0, d1) = transpose_2x32_16bit(vinputs[0], vinputs[1]);
_mm512_storeu_si512(dst + 0 * ld_dst * 2, d0);
_mm512_storeu_si512(dst + 0 * ld_dst * 2 + 32, d1);
}
#endif
// convert to vnni format
// from [N, K/2, 2] to [K/2, N, 2] for bfloat16 and float16
template <typename scalar_t, typename index_t>
void pack_vnni(
scalar_t* __restrict__ dst,
const scalar_t* __restrict__ src,
const index_t* __restrict__ ind,
int N,
int K,
int ld_src,
int ld_dst) {
#if defined(CPU_CAPABILITY_AVX512)
const int NB = div_up(N, 16);
const int KB = K / 32; // no remainder
const bool is_indexed = ind != nullptr;
for (int nb = 0; nb < NB; ++nb) {
for (int kb = 0; kb < KB; ++kb) {
// handle 16x512bits each block
int nb_size = std::min(N - nb * 16, 16);
pack_vnni_Nx32<scalar_t, index_t>(
/* dst */ dst + ((kb * 32) >> 1) * ld_dst * 2 + nb * 16 * 2,
/* src */ src + kb * 32 + (is_indexed ? 0 : nb * 16 * ld_src),
/* ind */ is_indexed ? ind + nb * 16 : nullptr,
/* N */ nb_size,
/* ld_src */ ld_src,
/* ld_dst */ ld_dst);
}
}
#else
for (int n = 0; n < N; ++n) {
index_t index = get_index(ind, n);
for (int k = 0; k < K / 2; ++k) {
for (int d = 0; d < 2; ++d) {
dst[k * ld_dst * 2 + n * 2 + d] = src[index * ld_src + k * 2 + d];
}
}
}
#endif
}
// convert to vnni format
// from [K/2, 2, N] to [K/2, N, 2] for bfloat16 and float16
template <typename scalar_t, typename index_t>
void pack_vnni2(
scalar_t* __restrict__ dst,
const scalar_t* __restrict__ src,
const index_t* __restrict__ ind,
int K,
int N,
int ld_src,
int ld_dst) {
#if defined(CPU_CAPABILITY_AVX512)
const int KB = div_up(K, 2);
const int NB = N / 32; // no remainder
const bool is_indexed = ind != nullptr;
for (int kb = 0; kb < KB; ++kb) {
for (int nb = 0; nb < NB; ++nb) {
// handle 2x512bits each block
int kb_size = std::min(K - kb * 2, 2);
pack_vnni_Kx32<scalar_t, index_t>(
/* dst */ dst + ((kb * 2) >> 1) * ld_dst * 2 + nb * 32 * 2,
/* src */ src + (is_indexed ? 0 : kb * 2 * ld_src) + nb * 32,
/* ind */ is_indexed ? ind + kb * 2 : nullptr,
/* K */ kb_size,
/* ld_src */ ld_src,
/* ld_dst */ ld_dst);
}
}
#else
int k = 0;
for (; k < (K >> 1) * 2; k += 2) {
index_t index0 = get_index(ind, k + 0);
index_t index1 = get_index(ind, k + 1);
for (int n = 0; n < N; ++n) {
dst[(k >> 1) * ld_dst * 2 + n * 2 + 0] = src[index0 * ld_src + n];
dst[(k >> 1) * ld_dst * 2 + n * 2 + 1] = src[index1 * ld_src + n];
}
}
if (K % 2 != 0) {
index_t index = get_index(ind, K - 1);
for (int n = 0; n < N; ++n) {
dst[(K >> 1) * ld_dst * 2 + n * 2 + 0] = src[index * ld_src + n];
dst[(K >> 1) * ld_dst * 2 + n * 2 + 1] = 0;
}
k += 2;
}
#endif
}
template <typename scalar_t>
inline void fill_stub(scalar_t* __restrict__ out, float val, int size) {
using Vec = at::vec::Vectorized<scalar_t>;
constexpr int kVecSize = Vec::size();
const Vec data_vec = Vec(static_cast<scalar_t>(val));
int d = 0;
#pragma GCC unroll 4
for (; d <= size - kVecSize; d += kVecSize) {
data_vec.store(out + d);
}
if (size - d > 0) {
data_vec.store(out + d, size - d);
}
}
template <typename scalar_t, int BLOCK_N>
inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ input) {
static_assert(BLOCK_N % 32 == 0);
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
constexpr int COLS = BLOCK_N / 16;
auto store = [&](auto i) {
constexpr int col = i % COLS;
// for COLS = 2, 4 use 512bit store
if constexpr (col % 2 == 0) {
fVec a_fvec0 = fVec::loadu(input + col * 16);
fVec a_fvec1 = fVec::loadu(input + col * 16 + 16);
bVec out_bvec = convert_from_float_ext<scalar_t>(a_fvec0, a_fvec1);
out_bvec.store(out + col * 16);
}
};
Unroll<COLS>{}(store);
}
template <typename scalar_t>
inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ acc, float s, int size) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
constexpr int kVecSize = bVec::size();
const fVec s_fvec = fVec(s);
int d = 0;
#pragma GCC unroll 4
for (; d <= size - kVecSize; d += kVecSize) {
fVec a_fvec0 = fVec::loadu(acc + d) * s_fvec;
fVec a_fvec1 = fVec::loadu(acc + d + fVec::size()) * s_fvec;
bVec out_bvec = convert_from_float_ext<scalar_t>(a_fvec0, a_fvec1);
out_bvec.store(out + d);
}
for (; d < size; ++d) {
out[d] = static_cast<scalar_t>(acc[d] * s);
}
}
template <typename scalar_t, typename index_t, int BLOCK_M, int BLOCK_N>
void extend_attention_kernel_impl(
scalar_t* __restrict__ o_extend,
const scalar_t* __restrict__ q_extend,
const scalar_t* __restrict__ k_extend,
const scalar_t* __restrict__ v_extend,
const scalar_t* __restrict__ k_buffer,
const scalar_t* __restrict__ v_buffer,
const index_t* __restrict__ req_to_token,
const int64_t* __restrict__ req_pool_indices,
const int64_t* __restrict__ seq_lens,
const index_t* __restrict__ extend_seq_lens,
const index_t* __restrict__ extend_start_loc,
const void* __restrict__ buffer,
int batches,
int num_heads,
int num_heads_kv,
int head_size,
int head_size_v,
int q_strideM,
int q_strideH,
int ke_strideN,
int ke_strideH,
int ve_strideN,
int ve_strideH,
int k_strideN,
int k_strideH,
int v_strideN,
int v_strideH,
float scaling,
float logit_cap,
int max_num_reqs,
int max_context_len,
int max_total_num_tokens,
int max_len_extend,
int buffer_size_per_thread,
bool is_prefix_skipped) {
using Vec = at::vec::Vectorized<float>;
// strides
const int o_strideM = num_heads * head_size_v;
const int o_strideH = head_size_v;
// we use same buffer for packed key and value
const int ldb_tmp = std::max(head_size, head_size_v);
const bool has_logit_cap = logit_cap > 0;
float rlogit_cap = has_logit_cap ? 1 / logit_cap : 0.f;
const int num_groups = num_heads / num_heads_kv;
TORCH_CHECK(num_groups * num_heads_kv == num_heads);
// number of blocks along M
int MB = div_up(max_len_extend, BLOCK_M);
// parallel on [batches, num_heads, BM]
at::parallel_for(0, batches * num_heads * MB, 0, [&](int begin, int end) {
int bs{0}, head_id{0}, mb{0};
data_index_init(begin, bs, batches, head_id, num_heads, mb, MB);
int tid = at::get_thread_num();
// s_i and s_delta: [BLOCK_M, BLOCK_N]
float* __restrict__ s_i = reinterpret_cast<float*>((char*)(buffer) + tid * buffer_size_per_thread);
float* __restrict__ s_delta = s_i;
// v_prime: [BLOCK_M, head_size_v]
float* __restrict__ v_prime = s_i + BLOCK_M * BLOCK_N;
// s_delta2: [BLOCK_M, BLOCK_N]; copy of s_delta in scalar_t
scalar_t* __restrict__ s_delta2 = reinterpret_cast<scalar_t*>(v_prime + BLOCK_N * head_size_v);
// Btmp: [BLOCK_N, max(head_size, head_size_v)]
scalar_t* __restrict__ Btmp = s_delta2 + BLOCK_M * BLOCK_N;
// init Btmp just once for each thread to prevent NaN
fill_stub(Btmp, 0.f, BLOCK_N * ldb_tmp);
alignas(64) float s_prime[BLOCK_M];
alignas(64) float m_prime[BLOCK_M];
for (int i = begin; i < end; ++i) {
// seq_len = prefix + extend
int head_kv_id = head_id / num_groups;
int seq_len = seq_lens[bs];
int seq_len_extend = extend_seq_lens[bs];
int seq_len_prefix = seq_len - seq_len_extend;
int seq_extend_start_loc = extend_start_loc[bs];
int req_pool_id = req_pool_indices[bs];
TORCH_CHECK(seq_len_prefix >= 0, "prefix len < 0!");
TORCH_CHECK(seq_len <= max_context_len, "seq_len out of scope!");
TORCH_CHECK(req_pool_id < max_num_reqs, "req_pool_id out of scope!");
if (is_prefix_skipped) {
TORCH_CHECK(seq_len_prefix == 0, "extend attention: expect seq_len_prefix to be 0, got ", seq_len_prefix);
}
// offset and size in MB
int m = mb * BLOCK_N;
int m_size = std::min(BLOCK_M, seq_len_extend - m);
if (m_size <= 0) {
data_index_step(bs, batches, head_id, num_heads, mb, MB);
continue;
}
// get query
const scalar_t* __restrict__ q_ptr = q_extend + (seq_extend_start_loc + m) * q_strideM + head_id * q_strideH;
// init v', s' and m'
fill_stub(v_prime, 0.f, m_size * head_size_v);
fill_stub(s_prime, 0.f, m_size);
fill_stub(m_prime, -std::numeric_limits<scalar_t>::infinity(), m_size);
// stage 1: compute scores with prefix
for (int n = 0; n < seq_len_prefix; n += BLOCK_N) {
int n_size = std::min(BLOCK_N, seq_len_prefix - n);
// `n_size` is K in 2nd gemm, pad to TILE_K;
const int padded_n_size = div_up(n_size, TILE_K) * TILE_K;
// get key and pack
pack_vnni<scalar_t, index_t>(
/* dst */ Btmp,
/* src */ k_buffer + head_kv_id * k_strideH,
/* ind */ req_to_token + req_pool_id * max_context_len + n,
/* N */ n_size,
/* K */ head_size,
/* ld_src */ k_strideN,
/* ld_dst */ BLOCK_N);
// calculate s_i <- Q @ K
at::native::cpublas::brgemm(
/* M */ m_size,
/* N */ n_size,
/* K */ head_size,
/* lda */ q_strideM,
/* ldb */ BLOCK_N,
/* ldc */ BLOCK_N,
/* add_C */ false,
/* A */ q_ptr,
/* B */ Btmp,
/* C */ s_i);
const Vec scale_vec = Vec(scaling);
for (int row = 0; row < m_size; ++row) {
// s_i <- s_i * scale
at::vec::map<float>(
[scale_vec](Vec x) { return x * scale_vec; }, s_i + row * BLOCK_N, s_i + row * BLOCK_N, n_size);
// TODO: `tanh` from torch uses sleef u10, going to be slow
if (has_logit_cap) {
at::vec::map<float>(
[logit_cap, rlogit_cap](Vec x) { return Vec(logit_cap) * (x * Vec(rlogit_cap)).tanh(); },
s_i + row * BLOCK_N,
s_i + row * BLOCK_N,
n_size);
}
// m_i: max value per row
float m_i = at::vec::reduce_all<float>(
[](Vec& x, Vec& y) { return at::vec::maximum(x, y); }, s_i + row * BLOCK_N, n_size);
m_i = std::max(m_i, m_prime[row]);
// m_delta <- exp(m' - m_i)
float m_delta = std::exp(m_prime[row] - m_i);
// s_delta <- exp(s_i - m_i)
at::vec::map<float>(
[m_i](Vec x) { return (x - Vec(m_i)).exp_u20(); }, s_delta + row * BLOCK_N, s_i + row * BLOCK_N, n_size);
// s' <- s' * m_delta + sum(s_delta)
s_prime[row] *= m_delta;
s_prime[row] +=
at::vec::reduce_all<float>([](Vec& x, Vec& y) { return x + y; }, s_delta + row * BLOCK_N, n_size);
m_prime[row] = m_i;
// v' <- v' * m_delta
at::vec::map<float>(
[m_delta](Vec x) { return x * Vec(m_delta); },
v_prime + row * head_size_v,
v_prime + row * head_size_v,
head_size_v);
// pad s_delta with 0 first and then convert to scalar_t
fill_stub(s_delta + row * BLOCK_N + n_size, 0.f, padded_n_size - n_size);
copy_stub<scalar_t, BLOCK_N>(s_delta2 + row * BLOCK_N, s_delta + row * BLOCK_N);
}
// get value and pack
pack_vnni2<scalar_t, index_t>(
/* dst */ Btmp,
/* src */ v_buffer + head_kv_id * v_strideH,
/* ind */ req_to_token + req_pool_id * max_context_len + n,
/* K */ n_size,
/* N */ head_size_v,
/* ld_src */ v_strideN,
/* ld_dst */ head_size_v);
// calculate V' <- s_delta @ V + V'
at::native::cpublas::brgemm(
/* M */ m_size,
/* N */ head_size_v,
/* K */ padded_n_size, // n_size
/* lda */ BLOCK_N,
/* ldb */ head_size_v,
/* ldc */ head_size_v,
/* add_C */ true,
/* A */ s_delta2,
/* B */ Btmp,
/* C */ v_prime);
} // loop with seq_len_prefix
// stage 2: compute the triangle part
int num_keys = std::min(seq_len_extend, m + BLOCK_M);
for (int n = 0; n < num_keys; n += BLOCK_N) {
int n_size = std::min(BLOCK_N, num_keys - n);
// `n_size` is K in 2nd gemm, pad to TILE_K;
const int padded_n_size = div_up(n_size, TILE_K) * TILE_K;
// get key and pack
pack_vnni<scalar_t, index_t>(
/* dst */ Btmp,
/* src */ k_extend + (seq_extend_start_loc + n) * ke_strideN + head_kv_id * ke_strideH,
/* ind */ nullptr,
/* N */ n_size,
/* K */ head_size,
/* ld_src */ ke_strideN,
/* ld_dst */ BLOCK_N);
// calculate s_i <- Q @ K
at::native::cpublas::brgemm(
/* M */ m_size,
/* N */ n_size,
/* K */ head_size,
/* lda */ q_strideM,
/* ldb */ BLOCK_N,
/* ldc */ BLOCK_N,
/* add_C */ false,
/* A */ q_ptr,
/* B */ Btmp,
/* C */ s_i);
// apply causal mask
if (num_keys - n <= BLOCK_N) {
for (int row = 0; row < m_size; ++row) {
int last_col = m + row - n;
// fill [last_col + 1, n_size) to -inf
float* row_ptr = s_i + row * BLOCK_N;
fill_stub(row_ptr + last_col + 1, -std::numeric_limits<float>::infinity(), n_size - last_col - 1);
}
}
const Vec scale_vec = Vec(scaling);
for (int row = 0; row < m_size; ++row) {
// s_i <- s_i * scale
at::vec::map<float>(
[scale_vec](Vec x) { return x * scale_vec; }, s_i + row * BLOCK_N, s_i + row * BLOCK_N, n_size);
// TODO: `tanh` from torch uses sleef u10, going to be slow
if (has_logit_cap) {
at::vec::map<float>(
[logit_cap, rlogit_cap](Vec x) { return Vec(logit_cap) * (x * Vec(rlogit_cap)).tanh(); },
s_i + row * BLOCK_N,
s_i + row * BLOCK_N,
n_size);
}
// m_i: max value per row
float m_i = at::vec::reduce_all<float>(
[](Vec& x, Vec& y) { return at::vec::maximum(x, y); }, s_i + row * BLOCK_N, n_size);
m_i = std::max(m_i, m_prime[row]);
// m_delta <- exp(m' - m_i)
float m_delta = std::exp(m_prime[row] - m_i);
// s_delta <- exp(s_i - m_i)
at::vec::map<float>(
[m_i](Vec x) { return (x - Vec(m_i)).exp_u20(); }, s_delta + row * BLOCK_N, s_i + row * BLOCK_N, n_size);
// s' <- s' * m_delta + sum(s_delta)
s_prime[row] *= m_delta;
s_prime[row] +=
at::vec::reduce_all<float>([](Vec& x, Vec& y) { return x + y; }, s_delta + row * BLOCK_N, n_size);
m_prime[row] = m_i;
// v' <- v' * m_delta
at::vec::map<float>(
[m_delta](Vec x) { return x * Vec(m_delta); },
v_prime + row * head_size_v,
v_prime + row * head_size_v,
head_size_v);
// pad s_delta with 0 first and then convert to scalar_t
fill_stub(s_delta + row * BLOCK_N + n_size, 0.f, padded_n_size - n_size);
copy_stub<scalar_t, BLOCK_N>(s_delta2 + row * BLOCK_N, s_delta + row * BLOCK_N);
}
// get value and pack
pack_vnni2<scalar_t, index_t>(
/* dst */ Btmp,
/* src */ v_extend + (seq_extend_start_loc + n) * ve_strideN + head_kv_id * ve_strideH,
/* ind */ nullptr,
/* K */ n_size,
/* N */ head_size_v,
/* ld_src */ ve_strideN,
/* ld_dst */ head_size_v);
// calculate V' <- s_delta @ V + V'
at::native::cpublas::brgemm(
/* M */ m_size,
/* N */ head_size_v,
/* K */ padded_n_size, // n_size
/* lda */ BLOCK_N,
/* ldb */ head_size_v,
/* ldc */ head_size_v,
/* add_C */ true,
/* A */ s_delta2,
/* B */ Btmp,
/* C */ v_prime);
} // loop with seq_len_extend
scalar_t* __restrict__ out_ptr = o_extend + (seq_extend_start_loc + m) * o_strideM + head_id * o_strideH;
for (int row = 0; row < m_size; ++row) {
float s = 1 / s_prime[row];
copy_stub<scalar_t>(out_ptr + row * o_strideM, v_prime + row * head_size_v, s, head_size_v);
}
// move to the next index
data_index_step(bs, batches, head_id, num_heads, mb, MB);
}
at::native::cpublas::brgemm_release();
});
}
} // anonymous namespace
// q_extend, k_extend, v_extend, o_extend: contiguous tensors
// k_buffer, v_buffer: (prefix + extend) tensors in mem_manager
//
// q_extend: [num_tokens, num_heads, head_size]
// k_extend: [num_extend_tokens, num_heads, head_size]
// v_extend: [num_extend_tokens, num_heads, head_size]
// o_extend: [num_tokens, num_heads, head_size]
// k_buffer: [max_total_num_tokens, num_heads, head_size]
// v_buffer: [max_total_num_tokens, num_heads, head_size]
// req_to_token: [max_num_reqs, max_context_len] int32 or int64
// req_pool_indices: [num_seqs] int64
// seq_lens: [num_seqs] int64
// extend_seq_lens: [num_seqs]
// extend_start_loc: [num_seqs]
//
void extend_attention_cpu(
at::Tensor& q_extend,
at::Tensor& k_extend,
at::Tensor& v_extend,
at::Tensor& o_extend,
at::Tensor& k_buffer,
at::Tensor& v_buffer,
at::Tensor& req_to_token,
at::Tensor& req_pool_indices,
at::Tensor& seq_lens,
at::Tensor& extend_seq_lens,
at::Tensor& extend_start_loc,
int64_t max_len_extend,
double sm_scale,
double logit_cap) {
RECORD_FUNCTION(
"sgl-kernel::extend_attention_cpu",
std::vector<c10::IValue>(
{q_extend,
k_extend,
v_extend,
o_extend,
k_buffer,
v_buffer,
req_to_token,
req_pool_indices,
seq_lens,
extend_seq_lens,
extend_start_loc}));
CHECK_LAST_DIM_CONTIGUOUS_INPUT(q_extend);
CHECK_INPUT(o_extend);
CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_extend);
CHECK_LAST_DIM_CONTIGUOUS_INPUT(v_extend);
CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_buffer);
CHECK_LAST_DIM_CONTIGUOUS_INPUT(v_buffer);
int num_seqs = seq_lens.size(0);
int max_num_reqs = req_to_token.size(0);
int max_context_len = req_to_token.size(1);
int max_total_num_tokens = k_buffer.size(0);
int num_heads = q_extend.size(1);
int num_heads_kv = k_extend.size(1);
int head_size = q_extend.size(2);
int head_size_v = v_extend.size(2);
// strides for q_extend, k_extend and v_extend
int q_strideM = q_extend.stride(0);
int q_strideH = q_extend.stride(1);
int ke_strideN = k_extend.stride(0);
int ke_strideH = k_extend.stride(1);
int ve_strideN = v_extend.stride(0);
int ve_strideH = v_extend.stride(1);
// strides for k_buffer and v_buffer
int k_strideN = k_buffer.stride(0);
int k_strideH = k_buffer.stride(1);
int v_strideN = v_buffer.stride(0);
int v_strideH = v_buffer.stride(1);
// check sizes
CHECK_EQ(req_pool_indices.size(0), num_seqs);
CHECK_EQ(extend_seq_lens.size(0), num_seqs);
CHECK_EQ(extend_start_loc.size(0), num_seqs);
CHECK_EQ(v_extend.size(1), num_heads_kv);
CHECK_EQ(k_buffer.size(1), v_buffer.size(1));
// MLA will skip prefix part
const bool is_prefix_skipped = k_buffer.size(1) != num_heads_kv;
// check index data types
const auto index_dtype = req_to_token.scalar_type();
TORCH_CHECK(
index_dtype == at::kInt || index_dtype == at::kLong,
"extend: expect req_to_token to be int32 or int64, got ",
index_dtype);
TORCH_CHECK(seq_lens.scalar_type() == at::kLong, "extend: expect req_lens to be int64, got ", seq_lens.scalar_type());
TORCH_CHECK(
req_pool_indices.scalar_type() == at::kLong,
"extend: expect req_pool_indices to be int64, got ",
req_pool_indices.scalar_type());
TORCH_CHECK(
extend_seq_lens.scalar_type() == index_dtype && extend_start_loc.scalar_type() == index_dtype,
"extend: expect extend_seq_lens and extend_start_loc to have same dtype as req_to_token.");
// D and DV need to be 32x as we transpose by 512-bit
TORCH_CHECK(head_size % 32 == 0, "invalid head_size ", head_size);
TORCH_CHECK(head_size_v % 32 == 0, "invalid head_size_v ", head_size_v);
// block size for query seq length
constexpr int BLOCK_M = 32;
// block size for key/value seq length
constexpr int BLOCK_N = 32;
const int size_per_thread =
/* s_i */ BLOCK_M * BLOCK_N * sizeof(float) +
/* v_prime */ BLOCK_M * head_size_v * sizeof(float) +
/* s_delta */ BLOCK_M * BLOCK_N * sizeof(uint16_t) +
/* Btmp */ BLOCK_N * std::max(head_size, head_size_v) * sizeof(uint16_t);
int num_threads = at::get_num_threads();
auto buffer = at::empty({num_threads, size_per_thread}, q_extend.options().dtype(at::kChar));
AT_DISPATCH_REDUCED_FLOATING_TYPES(q_extend.scalar_type(), "extend_attention_kernel", [&] {
AT_DISPATCH_INDEX_TYPES(index_dtype, "extend_attention_indices", [&] {
extend_attention_kernel_impl<scalar_t, index_t, BLOCK_M, BLOCK_N>(
o_extend.data_ptr<scalar_t>(),
q_extend.data_ptr<scalar_t>(),
k_extend.data_ptr<scalar_t>(),
v_extend.data_ptr<scalar_t>(),
k_buffer.data_ptr<scalar_t>(),
v_buffer.data_ptr<scalar_t>(),
req_to_token.data_ptr<index_t>(),
req_pool_indices.data_ptr<int64_t>(),
seq_lens.data_ptr<int64_t>(),
extend_seq_lens.data_ptr<index_t>(),
extend_start_loc.data_ptr<index_t>(),
buffer.data_ptr(),
num_seqs,
num_heads,
num_heads_kv,
head_size,
head_size_v,
q_strideM,
q_strideH,
ke_strideN,
ke_strideH,
ve_strideN,
ve_strideH,
k_strideN,
k_strideH,
v_strideN,
v_strideH,
sm_scale,
logit_cap,
max_num_reqs,
max_context_len,
max_total_num_tokens,
max_len_extend,
size_per_thread,
is_prefix_skipped);
});
});
}

View File

@@ -0,0 +1,525 @@
#include "gemm.h"
#include "common.h"
#include "vec.h"
namespace {
// packed layout:
// quants {N, K} int8_t
// comp {N} int32_t
template <int BLOCK_N>
inline void s8s8_compensation(int8_t* __restrict__ packed, int K) {
#if defined(CPU_CAPABILITY_AVX512)
constexpr int COLS = BLOCK_N / 16;
__m512i vcomp[COLS];
for (int col = 0; col < COLS; ++col) {
vcomp[col] = _mm512_setzero_si512();
}
const int64_t offset = BLOCK_N * K;
const __m512i off = _mm512_set1_epi8(static_cast<char>(0x80));
for (int k = 0; k < K / 4; ++k) {
for (int col = 0; col < COLS; ++col) {
__m512i vb = _mm512_loadu_si512((const __m512i*)(packed + k * BLOCK_N * 4 + col * 64));
vcomp[col] = _mm512_dpbusd_epi32(vcomp[col], off, vb);
}
}
for (int col = 0; col < COLS; ++col) {
_mm512_storeu_si512((__m512i*)(packed + offset + col * 64), vcomp[col]);
}
#else
TORCH_CHECK(false, "s8s8_compensation not implemented!");
#endif
}
// convert to vnni format
// from [N, K] to [K/2, N, 2] for bfloat16 and float16
template <typename packed_t>
inline void pack_vnni(packed_t* __restrict__ packed, const packed_t* __restrict__ weight, int N, int K) {
const int VNNI_BLK = 2;
for (int n = 0; n < N; ++n) {
for (int k = 0; k < K / VNNI_BLK; ++k) {
for (int d = 0; d < VNNI_BLK; ++d) {
packed[k * N * VNNI_BLK + n * VNNI_BLK + d] = weight[n * K + k * VNNI_BLK + d];
}
}
}
}
template <>
inline void pack_vnni<int8_t>(int8_t* __restrict__ packed, const int8_t* __restrict__ weight, int N, int K) {
constexpr int BLOCK_N = block_size_n();
TORCH_CHECK(N == BLOCK_N);
const int VNNI_BLK = 4;
for (int n = 0; n < N; ++n) {
for (int k = 0; k < K / VNNI_BLK; ++k) {
for (int d = 0; d < VNNI_BLK; ++d) {
packed[k * N * VNNI_BLK + n * VNNI_BLK + d] = weight[n * K + k * VNNI_BLK + d];
}
}
}
s8s8_compensation<BLOCK_N>(packed, K);
}
template <typename scalar_t>
inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ input, int64_t size) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
constexpr int kVecSize = bVec::size();
int64_t d;
#pragma GCC unroll 4
for (d = 0; d <= size - kVecSize; d += kVecSize) {
fVec data0 = fVec::loadu(input + d);
fVec data1 = fVec::loadu(input + d + fVec::size());
bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1);
out_vec.store(out + d);
}
for (; d < size; ++d) {
out[d] = static_cast<scalar_t>(input[d]);
}
}
template <typename scalar_t>
inline void copy_add_stub(
scalar_t* __restrict__ out, const float* __restrict__ input, const float* __restrict__ bias, int64_t size) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
constexpr int kVecSize = bVec::size();
int64_t d;
#pragma GCC unroll 4
for (d = 0; d <= size - kVecSize; d += kVecSize) {
fVec data0 = fVec::loadu(input + d) + fVec::loadu(bias + d);
fVec data1 = fVec::loadu(input + d + fVec::size()) + fVec::loadu(bias + d + fVec::size());
bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1);
out_vec.store(out + d);
}
for (; d < size; ++d) {
out[d] = static_cast<scalar_t>(input[d] + bias[d]);
}
}
template <typename scalar_t, bool has_bias, int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_nn {
static inline void apply(
const scalar_t* __restrict__ A,
const scalar_t* __restrict__ B,
scalar_t* __restrict__ C,
const float* __restrict__ bias,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc) {
TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!");
}
};
#if defined(CPU_CAPABILITY_AVX512)
template <bool has_bias, int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_nn<at::BFloat16, has_bias, BLOCK_M, BLOCK_N> {
static inline void apply(
const at::BFloat16* __restrict__ A,
const at::BFloat16* __restrict__ B,
at::BFloat16* __restrict__ C,
const float* __restrict__ bias,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc) {
constexpr int ROWS = BLOCK_M;
constexpr int COLS = BLOCK_N / 16;
// prefetch distance
constexpr int PREFETCH_SIZE_K = 0;
__m512bh va;
__m512bh vb[COLS];
__m512 vc[ROWS * COLS];
auto loadc = [&](auto i) {
constexpr int col = i % COLS;
if constexpr (has_bias) {
vc[i] = _mm512_loadu_ps(bias + col * 16);
} else {
vc[i] = _mm512_set1_ps(0.f);
}
};
Unroll<ROWS * COLS>{}(loadc);
const int64_t K2 = K >> 1;
const int64_t lda2 = lda >> 1;
const int64_t ldb2 = ldb; // ldb * 2 >> 1;
const float* a_ptr = reinterpret_cast<const float*>(A);
const float* b_ptr = reinterpret_cast<const float*>(B);
auto compute = [&](auto i, int64_t k) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
if constexpr (col == 0) {
va = (__m512bh)(_mm512_set1_ps(a_ptr[row * lda2 + k]));
}
if constexpr (row == 0) {
vb[col] = (__m512bh)(_mm512_loadu_si512(b_ptr + k * ldb2 + col * 16));
if constexpr (PREFETCH_SIZE_K > 0) {
_mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0);
}
}
vc[i] = _mm512_dpbf16_ps(vc[i], va, vb[col]);
};
for (int64_t k = 0; k < K2; ++k) {
Unroll<ROWS * COLS>{}(compute, k);
}
auto storec = [&](auto i) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
// for COLS = 2, 4 use 512bit store
// for COLS = 1, 3 use 256bit store
if constexpr (COLS % 2 == 0) {
if constexpr (col % 2 == 0) {
_mm512_storeu_si512(
reinterpret_cast<__m512i*>((C + row * ldc + col * 16)),
(__m512i)(_mm512_cvtne2ps_pbh(vc[row * COLS + col + 1], vc[row * COLS + col])));
}
} else {
_mm256_storeu_si256(reinterpret_cast<__m256i*>(C + row * ldc + col * 16), (__m256i)(_mm512_cvtneps_pbh(vc[i])));
}
};
Unroll<ROWS * COLS>{}(storec);
}
};
#endif
#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \
tinygemm_kernel_nn<scalar_t, has_bias, MB_SIZE, NB_SIZE>::apply( \
A + mb_start * lda, \
B + nb_start * 2, \
C + mb_start * ldc + nb_start, \
has_bias ? bias + nb_start : nullptr, \
K, \
lda, \
ldb, \
ldc);
template <typename scalar_t, bool has_bias>
struct brgemm {
static inline void apply(
const scalar_t* __restrict__ A,
const scalar_t* __restrict__ B,
scalar_t* __restrict__ C,
float* __restrict__ Ctmp,
const float* __restrict__ bias,
int64_t M,
int64_t N,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc) {
constexpr int BLOCK_N = block_size_n();
at::native::cpublas::brgemm(M, N, K, lda, ldb, BLOCK_N, /* add_C */ false, A, B, Ctmp);
// copy from Ctmp to C
for (int64_t m = 0; m < M; ++m) {
if constexpr (has_bias) {
copy_add_stub(C + m * ldc, Ctmp + m * BLOCK_N, bias, N);
} else {
copy_stub(C + m * ldc, Ctmp + m * BLOCK_N, N);
}
}
}
};
template <typename scalar_t, bool has_bias>
void tinygemm_kernel(
const scalar_t* __restrict__ A,
const scalar_t* __restrict__ B,
scalar_t* __restrict__ C,
float* __restrict__ Ctmp,
const float* __restrict__ bias,
int64_t M,
int64_t N,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc,
bool brg) {
if (brg) {
brgemm<scalar_t, has_bias>::apply(A, B, C, Ctmp, bias, M, N, K, lda, ldb, ldc);
return;
}
// pattern: 1-4-16, N = 16, 32, 48, 64
constexpr int64_t BLOCK_M = 4;
constexpr int64_t BLOCK_N = 64;
const int64_t MB = div_up(M, BLOCK_M);
const int64_t NB = div_up(N, BLOCK_N);
for (int mb = 0; mb < MB; ++mb) {
int64_t mb_start = mb * BLOCK_M;
int64_t mb_size = std::min(BLOCK_M, M - mb_start);
for (int64_t nb = 0; nb < NB; ++nb) {
int64_t nb_start = nb * BLOCK_N;
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
switch (mb_size << 4 | nb_size >> 4) {
// mb_size = 1
case 0x11:
LAUNCH_TINYGEMM_KERNEL_NN(1, 16);
break;
case 0x12:
LAUNCH_TINYGEMM_KERNEL_NN(1, 32);
break;
case 0x13:
LAUNCH_TINYGEMM_KERNEL_NN(1, 48);
break;
case 0x14:
LAUNCH_TINYGEMM_KERNEL_NN(1, 64);
break;
// mb_size = 2
case 0x21:
LAUNCH_TINYGEMM_KERNEL_NN(2, 16);
break;
case 0x22:
LAUNCH_TINYGEMM_KERNEL_NN(2, 32);
break;
case 0x23:
LAUNCH_TINYGEMM_KERNEL_NN(2, 48);
break;
case 0x24:
LAUNCH_TINYGEMM_KERNEL_NN(2, 64);
break;
// mb_size = 3
case 0x31:
LAUNCH_TINYGEMM_KERNEL_NN(3, 16);
break;
case 0x32:
LAUNCH_TINYGEMM_KERNEL_NN(3, 32);
break;
case 0x33:
LAUNCH_TINYGEMM_KERNEL_NN(3, 48);
break;
case 0x34:
LAUNCH_TINYGEMM_KERNEL_NN(3, 64);
break;
// mb_size = 4
case 0x41:
LAUNCH_TINYGEMM_KERNEL_NN(4, 16);
break;
case 0x42:
LAUNCH_TINYGEMM_KERNEL_NN(4, 32);
break;
case 0x43:
LAUNCH_TINYGEMM_KERNEL_NN(4, 48);
break;
case 0x44:
LAUNCH_TINYGEMM_KERNEL_NN(4, 64);
break;
default:
TORCH_CHECK(false, "Unexpected block size, ", mb_size, " x ", nb_size);
}
}
}
}
template <typename scalar_t>
void weight_packed_linear_kernel_impl(
scalar_t* __restrict__ out,
const scalar_t* __restrict__ mat1,
const scalar_t* __restrict__ mat2,
const float* __restrict__ bias,
int64_t M,
int64_t N,
int64_t K,
int64_t mat1_strideM,
int64_t out_strideM) {
constexpr int64_t BLOCK_M = block_size_m();
constexpr int64_t BLOCK_N = block_size_n();
const int64_t MB = div_up(M, BLOCK_M);
const int64_t NB = div_up(N, BLOCK_N);
const bool use_brgemm = can_use_brgemm<scalar_t>(M);
// parallel on [MB, NB]
AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] {
parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
// for brgemm, use float32 for accumulate
alignas(64) float Ctmp[BLOCK_M * BLOCK_N];
loop_2d<scalar_t>(mb0, mb1, nb0, nb1, BLOCK_N * K, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
int64_t mb_start = mb * BLOCK_M;
int64_t mb_size = std::min(M - mb_start, BLOCK_M);
int64_t nb_start = nb * BLOCK_N;
int64_t nb_size = std::min(N - nb_start, BLOCK_N);
tinygemm_kernel<scalar_t, has_bias>(
/* A */ mat1 + mb_start * mat1_strideM,
/* B */ mat2 + nb_start * K /* nb * BLOCK_N * K */,
/* C */ out + mb_start * out_strideM + nb_start,
/* Ctmp*/ Ctmp,
/* bias*/ bias + nb_start,
/* M */ mb_size,
/* N */ nb_size,
/* K */ K,
/* lda */ mat1_strideM,
/* ldb */ nb_size,
/* ldc */ out_strideM,
/* brg */ use_brgemm);
});
if (use_brgemm) {
at::native::cpublas::brgemm_release();
}
});
});
}
} // anonymous namespace
// tinygemm interface
template <typename scalar_t>
void tinygemm_kernel(
const scalar_t* __restrict__ A,
const scalar_t* __restrict__ B,
scalar_t* __restrict__ C,
float* __restrict__ Ctmp,
int64_t M,
int64_t N,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc,
bool brg) {
tinygemm_kernel<scalar_t, false>(A, B, C, Ctmp, nullptr, M, N, K, lda, ldb, ldc, brg);
}
#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \
template void tinygemm_kernel<TYPE>( \
const TYPE* __restrict__ A, \
const TYPE* __restrict__ B, \
TYPE* __restrict__ C, \
float* __restrict__ Ctmp, \
int64_t M, \
int64_t N, \
int64_t K, \
int64_t lda, \
int64_t ldb, \
int64_t ldc, \
bool brg)
INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16);
INSTANTIATE_TINYGEMM_TEMPLATE(at::Half);
at::Tensor convert_weight_packed(at::Tensor& weight) {
// for 3d moe weights
// weight : [E, OC, IC]
// w1 : [E, 2N, K]
// w2 : [E, K, N]
CHECK_INPUT(weight);
const int64_t ndim = weight.ndimension();
TORCH_CHECK(ndim == 2 || ndim == 3, "expect weight to be 2d or 3d, got ", ndim, "d tensor.");
const auto st = weight.scalar_type();
const int64_t E = ndim == 3 ? weight.size(0) : 1;
const int64_t OC = ndim == 3 ? weight.size(1) : weight.size(0);
const int64_t IC = ndim == 3 ? weight.size(2) : weight.size(1);
// we handle 2 TILE_N at a time.
TORCH_CHECK(OC % TILE_N == 0, "invalid weight out features ", OC);
TORCH_CHECK(IC % TILE_K == 0, "invalid weight input features ", IC);
constexpr int64_t BLOCK_N = block_size_n();
const int64_t NB = div_up(OC, BLOCK_N);
// use phony sizes here [E, OC, IC], for each [E], [OC, IC] -> [IC / 2, OC, 2]
auto packed_weight = at::empty({}, weight.options());
const int64_t stride = OC * IC;
TORCH_CHECK(
st == at::kBFloat16 || st == at::kHalf || st == at::kChar || st == at::kFloat8_e4m3fn,
"expect weight to be bfloat16, float16, int8 or fp8_e4m3.");
CPU_DISPATCH_PACKED_TYPES(st, [&] {
// adjust most inner dimension size
const int packed_row_size = get_row_size<packed_t>(IC);
auto sizes = weight.sizes().vec();
sizes[ndim - 1] = packed_row_size;
packed_weight.resize_(sizes);
const packed_t* w_data = weight.data_ptr<packed_t>();
packed_t* packed_data = packed_weight.data_ptr<packed_t>();
// parallel on {E, NB}
at::parallel_for(0, E * NB, 0, [&](int64_t begin, int64_t end) {
int64_t e{0}, nb{0};
data_index_init(begin, e, E, nb, NB);
for (int64_t i = begin; i < end; ++i) {
UNUSED(i);
int64_t n = nb * BLOCK_N;
int64_t n_size = std::min(BLOCK_N, OC - n);
pack_vnni<packed_t>(
packed_data + e * OC * packed_row_size + n * packed_row_size, w_data + e * stride + n * IC, n_size, IC);
// move to the next index
data_index_step(e, E, nb, NB);
}
});
});
return packed_weight;
}
// mat1 : [M, K]
// mat2 : [N, K]
// bias : [N]
// out : [M, N]
//
at::Tensor
weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, const std::optional<at::Tensor>& bias, bool is_vnni) {
RECORD_FUNCTION("sgl-kernel::weight_packed_linear", std::vector<c10::IValue>({mat1, mat2, bias}));
auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1);
CHECK_INPUT(mat2);
int64_t M = mat1.size(0);
int64_t N = mat2.size(0);
int64_t K = mat2.size(1);
CHECK_EQ(mat1.size(1), K);
CHECK_DIM(2, mat1);
CHECK_DIM(2, mat2);
auto out = at::empty({M, N}, mat1.options());
// strides
int64_t mat1_strideM = mat1.stride(0);
int64_t out_strideM = out.stride(0);
const bool has_bias = bias.has_value();
const float* bias_data = nullptr;
if (has_bias) {
CHECK_EQ(bias.value().size(0), N);
bias_data = bias.value().data_ptr<float>();
}
AT_DISPATCH_REDUCED_FLOATING_TYPES(mat1.scalar_type(), "weight_packed_linear_kernel_impl", [&] {
weight_packed_linear_kernel_impl<scalar_t>(
out.data_ptr<scalar_t>(),
mat1.data_ptr<scalar_t>(),
packed_w.data_ptr<scalar_t>(),
bias_data,
M,
N,
K,
mat1_strideM,
out_strideM);
});
return out;
}

202
sgl-kernel/csrc/cpu/gemm.h Normal file
View File

@@ -0,0 +1,202 @@
#pragma once
#include <ATen/native/CPUBlas.h>
#include "common.h"
// amx-bf16
#define TILE_M 16
#define TILE_N 16
#define TILE_K 32
// block size for AMX gemm
constexpr int block_size_m() {
return 2 * TILE_M;
}
constexpr int block_size_n() {
return 2 * TILE_N;
}
// define threshold using brgemm (intel AMX)
template <typename T>
inline bool can_use_brgemm(int M);
template <>
inline bool can_use_brgemm<at::BFloat16>(int M) {
return M > 4;
}
template <>
inline bool can_use_brgemm<at::Half>(int M) {
return true;
}
// this requires PyTorch 2.7 or above
template <>
inline bool can_use_brgemm<int8_t>(int M) {
return M > 4;
}
template <>
inline bool can_use_brgemm<at::Float8_e4m3fn>(int M) {
return M > 4;
}
// work around compiler internal error
#define BLOCK_K 128 // 4 * TILE_K
// adjust leading dimension size for K
template <typename T>
inline int64_t get_row_size(int64_t K) {
return K;
}
template <>
inline int64_t get_row_size<int8_t>(int64_t K) {
return K + sizeof(int32_t);
}
inline int64_t get_row_size(int64_t K, bool use_int8_w8a8) {
return use_int8_w8a8 ? K + sizeof(int32_t) : K;
}
// pack weight to vnni format
at::Tensor convert_weight_packed(at::Tensor& weight);
// moe implementations for int8 w8a8
template <typename scalar_t>
void fused_experts_int8_kernel_impl(
scalar_t* __restrict__ output,
scalar_t* __restrict__ ic1,
scalar_t* __restrict__ ic2,
uint8_t* __restrict__ A_tmp,
float* __restrict__ C_tmp,
uint8_t* __restrict__ Aq_tmp,
float* __restrict__ As_tmp,
const scalar_t* __restrict__ input,
const int8_t* __restrict__ packed_w1,
const int8_t* __restrict__ packed_w2,
const float* __restrict__ w1s,
const float* __restrict__ w2s,
const float* __restrict__ topk_weights,
const int32_t* __restrict__ sorted_ids,
const int32_t* __restrict__ expert_ids,
const int32_t* __restrict__ offsets,
int64_t M,
int64_t N,
int64_t K,
int64_t E,
int64_t topk,
int64_t num_tokens_post_pad);
// moe implementations for fp8 w8a16
template <typename scalar_t>
void fused_experts_fp8_kernel_impl(
scalar_t* __restrict__ output,
scalar_t* __restrict__ ic0,
scalar_t* __restrict__ ic1,
scalar_t* __restrict__ ic2,
scalar_t* __restrict__ A_tmp,
scalar_t* __restrict__ B_tmp,
float* __restrict__ C_tmp,
const scalar_t* __restrict__ input,
const at::Float8_e4m3fn* __restrict__ packed_w1,
const at::Float8_e4m3fn* __restrict__ packed_w2,
const float* __restrict__ w1s,
const float* __restrict__ w2s,
int64_t block_size_N,
int64_t block_size_K,
const float* __restrict__ topk_weights,
const int32_t* __restrict__ sorted_ids,
const int32_t* __restrict__ expert_ids,
const int32_t* __restrict__ offsets,
int64_t M,
int64_t N,
int64_t K,
int64_t E,
int64_t topk,
int64_t num_tokens_post_pad);
// shared expert implementation for int8 w8a8
template <typename scalar_t>
void shared_expert_int8_kernel_impl(
scalar_t* __restrict__ output,
scalar_t* __restrict__ ic1,
float* __restrict__ C_tmp,
uint8_t* __restrict__ Aq_tmp,
float* __restrict__ As_tmp,
const scalar_t* __restrict__ input,
const int8_t* __restrict__ packed_w1,
const int8_t* __restrict__ packed_w2,
const float* __restrict__ w1s,
const float* __restrict__ w2s,
const scalar_t* __restrict__ fused_experts_out,
float routed_scaling_factor,
int64_t M,
int64_t N,
int64_t K);
template <typename scalar_t>
void shared_expert_fp8_kernel_impl(
scalar_t* __restrict__ output,
scalar_t* __restrict__ ic0,
scalar_t* __restrict__ ic1,
scalar_t* __restrict__ B_tmp,
float* __restrict__ C_tmp,
const scalar_t* __restrict__ input,
const at::Float8_e4m3fn* __restrict__ packed_w1,
const at::Float8_e4m3fn* __restrict__ packed_w2,
const float* __restrict__ w1s,
const float* __restrict__ w2s,
int64_t block_size_N,
int64_t block_size_K,
const scalar_t* __restrict__ fused_experts_out,
float routed_scaling_factor,
int64_t M,
int64_t N,
int64_t K);
// tinygemm interface
template <typename scalar_t>
void tinygemm_kernel(
const scalar_t* __restrict__ A,
const scalar_t* __restrict__ B,
scalar_t* __restrict__ C,
float* __restrict__ Ctmp,
int64_t M,
int64_t N,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc,
bool brg);
template <typename scalar_t>
void tinygemm_kernel(
const uint8_t* __restrict__ A,
const int8_t* __restrict__ B,
scalar_t* __restrict__ C,
int32_t* __restrict__ Ctmp,
const float* __restrict__ As,
const float* __restrict__ Bs,
int64_t M,
int64_t N,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc,
bool brg);
template <typename scalar_t>
void tinygemm_kernel(
const scalar_t* __restrict__ A,
const at::Float8_e4m3fn* __restrict__ B,
scalar_t* __restrict__ C,
scalar_t* __restrict__ Btmp,
float* __restrict__ Ctmp,
const float* __restrict__ scale,
int64_t M,
int64_t N,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc,
bool brg,
int64_t block_size_K,
bool do_unpack = true);

View File

@@ -0,0 +1,551 @@
#include "common.h"
#include "gemm.h"
#include "vec.h"
namespace {
template <typename scalar_t>
inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ input, int64_t size) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
constexpr int kVecSize = bVec::size();
int64_t d;
#pragma GCC unroll 4
for (d = 0; d <= size - kVecSize; d += kVecSize) {
fVec data0 = fVec::loadu(input + d);
fVec data1 = fVec::loadu(input + d + fVec::size());
bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1);
out_vec.store(out + d);
}
for (; d < size; ++d) {
out[d] = static_cast<scalar_t>(input[d]);
}
}
template <typename scalar_t>
inline void copy_add_stub(
scalar_t* __restrict__ out, const float* __restrict__ input, const float* __restrict__ bias, int64_t size) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
constexpr int kVecSize = bVec::size();
int64_t d;
#pragma GCC unroll 4
for (d = 0; d <= size - kVecSize; d += kVecSize) {
fVec data0 = fVec::loadu(input + d) + fVec::loadu(bias + d);
fVec data1 = fVec::loadu(input + d + fVec::size()) + fVec::loadu(bias + d + fVec::size());
bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1);
out_vec.store(out + d);
}
for (; d < size; ++d) {
out[d] = static_cast<scalar_t>(input[d] + bias[d]);
}
}
inline void unpack_B(
at::BFloat16* __restrict__ Btmp,
const at::Float8_e4m3fn* __restrict__ packed_B,
int N,
int K,
int ldb,
int ldb_tmp,
float scale) {
#if defined(CPU_CAPABILITY_AVX512)
// [K/2, N, 2]
const int K2 = K >> 1;
const int ldb2 = ldb; // ldb * 2 >> 1;
const uint16_t* b_ptr = reinterpret_cast<const uint16_t*>(packed_B);
const __m512 vd = _mm512_set1_ps(scale);
constexpr int BLOCK_N = block_size_n();
static_assert(BLOCK_N == 32);
// prefetch distance
constexpr int PREFETCH_SIZE_K = 64;
#pragma GCC unroll 4
for (int k = 0; k < K2; ++k) {
__m512i b8 = _mm512_loadu_si512(b_ptr + k * ldb2);
if constexpr (PREFETCH_SIZE_K > 0) {
_mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2, _MM_HINT_T0);
}
__m256i b8_0 = _mm512_extracti32x8_epi32(b8, 0);
__m256i b8_1 = _mm512_extracti32x8_epi32(b8, 1);
__m512bh bf16_0 = CVT_FP8_TO_BF16(b8_0);
__m512bh bf16_1 = CVT_FP8_TO_BF16(b8_1);
// Apply scale
__m512 f0_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 0));
__m512 f0_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 1));
__m512 f1_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 0));
__m512 f1_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 1));
f0_lo = _mm512_mul_ps(f0_lo, vd);
f0_hi = _mm512_mul_ps(f0_hi, vd);
f1_lo = _mm512_mul_ps(f1_lo, vd);
f1_hi = _mm512_mul_ps(f1_hi, vd);
bf16_0 = _mm512_cvtne2ps_pbh(f0_hi, f0_lo);
bf16_1 = _mm512_cvtne2ps_pbh(f1_hi, f1_lo);
_mm512_storeu_si512(Btmp + k * ldb_tmp * 2 + 0, (__m512i)bf16_0);
_mm512_storeu_si512(Btmp + k * ldb_tmp * 2 + 32, (__m512i)bf16_1);
}
#else
TORCH_CHECK(false, "unpack_B: scalar path not implemented!");
#endif
}
template <typename scalar_t, typename packed_t, bool has_bias, int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_nn {
static inline void apply(
const scalar_t* __restrict__ A,
const packed_t* __restrict__ B,
scalar_t* __restrict__ C,
const float* __restrict__ bias,
const float* __restrict__ scale,
int K,
int lda,
int ldb,
int ldc,
int64_t block_size_K) {
TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!");
}
};
#if defined(CPU_CAPABILITY_AVX512)
template <bool has_bias, int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_nn<at::BFloat16, at::Float8_e4m3fn, has_bias, BLOCK_M, BLOCK_N> {
static inline void apply(
const at::BFloat16* __restrict__ A,
const at::Float8_e4m3fn* __restrict__ B,
at::BFloat16* __restrict__ C,
const float* __restrict__ bias,
const float* __restrict__ scale,
int K,
int lda,
int ldb,
int ldc,
int64_t block_size_K) {
constexpr int ROWS = BLOCK_M;
constexpr int COLS = BLOCK_N / 16;
const int KB = div_up(K, BLOCK_K);
// prefetch distance
constexpr int PREFETCH_SIZE_K = 64;
constexpr int PREFETCH_SIZE_KB = 1;
__m512bh va;
__m512bh vb[COLS];
__m512 vc[ROWS * COLS];
__m512 vsum[ROWS * COLS];
// block quant scale
__m512 vscale;
auto loadc = [&](auto i) {
constexpr int col = i % COLS;
if constexpr (has_bias) {
vc[i] = _mm512_loadu_ps(bias + col * 16);
} else {
vc[i] = _mm512_setzero_ps();
}
};
Unroll<ROWS * COLS>{}(loadc);
const int lda2 = lda >> 1;
const int ldb2 = ldb; // ldb * 2 >> 1;
const float* a_ptr = reinterpret_cast<const float*>(A);
const uint16_t* b_ptr = reinterpret_cast<const uint16_t*>(B);
auto compute = [&](auto i, int k) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
if constexpr (col == 0) {
va = (__m512bh)(_mm512_set1_ps(a_ptr[row * lda2 + k]));
if constexpr (PREFETCH_SIZE_K > 0) {
_mm_prefetch(a_ptr + row * lda2 + k + PREFETCH_SIZE_K, _MM_HINT_T0);
}
}
if constexpr (row == 0) {
if constexpr (col % 2 == 0) {
__m512i b8 = _mm512_loadu_si512(b_ptr + k * ldb2 + col * 16);
if constexpr (PREFETCH_SIZE_K > 0) {
_mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0);
}
vb[col + 0] = CVT_FP8_TO_BF16(_mm512_extracti32x8_epi32(b8, 0));
vb[col + 1] = CVT_FP8_TO_BF16(_mm512_extracti32x8_epi32(b8, 1));
}
}
vsum[i] = _mm512_dpbf16_ps(vsum[i], va, vb[col]);
};
constexpr int BLOCK_K2 = BLOCK_K >> 1;
for (int kb = 0; kb < KB; ++kb) {
int kb_start = kb * BLOCK_K2;
int kb_end = std::min(K >> 1, kb_start + BLOCK_K2);
// 1. load scale vector
vscale = _mm512_set1_ps(scale[kb]);
if constexpr (PREFETCH_SIZE_KB > 0) {
_mm_prefetch(scale + kb + PREFETCH_SIZE_KB, _MM_HINT_T0);
}
// 2. zero vsum for each block
Unroll<ROWS * COLS>{}([&](auto i) { vsum[i] = _mm512_setzero_ps(); });
// 3. accumulate across each block
for (int k = kb_start; k < kb_end; ++k) {
Unroll<ROWS * COLS>{}(compute, k);
}
// 4. apply scale
Unroll<ROWS * COLS>{}([&](auto i) { vc[i] = _mm512_fmadd_ps(vsum[i], vscale, vc[i]); });
}
auto storec = [&](auto i) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
// for COLS = 2,4 use 512bit store
if constexpr (col % 2 == 0) {
_mm512_storeu_si512(
reinterpret_cast<__m512i*>((C + row * ldc + col * 16)),
(__m512i)(_mm512_cvtne2ps_pbh(vc[row * COLS + col + 1], vc[row * COLS + col])));
}
};
Unroll<ROWS * COLS>{}(storec);
}
};
#endif
#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \
tinygemm_kernel_nn<scalar_t, at::Float8_e4m3fn, has_bias, MB_SIZE, NB_SIZE>::apply( \
A + mb_start * lda, \
B + nb_start * 2, \
C + mb_start * ldc + nb_start, \
has_bias ? bias + nb_start : nullptr, \
scale, \
K, \
lda, \
ldb, \
ldc, \
block_size_K);
template <typename scalar_t, typename packed_t, bool has_bias>
struct brgemm {
static inline void apply(
const scalar_t* __restrict__ A,
const packed_t* __restrict__ B,
scalar_t* __restrict__ C,
scalar_t* __restrict__ Btmp,
float* __restrict__ Ctmp,
const float* __restrict__ bias,
const float* __restrict__ scale,
int M,
int N,
int K,
int lda,
int ldb,
int ldc,
bool do_unpack = true) {
TORCH_CHECK(false, "struct brgemm: primary template not implemented!");
}
};
template <bool has_bias>
struct brgemm<at::BFloat16, at::Float8_e4m3fn, has_bias> {
static inline void apply(
const at::BFloat16* __restrict__ A,
const at::Float8_e4m3fn* __restrict__ B,
at::BFloat16* __restrict__ C,
at::BFloat16* __restrict__ Btmp,
float* __restrict__ Ctmp,
const float* __restrict__ bias,
const float* __restrict__ scale,
int M,
int N,
int K,
int lda,
int ldb,
int ldc,
bool do_unpack = true) {
constexpr int BLOCK_N = block_size_n();
// [K, BLOCK_N] -> [K / 2, BLOCK_N * 2]
const int ldb_tmp = BLOCK_N;
if (do_unpack) {
for (int k = 0; k < K; k += BLOCK_K) {
int kb_size = std::min(BLOCK_K, K - k);
int idx = k >> 7; // k / BLOCK_K where BLOCK_K = 128
unpack_B(Btmp + k * ldb_tmp, B + k * ldb, N, kb_size, ldb, ldb_tmp, scale[idx]);
}
}
at::native::cpublas::brgemm(M, N, K, lda, ldb_tmp, BLOCK_N, /* add_C */ false, A, Btmp, Ctmp);
// copy from Ctmp to C
for (int m = 0; m < M; ++m) {
if constexpr (has_bias) {
copy_add_stub(C + m * ldc, Ctmp + m * BLOCK_N, bias, N);
} else {
copy_stub(C + m * ldc, Ctmp + m * BLOCK_N, N);
}
}
}
};
template <typename scalar_t, bool has_bias>
void tinygemm_kernel(
const scalar_t* __restrict__ A,
const at::Float8_e4m3fn* __restrict__ B,
scalar_t* __restrict__ C,
scalar_t* __restrict__ Btmp,
float* __restrict__ Ctmp,
const float* __restrict__ scale,
const float* __restrict__ bias,
int64_t M,
int64_t N,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc,
bool brg,
int64_t block_size_K,
bool do_unpack = true) {
if (brg) {
brgemm<scalar_t, at::Float8_e4m3fn, has_bias>::apply(
A, B, C, Btmp, Ctmp, bias, scale, M, N, K, lda, ldb, ldc, do_unpack);
return;
}
// pattern: 1-4-16
constexpr int64_t BLOCK_M = 4;
constexpr int64_t BLOCK_N = 64;
const int64_t MB = div_up(M, BLOCK_M);
const int64_t NB = div_up(N, BLOCK_N);
for (int mb = 0; mb < MB; ++mb) {
int64_t mb_start = mb * BLOCK_M;
int64_t mb_size = std::min(BLOCK_M, M - mb_start);
for (int64_t nb = 0; nb < NB; ++nb) {
int64_t nb_start = nb * BLOCK_N;
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
switch (mb_size << 4 | nb_size >> 4) {
case 0x12:
LAUNCH_TINYGEMM_KERNEL_NN(1, 32);
break;
case 0x22:
LAUNCH_TINYGEMM_KERNEL_NN(2, 32);
break;
case 0x32:
LAUNCH_TINYGEMM_KERNEL_NN(3, 32);
break;
case 0x42:
LAUNCH_TINYGEMM_KERNEL_NN(4, 32);
break;
default:
TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size");
}
}
}
}
template <typename scalar_t>
void fp8_scaled_mm_kernel_impl(
scalar_t* __restrict__ out,
const scalar_t* __restrict__ mat1,
const at::Float8_e4m3fn* __restrict__ mat2,
const float* __restrict__ scales2,
const float* __restrict__ bias,
scalar_t* __restrict__ buffer,
int64_t M,
int64_t N,
int64_t K,
int64_t mat1_strideM,
int64_t out_strideM,
int64_t block_size_N,
int64_t block_size_K,
int64_t buffer_size_per_thread) {
constexpr int64_t BLOCK_M = block_size_m();
constexpr int64_t BLOCK_N = block_size_n();
const int64_t MB = div_up(M, BLOCK_M);
const int64_t NB = div_up(N, BLOCK_N);
const int64_t scale_size_K = div_up(K, block_size_K);
const int64_t blocks_n_per_group = block_size_N / BLOCK_N;
const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(M);
// parallel on [MB, NB]
AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] {
parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
int tid = get_thread_num();
scalar_t* __restrict__ Btmp = buffer + tid * buffer_size_per_thread;
float* __restrict__ Ctmp = (float*)((void*)(Btmp + MAX_CACHE_BLOCK_SIZE * BLOCK_N * K));
loop_2d<at::Float8_e4m3fn>(mb0, mb1, nb0, nb1, BLOCK_N * K, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
const float* scale_ptr = scales2 + (nb / blocks_n_per_group) * scale_size_K;
int64_t mb_start = mb * BLOCK_M;
int64_t mb_size = std::min(M - mb_start, BLOCK_M);
int64_t nb_start = nb * BLOCK_N;
int64_t nb_size = std::min(N - nb_start, BLOCK_N);
// only do unpacking for the first row
bool do_unpack = (mb == mb0);
tinygemm_kernel<scalar_t, has_bias>(
/* A */ mat1 + mb_start * mat1_strideM,
/* B */ mat2 + nb_start * K, // nb * BLOCK_N * K
/* C */ out + mb_start * out_strideM + nb_start,
/* Btmp */ Btmp + nb_offset * BLOCK_N * K,
/* Ctmp */ Ctmp,
/* scale */ scale_ptr,
/* bias */ bias + nb_start,
/* M */ mb_size,
/* N */ nb_size,
/* K */ K,
/* lda */ mat1_strideM,
/* ldb */ nb_size,
/* ldc */ out_strideM,
/* brg */ use_brgemm,
/* block_size_K */ block_size_K,
/* do_unpack */ do_unpack);
});
if (use_brgemm) {
at::native::cpublas::brgemm_release();
}
});
});
}
} // anonymous namespace
// tinygemm interface
template <typename scalar_t>
void tinygemm_kernel(
const scalar_t* __restrict__ A,
const at::Float8_e4m3fn* __restrict__ B,
scalar_t* __restrict__ C,
scalar_t* __restrict__ Btmp,
float* __restrict__ Ctmp,
const float* __restrict__ scale,
int64_t M,
int64_t N,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc,
bool brg,
int64_t block_size_K,
bool do_unpack) {
tinygemm_kernel<scalar_t, false>(
A, B, C, Btmp, Ctmp, scale, nullptr, M, N, K, lda, ldb, ldc, brg, block_size_K, do_unpack);
}
#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \
template void tinygemm_kernel<TYPE>( \
const TYPE* __restrict__ A, \
const at::Float8_e4m3fn* __restrict__ B, \
TYPE* __restrict__ C, \
TYPE* __restrict__ Btmp, \
float* __restrict__ Ctmp, \
const float* __restrict__ scale, \
int64_t M, \
int64_t N, \
int64_t K, \
int64_t lda, \
int64_t ldb, \
int64_t ldc, \
bool brg, \
int64_t block_size_K, \
bool do_unpack)
INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16);
INSTANTIATE_TINYGEMM_TEMPLATE(at::Half);
at::Tensor fp8_scaled_mm_cpu(
at::Tensor& mat1,
at::Tensor& mat2,
at::Tensor& scales2,
std::vector<int64_t> block_size,
const std::optional<at::Tensor>& bias,
at::ScalarType out_dtype,
bool is_vnni) {
RECORD_FUNCTION("sgl-kernel::fp8_scaled_mm_cpu", std::vector<c10::IValue>({mat1, mat2, scales2, block_size, bias}));
auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1);
CHECK_INPUT(mat2);
CHECK_INPUT(scales2);
TORCH_CHECK(scales2.scalar_type() == at::kFloat, "fp8_scaled_mm_cpu: expect scales2 to be float32.");
int64_t M = mat1.size(0);
int64_t N = mat2.size(0);
int64_t K = mat2.size(1);
CHECK_EQ(mat1.size(1), K);
CHECK_DIM(2, mat1);
CHECK_DIM(2, mat2);
TORCH_CHECK(block_size.size() == 2, "fp8_scaled_mm_cpu: expect block_size.size() to be 2.");
int64_t block_size_N = block_size[0];
int64_t block_size_K = block_size[1];
constexpr int64_t BLOCK_M = block_size_m();
constexpr int64_t BLOCK_N = block_size_n();
TORCH_CHECK(block_size_N % BLOCK_N == 0, "fp8_scaled_mm_cpu: expect block_size_N to be multiples of BLOCK_N");
TORCH_CHECK(block_size_K == BLOCK_K, "fp8_scaled_mm_cpu: expect block_size_K equals to BLOCK_K");
CHECK_EQ(scales2.size(0), div_up(N, block_size_N));
CHECK_EQ(scales2.size(1), div_up(K, block_size_K));
const auto st = mat1.scalar_type();
TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf, "fp8_scaled_mm_cpu: expect A to be bfloat16 or half.");
TORCH_CHECK(st == out_dtype, "fp8_scaled_mm_cpu: expect A has same dtype with out_dtype.");
TORCH_CHECK(mat2.scalar_type() == at::kFloat8_e4m3fn, "fp8_scaled_mm_cpu: expect mat2 to be fp8_e4m3.");
TORCH_CHECK(scales2.scalar_type() == at::kFloat, "fp8_scaled_mm_cpu: expect scales to be float32.");
auto out = at::empty({M, N}, mat1.options().dtype(out_dtype));
// strides
int64_t mat1_strideM = mat1.stride(0);
int64_t out_strideM = out.stride(0);
const bool has_bias = bias.has_value();
const float* bias_data = nullptr;
if (has_bias) {
CHECK_EQ(bias.value().size(0), N);
bias_data = bias.value().data_ptr<float>();
}
// Btmp : [T, BLOCK_N * K]
// Ctmp : [T, BLOCK_M * BLOCK_N]
int num_threads = at::get_num_threads();
int64_t size_per_thread = MAX_CACHE_BLOCK_SIZE * BLOCK_N * K + BLOCK_M * BLOCK_N * 2;
auto buffer = at::empty({num_threads, size_per_thread}, mat1.options());
AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "fp8_scaled_mm_kernel_impl", [&] {
fp8_scaled_mm_kernel_impl<scalar_t>(
out.data_ptr<scalar_t>(),
mat1.data_ptr<scalar_t>(),
packed_w.data_ptr<at::Float8_e4m3fn>(),
scales2.data_ptr<float>(),
bias_data,
buffer.data_ptr<scalar_t>(),
M,
N,
K,
mat1_strideM,
out_strideM,
block_size_N,
block_size_K,
size_per_thread);
});
return out;
}

View File

@@ -0,0 +1,547 @@
#include "common.h"
#include "gemm.h"
#include "vec.h"
namespace {
template <typename scalar_t, bool has_bias, int BLOCK_N>
struct scale_C {
static inline void apply(
scalar_t* __restrict__ C,
const int32_t* __restrict__ Ctmp,
const int32_t* __restrict__ Bcomp,
const float* __restrict__ bias,
float As,
const float* __restrict__ Bs) {
TORCH_CHECK(false, "scale_C: scalar path not implemented!");
}
};
#if defined(CPU_CAPABILITY_AVX512)
template <bool has_bias, int BLOCK_N>
struct scale_C<at::BFloat16, has_bias, BLOCK_N> {
static inline void apply(
at::BFloat16* __restrict__ C,
const int32_t* __restrict__ Ctmp,
const int32_t* __restrict__ Bcomp,
const float* __restrict__ bias,
float As,
const float* __restrict__ Bs) {
constexpr int COLS = BLOCK_N / 16;
static_assert(COLS % 2 == 0);
__m512 vc[COLS];
__m512 vd0 = _mm512_set1_ps(As);
auto compute = [&](auto col) {
__m512 vd1 = _mm512_loadu_ps(Bs + col * 16);
__m512i vcomp = _mm512_loadu_si512(Bcomp + col * 16);
__m512i vc32 = _mm512_loadu_si512(Ctmp + col * 16);
vc[col] = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc32, vcomp));
if constexpr (has_bias) {
__m512 vbias = _mm512_loadu_ps(bias + col * 16);
vc[col] = _mm512_fmadd_ps(_mm512_mul_ps(vc[col], vd0), vd1, vbias);
} else {
vc[col] = _mm512_mul_ps(_mm512_mul_ps(vc[col], vd0), vd1);
}
};
Unroll<COLS>{}(compute);
auto storec = [&](auto col) {
// for COLS = 2, 4 use 512bit store
if constexpr (col % 2 == 0) {
_mm512_storeu_si512(
reinterpret_cast<__m512i*>((C + col * 16)), (__m512i)(_mm512_cvtne2ps_pbh(vc[col + 1], vc[col + 0])));
}
};
Unroll<COLS>{}(storec);
}
};
#endif
template <typename scalar_t, bool has_bias, int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_nn {
static inline void apply(
const uint8_t* __restrict__ A,
const int8_t* __restrict__ B,
scalar_t* __restrict__ C,
const float* __restrict__ As,
const float* __restrict__ Bs,
const int32_t* __restrict__ Bcomp,
const float* __restrict__ bias,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc) {
TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!");
}
};
#if defined(CPU_CAPABILITY_AVX512)
template <bool has_bias, int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_nn<at::BFloat16, has_bias, BLOCK_M, BLOCK_N> {
static inline void apply(
const uint8_t* __restrict__ A,
const int8_t* __restrict__ B,
at::BFloat16* __restrict__ C,
const float* __restrict__ As,
const float* __restrict__ Bs,
const int32_t* __restrict__ Bcomp,
const float* __restrict__ bias,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc) {
constexpr int ROWS = BLOCK_M;
constexpr int COLS = BLOCK_N / 16;
static_assert(COLS % 2 == 0);
// prefetch distance
constexpr int PREFETCH_SIZE_K = 0;
__m512i va;
__m512i vb[COLS];
__m512i vc[ROWS * COLS];
__m512i vcomp[COLS];
__m512 vd0;
__m512 vd1[COLS];
// oops! 4x4 spills but we use 4x2
__m512 vbias[COLS];
// [NOTE]: s8s8 igemm compensation in avx512-vnni
//
// avx512-vnni has no s8s8, so we need to change s8s8 to u8s8 with compensate:
//
// a * b = (a + 128) * b - 128 * b
// s s u s u s
//
// 1) 128 * b is pre-computed when packing B to vnni formats
// 2) a + 128 is fused when dynamically quantize A
//
auto loadc = [&](auto i) { vc[i] = _mm512_set1_epi32(0); };
Unroll<ROWS * COLS>{}(loadc);
const int64_t K4 = K >> 2;
const int64_t lda4 = lda >> 2;
const int64_t ldb4 = ldb; // ldb * 4 >> 2;
const int32_t* a_ptr = reinterpret_cast<const int32_t*>(A);
const int32_t* b_ptr = reinterpret_cast<const int32_t*>(B);
auto compute = [&](auto i, int64_t k) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
if constexpr (col == 0) {
va = _mm512_set1_epi32(a_ptr[row * lda4 + k]);
}
if constexpr (row == 0) {
vb[col] = _mm512_loadu_si512(b_ptr + k * ldb4 + col * 16);
if constexpr (PREFETCH_SIZE_K > 0) {
_mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb4 + col * 16, _MM_HINT_T0);
}
}
vc[i] = _mm512_dpbusd_epi32(vc[i], va, vb[col]);
};
for (int64_t k = 0; k < K4; ++k) {
Unroll<ROWS * COLS>{}(compute, k);
}
auto storec = [&](auto i) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
// load a scale
if constexpr (col == 0) {
vd0 = _mm512_set1_ps(As[row]);
}
// load b scale and vcomp per 2 vectors
// also load bias if any
if constexpr (row == 0) {
if constexpr (col % 2 == 0) {
vd1[col + 0] = _mm512_loadu_ps(Bs + col * 16);
vd1[col + 1] = _mm512_loadu_ps(Bs + col * 16 + 16);
vcomp[col + 0] = _mm512_loadu_si512(Bcomp + col * 16);
vcomp[col + 1] = _mm512_loadu_si512(Bcomp + col * 16 + 16);
if constexpr (has_bias) {
vbias[col + 0] = _mm512_loadu_ps(bias + col * 16);
vbias[col + 1] = _mm512_loadu_ps(bias + col * 16 + 16);
}
}
}
// for COLS = 2, 4 use 512bit store
if constexpr (col % 2 == 0) {
__m512 vc0 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc[row * COLS + col + 0], vcomp[col + 0]));
__m512 vc1 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc[row * COLS + col + 1], vcomp[col + 1]));
if constexpr (has_bias) {
vc0 = _mm512_fmadd_ps(_mm512_mul_ps(vc0, vd0), vd1[col + 0], vbias[col + 0]);
vc1 = _mm512_fmadd_ps(_mm512_mul_ps(vc1, vd0), vd1[col + 1], vbias[col + 1]);
} else {
vc0 = _mm512_mul_ps(_mm512_mul_ps(vc0, vd0), vd1[col + 0]);
vc1 = _mm512_mul_ps(_mm512_mul_ps(vc1, vd0), vd1[col + 1]);
}
_mm512_storeu_si512(
reinterpret_cast<__m512i*>((C + row * ldc + col * 16)), (__m512i)(_mm512_cvtne2ps_pbh(vc1, vc0)));
}
};
Unroll<ROWS * COLS>{}(storec);
}
};
#endif
#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \
tinygemm_kernel_nn<scalar_t, has_bias, MB_SIZE, NB_SIZE>::apply( \
A + mb_start * lda, \
B + nb_start * 4, \
C + mb_start * ldc + nb_start, \
As + mb_start, \
Bs + nb_start, \
Bcomp + nb_start, \
has_bias ? bias + nb_start : nullptr, \
K, \
lda, \
ldb, \
ldc);
template <typename scalar_t, bool has_bias>
void tinygemm_kernel(
const uint8_t* __restrict__ A,
const int8_t* __restrict__ B,
scalar_t* __restrict__ C,
int32_t* __restrict__ Ctmp,
const float* __restrict__ As,
const float* __restrict__ Bs,
const float* __restrict__ bias,
int64_t M,
int64_t N,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc,
bool brg) {
// B compensation
const int32_t* Bcomp = reinterpret_cast<const int32_t*>(B + block_size_n() * K);
if (brg) {
constexpr int BLOCK_N = block_size_n();
at::native::cpublas::brgemm(M, N, K, lda, ldb, BLOCK_N, /* add_C */ false, A, B, Ctmp);
// apply compensation and scale
for (int64_t m = 0; m < M; ++m) {
scale_C<scalar_t, has_bias, BLOCK_N>::apply(C + m * ldc, Ctmp + m * BLOCK_N, Bcomp, bias, As[m], Bs);
}
return;
}
// pattern: 1-4-16
constexpr int64_t BLOCK_M = 4;
constexpr int64_t BLOCK_N = 64;
const int64_t MB = div_up(M, BLOCK_M);
const int64_t NB = div_up(N, BLOCK_N);
for (int64_t mb = 0; mb < MB; ++mb) {
int64_t mb_start = mb * BLOCK_M;
int64_t mb_size = std::min(BLOCK_M, M - mb_start);
for (int64_t nb = 0; nb < NB; ++nb) {
int64_t nb_start = nb * BLOCK_N;
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
switch (mb_size << 4 | nb_size >> 4) {
// mb_size = 1
case 0x12:
LAUNCH_TINYGEMM_KERNEL_NN(1, 32);
break;
case 0x14:
LAUNCH_TINYGEMM_KERNEL_NN(1, 64);
break;
// mb_size = 2
case 0x22:
LAUNCH_TINYGEMM_KERNEL_NN(2, 32);
break;
case 0x24:
LAUNCH_TINYGEMM_KERNEL_NN(2, 64);
break;
// mb_size = 3
case 0x32:
LAUNCH_TINYGEMM_KERNEL_NN(3, 32);
break;
case 0x34:
LAUNCH_TINYGEMM_KERNEL_NN(3, 64);
break;
// mb_size = 4
case 0x42:
LAUNCH_TINYGEMM_KERNEL_NN(4, 32);
break;
case 0x44:
LAUNCH_TINYGEMM_KERNEL_NN(4, 64);
break;
default:
TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size");
}
}
}
}
template <typename scalar_t>
void int8_scaled_mm_kernel_impl(
scalar_t* __restrict__ out,
const uint8_t* __restrict__ mat1,
const int8_t* __restrict__ mat2,
const float* __restrict__ scales1,
const float* __restrict__ scales2,
const float* __restrict__ bias,
int64_t M,
int64_t N,
int64_t K) {
constexpr int64_t BLOCK_M = block_size_m();
constexpr int64_t BLOCK_N = block_size_n();
const int64_t MB = div_up(M, BLOCK_M);
const int64_t NB = div_up(N, BLOCK_N);
const bool use_brgemm = can_use_brgemm<int8_t>(M);
// K + 4 after compensation
const int64_t packed_row_size = get_row_size<int8_t>(K);
AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] {
parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
// for brgemm, use int32_t for accumulate
alignas(64) int32_t Ctmp[BLOCK_M * BLOCK_N];
loop_2d<int8_t>(mb0, mb1, nb0, nb1, BLOCK_N * K, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
int mb_start = mb * BLOCK_M;
int mb_size = std::min(M - mb_start, BLOCK_M);
int nb_start = nb * BLOCK_N;
int nb_size = std::min(N - nb_start, BLOCK_N);
tinygemm_kernel<scalar_t, has_bias>(
/* A */ mat1 + mb_start * K,
/* B */ mat2 + nb_start * packed_row_size /* nb * BLOCK_N * (K + 4) */,
/* C */ out + mb_start * N + nb_start,
/* Ctmp*/ Ctmp,
/* As */ scales1 + mb_start,
/* Bs */ scales2 + nb_start,
/* bias*/ bias + nb_start,
/* M */ mb_size,
/* N */ nb_size,
/* K */ K,
/* lda */ K,
/* ldb */ nb_size,
/* ldc */ N,
/* brg */ use_brgemm);
});
if (use_brgemm) {
at::native::cpublas::brgemm_release();
}
});
});
}
} // anonymous namespace
// tinygemm interface
template <typename scalar_t>
void tinygemm_kernel(
const uint8_t* __restrict__ A,
const int8_t* __restrict__ B,
scalar_t* __restrict__ C,
int32_t* __restrict__ Ctmp,
const float* __restrict__ As,
const float* __restrict__ Bs,
int64_t M,
int64_t N,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc,
bool brg) {
tinygemm_kernel<scalar_t, false>(A, B, C, Ctmp, As, Bs, nullptr, M, N, K, lda, ldb, ldc, brg);
}
#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \
template void tinygemm_kernel<TYPE>( \
const uint8_t* __restrict__ A, \
const int8_t* __restrict__ B, \
TYPE* __restrict__ C, \
int32_t* __restrict__ Ctmp, \
const float* __restrict__ As, \
const float* __restrict__ Bs, \
int64_t M, \
int64_t N, \
int64_t K, \
int64_t lda, \
int64_t ldb, \
int64_t ldc, \
bool brg)
INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16);
INSTANTIATE_TINYGEMM_TEMPLATE(at::Half);
std::tuple<at::Tensor, at::Tensor> per_token_quant_int8_cpu(at::Tensor& A) {
RECORD_FUNCTION("sgl-kernel::per_token_quant_int8_cpu", std::vector<c10::IValue>({A}));
CHECK_LAST_DIM_CONTIGUOUS_INPUT(A);
CHECK_DIM(2, A);
int64_t M = A.size(0);
int64_t K = A.size(1);
int64_t lda = A.stride(0);
const auto st = A.scalar_type();
TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf, "per_token_quant_int8: expect A to be bfloat16 or half.");
auto Aq = at::empty({M, K}, A.options().dtype(at::kByte));
auto As = at::empty({M}, A.options().dtype(at::kFloat));
AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "per_token_quant_int8", [&] {
uint8_t* __restrict__ Aq_data = Aq.data_ptr<uint8_t>();
float* __restrict__ As_data = As.data_ptr<float>();
const scalar_t* __restrict__ A_data = A.data_ptr<scalar_t>();
at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
for (int64_t m = begin; m < end; ++m) {
quantize_row_int8<scalar_t>(Aq_data + m * K, As_data[m], A_data + m * lda, K);
}
});
});
return std::make_tuple(Aq, As);
}
// weight : static, per-channel, symmetric
// activation : dynamic, per-token, symmetric
//
// mat1 : [M, K]
// mat2 : [N, K]
// scales1 : [M]
// scales2 : [N]
// bias : [N]
// out : [M, N]
//
at::Tensor int8_scaled_mm_cpu(
at::Tensor& mat1,
at::Tensor& mat2,
at::Tensor& scales1,
at::Tensor& scales2,
const std::optional<at::Tensor>& bias,
at::ScalarType out_dtype,
bool is_vnni) {
RECORD_FUNCTION("sgl-kernel::int8_scaled_mm_cpu", std::vector<c10::IValue>({mat1, mat2, scales1, scales2, bias}));
auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
CHECK_INPUT(mat1);
CHECK_INPUT(mat2);
CHECK_INPUT(scales1);
CHECK_INPUT(scales2);
CHECK_DIM(2, mat1);
CHECK_DIM(2, mat2);
int64_t M = mat1.size(0);
int64_t N = mat2.size(0);
int64_t K = mat1.size(1);
// see [NOTE]: s8s8 igemm compensation in avx512-vnni
CHECK_EQ(mat2.size(1), (int64_t)(is_vnni ? K + sizeof(int32_t) : K));
CHECK_EQ(scales1.numel(), M);
CHECK_EQ(scales2.numel(), N);
TORCH_CHECK(mat1.scalar_type() == at::kByte, "int8_scaled_mm: expect mat1 to be uint8.");
TORCH_CHECK(mat2.scalar_type() == at::kChar, "int8_scaled_mm: expect mat2 to be int8.");
TORCH_CHECK(
scales1.scalar_type() == at::kFloat && scales2.scalar_type() == at::kFloat,
"int8_scaled_mm: expect scales to be float32.");
auto out = at::empty({M, N}, mat1.options().dtype(out_dtype));
const bool has_bias = bias.has_value();
const float* bias_data = nullptr;
if (has_bias) {
CHECK_EQ(bias.value().size(0), N);
bias_data = bias.value().data_ptr<float>();
}
AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "int8_scaled_mm_kernel_impl", [&] {
int8_scaled_mm_kernel_impl<scalar_t>(
out.data_ptr<scalar_t>(),
mat1.data_ptr<uint8_t>(),
packed_w.data_ptr<int8_t>(),
scales1.data_ptr<float>(),
scales2.data_ptr<float>(),
bias_data,
M,
N,
K);
});
return out;
}
// fused `per_token_quant_int8_cpu` and `int8_scaled_mm_cpu`
at::Tensor int8_scaled_mm_with_quant(
at::Tensor& mat1,
at::Tensor& mat2,
at::Tensor& scales2,
const std::optional<at::Tensor>& bias,
at::ScalarType out_dtype,
bool is_vnni) {
RECORD_FUNCTION("sgl-kernel::int8_scaled_mm_cpu", std::vector<c10::IValue>({mat1, mat2, scales2, bias}));
auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1);
CHECK_INPUT(mat2);
CHECK_INPUT(scales2);
CHECK_DIM(2, mat1);
CHECK_DIM(2, mat2);
int64_t M = mat1.size(0);
int64_t N = mat2.size(0);
int64_t K = mat1.size(1);
int64_t lda = mat1.stride(0);
// see [NOTE]: s8s8 igemm compensation in avx512-vnni
CHECK_EQ(mat2.size(1), (int64_t)(is_vnni ? K + sizeof(int32_t) : K));
CHECK_EQ(scales2.numel(), N);
const auto st = mat1.scalar_type();
TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf, "int8_scaled_mm_with_quant: expect A to be bfloat16 or half.");
TORCH_CHECK(st == out_dtype, "int8_scaled_mm_with_quant: expect A has same dtype with out_dtype.");
TORCH_CHECK(mat2.scalar_type() == at::kChar, "int8_scaled_mm_with_quant: expect mat2 to be int8.");
TORCH_CHECK(scales2.scalar_type() == at::kFloat, "int8_scaled_mm_with_quant: expect scales to be float32.");
const int64_t buffer_size = M * K + M * sizeof(float);
auto buffer = at::empty({buffer_size}, mat1.options().dtype(at::kByte));
auto out = at::empty({M, N}, mat1.options().dtype(out_dtype));
const bool has_bias = bias.has_value();
const float* bias_data = nullptr;
if (has_bias) {
CHECK_EQ(bias.value().size(0), N);
bias_data = bias.value().data_ptr<float>();
}
AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "int8_scaled_mm_with_quant_kernel_impl", [&] {
uint8_t* __restrict__ Aq_data = buffer.data_ptr<uint8_t>();
float* __restrict__ As_data = (float*)((void*)(Aq_data + M * K));
const scalar_t* __restrict__ A_data = mat1.data_ptr<scalar_t>();
at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
for (int64_t m = begin; m < end; ++m) {
quantize_row_int8<scalar_t>(Aq_data + m * K, As_data[m], A_data + m * lda, K);
}
});
int8_scaled_mm_kernel_impl<scalar_t>(
out.data_ptr<scalar_t>(),
Aq_data,
packed_w.data_ptr<int8_t>(),
As_data,
scales2.data_ptr<float>(),
bias_data,
M,
N,
K);
});
return out;
}

Some files were not shown because too many files have changed in this diff Show More