[sgl-kernel] misc: update deepgemm version for sgl-kernel (#9340)
Co-authored-by: Yineng Zhang <me@zhyncs.com> Co-authored-by: fzyzcjy <ch271828n@outlook.com>
This commit is contained in:
@@ -50,25 +50,17 @@ FetchContent_Declare(
|
||||
)
|
||||
FetchContent_Populate(repo-cutlass)
|
||||
|
||||
# DeepGEMM
|
||||
if("${CUDA_VERSION}" VERSION_EQUAL "12.8")
|
||||
set(DeepGEMM_REPO "https://github.com/sgl-project/DeepGEMM")
|
||||
set(DeepGEMM_TAG "blackwell")
|
||||
elseif("${CUDA_VERSION}" VERSION_EQUAL "12.9")
|
||||
set(DeepGEMM_REPO "https://github.com/sgl-project/DeepGEMM")
|
||||
set(DeepGEMM_TAG "blackwell")
|
||||
elseif("${CUDA_VERSION}" VERSION_EQUAL "13.0")
|
||||
set(DeepGEMM_REPO "https://github.com/sgl-project/DeepGEMM")
|
||||
set(DeepGEMM_TAG "blackwell")
|
||||
else()
|
||||
set(DeepGEMM_REPO "https://github.com/deepseek-ai/DeepGEMM")
|
||||
set(DeepGEMM_TAG "391755ada0ffefa9a6a52b6f14dcaf22d1a463e0")
|
||||
endif()
|
||||
FetchContent_Declare(
|
||||
repo-fmt
|
||||
GIT_REPOSITORY https://github.com/fmtlib/fmt
|
||||
GIT_TAG 553ec11ec06fbe0beebfbb45f9dc3c9eabd83d28
|
||||
GIT_SHALLOW OFF
|
||||
)
|
||||
|
||||
FetchContent_Declare(
|
||||
repo-deepgemm
|
||||
GIT_REPOSITORY ${DeepGEMM_REPO}
|
||||
GIT_TAG ${DeepGEMM_TAG}
|
||||
GIT_REPOSITORY https://github.com/sgl-project/DeepGEMM
|
||||
GIT_TAG sgl
|
||||
GIT_SHALLOW OFF
|
||||
)
|
||||
FetchContent_Populate(repo-deepgemm)
|
||||
@@ -86,7 +78,7 @@ FetchContent_Populate(repo-triton)
|
||||
FetchContent_Declare(
|
||||
repo-flashinfer
|
||||
GIT_REPOSITORY https://github.com/flashinfer-ai/flashinfer.git
|
||||
GIT_TAG 018b551825c8e5579206e6eb9d3229fa679202b3
|
||||
GIT_TAG 9220fb3443b5a5d274f00ca5552f798e225239b7
|
||||
GIT_SHALLOW OFF
|
||||
)
|
||||
FetchContent_Populate(repo-flashinfer)
|
||||
@@ -182,28 +174,11 @@ if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8" OR SGL_KERNEL_ENABLE_SM100A)
|
||||
list(APPEND SGL_KERNEL_CUDA_FLAGS
|
||||
"-gencode=arch=compute_100,code=sm_100"
|
||||
"-gencode=arch=compute_100a,code=sm_100a"
|
||||
"-gencode=arch=compute_103,code=sm_103"
|
||||
"-gencode=arch=compute_103a,code=sm_103a"
|
||||
"-gencode=arch=compute_101,code=sm_101"
|
||||
"-gencode=arch=compute_101a,code=sm_101a"
|
||||
"-gencode=arch=compute_120,code=sm_120"
|
||||
"-gencode=arch=compute_120a,code=sm_120a"
|
||||
)
|
||||
|
||||
# For the sm_121, sm_110 and sm_101 descriptions, see https://github.com/pytorch/pytorch/pull/156176
|
||||
if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "13.0")
|
||||
list(APPEND SGL_KERNEL_CUDA_FLAGS
|
||||
"-gencode=arch=compute_110,code=sm_110"
|
||||
"-gencode=arch=compute_110a,code=sm_110a"
|
||||
"-gencode=arch=compute_121,code=sm_121"
|
||||
"-gencode=arch=compute_121a,code=sm_121a"
|
||||
"--compress-mode=size"
|
||||
)
|
||||
else()
|
||||
list(APPEND SGL_KERNEL_CUDA_FLAGS
|
||||
"-gencode=arch=compute_101,code=sm_101"
|
||||
"-gencode=arch=compute_101a,code=sm_101a"
|
||||
)
|
||||
endif()
|
||||
|
||||
else()
|
||||
list(APPEND SGL_KERNEL_CUDA_FLAGS
|
||||
"-use_fast_math"
|
||||
@@ -286,6 +261,12 @@ set(SOURCES
|
||||
"csrc/moe/cutlass_moe/w4a8/w4a8_moe_data.cu"
|
||||
"csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cu"
|
||||
"csrc/moe/marlin_moe_wna16/ops.cu"
|
||||
"csrc/moe/marlin_moe_wna16/kernel_bf16_ku4.cu"
|
||||
"csrc/moe/marlin_moe_wna16/kernel_bf16_ku4b8.cu"
|
||||
"csrc/moe/marlin_moe_wna16/kernel_bf16_ku8b128.cu"
|
||||
"csrc/moe/marlin_moe_wna16/kernel_fp16_ku4.cu"
|
||||
"csrc/moe/marlin_moe_wna16/kernel_fp16_ku4b8.cu"
|
||||
"csrc/moe/marlin_moe_wna16/kernel_fp16_ku8b128.cu"
|
||||
"csrc/moe/moe_align_kernel.cu"
|
||||
"csrc/moe/moe_fused_gate.cu"
|
||||
"csrc/moe/moe_topk_softmax_kernels.cu"
|
||||
@@ -321,8 +302,6 @@ target_include_directories(common_ops PRIVATE
|
||||
${repo-cutlass_SOURCE_DIR}/examples/common
|
||||
${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src
|
||||
)
|
||||
set_source_files_properties("csrc/gemm/per_token_group_quant_8bit" PROPERTIES COMPILE_OPTIONS "--use_fast_math")
|
||||
|
||||
|
||||
find_package(Python3 COMPONENTS Interpreter REQUIRED)
|
||||
execute_process(
|
||||
@@ -464,13 +443,38 @@ install(TARGETS spatial_ops LIBRARY DESTINATION sgl_kernel)
|
||||
set(DEEPGEMM_SOURCES
|
||||
"${repo-deepgemm_SOURCE_DIR}/csrc/python_api.cpp"
|
||||
)
|
||||
# JIT Logic
|
||||
# DeepGEMM
|
||||
|
||||
install(DIRECTORY "${repo-deepgemm_SOURCE_DIR}/deep_gemm/"
|
||||
DESTINATION "deep_gemm"
|
||||
PATTERN ".git*" EXCLUDE
|
||||
PATTERN "__pycache__" EXCLUDE)
|
||||
Python_add_library(deep_gemm_cpp MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${DEEPGEMM_SOURCES})
|
||||
|
||||
# Link against necessary libraries, including nvrtc for JIT compilation.
|
||||
target_link_libraries(deep_gemm_cpp PRIVATE ${TORCH_LIBRARIES} c10 cuda nvrtc mscclpp_static)
|
||||
|
||||
# Add include directories needed by DeepGEMM.
|
||||
target_include_directories(deep_gemm_cpp PRIVATE
|
||||
${repo-deepgemm_SOURCE_DIR}/deep_gemm/include
|
||||
${repo-cutlass_SOURCE_DIR}/include
|
||||
${repo-fmt_SOURCE_DIR}/include
|
||||
)
|
||||
|
||||
# Apply the same compile options as common_ops.
|
||||
target_compile_options(deep_gemm_cpp PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_KERNEL_CUDA_FLAGS}>)
|
||||
|
||||
# Create an empty __init__.py to make `deep_gemm` a Python package.
|
||||
file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/deepgemm_pkg_init.py "")
|
||||
install(
|
||||
FILES ${CMAKE_CURRENT_BINARY_DIR}/deepgemm_pkg_init.py
|
||||
DESTINATION deep_gemm
|
||||
RENAME __init__.py
|
||||
)
|
||||
|
||||
# Install the compiled DeepGEMM API library.
|
||||
install(TARGETS deep_gemm_cpp LIBRARY DESTINATION deep_gemm)
|
||||
|
||||
# Install the source files required by DeepGEMM for runtime JIT compilation.
|
||||
install(
|
||||
DIRECTORY ${repo-deepgemm_SOURCE_DIR}/deep_gemm/
|
||||
DESTINATION deep_gemm
|
||||
)
|
||||
|
||||
install(DIRECTORY "${repo-cutlass_SOURCE_DIR}/include/cute/"
|
||||
DESTINATION "deep_gemm/include/cute")
|
||||
|
||||
@@ -9,7 +9,6 @@ import jinja2
|
||||
FILE_HEAD = """
|
||||
// auto generated by generate.py
|
||||
// clang-format off
|
||||
#pragma once
|
||||
|
||||
#include "kernel.h"
|
||||
#include "marlin_template.h"
|
||||
@@ -34,17 +33,6 @@ TEMPLATE = (
|
||||
"( MARLIN_KERNEL_PARAMS );"
|
||||
)
|
||||
|
||||
KERNEL_FILE_TEMPLATE = (
|
||||
"// auto generated by generate.py\n"
|
||||
"// clang-format off\n"
|
||||
"#pragma once\n\n"
|
||||
"{% for kernel_file in kernel_files %}"
|
||||
'#include "{{ kernel_file }}"\n'
|
||||
"{% endfor %}"
|
||||
)
|
||||
|
||||
KERNEL_FILE_NAME = "kernel_marlin.cuh"
|
||||
|
||||
# int8 with zero point case (sglang::kU8) is also supported,
|
||||
# but it is omitted to reduce wheel size.
|
||||
SCALAR_TYPES = ["sglang::kU4", "sglang::kU4B8", "sglang::kU8B128"]
|
||||
@@ -60,12 +48,11 @@ DTYPES = ["fp16", "bf16"]
|
||||
|
||||
|
||||
def remove_old_kernels():
|
||||
for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cuh"):
|
||||
for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cu"):
|
||||
subprocess.call(["rm", "-f", filename])
|
||||
|
||||
|
||||
def generate_new_kernels():
|
||||
kernel_files = set()
|
||||
for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES):
|
||||
has_zp = "B" not in scalar_type
|
||||
all_template_str_list = []
|
||||
@@ -108,20 +95,10 @@ def generate_new_kernels():
|
||||
|
||||
file_content = FILE_HEAD + "\n\n"
|
||||
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
|
||||
filename = f"kernel_{dtype}_{scalar_type[8:].lower()}.cuh"
|
||||
filename = f"kernel_{dtype}_{scalar_type[8:].lower()}.cu"
|
||||
|
||||
with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
|
||||
f.write(file_content)
|
||||
kernel_files.add(filename)
|
||||
|
||||
kernel_files = list(kernel_files)
|
||||
kernel_files.sort()
|
||||
|
||||
file_content = jinja2.Template(KERNEL_FILE_TEMPLATE).render(
|
||||
kernel_files=kernel_files
|
||||
)
|
||||
with open(os.path.join(os.path.dirname(__file__), KERNEL_FILE_NAME), "w") as f:
|
||||
f.write(file_content)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef MARLIN_NAMESPACE_NAME
|
||||
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
// auto generated by generate.py
|
||||
// clang-format off
|
||||
#pragma once
|
||||
|
||||
#include "kernel.h"
|
||||
#include "marlin_template.h"
|
||||
@@ -1,6 +1,5 @@
|
||||
// auto generated by generate.py
|
||||
// clang-format off
|
||||
#pragma once
|
||||
|
||||
#include "kernel.h"
|
||||
#include "marlin_template.h"
|
||||
@@ -1,6 +1,5 @@
|
||||
// auto generated by generate.py
|
||||
// clang-format off
|
||||
#pragma once
|
||||
|
||||
#include "kernel.h"
|
||||
#include "marlin_template.h"
|
||||
@@ -1,6 +1,5 @@
|
||||
// auto generated by generate.py
|
||||
// clang-format off
|
||||
#pragma once
|
||||
|
||||
#include "kernel.h"
|
||||
#include "marlin_template.h"
|
||||
@@ -1,6 +1,5 @@
|
||||
// auto generated by generate.py
|
||||
// clang-format off
|
||||
#pragma once
|
||||
|
||||
#include "kernel.h"
|
||||
#include "marlin_template.h"
|
||||
@@ -1,6 +1,5 @@
|
||||
// auto generated by generate.py
|
||||
// clang-format off
|
||||
#pragma once
|
||||
|
||||
#include "kernel.h"
|
||||
#include "marlin_template.h"
|
||||
@@ -1,10 +0,0 @@
|
||||
// auto generated by generate.py
|
||||
// clang-format off
|
||||
#pragma once
|
||||
|
||||
#include "kernel_bf16_ku4.cuh"
|
||||
#include "kernel_bf16_ku4b8.cuh"
|
||||
#include "kernel_bf16_ku8b128.cuh"
|
||||
#include "kernel_fp16_ku4.cuh"
|
||||
#include "kernel_fp16_ku4b8.cuh"
|
||||
#include "kernel_fp16_ku8b128.cuh"
|
||||
@@ -18,8 +18,6 @@
|
||||
/*
|
||||
* Adapted from https://github.com/IST-DASLab/marlin
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#ifndef MARLIN_NAMESPACE_NAME
|
||||
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
|
||||
#endif
|
||||
|
||||
@@ -24,7 +24,6 @@
|
||||
#endif
|
||||
|
||||
#include "kernel.h"
|
||||
#include "kernel_marlin.cuh"
|
||||
|
||||
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
|
||||
static_assert( \
|
||||
|
||||
@@ -23,7 +23,6 @@ limitations under the License.
|
||||
#ifndef USE_ROCM
|
||||
#include <cub/cub.cuh>
|
||||
#include <cub/util_type.cuh>
|
||||
#include <cuda/functional>
|
||||
#else
|
||||
#include <hipcub/hipcub.hpp>
|
||||
#include <hipcub/util_type.hpp>
|
||||
@@ -34,16 +33,6 @@ limitations under the License.
|
||||
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||
|
||||
// Define reduction operators based on CUDA version
|
||||
// CUDA 12.9 and later (including CUDA 13) deprecate cub::Max/Min in favor of cuda::maximum/minimum
|
||||
#if CUDA_VERSION >= 12090
|
||||
using MaxReduceOp = cuda::maximum<>;
|
||||
using MinReduceOp = cuda::minimum<>;
|
||||
#else
|
||||
using MaxReduceOp = cub::Max;
|
||||
using MinReduceOp = cub::Min;
|
||||
#endif
|
||||
|
||||
/// Aligned array type
|
||||
template <
|
||||
typename T,
|
||||
@@ -83,6 +72,7 @@ __launch_bounds__(TPB) __global__
|
||||
|
||||
const int thread_row_offset = blockIdx.x * num_cols;
|
||||
|
||||
cub::Sum sum;
|
||||
float threadData(-FLT_MAX);
|
||||
|
||||
// Don't touch finished rows.
|
||||
@@ -95,7 +85,7 @@ __launch_bounds__(TPB) __global__
|
||||
threadData = max(convert_to_float<T>(input[idx]), threadData);
|
||||
}
|
||||
|
||||
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, MaxReduceOp());
|
||||
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
float_max = maxElem;
|
||||
@@ -109,7 +99,7 @@ __launch_bounds__(TPB) __global__
|
||||
threadData += exp((convert_to_float<T>(input[idx]) - float_max));
|
||||
}
|
||||
|
||||
const auto Z = BlockReduce(tmpStorage).Sum(threadData);
|
||||
const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum);
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
normalizing_factor = 1.f / Z;
|
||||
|
||||
@@ -8,7 +8,7 @@ build-backend = "scikit_build_core.build"
|
||||
|
||||
[project]
|
||||
name = "sgl-kernel"
|
||||
version = "0.3.6.post2"
|
||||
version = "0.3.7"
|
||||
description = "Kernel Library for SGLang"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
|
||||
@@ -8,7 +8,7 @@ build-backend = "scikit_build_core.build"
|
||||
|
||||
[project]
|
||||
name = "sgl-kernel"
|
||||
version = "0.3.6.post2"
|
||||
version = "0.3.7"
|
||||
description = "Kernel Library for SGLang"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
|
||||
@@ -9,7 +9,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "sgl-kernel"
|
||||
version = "0.3.6.post2"
|
||||
version = "0.3.7"
|
||||
description = "Kernel Library for SGLang"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = "0.3.6.post2"
|
||||
__version__ = "0.3.7"
|
||||
|
||||
Reference in New Issue
Block a user