From 79e6a8a6acd81325630125c663868335f47cc07f Mon Sep 17 00:00:00 2001 From: Rain Jiang <96632942+rainj-me@users.noreply.github.com> Date: Tue, 26 Aug 2025 23:13:27 -0700 Subject: [PATCH] support cuda 13.0 and trtllm kernel by Aug 25 2025 (#9495) --- sgl-kernel/CMakeLists.txt | 32 +++++++++++++------ .../moe/marlin_moe_wna16/generate_kernels.py | 27 ++++++++++++++-- sgl-kernel/csrc/moe/marlin_moe_wna16/kernel.h | 1 + ...kernel_bf16_ku4.cu => kernel_bf16_ku4.cuh} | 1 + ...el_bf16_ku4b8.cu => kernel_bf16_ku4b8.cuh} | 1 + ...f16_ku8b128.cu => kernel_bf16_ku8b128.cuh} | 1 + ...kernel_fp16_ku4.cu => kernel_fp16_ku4.cuh} | 1 + ...el_fp16_ku4b8.cu => kernel_fp16_ku4b8.cuh} | 1 + ...p16_ku8b128.cu => kernel_fp16_ku8b128.cuh} | 1 + .../moe/marlin_moe_wna16/kernel_marlin.cuh | 10 ++++++ .../moe/marlin_moe_wna16/marlin_template.h | 2 ++ sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu | 1 + .../csrc/moe/moe_topk_softmax_kernels.cu | 16 ++++++++-- 13 files changed, 81 insertions(+), 14 deletions(-) rename sgl-kernel/csrc/moe/marlin_moe_wna16/{kernel_bf16_ku4.cu => kernel_bf16_ku4.cuh} (99%) rename sgl-kernel/csrc/moe/marlin_moe_wna16/{kernel_bf16_ku4b8.cu => kernel_bf16_ku4b8.cuh} (99%) rename sgl-kernel/csrc/moe/marlin_moe_wna16/{kernel_bf16_ku8b128.cu => kernel_bf16_ku8b128.cuh} (99%) rename sgl-kernel/csrc/moe/marlin_moe_wna16/{kernel_fp16_ku4.cu => kernel_fp16_ku4.cuh} (99%) rename sgl-kernel/csrc/moe/marlin_moe_wna16/{kernel_fp16_ku4b8.cu => kernel_fp16_ku4b8.cuh} (99%) rename sgl-kernel/csrc/moe/marlin_moe_wna16/{kernel_fp16_ku8b128.cu => kernel_fp16_ku8b128.cuh} (99%) create mode 100644 sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_marlin.cuh diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt index 307734ca7..975291435 100644 --- a/sgl-kernel/CMakeLists.txt +++ b/sgl-kernel/CMakeLists.txt @@ -57,6 +57,9 @@ if("${CUDA_VERSION}" VERSION_EQUAL "12.8") elseif("${CUDA_VERSION}" VERSION_EQUAL "12.9") set(DeepGEMM_REPO "https://github.com/sgl-project/DeepGEMM") 
set(DeepGEMM_TAG "blackwell") +elseif("${CUDA_VERSION}" VERSION_EQUAL "13.0") + set(DeepGEMM_REPO "https://github.com/sgl-project/DeepGEMM") + set(DeepGEMM_TAG "blackwell") else() set(DeepGEMM_REPO "https://github.com/deepseek-ai/DeepGEMM") set(DeepGEMM_TAG "391755ada0ffefa9a6a52b6f14dcaf22d1a463e0") @@ -83,7 +86,7 @@ FetchContent_Populate(repo-triton) FetchContent_Declare( repo-flashinfer GIT_REPOSITORY https://github.com/flashinfer-ai/flashinfer.git - GIT_TAG 9220fb3443b5a5d274f00ca5552f798e225239b7 + GIT_TAG 018b551825c8e5579206e6eb9d3229fa679202b3 GIT_SHALLOW OFF ) FetchContent_Populate(repo-flashinfer) @@ -179,11 +182,28 @@ if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8" OR SGL_KERNEL_ENABLE_SM100A) list(APPEND SGL_KERNEL_CUDA_FLAGS "-gencode=arch=compute_100,code=sm_100" "-gencode=arch=compute_100a,code=sm_100a" - "-gencode=arch=compute_101,code=sm_101" - "-gencode=arch=compute_101a,code=sm_101a" + "-gencode=arch=compute_103,code=sm_103" + "-gencode=arch=compute_103a,code=sm_103a" "-gencode=arch=compute_120,code=sm_120" "-gencode=arch=compute_120a,code=sm_120a" ) + + # refer sm_121, sm_110 and sm_101 description https://github.com/pytorch/pytorch/pull/156176 + if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "13.0") + list(APPEND SGL_KERNEL_CUDA_FLAGS + "-gencode=arch=compute_110,code=sm_110" + "-gencode=arch=compute_110a,code=sm_110a" + "-gencode=arch=compute_121,code=sm_121" + "-gencode=arch=compute_121a,code=sm_121a" + "--compress-mode=size" + ) + else() + list(APPEND SGL_KERNEL_CUDA_FLAGS + "-gencode=arch=compute_101,code=sm_101" + "-gencode=arch=compute_101a,code=sm_101a" + ) + endif() + else() list(APPEND SGL_KERNEL_CUDA_FLAGS "-use_fast_math" @@ -266,12 +286,6 @@ set(SOURCES "csrc/moe/cutlass_moe/w4a8/w4a8_moe_data.cu" "csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cu" "csrc/moe/marlin_moe_wna16/ops.cu" - "csrc/moe/marlin_moe_wna16/kernel_bf16_ku4.cu" - "csrc/moe/marlin_moe_wna16/kernel_bf16_ku4b8.cu" - "csrc/moe/marlin_moe_wna16/kernel_bf16_ku8b128.cu" 
- "csrc/moe/marlin_moe_wna16/kernel_fp16_ku4.cu" - "csrc/moe/marlin_moe_wna16/kernel_fp16_ku4b8.cu" - "csrc/moe/marlin_moe_wna16/kernel_fp16_ku8b128.cu" "csrc/moe/moe_align_kernel.cu" "csrc/moe/moe_fused_gate.cu" "csrc/moe/moe_topk_softmax_kernels.cu" diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/generate_kernels.py b/sgl-kernel/csrc/moe/marlin_moe_wna16/generate_kernels.py index 833d074ea..b3ed863a3 100644 --- a/sgl-kernel/csrc/moe/marlin_moe_wna16/generate_kernels.py +++ b/sgl-kernel/csrc/moe/marlin_moe_wna16/generate_kernels.py @@ -9,6 +9,7 @@ import jinja2 FILE_HEAD = """ // auto generated by generate.py // clang-format off +#pragma once #include "kernel.h" #include "marlin_template.h" @@ -33,6 +34,17 @@ TEMPLATE = ( "( MARLIN_KERNEL_PARAMS );" ) +KERNEL_FILE_TEMPLATE = ( + "// auto generated by generate.py\n" + "// clang-format off\n" + "#pragma once\n\n" + "{% for kernel_file in kernel_files %}" + '#include "{{ kernel_file }}"\n' + "{% endfor %}" +) + +KERNEL_FILE_NAME = "kernel_marlin.cuh" + # int8 with zero point case (sglang::kU8) is also supported, # we don't add it to reduce wheel size. 
SCALAR_TYPES = ["sglang::kU4", "sglang::kU4B8", "sglang::kU8B128"] @@ -48,11 +60,12 @@ DTYPES = ["fp16", "bf16"] def remove_old_kernels(): - for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cu"): + for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cuh"): subprocess.call(["rm", "-f", filename]) def generate_new_kernels(): + kernel_files = set() for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES): has_zp = "B" not in scalar_type all_template_str_list = [] @@ -95,10 +108,20 @@ def generate_new_kernels(): file_content = FILE_HEAD + "\n\n" file_content += "\n\n".join(all_template_str_list) + "\n\n}\n" - filename = f"kernel_{dtype}_{scalar_type[8:].lower()}.cu" + filename = f"kernel_{dtype}_{scalar_type[8:].lower()}.cuh" with open(os.path.join(os.path.dirname(__file__), filename), "w") as f: f.write(file_content) + kernel_files.add(filename) + + kernel_files = list(kernel_files) + kernel_files.sort() + + file_content = jinja2.Template(KERNEL_FILE_TEMPLATE).render( + kernel_files=kernel_files + ) + with open(os.path.join(os.path.dirname(__file__), KERNEL_FILE_NAME), "w") as f: + f.write(file_content) if __name__ == "__main__": diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel.h b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel.h index 88d157507..afa7c377b 100644 --- a/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel.h +++ b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel.h @@ -1,3 +1,4 @@ +#pragma once #ifndef MARLIN_NAMESPACE_NAME #define MARLIN_NAMESPACE_NAME marlin_moe_wna16 diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4.cu b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4.cuh similarity index 99% rename from sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4.cu rename to sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4.cuh index 1e3d923ae..7e83bed8f 100644 --- a/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4.cu +++ b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4.cuh @@ -1,5 +1,6 @@ // 
auto generated by generate.py // clang-format off +#pragma once #include "kernel.h" #include "marlin_template.h" diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4b8.cu b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4b8.cuh similarity index 99% rename from sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4b8.cu rename to sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4b8.cuh index 513ddc2ed..60e2dea31 100644 --- a/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4b8.cu +++ b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku4b8.cuh @@ -1,5 +1,6 @@ // auto generated by generate.py // clang-format off +#pragma once #include "kernel.h" #include "marlin_template.h" diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku8b128.cu b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku8b128.cuh similarity index 99% rename from sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku8b128.cu rename to sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku8b128.cuh index eebe9d3da..7eb6b18de 100644 --- a/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku8b128.cu +++ b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_bf16_ku8b128.cuh @@ -1,5 +1,6 @@ // auto generated by generate.py // clang-format off +#pragma once #include "kernel.h" #include "marlin_template.h" diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4.cu b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4.cuh similarity index 99% rename from sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4.cu rename to sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4.cuh index 9adc6623a..ec41e018b 100644 --- a/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4.cu +++ b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4.cuh @@ -1,5 +1,6 @@ // auto generated by generate.py // clang-format off +#pragma once #include "kernel.h" #include "marlin_template.h" diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4b8.cu 
b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4b8.cuh similarity index 99% rename from sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4b8.cu rename to sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4b8.cuh index 66ca7e36a..7df28701b 100644 --- a/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4b8.cu +++ b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku4b8.cuh @@ -1,5 +1,6 @@ // auto generated by generate.py // clang-format off +#pragma once #include "kernel.h" #include "marlin_template.h" diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku8b128.cu b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku8b128.cuh similarity index 99% rename from sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku8b128.cu rename to sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku8b128.cuh index 21fdf0c1a..1150844e2 100644 --- a/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku8b128.cu +++ b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_fp16_ku8b128.cuh @@ -1,5 +1,6 @@ // auto generated by generate.py // clang-format off +#pragma once #include "kernel.h" #include "marlin_template.h" diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_marlin.cuh b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_marlin.cuh new file mode 100644 index 000000000..bb828dc5b --- /dev/null +++ b/sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_marlin.cuh @@ -0,0 +1,10 @@ +// auto generated by generate.py +// clang-format off +#pragma once + +#include "kernel_bf16_ku4.cuh" +#include "kernel_bf16_ku4b8.cuh" +#include "kernel_bf16_ku8b128.cuh" +#include "kernel_fp16_ku4.cuh" +#include "kernel_fp16_ku4b8.cuh" +#include "kernel_fp16_ku8b128.cuh" diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/marlin_template.h b/sgl-kernel/csrc/moe/marlin_moe_wna16/marlin_template.h index 71c91839d..ade562af6 100644 --- a/sgl-kernel/csrc/moe/marlin_moe_wna16/marlin_template.h +++ b/sgl-kernel/csrc/moe/marlin_moe_wna16/marlin_template.h @@ -18,6 +18,8 @@ /* * Adapted from 
https://github.com/IST-DASLab/marlin */ +#pragma once + #ifndef MARLIN_NAMESPACE_NAME #define MARLIN_NAMESPACE_NAME marlin_moe_wna16 #endif diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu b/sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu index f430390d1..b249f6415 100644 --- a/sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu +++ b/sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu @@ -24,6 +24,7 @@ #endif #include "kernel.h" +#include "kernel_marlin.cuh" #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ static_assert( \ diff --git a/sgl-kernel/csrc/moe/moe_topk_softmax_kernels.cu b/sgl-kernel/csrc/moe/moe_topk_softmax_kernels.cu index 050e8d52b..c9bc8a628 100644 --- a/sgl-kernel/csrc/moe/moe_topk_softmax_kernels.cu +++ b/sgl-kernel/csrc/moe/moe_topk_softmax_kernels.cu @@ -23,6 +23,7 @@ limitations under the License. #ifndef USE_ROCM #include <cub/cub.cuh> #include <cub/util_type.cuh> +#include <cuda/functional> #else #include <hipcub/hipcub.hpp> #include <hipcub/util_type.hpp> @@ -33,6 +34,16 @@ limitations under the License. #define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b)) +// Define reduction operators based on CUDA version +// CUDA 13 (12.9+) deprecated cub::Max/Min in favor of cuda::maximum/minimum +#if CUDA_VERSION >= 12090 +using MaxReduceOp = cuda::maximum<>; +using MinReduceOp = cuda::minimum<>; +#else +using MaxReduceOp = cub::Max; +using MinReduceOp = cub::Min; +#endif + /// Aligned array type template < typename T, @@ -72,7 +83,6 @@ __launch_bounds__(TPB) __global__ const int thread_row_offset = blockIdx.x * num_cols; - cub::Sum sum; float threadData(-FLT_MAX); // Don't touch finished rows.
@@ -85,7 +95,7 @@ __launch_bounds__(TPB) __global__ threadData = max(convert_to_float(input[idx]), threadData); } - const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max()); + const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, MaxReduceOp()); if (threadIdx.x == 0) { float_max = maxElem; @@ -99,7 +109,7 @@ __launch_bounds__(TPB) __global__ threadData += exp((convert_to_float(input[idx]) - float_max)); } - const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum); + const auto Z = BlockReduce(tmpStorage).Sum(threadData); if (threadIdx.x == 0) { normalizing_factor = 1.f / Z;