diff --git a/.github/workflows/pr-test-amd.yml b/.github/workflows/pr-test-amd.yml index 03b9c433c..a510a5ded 100644 --- a/.github/workflows/pr-test-amd.yml +++ b/.github/workflows/pr-test-amd.yml @@ -47,7 +47,7 @@ jobs: run: | docker exec ci_sglang pip install --upgrade pip docker exec ci_sglang pip uninstall sgl-kernel -y || true - docker exec -w /sglang-checkout/sgl-kernel ci_sglang python3 setup_rocm.py install + docker exec -w /sglang-checkout/sgl-kernel ci_sglang bash -c "rm -f pyproject.toml && mv pyproject_rocm.toml pyproject.toml && python3 setup_rocm.py install" docker exec ci_sglang pip install -e "python[dev_hip]" docker exec -w / ci_sglang git clone https://github.com/merrymercy/human-eval.git @@ -87,7 +87,7 @@ jobs: run: | docker exec ci_sglang pip install --upgrade pip docker exec ci_sglang pip uninstall sgl-kernel -y || true - docker exec -w /sglang-checkout/sgl-kernel ci_sglang python3 setup_rocm.py install + docker exec -w /sglang-checkout/sgl-kernel ci_sglang bash -c "rm -f pyproject.toml && mv pyproject_rocm.toml pyproject.toml && python3 setup_rocm.py install" docker exec ci_sglang pip install -e "python[dev_hip]" docker exec -w / ci_sglang git clone https://github.com/merrymercy/human-eval.git diff --git a/.github/workflows/pr-test-sgl-kernel.yml b/.github/workflows/pr-test-sgl-kernel.yml index 6944f9a44..c42f487f7 100644 --- a/.github/workflows/pr-test-sgl-kernel.yml +++ b/.github/workflows/pr-test-sgl-kernel.yml @@ -84,6 +84,7 @@ jobs: pip3 uninstall sgl-kernel -y || true pip3 install sgl-kernel/dist/*whl --force-reinstall --no-deps pip3 list | grep sgl-kernel + git clone --recursive https://github.com/deepseek-ai/DeepGEMM.git && cd DeepGEMM && python3 setup.py develop - name: Run test timeout-minutes: 30 @@ -115,6 +116,7 @@ jobs: pip3 uninstall sgl-kernel -y || true pip3 install sgl-kernel/dist/*whl --force-reinstall --no-deps pip3 list | grep sgl-kernel + git clone --recursive https://github.com/deepseek-ai/DeepGEMM.git && cd DeepGEMM && 
python3 setup.py develop - name: Run test timeout-minutes: 30 diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm index dfbbe5c70..014272c75 100644 --- a/docker/Dockerfile.rocm +++ b/docker/Dockerfile.rocm @@ -29,6 +29,8 @@ RUN git clone ${SGL_REPO} \ git checkout ${SGL_BRANCH}; \ fi \ && cd sgl-kernel \ + && rm -f pyproject.toml \ + && mv pyproject_rocm.toml pyproject.toml \ && python setup_rocm.py install \ && cd .. \ && if [ "$BUILD_TYPE" = "srt" ]; then \ diff --git a/sgl-kernel/3rdparty/flashinfer b/sgl-kernel/3rdparty/flashinfer index e5a3befbe..79fd1ae90 160000 --- a/sgl-kernel/3rdparty/flashinfer +++ b/sgl-kernel/3rdparty/flashinfer @@ -1 +1 @@ -Subproject commit e5a3befbe3e63025f0158bc96b218a9c5f402ac7 +Subproject commit 79fd1ae90d9b8098ca70dec6071da96f3f6da7b9 diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt new file mode 100644 index 000000000..dd889d37b --- /dev/null +++ b/sgl-kernel/CMakeLists.txt @@ -0,0 +1,166 @@ +cmake_minimum_required(VERSION 3.26 FATAL_ERROR) +project(sgl-kernel LANGUAGES CXX CUDA) + +# we only want to download 3rd, but not build them. +# FetchContent_MakeAvailable will build it. 
+cmake_policy(SET CMP0169 OLD) + +find_package(Python COMPONENTS Interpreter Development.Module Development.SABIModule REQUIRED) + +enable_language(CUDA) +find_package(CUDAToolkit REQUIRED) + +message(STATUS "Detected CUDA_VERSION=${CUDA_VERSION}") +if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8") + message("CUDA_VERSION ${CUDA_VERSION} >= 12.8") +elseif ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.4") + message("CUDA_VERSION ${CUDA_VERSION} >= 12.4") +elseif ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.1") + message("CUDA_VERSION ${CUDA_VERSION} >= 12.1") +elseif ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "11.8") + message("CUDA_VERSION ${CUDA_VERSION} >= 11.8") +endif() + +find_package(Torch REQUIRED) + +include(FetchContent) + +FetchContent_Declare( + repo-cutlass + GIT_REPOSITORY https://github.com/NVIDIA/cutlass + GIT_TAG 62750a2b75c802660e4894434dc55e839f322277 + GIT_SHALLOW ON +) +FetchContent_Populate(repo-cutlass) +FetchContent_Declare( + repo-deepgemm + GIT_REPOSITORY https://github.com/deepseek-ai/DeepGEMM + GIT_TAG c57699ac933a93651c34d365797c2d8b41a4765b + GIT_SHALLOW ON +) +FetchContent_Populate(repo-deepgemm) +FetchContent_Declare( + repo-flashinfer + GIT_REPOSITORY https://github.com/flashinfer-ai/flashinfer + GIT_TAG 79fd1ae90d9b8098ca70dec6071da96f3f6da7b9 + GIT_SHALLOW OFF +) +FetchContent_Populate(repo-flashinfer) + +include_directories( + ${PROJECT_SOURCE_DIR}/include + ${PROJECT_SOURCE_DIR}/csrc + ${repo-cutlass_SOURCE_DIR}/include + ${repo-cutlass_SOURCE_DIR}/tools/util/include + ${repo-flashinfer_SOURCE_DIR}/include + ${repo-flashinfer_SOURCE_DIR}/csrc +) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3") + +set(SGL_KERNEL_CUDA_FLAGS + "-DNDEBUG" + "-DOPERATOR_NAMESPACE=sgl-kernel" + "-O3" + "-Xcompiler" + "-fPIC" + "-gencode=arch=compute_75,code=sm_75" + "-gencode=arch=compute_80,code=sm_80" + "-gencode=arch=compute_89,code=sm_89" + "-gencode=arch=compute_90,code=sm_90" + "-std=c++17" + "-DFLASHINFER_ENABLE_F16" + 
"-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1" + "-DCUTLASS_VERSIONS_GENERATED" + "-DCUTE_USE_PACKED_TUPLE=1" + "-DCUTLASS_TEST_LEVEL=0" + "-DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1" + "-DCUTLASS_DEBUG_TRACE_LEVEL=0" + "--expt-relaxed-constexpr" + "-Xcompiler=-Wconversion" + "-Xcompiler=-fno-strict-aliasing" +) + +option(SGL_KERNEL_ENABLE_SM100A "Enable SM100A" OFF) +option(SGL_KERNEL_ENABLE_SM90A "Enable SM90A" OFF) +option(SGL_KERNEL_ENABLE_BF16 "Enable BF16" ON) +option(SGL_KERNEL_ENABLE_FP8 "Enable FP8" ON) + +if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8" OR SGL_KERNEL_ENABLE_SM100A) + list(APPEND SGL_KERNEL_CUDA_FLAGS + "-gencode=arch=compute_100,code=sm_100" + "-gencode=arch=compute_100a,code=sm_100a" + ) +else() + list(APPEND SGL_KERNEL_CUDA_FLAGS + "-use_fast_math" + ) +endif() + +if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.4" OR SGL_KERNEL_ENABLE_SM90A) + list(APPEND SGL_KERNEL_CUDA_FLAGS + "-gencode=arch=compute_90a,code=sm_90a" + ) +endif() + +if (SGL_KERNEL_ENABLE_BF16) + list(APPEND SGL_KERNEL_CUDA_FLAGS + "-DFLASHINFER_ENABLE_BF16" + ) +endif() + +if (SGL_KERNEL_ENABLE_FP8) + list(APPEND SGL_KERNEL_CUDA_FLAGS + "-DFLASHINFER_ENABLE_FP8" + "-DFLASHINFER_ENABLE_FP8_E4M3" + "-DFLASHINFER_ENABLE_FP8_E5M2" + ) +endif() + +string(REPLACE "-D__CUDA_NO_HALF_OPERATORS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}") +string(REPLACE "-D__CUDA_NO_HALF_CONVERSIONS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}") +string(REPLACE "-D__CUDA_NO_BFLOAT16_CONVERSIONS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}") +string(REPLACE "-D__CUDA_NO_HALF2_OPERATORS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}") + +set(SOURCES + "csrc/allreduce/trt_reduce_internal.cu" + "csrc/allreduce/trt_reduce_kernel.cu" + "csrc/attention/lightning_attention_decode_kernel.cu" + "csrc/elementwise/activation.cu" + "csrc/elementwise/fused_add_rms_norm_kernel.cu" + "csrc/elementwise/rope.cu" + "csrc/gemm/awq_kernel.cu" + "csrc/gemm/bmm_fp8.cu" + "csrc/gemm/cublas_grouped_gemm.cu" + 
"csrc/gemm/fp8_blockwise_gemm_kernel.cu" + "csrc/gemm/fp8_gemm_kernel.cu" + "csrc/gemm/int8_gemm_kernel.cu" + "csrc/gemm/nvfp4_quant_entry.cu" + "csrc/gemm/nvfp4_quant_kernels.cu" + "csrc/gemm/nvfp4_scaled_mm_entry.cu" + "csrc/gemm/nvfp4_scaled_mm_kernels.cu" + "csrc/gemm/per_tensor_quant_fp8.cu" + "csrc/gemm/per_token_group_quant_8bit.cu" + "csrc/gemm/per_token_quant_fp8.cu" + "csrc/moe/moe_align_kernel.cu" + "csrc/moe/moe_topk_softmax_kernels.cu" + "csrc/speculative/eagle_utils.cu" + "csrc/speculative/speculative_sampling.cu" + "csrc/speculative/packbit.cu" + "csrc/torch_extension.cc" + "${repo-flashinfer_SOURCE_DIR}/csrc/norm.cu" + "${repo-flashinfer_SOURCE_DIR}/csrc/renorm.cu" + "${repo-flashinfer_SOURCE_DIR}/csrc/sampling.cu" +) + +# Support abi3 for build +Python_add_library(common_ops MODULE USE_SABI 3.9 WITH_SOABI ${SOURCES}) + +target_compile_options(common_ops PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_KERNEL_CUDA_FLAGS}>) + +target_include_directories(common_ops PRIVATE ${TORCH_INCLUDE_DIRS}) + +target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cublasLt) + +install(TARGETS common_ops LIBRARY DESTINATION "sgl_kernel") diff --git a/sgl-kernel/Makefile b/sgl-kernel/Makefile index 53375fa0f..2eb9ddc55 100644 --- a/sgl-kernel/Makefile +++ b/sgl-kernel/Makefile @@ -19,13 +19,14 @@ submodule: ## Initialize and update git submodules @git submodule update --init --recursive ln: submodule ## Create compilation database - @rm -rf build && bear python3 setup.py build + @rm -rf build && mkdir build && cd build && cmake .. -DCMAKE_EXPORT_COMPILE_COMMANDS=YES + install: submodule ## Install package in development mode @pip install -e . build: submodule ## Build and install wheel package - @rm -rf dist/* || true && export MAX_JOBS=$(nproc) && python3 setup.py bdist_wheel && pip3 install dist/*whl --force-reinstall --no-deps + @rm -rf dist/* || true && export MAX_JOBS=$(nproc) && uv build --wheel -Cbuild-dir=build . 
--verbose --color=always && pip3 install dist/*whl --force-reinstall --no-deps clean: ## Remove build artifacts @rm -rf build dist *.egg-info diff --git a/sgl-kernel/build.sh b/sgl-kernel/build.sh index a9eb9c5cc..7d5ee6b6b 100755 --- a/sgl-kernel/build.sh +++ b/sgl-kernel/build.sh @@ -15,7 +15,7 @@ docker run --rm \ pytorch/manylinux-builder:cuda${CUDA_VERSION} \ bash -c " ${PYTHON_ROOT_PATH}/bin/pip install --no-cache-dir torch==2.5.1 --index-url https://download.pytorch.org/whl/cu${CUDA_VERSION//.} && \ - ${PYTHON_ROOT_PATH}/bin/pip install --no-cache-dir ninja setuptools==75.0.0 wheel==0.41.0 numpy && \ + ${PYTHON_ROOT_PATH}/bin/pip install --no-cache-dir ninja setuptools==75.0.0 wheel==0.41.0 numpy uv && \ export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0+PTX' && \ export CUDA_VERSION=${CUDA_VERSION} && \ export SGL_KERNEL_ENABLE_BF16=1 && \ @@ -25,5 +25,5 @@ docker run --rm \ ln -s /usr/local/cuda-${CUDA_VERSION}/targets/x86_64-linux/lib/stubs/libcuda.so /usr/lib/x86_64-linux-gnu/libcuda.so && \ cd /sgl-kernel && \ ls -la ${PYTHON_ROOT_PATH}/lib/python${PYTHON_VERSION}/site-packages/wheel/ && \ - PYTHONPATH=${PYTHON_ROOT_PATH}/lib/python${PYTHON_VERSION}/site-packages ${PYTHON_ROOT_PATH}/bin/python setup.py bdist_wheel + PYTHONPATH=${PYTHON_ROOT_PATH}/lib/python${PYTHON_VERSION}/site-packages ${PYTHON_ROOT_PATH}/bin/python -m uv build --wheel -Cbuild-dir=build . --color=always " diff --git a/sgl-kernel/csrc/attention/lightning_attention_decode_kernel.cu b/sgl-kernel/csrc/attention/lightning_attention_decode_kernel.cu index f9d524f60..01bd4797c 100644 --- a/sgl-kernel/csrc/attention/lightning_attention_decode_kernel.cu +++ b/sgl-kernel/csrc/attention/lightning_attention_decode_kernel.cu @@ -18,7 +18,7 @@ limitations under the License. 
#include #include #include -#include +#include #define THREADS_PER_BLOCK 128 diff --git a/sgl-kernel/csrc/gemm/cublas_grouped_gemm.cu b/sgl-kernel/csrc/gemm/cublas_grouped_gemm.cu index 2add0826f..ca7d131e1 100644 --- a/sgl-kernel/csrc/gemm/cublas_grouped_gemm.cu +++ b/sgl-kernel/csrc/gemm/cublas_grouped_gemm.cu @@ -12,7 +12,6 @@ #include #include #include -#include #include #include diff --git a/sgl-kernel/csrc/moe/moe_align_kernel.cu b/sgl-kernel/csrc/moe/moe_align_kernel.cu index cfb7adca5..6ffa73924 100644 --- a/sgl-kernel/csrc/moe/moe_align_kernel.cu +++ b/sgl-kernel/csrc/moe/moe_align_kernel.cu @@ -16,7 +16,6 @@ limitations under the License. #include #include #include -#include #include diff --git a/sgl-kernel/csrc/torch_extension.cc b/sgl-kernel/csrc/torch_extension.cc index 80c5c73d1..16d9adb12 100644 --- a/sgl-kernel/csrc/torch_extension.cc +++ b/sgl-kernel/csrc/torch_extension.cc @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ - #include +#include #include #include "sgl_kernel_ops.h" @@ -178,9 +178,9 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) { m.impl("min_p_sampling_from_probs", torch::kCUDA, &min_p_sampling_from_probs); m.def( - "top_k_renorm_probs_wrapper(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_k_arr, int top_k_val, int " + "top_k_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_k_arr, int top_k_val, int " "cuda_stream) -> ()"); - m.impl("top_k_renorm_probs_wrapper", torch::kCUDA, &top_k_renorm_probs_wrapper); + m.impl("top_k_renorm_probs", torch::kCUDA, &top_k_renorm_probs); m.def( "top_p_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? 
maybe_top_p_arr, float top_p_val, int " diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h index 36921a29f..f4961ab4f 100644 --- a/sgl-kernel/include/sgl_kernel_ops.h +++ b/sgl-kernel/include/sgl_kernel_ops.h @@ -15,8 +15,11 @@ limitations under the License. #pragma once +#include +#include #include -#include +#include +#include #include @@ -253,23 +256,12 @@ void min_p_sampling_from_probs( double min_p_val, bool deterministic, int64_t cuda_stream); -// top k renorm probs -// patch here, cause flashinfer use unsigned int. but torch must use int64_t for extension. void top_k_renorm_probs( - at::Tensor probs, - at::Tensor renorm_probs, - std::optional maybe_top_k_arr, - unsigned int top_k_val, - int64_t cuda_stream); -// patch here, cause flashinfer use unsigned int. but torch must use int64_t for extension. -inline void top_k_renorm_probs_wrapper( at::Tensor probs, at::Tensor renorm_probs, std::optional maybe_top_k_arr, int64_t top_k_val, - int64_t cuda_stream) { - top_k_renorm_probs(probs, renorm_probs, maybe_top_k_arr, static_cast(top_k_val), cuda_stream); -} + int64_t cuda_stream); void top_p_renorm_probs( at::Tensor probs, at::Tensor renorm_probs, diff --git a/sgl-kernel/include/utils.h b/sgl-kernel/include/utils.h index 7213c05c5..229c6e9c4 100644 --- a/sgl-kernel/include/utils.h +++ b/sgl-kernel/include/utils.h @@ -15,14 +15,190 @@ limitations under the License. #pragma once +#include #include -#ifndef USE_ROCM -#include -#endif -#include +#include #include +#ifndef USE_ROCM +// Adapt from FlashInfer +#ifdef FLASHINFER_ENABLE_F16 +#define _DISPATCH_CASE_F16(c_type, ...) \ + case at::ScalarType::Half: { \ + using c_type = nv_half; \ + return __VA_ARGS__(); \ + } +#else +#define _DISPATCH_CASE_F16(c_type, ...) +#endif + +#ifdef FLASHINFER_ENABLE_BF16 +#define _DISPATCH_CASE_BF16(c_type, ...) 
\ + case at::ScalarType::BFloat16: { \ + using c_type = nv_bfloat16; \ + return __VA_ARGS__(); \ + } +#else +#define _DISPATCH_CASE_BF16(c_type, ...) +#endif + +#ifdef FLASHINFER_ENABLE_FP8_E4M3 +#define _DISPATCH_CASE_FP8_E4M3(c_type, ...) \ + case at::ScalarType::Float8_e4m3fn: { \ + using c_type = __nv_fp8_e4m3; \ + return __VA_ARGS__(); \ + } +#else +#define _DISPATCH_CASE_FP8_E4M3(c_type, ...) +#endif + +#ifdef FLASHINFER_ENABLE_FP8_E5M2 +#define _DISPATCH_CASE_FP8_E5M2(c_type, ...) \ + case at::ScalarType::Float8_e5m2: { \ + using c_type = __nv_fp8_e5m2; \ + return __VA_ARGS__(); \ + } +#else +#define _DISPATCH_CASE_FP8_E5M2(c_type, ...) +#endif + +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(pytorch_dtype, c_type, ...) \ + [&]() -> bool { \ + switch (pytorch_dtype) { \ + _DISPATCH_CASE_F16(c_type, __VA_ARGS__) \ + _DISPATCH_CASE_BF16(c_type, __VA_ARGS__) \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \ + TORCH_CHECK(false, oss.str()); \ + return false; \ + } \ + }() + +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(pytorch_dtype, c_type, ...) \ + [&]() -> bool { \ + switch (pytorch_dtype) { \ + _DISPATCH_CASE_FP8_E4M3(c_type, __VA_ARGS__) \ + _DISPATCH_CASE_FP8_E5M2(c_type, __VA_ARGS__) \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch fp8 data type " << pytorch_dtype; \ + TORCH_CHECK(false, oss.str()); \ + return false; \ + } \ + }() + +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE(pytorch_dtype, c_type, ...) 
\ + [&]() -> bool { \ + switch (pytorch_dtype) { \ + _DISPATCH_CASE_F16(c_type, __VA_ARGS__) \ + _DISPATCH_CASE_BF16(c_type, __VA_ARGS__) \ + _DISPATCH_CASE_FP8_E4M3(c_type, __VA_ARGS__) \ + _DISPATCH_CASE_FP8_E5M2(c_type, __VA_ARGS__) \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \ + TORCH_CHECK(false, oss.str()); \ + return false; \ + } \ + }() + +#define _DISPATCH_SWITCH(var_name, cond, ...) \ + [&]() -> bool { \ + switch (cond) { \ + __VA_ARGS__ \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch " var_name " " << int(cond); \ + TORCH_CHECK(false, oss.str()); \ + return false; \ + } \ + }() + +#define _DISPATCH_SWITCH_U16x2(var1_name, var2_name, cond1, cond2, ...) \ + [&]() -> bool { \ + switch (pack_u16(cond1, cond2)) { \ + __VA_ARGS__ \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch (" var1_name ", " var2_name "): (" << int(cond1) << ", " \ + << int(cond2) << ")"; \ + TORCH_CHECK(false, oss.str()); \ + return false; \ + } \ + }() + +#define _DISPATCH_CASE(case_expr, case_var, ...) \ + case case_expr: { \ + constexpr auto case_var = case_expr; \ + return __VA_ARGS__(); \ + } + +#define _DISPATCH_CASE_U16x2(case_expr1, case_expr2, case_var1, case_var2, ...) \ + case pack_u16(case_expr1, case_expr2): { \ + constexpr auto case_var1 = case_expr1; \ + constexpr auto case_var2 = case_expr2; \ + return __VA_ARGS__(); \ + } + +#define DISPATCH_BOOL(expr, const_expr, ...) \ + [&]() -> bool { \ + if (expr) { \ + constexpr bool const_expr = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr bool const_expr = false; \ + return __VA_ARGS__(); \ + } \ + }() + +inline void check_shape(const at::Tensor& a, const at::Tensor& b, const char* a_name, const char* b_name) { + TORCH_CHECK(a.dim() == b.dim(), a_name, ".dim() != ", b_name, ".dim(). 
", a.dim(), " vs ", b.dim()); + for (int i = 0; i < a.dim(); ++i) { + TORCH_CHECK(a.size(i) == b.size(i), a_name, ".size(", i, ") != ", b_name, ".size(", i, ")"); + } +} + +inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) { + return (uint32_t(a) << 16) | uint32_t(b); +} + +#define CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads) \ + TORCH_CHECK( \ + num_qo_heads % num_kv_heads == 0, \ + "num_qo_heads(", \ + num_qo_heads, \ + ") must be divisible by num_kv_heads(", \ + num_kv_heads, \ + ")") + +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") + +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_LAST_DIM_CONTIGUOUS(x) \ + TORCH_CHECK(x.strides()[x.strides().size() - 1] == 1, #x "must be contiguous at last dimension") + +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) +#define CHECK_LAST_DIM_CONTIGUOUS_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_LAST_DIM_CONTIGUOUS(x) + +#define CHECK_DIM(d, x) TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor") + +#define CHECK_SHAPE(a, b) check_shape(a, b, #a, #b) + +#define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b) + +#define CHECK_GE(a, b) TORCH_CHECK((a) >= (b), "CHECK_GE(" #a ", " #b ") failed. ", a, " vs ", b) + +inline bool is_float8_tensor(const at::Tensor& tensor) { + return tensor.scalar_type() == at::ScalarType::Float8_e4m3fn || tensor.scalar_type() == at::ScalarType::Float8_e5m2; +} +#endif + struct cuda_error : public std::runtime_error { /** * @brief Constructs a `cuda_error` object with the given `message`. 
diff --git a/sgl-kernel/pyproject.toml b/sgl-kernel/pyproject.toml index 850f4ca20..2169f5118 100644 --- a/sgl-kernel/pyproject.toml +++ b/sgl-kernel/pyproject.toml @@ -1,11 +1,10 @@ [build-system] requires = [ - "setuptools>=75.0", "scikit-build-core>=0.10", "torch==2.5.1", "wheel", ] -build-backend = "setuptools.build_meta" +build-backend = "scikit_build_core.build" [project] name = "sgl-kernel" @@ -30,3 +29,11 @@ exclude = [ "dist*", "tests*", ] + +[tool.scikit-build] +cmake.build-type = "Release" +minimum-version = "build-system.requires" + +wheel.py-api = "cp39" +wheel.license-files = [] +wheel.packages = ["python/sgl_kernel"] diff --git a/sgl-kernel/pyproject_rocm.toml b/sgl-kernel/pyproject_rocm.toml new file mode 100644 index 000000000..850f4ca20 --- /dev/null +++ b/sgl-kernel/pyproject_rocm.toml @@ -0,0 +1,32 @@ +[build-system] +requires = [ + "setuptools>=75.0", + "scikit-build-core>=0.10", + "torch==2.5.1", + "wheel", +] +build-backend = "setuptools.build_meta" + +[project] +name = "sgl-kernel" +version = "0.0.5.post3" +description = "Kernel Library for SGLang" +readme = "README.md" +requires-python = ">=3.9" +license = { file = "LICENSE" } +classifiers = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache Software License", + "Environment :: GPU :: NVIDIA CUDA" +] +dependencies = [] + +[project.urls] +"Homepage" = "https://github.com/sgl-project/sglang/tree/main/sgl-kernel" +"Bug Tracker" = "https://github.com/sgl-project/sglang/issues" + +[tool.wheel] +exclude = [ + "dist*", + "tests*", +] diff --git a/sgl-kernel/python/sgl_kernel/sampling.py b/sgl-kernel/python/sgl_kernel/sampling.py index 7bf10bd4a..2f57f1313 100644 --- a/sgl-kernel/python/sgl_kernel/sampling.py +++ b/sgl-kernel/python/sgl_kernel/sampling.py @@ -12,7 +12,7 @@ def _top_k_renorm_probs_internal( probs = probs.float() maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None renorm_probs = torch.empty_like(probs) - 
torch.ops.sgl_kernel.top_k_renorm_probs_wrapper( + torch.ops.sgl_kernel.top_k_renorm_probs( probs, renorm_probs, maybe_top_k_arr, diff --git a/sgl-kernel/rename_wheels.sh b/sgl-kernel/rename_wheels.sh new file mode 100755 index 000000000..95c241893 --- /dev/null +++ b/sgl-kernel/rename_wheels.sh @@ -0,0 +1,17 @@ +#!/usr/bin/env bash +set -ex + +WHEEL_DIR="dist" + +wheel_files=($WHEEL_DIR/*.whl) +for wheel in "${wheel_files[@]}"; do + new_wheel="${wheel/linux/manylinux2014}" + + if [[ "$wheel" != "$new_wheel" ]]; then + echo "Renaming $wheel to $new_wheel" + mv -- "$wheel" "$new_wheel" + + fi +done + +echo "Wheel renaming completed." diff --git a/sgl-kernel/setup_rocm.py b/sgl-kernel/setup_rocm.py index b147e6b53..66a2d0a5c 100644 --- a/sgl-kernel/setup_rocm.py +++ b/sgl-kernel/setup_rocm.py @@ -21,9 +21,6 @@ from torch.utils.cpp_extension import BuildExtension, CUDAExtension root = Path(__file__).parent.resolve() -if "bdist_wheel" in sys.argv and "--plat-name" not in sys.argv: - sys.argv.extend(["--plat-name", "manylinux2014_x86_64"]) - def _get_version(): with open(root / "pyproject.toml") as f: