Support compiling sgl-kernel on CUDA 13.0 (#9721)
This commit is contained in:
@@ -78,7 +78,7 @@ FetchContent_Populate(repo-triton)
|
|||||||
FetchContent_Declare(
|
FetchContent_Declare(
|
||||||
repo-flashinfer
|
repo-flashinfer
|
||||||
GIT_REPOSITORY https://github.com/flashinfer-ai/flashinfer.git
|
GIT_REPOSITORY https://github.com/flashinfer-ai/flashinfer.git
|
||||||
GIT_TAG 9220fb3443b5a5d274f00ca5552f798e225239b7
|
GIT_TAG 018b551825c8e5579206e6eb9d3229fa679202b3
|
||||||
GIT_SHALLOW OFF
|
GIT_SHALLOW OFF
|
||||||
)
|
)
|
||||||
FetchContent_Populate(repo-flashinfer)
|
FetchContent_Populate(repo-flashinfer)
|
||||||
@@ -174,11 +174,28 @@ if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8" OR SGL_KERNEL_ENABLE_SM100A)
|
|||||||
list(APPEND SGL_KERNEL_CUDA_FLAGS
|
list(APPEND SGL_KERNEL_CUDA_FLAGS
|
||||||
"-gencode=arch=compute_100,code=sm_100"
|
"-gencode=arch=compute_100,code=sm_100"
|
||||||
"-gencode=arch=compute_100a,code=sm_100a"
|
"-gencode=arch=compute_100a,code=sm_100a"
|
||||||
"-gencode=arch=compute_101,code=sm_101"
|
|
||||||
"-gencode=arch=compute_101a,code=sm_101a"
|
|
||||||
"-gencode=arch=compute_120,code=sm_120"
|
"-gencode=arch=compute_120,code=sm_120"
|
||||||
"-gencode=arch=compute_120a,code=sm_120a"
|
"-gencode=arch=compute_120a,code=sm_120a"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# For a description of sm_121, sm_110 and sm_101, see https://github.com/pytorch/pytorch/pull/156176
|
||||||
|
if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "13.0")
|
||||||
|
list(APPEND SGL_KERNEL_CUDA_FLAGS
|
||||||
|
"-gencode=arch=compute_103,code=sm_103"
|
||||||
|
"-gencode=arch=compute_103a,code=sm_103a"
|
||||||
|
"-gencode=arch=compute_110,code=sm_110"
|
||||||
|
"-gencode=arch=compute_110a,code=sm_110a"
|
||||||
|
"-gencode=arch=compute_121,code=sm_121"
|
||||||
|
"-gencode=arch=compute_121a,code=sm_121a"
|
||||||
|
"--compress-mode=size"
|
||||||
|
)
|
||||||
|
else()
|
||||||
|
list(APPEND SGL_KERNEL_CUDA_FLAGS
|
||||||
|
"-gencode=arch=compute_101,code=sm_101"
|
||||||
|
"-gencode=arch=compute_101a,code=sm_101a"
|
||||||
|
)
|
||||||
|
endif()
|
||||||
|
|
||||||
else()
|
else()
|
||||||
list(APPEND SGL_KERNEL_CUDA_FLAGS
|
list(APPEND SGL_KERNEL_CUDA_FLAGS
|
||||||
"-use_fast_math"
|
"-use_fast_math"
|
||||||
@@ -261,12 +278,6 @@ set(SOURCES
|
|||||||
"csrc/moe/cutlass_moe/w4a8/w4a8_moe_data.cu"
|
"csrc/moe/cutlass_moe/w4a8/w4a8_moe_data.cu"
|
||||||
"csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cu"
|
"csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cu"
|
||||||
"csrc/moe/marlin_moe_wna16/ops.cu"
|
"csrc/moe/marlin_moe_wna16/ops.cu"
|
||||||
"csrc/moe/marlin_moe_wna16/kernel_bf16_ku4.cu"
|
|
||||||
"csrc/moe/marlin_moe_wna16/kernel_bf16_ku4b8.cu"
|
|
||||||
"csrc/moe/marlin_moe_wna16/kernel_bf16_ku8b128.cu"
|
|
||||||
"csrc/moe/marlin_moe_wna16/kernel_fp16_ku4.cu"
|
|
||||||
"csrc/moe/marlin_moe_wna16/kernel_fp16_ku4b8.cu"
|
|
||||||
"csrc/moe/marlin_moe_wna16/kernel_fp16_ku8b128.cu"
|
|
||||||
"csrc/moe/moe_align_kernel.cu"
|
"csrc/moe/moe_align_kernel.cu"
|
||||||
"csrc/moe/moe_fused_gate.cu"
|
"csrc/moe/moe_fused_gate.cu"
|
||||||
"csrc/moe/moe_topk_softmax_kernels.cu"
|
"csrc/moe/moe_topk_softmax_kernels.cu"
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import jinja2
|
|||||||
FILE_HEAD = """
|
FILE_HEAD = """
|
||||||
// auto generated by generate.py
|
// auto generated by generate.py
|
||||||
// clang-format off
|
// clang-format off
|
||||||
|
#pragma once
|
||||||
|
|
||||||
#include "kernel.h"
|
#include "kernel.h"
|
||||||
#include "marlin_template.h"
|
#include "marlin_template.h"
|
||||||
@@ -33,6 +34,17 @@ TEMPLATE = (
|
|||||||
"( MARLIN_KERNEL_PARAMS );"
|
"( MARLIN_KERNEL_PARAMS );"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
KERNEL_FILE_TEMPLATE = (
|
||||||
|
"// auto generated by generate.py\n"
|
||||||
|
"// clang-format off\n"
|
||||||
|
"#pragma once\n\n"
|
||||||
|
"{% for kernel_file in kernel_files %}"
|
||||||
|
'#include "{{ kernel_file }}"\n'
|
||||||
|
"{% endfor %}"
|
||||||
|
)
|
||||||
|
|
||||||
|
KERNEL_FILE_NAME = "kernel_marlin.cuh"
|
||||||
|
|
||||||
# int8 with zero point case (sglang::kU8) is also supported,
|
# int8 with zero point case (sglang::kU8) is also supported,
|
||||||
# we don't add it to reduce wheel size.
|
# we don't add it to reduce wheel size.
|
||||||
SCALAR_TYPES = ["sglang::kU4", "sglang::kU4B8", "sglang::kU8B128"]
|
SCALAR_TYPES = ["sglang::kU4", "sglang::kU4B8", "sglang::kU8B128"]
|
||||||
@@ -48,11 +60,12 @@ DTYPES = ["fp16", "bf16"]
|
|||||||
|
|
||||||
|
|
||||||
def remove_old_kernels():
|
def remove_old_kernels():
|
||||||
for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cu"):
|
for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cuh"):
|
||||||
subprocess.call(["rm", "-f", filename])
|
subprocess.call(["rm", "-f", filename])
|
||||||
|
|
||||||
|
|
||||||
def generate_new_kernels():
|
def generate_new_kernels():
|
||||||
|
kernel_files = set()
|
||||||
for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES):
|
for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES):
|
||||||
has_zp = "B" not in scalar_type
|
has_zp = "B" not in scalar_type
|
||||||
all_template_str_list = []
|
all_template_str_list = []
|
||||||
@@ -95,10 +108,20 @@ def generate_new_kernels():
|
|||||||
|
|
||||||
file_content = FILE_HEAD + "\n\n"
|
file_content = FILE_HEAD + "\n\n"
|
||||||
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
|
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
|
||||||
filename = f"kernel_{dtype}_{scalar_type[8:].lower()}.cu"
|
filename = f"kernel_{dtype}_{scalar_type[8:].lower()}.cuh"
|
||||||
|
|
||||||
with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
|
with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
|
||||||
f.write(file_content)
|
f.write(file_content)
|
||||||
|
kernel_files.add(filename)
|
||||||
|
|
||||||
|
kernel_files = list(kernel_files)
|
||||||
|
kernel_files.sort()
|
||||||
|
|
||||||
|
file_content = jinja2.Template(KERNEL_FILE_TEMPLATE).render(
|
||||||
|
kernel_files=kernel_files
|
||||||
|
)
|
||||||
|
with open(os.path.join(os.path.dirname(__file__), KERNEL_FILE_NAME), "w") as f:
|
||||||
|
f.write(file_content)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
#ifndef MARLIN_NAMESPACE_NAME
|
#ifndef MARLIN_NAMESPACE_NAME
|
||||||
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
|
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
// auto generated by generate.py
|
// auto generated by generate.py
|
||||||
// clang-format off
|
// clang-format off
|
||||||
|
#pragma once
|
||||||
|
|
||||||
#include "kernel.h"
|
#include "kernel.h"
|
||||||
#include "marlin_template.h"
|
#include "marlin_template.h"
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
// auto generated by generate.py
|
// auto generated by generate.py
|
||||||
// clang-format off
|
// clang-format off
|
||||||
|
#pragma once
|
||||||
|
|
||||||
#include "kernel.h"
|
#include "kernel.h"
|
||||||
#include "marlin_template.h"
|
#include "marlin_template.h"
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
// auto generated by generate.py
|
// auto generated by generate.py
|
||||||
// clang-format off
|
// clang-format off
|
||||||
|
#pragma once
|
||||||
|
|
||||||
#include "kernel.h"
|
#include "kernel.h"
|
||||||
#include "marlin_template.h"
|
#include "marlin_template.h"
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
// auto generated by generate.py
|
// auto generated by generate.py
|
||||||
// clang-format off
|
// clang-format off
|
||||||
|
#pragma once
|
||||||
|
|
||||||
#include "kernel.h"
|
#include "kernel.h"
|
||||||
#include "marlin_template.h"
|
#include "marlin_template.h"
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
// auto generated by generate.py
|
// auto generated by generate.py
|
||||||
// clang-format off
|
// clang-format off
|
||||||
|
#pragma once
|
||||||
|
|
||||||
#include "kernel.h"
|
#include "kernel.h"
|
||||||
#include "marlin_template.h"
|
#include "marlin_template.h"
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
// auto generated by generate.py
|
// auto generated by generate.py
|
||||||
// clang-format off
|
// clang-format off
|
||||||
|
#pragma once
|
||||||
|
|
||||||
#include "kernel.h"
|
#include "kernel.h"
|
||||||
#include "marlin_template.h"
|
#include "marlin_template.h"
|
||||||
10
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_marlin.cuh
Normal file
10
sgl-kernel/csrc/moe/marlin_moe_wna16/kernel_marlin.cuh
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
// auto generated by generate.py
|
||||||
|
// clang-format off
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "kernel_bf16_ku4.cuh"
|
||||||
|
#include "kernel_bf16_ku4b8.cuh"
|
||||||
|
#include "kernel_bf16_ku8b128.cuh"
|
||||||
|
#include "kernel_fp16_ku4.cuh"
|
||||||
|
#include "kernel_fp16_ku4b8.cuh"
|
||||||
|
#include "kernel_fp16_ku8b128.cuh"
|
||||||
@@ -18,6 +18,8 @@
|
|||||||
/*
|
/*
|
||||||
* Adapted from https://github.com/IST-DASLab/marlin
|
* Adapted from https://github.com/IST-DASLab/marlin
|
||||||
*/
|
*/
|
||||||
|
#pragma once
|
||||||
|
|
||||||
#ifndef MARLIN_NAMESPACE_NAME
|
#ifndef MARLIN_NAMESPACE_NAME
|
||||||
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
|
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
@@ -24,6 +24,7 @@
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
#include "kernel.h"
|
#include "kernel.h"
|
||||||
|
#include "kernel_marlin.cuh"
|
||||||
|
|
||||||
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
|
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
|
||||||
static_assert( \
|
static_assert( \
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ limitations under the License.
|
|||||||
#ifndef USE_ROCM
|
#ifndef USE_ROCM
|
||||||
#include <cub/cub.cuh>
|
#include <cub/cub.cuh>
|
||||||
#include <cub/util_type.cuh>
|
#include <cub/util_type.cuh>
|
||||||
|
#include <cuda/functional>
|
||||||
#else
|
#else
|
||||||
#include <hipcub/hipcub.hpp>
|
#include <hipcub/hipcub.hpp>
|
||||||
#include <hipcub/util_type.hpp>
|
#include <hipcub/util_type.hpp>
|
||||||
@@ -33,6 +34,16 @@ limitations under the License.
|
|||||||
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||||
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||||
|
|
||||||
|
// Define reduction operators based on CUDA version
|
||||||
|
// From CUDA 12.9 onward (including CUDA 13), cub::Max/cub::Min are deprecated in favor of cuda::maximum/cuda::minimum
|
||||||
|
#if CUDA_VERSION >= 12090
|
||||||
|
using MaxReduceOp = cuda::maximum<>;
|
||||||
|
using MinReduceOp = cuda::minimum<>;
|
||||||
|
#else
|
||||||
|
using MaxReduceOp = cub::Max;
|
||||||
|
using MinReduceOp = cub::Min;
|
||||||
|
#endif
|
||||||
|
|
||||||
/// Aligned array type
|
/// Aligned array type
|
||||||
template <
|
template <
|
||||||
typename T,
|
typename T,
|
||||||
@@ -72,7 +83,6 @@ __launch_bounds__(TPB) __global__
|
|||||||
|
|
||||||
const int thread_row_offset = blockIdx.x * num_cols;
|
const int thread_row_offset = blockIdx.x * num_cols;
|
||||||
|
|
||||||
cub::Sum sum;
|
|
||||||
float threadData(-FLT_MAX);
|
float threadData(-FLT_MAX);
|
||||||
|
|
||||||
// Don't touch finished rows.
|
// Don't touch finished rows.
|
||||||
@@ -85,7 +95,7 @@ __launch_bounds__(TPB) __global__
|
|||||||
threadData = max(convert_to_float<T>(input[idx]), threadData);
|
threadData = max(convert_to_float<T>(input[idx]), threadData);
|
||||||
}
|
}
|
||||||
|
|
||||||
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
|
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, MaxReduceOp());
|
||||||
|
|
||||||
if (threadIdx.x == 0) {
|
if (threadIdx.x == 0) {
|
||||||
float_max = maxElem;
|
float_max = maxElem;
|
||||||
@@ -99,7 +109,7 @@ __launch_bounds__(TPB) __global__
|
|||||||
threadData += exp((convert_to_float<T>(input[idx]) - float_max));
|
threadData += exp((convert_to_float<T>(input[idx]) - float_max));
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum);
|
const auto Z = BlockReduce(tmpStorage).Sum(threadData);
|
||||||
|
|
||||||
if (threadIdx.x == 0) {
|
if (threadIdx.x == 0) {
|
||||||
normalizing_factor = 1.f / Z;
|
normalizing_factor = 1.f / Z;
|
||||||
|
|||||||
Reference in New Issue
Block a user