support cmake for sgl-kernel (#4706)
Co-authored-by: hebiao064 <hebiaobuaa@gmail.com> Co-authored-by: yinfan98 <1106310035@qq.com>
This commit is contained in:
4
.github/workflows/pr-test-amd.yml
vendored
4
.github/workflows/pr-test-amd.yml
vendored
@@ -47,7 +47,7 @@ jobs:
|
||||
run: |
|
||||
docker exec ci_sglang pip install --upgrade pip
|
||||
docker exec ci_sglang pip uninstall sgl-kernel -y || true
|
||||
docker exec -w /sglang-checkout/sgl-kernel ci_sglang python3 setup_rocm.py install
|
||||
docker exec -w /sglang-checkout/sgl-kernel ci_sglang bash -c "rm -f pyproject.toml && mv pyproject_rocm.toml pyproject.toml && python3 setup_rocm.py install"
|
||||
docker exec ci_sglang pip install -e "python[dev_hip]"
|
||||
|
||||
docker exec -w / ci_sglang git clone https://github.com/merrymercy/human-eval.git
|
||||
@@ -87,7 +87,7 @@ jobs:
|
||||
run: |
|
||||
docker exec ci_sglang pip install --upgrade pip
|
||||
docker exec ci_sglang pip uninstall sgl-kernel -y || true
|
||||
docker exec -w /sglang-checkout/sgl-kernel ci_sglang python3 setup_rocm.py install
|
||||
docker exec -w /sglang-checkout/sgl-kernel ci_sglang bash -c "rm -f pyproject.toml && mv pyproject_rocm.toml pyproject.toml && python3 setup_rocm.py install"
|
||||
docker exec ci_sglang pip install -e "python[dev_hip]"
|
||||
|
||||
docker exec -w / ci_sglang git clone https://github.com/merrymercy/human-eval.git
|
||||
|
||||
2
.github/workflows/pr-test-sgl-kernel.yml
vendored
2
.github/workflows/pr-test-sgl-kernel.yml
vendored
@@ -84,6 +84,7 @@ jobs:
|
||||
pip3 uninstall sgl-kernel -y || true
|
||||
pip3 install sgl-kernel/dist/*whl --force-reinstall --no-deps
|
||||
pip3 list | grep sgl-kernel
|
||||
git clone --recursive https://github.com/deepseek-ai/DeepGEMM.git && cd DeepGEMM && python3 setup.py develop
|
||||
|
||||
- name: Run test
|
||||
timeout-minutes: 30
|
||||
@@ -115,6 +116,7 @@ jobs:
|
||||
pip3 uninstall sgl-kernel -y || true
|
||||
pip3 install sgl-kernel/dist/*whl --force-reinstall --no-deps
|
||||
pip3 list | grep sgl-kernel
|
||||
git clone --recursive https://github.com/deepseek-ai/DeepGEMM.git && cd DeepGEMM && python3 setup.py develop
|
||||
|
||||
- name: Run test
|
||||
timeout-minutes: 30
|
||||
|
||||
@@ -29,6 +29,8 @@ RUN git clone ${SGL_REPO} \
|
||||
git checkout ${SGL_BRANCH}; \
|
||||
fi \
|
||||
&& cd sgl-kernel \
|
||||
&& rm -f pyproject.toml \
|
||||
&& mv pyproject_rocm.toml pyproject.toml \
|
||||
&& python setup_rocm.py install \
|
||||
&& cd .. \
|
||||
&& if [ "$BUILD_TYPE" = "srt" ]; then \
|
||||
|
||||
2
sgl-kernel/3rdparty/flashinfer
vendored
2
sgl-kernel/3rdparty/flashinfer
vendored
Submodule sgl-kernel/3rdparty/flashinfer updated: e5a3befbe3...79fd1ae90d
166
sgl-kernel/CMakeLists.txt
Normal file
166
sgl-kernel/CMakeLists.txt
Normal file
@@ -0,0 +1,166 @@
|
||||
cmake_minimum_required(VERSION 3.26 FATAL_ERROR)
|
||||
project(sgl-kernel LANGUAGES CXX CUDA)
|
||||
|
||||
# we only want to download 3rd, but not build them.
|
||||
# FetchContent_MakeAvailable will build it.
|
||||
cmake_policy(SET CMP0169 OLD)
|
||||
|
||||
find_package(Python COMPONENTS Interpreter Development.Module Development.SABIModule REQUIRED)
|
||||
|
||||
enable_language(CUDA)
|
||||
find_package(CUDAToolkit REQUIRED)
|
||||
|
||||
message(STATUS "Detected CUDA_VERSION=${CUDA_VERSION}")
|
||||
if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8")
|
||||
message("CUDA_VERSION ${CUDA_VERSION} >= 12.8")
|
||||
elseif ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.4")
|
||||
message("CUDA_VERSION ${CUDA_VERSION} >= 12.4")
|
||||
elseif ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.1")
|
||||
message("CUDA_VERSION ${CUDA_VERSION} >= 12.1")
|
||||
elseif ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "11.8")
|
||||
message("CUDA_VERSION ${CUDA_VERSION} >= 11.8")
|
||||
endif()
|
||||
|
||||
find_package(Torch REQUIRED)
|
||||
|
||||
include(FetchContent)
|
||||
|
||||
FetchContent_Declare(
|
||||
repo-cutlass
|
||||
GIT_REPOSITORY https://github.com/NVIDIA/cutlass
|
||||
GIT_TAG 62750a2b75c802660e4894434dc55e839f322277
|
||||
GIT_SHALLOW ON
|
||||
)
|
||||
FetchContent_Populate(repo-cutlass)
|
||||
FetchContent_Declare(
|
||||
repo-deepgemm
|
||||
GIT_REPOSITORY https://github.com/deepseek-ai/DeepGEMM
|
||||
GIT_TAG c57699ac933a93651c34d365797c2d8b41a4765b
|
||||
GIT_SHALLOW ON
|
||||
)
|
||||
FetchContent_Populate(repo-deepgemm)
|
||||
FetchContent_Declare(
|
||||
repo-flashinfer
|
||||
GIT_REPOSITORY https://github.com/flashinfer-ai/flashinfer
|
||||
GIT_TAG 79fd1ae90d9b8098ca70dec6071da96f3f6da7b9
|
||||
GIT_SHALLOW OFF
|
||||
)
|
||||
FetchContent_Populate(repo-flashinfer)
|
||||
|
||||
include_directories(
|
||||
${PROJECT_SOURCE_DIR}/include
|
||||
${PROJECT_SOURCE_DIR}/csrc
|
||||
${repo-cutlass_SOURCE_DIR}/include
|
||||
${repo-cutlass_SOURCE_DIR}/tools/util/include
|
||||
${repo-flashinfer_SOURCE_DIR}/include
|
||||
${repo-flashinfer_SOURCE_DIR}/csrc
|
||||
)
|
||||
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3")
|
||||
|
||||
set(SGL_KERNEL_CUDA_FLAGS
|
||||
"-DNDEBUG"
|
||||
"-DOPERATOR_NAMESPACE=sgl-kernel"
|
||||
"-O3"
|
||||
"-Xcompiler"
|
||||
"-fPIC"
|
||||
"-gencode=arch=compute_75,code=sm_75"
|
||||
"-gencode=arch=compute_80,code=sm_80"
|
||||
"-gencode=arch=compute_89,code=sm_89"
|
||||
"-gencode=arch=compute_90,code=sm_90"
|
||||
"-std=c++17"
|
||||
"-DFLASHINFER_ENABLE_F16"
|
||||
"-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1"
|
||||
"-DCUTLASS_VERSIONS_GENERATED"
|
||||
"-DCUTE_USE_PACKED_TUPLE=1"
|
||||
"-DCUTLASS_TEST_LEVEL=0"
|
||||
"-DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1"
|
||||
"-DCUTLASS_DEBUG_TRACE_LEVEL=0"
|
||||
"--expt-relaxed-constexpr"
|
||||
"-Xcompiler=-Wconversion"
|
||||
"-Xcompiler=-fno-strict-aliasing"
|
||||
)
|
||||
|
||||
option(SGL_KERNEL_ENABLE_SM100A "Enable SM100A" OFF)
|
||||
option(SGL_KERNEL_ENABLE_SM90A "Enable SM90A" OFF)
|
||||
option(SGL_KERNEL_ENABLE_BF16 "Enable BF16" ON)
|
||||
option(SGL_KERNEL_ENABLE_FP8 "Enable FP8" ON)
|
||||
|
||||
if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8" OR SGL_KERNEL_ENABLE_SM100A)
|
||||
list(APPEND SGL_KERNEL_CUDA_FLAGS
|
||||
"-gencode=arch=compute_100,code=sm_100"
|
||||
"-gencode=arch=compute_100a,code=sm_100a"
|
||||
)
|
||||
else()
|
||||
list(APPEND SGL_KERNEL_CUDA_FLAGS
|
||||
"-use_fast_math"
|
||||
)
|
||||
endif()
|
||||
|
||||
if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.4" OR SGL_KERNEL_ENABLE_SM90A)
|
||||
list(APPEND SGL_KERNEL_CUDA_FLAGS
|
||||
"-gencode=arch=compute_90a,code=sm_90a"
|
||||
)
|
||||
endif()
|
||||
|
||||
if (SGL_KERNEL_ENABLE_BF16)
|
||||
list(APPEND SGL_KERNEL_CUDA_FLAGS
|
||||
"-DFLASHINFER_ENABLE_BF16"
|
||||
)
|
||||
endif()
|
||||
|
||||
if (SGL_KERNEL_ENABLE_FP8)
|
||||
list(APPEND SGL_KERNEL_CUDA_FLAGS
|
||||
"-DFLASHINFER_ENABLE_FP8"
|
||||
"-DFLASHINFER_ENABLE_FP8_E4M3"
|
||||
"-DFLASHINFER_ENABLE_FP8_E5M2"
|
||||
)
|
||||
endif()
|
||||
|
||||
string(REPLACE "-D__CUDA_NO_HALF_OPERATORS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
|
||||
string(REPLACE "-D__CUDA_NO_HALF_CONVERSIONS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
|
||||
string(REPLACE "-D__CUDA_NO_BFLOAT16_CONVERSIONS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
|
||||
string(REPLACE "-D__CUDA_NO_HALF2_OPERATORS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
|
||||
|
||||
set(SOURCES
|
||||
"csrc/allreduce/trt_reduce_internal.cu"
|
||||
"csrc/allreduce/trt_reduce_kernel.cu"
|
||||
"csrc/attention/lightning_attention_decode_kernel.cu"
|
||||
"csrc/elementwise/activation.cu"
|
||||
"csrc/elementwise/fused_add_rms_norm_kernel.cu"
|
||||
"csrc/elementwise/rope.cu"
|
||||
"csrc/gemm/awq_kernel.cu"
|
||||
"csrc/gemm/bmm_fp8.cu"
|
||||
"csrc/gemm/cublas_grouped_gemm.cu"
|
||||
"csrc/gemm/fp8_blockwise_gemm_kernel.cu"
|
||||
"csrc/gemm/fp8_gemm_kernel.cu"
|
||||
"csrc/gemm/int8_gemm_kernel.cu"
|
||||
"csrc/gemm/nvfp4_quant_entry.cu"
|
||||
"csrc/gemm/nvfp4_quant_kernels.cu"
|
||||
"csrc/gemm/nvfp4_scaled_mm_entry.cu"
|
||||
"csrc/gemm/nvfp4_scaled_mm_kernels.cu"
|
||||
"csrc/gemm/per_tensor_quant_fp8.cu"
|
||||
"csrc/gemm/per_token_group_quant_8bit.cu"
|
||||
"csrc/gemm/per_token_quant_fp8.cu"
|
||||
"csrc/moe/moe_align_kernel.cu"
|
||||
"csrc/moe/moe_topk_softmax_kernels.cu"
|
||||
"csrc/speculative/eagle_utils.cu"
|
||||
"csrc/speculative/speculative_sampling.cu"
|
||||
"csrc/speculative/packbit.cu"
|
||||
"csrc/torch_extension.cc"
|
||||
"${repo-flashinfer_SOURCE_DIR}/csrc/norm.cu"
|
||||
"${repo-flashinfer_SOURCE_DIR}/csrc/renorm.cu"
|
||||
"${repo-flashinfer_SOURCE_DIR}/csrc/sampling.cu"
|
||||
)
|
||||
|
||||
# Support abi3 for build
|
||||
Python_add_library(common_ops MODULE USE_SABI 3.9 WITH_SOABI ${SOURCES})
|
||||
|
||||
target_compile_options(common_ops PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_KERNEL_CUDA_FLAGS}>)
|
||||
|
||||
target_include_directories(common_ops PRIVATE ${TORCH_INCLUDE_DIRS})
|
||||
|
||||
target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cublasLt)
|
||||
|
||||
install(TARGETS common_ops LIBRARY DESTINATION "sgl_kernel")
|
||||
@@ -19,13 +19,14 @@ submodule: ## Initialize and update git submodules
|
||||
@git submodule update --init --recursive
|
||||
|
||||
ln: submodule ## Create compilation database
|
||||
@rm -rf build && bear python3 setup.py build
|
||||
@rm -rf build && mkdir build && cd build && cmake .. -DCMAKE_EXPORT_COMPILE_COMMANDS=YES
|
||||
|
||||
|
||||
install: submodule ## Install package in development mode
|
||||
@pip install -e .
|
||||
|
||||
build: submodule ## Build and install wheel package
|
||||
@rm -rf dist/* || true && export MAX_JOBS=$(nproc) && python3 setup.py bdist_wheel && pip3 install dist/*whl --force-reinstall --no-deps
|
||||
@rm -rf dist/* || true && export MAX_JOBS=$(nproc) && uv build --wheel -Cbuild-dir=build . --verbose --color=always && pip3 install dist/*whl --force-reinstall --no-deps
|
||||
|
||||
clean: ## Remove build artifacts
|
||||
@rm -rf build dist *.egg-info
|
||||
|
||||
@@ -15,7 +15,7 @@ docker run --rm \
|
||||
pytorch/manylinux-builder:cuda${CUDA_VERSION} \
|
||||
bash -c "
|
||||
${PYTHON_ROOT_PATH}/bin/pip install --no-cache-dir torch==2.5.1 --index-url https://download.pytorch.org/whl/cu${CUDA_VERSION//.} && \
|
||||
${PYTHON_ROOT_PATH}/bin/pip install --no-cache-dir ninja setuptools==75.0.0 wheel==0.41.0 numpy && \
|
||||
${PYTHON_ROOT_PATH}/bin/pip install --no-cache-dir ninja setuptools==75.0.0 wheel==0.41.0 numpy uv && \
|
||||
export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0+PTX' && \
|
||||
export CUDA_VERSION=${CUDA_VERSION} && \
|
||||
export SGL_KERNEL_ENABLE_BF16=1 && \
|
||||
@@ -25,5 +25,5 @@ docker run --rm \
|
||||
ln -s /usr/local/cuda-${CUDA_VERSION}/targets/x86_64-linux/lib/stubs/libcuda.so /usr/lib/x86_64-linux-gnu/libcuda.so && \
|
||||
cd /sgl-kernel && \
|
||||
ls -la ${PYTHON_ROOT_PATH}/lib/python${PYTHON_VERSION}/site-packages/wheel/ && \
|
||||
PYTHONPATH=${PYTHON_ROOT_PATH}/lib/python${PYTHON_VERSION}/site-packages ${PYTHON_ROOT_PATH}/bin/python setup.py bdist_wheel
|
||||
PYTHONPATH=${PYTHON_ROOT_PATH}/lib/python${PYTHON_VERSION}/site-packages ${PYTHON_ROOT_PATH}/bin/python -m uv build --wheel -Cbuild-dir=build . --color=always
|
||||
"
|
||||
|
||||
@@ -18,7 +18,7 @@ limitations under the License.
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <torch/extension.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#define THREADS_PER_BLOCK 128
|
||||
|
||||
|
||||
@@ -12,7 +12,6 @@
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <torch/all.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
|
||||
@@ -16,7 +16,6 @@ limitations under the License.
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include <THC/THCAtomics.cuh>
|
||||
|
||||
|
||||
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <ATen/core/dispatch/Dispatcher.h>
|
||||
#include <torch/all.h>
|
||||
#include <torch/library.h>
|
||||
|
||||
#include "sgl_kernel_ops.h"
|
||||
@@ -178,9 +178,9 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
|
||||
m.impl("min_p_sampling_from_probs", torch::kCUDA, &min_p_sampling_from_probs);
|
||||
|
||||
m.def(
|
||||
"top_k_renorm_probs_wrapper(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_k_arr, int top_k_val, int "
|
||||
"top_k_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_k_arr, int top_k_val, int "
|
||||
"cuda_stream) -> ()");
|
||||
m.impl("top_k_renorm_probs_wrapper", torch::kCUDA, &top_k_renorm_probs_wrapper);
|
||||
m.impl("top_k_renorm_probs", torch::kCUDA, &top_k_renorm_probs);
|
||||
|
||||
m.def(
|
||||
"top_p_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_p_arr, float top_p_val, int "
|
||||
|
||||
@@ -15,8 +15,11 @@ limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/Tensor.h>
|
||||
#include <Python.h>
|
||||
#include <torch/extension.h>
|
||||
#include <torch/library.h>
|
||||
#include <torch/torch.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
@@ -253,23 +256,12 @@ void min_p_sampling_from_probs(
|
||||
double min_p_val,
|
||||
bool deterministic,
|
||||
int64_t cuda_stream);
|
||||
// top k renorm probs
|
||||
// patch here, cause flashinfer use unsigned int. but torch must use int64_t for extension.
|
||||
void top_k_renorm_probs(
|
||||
at::Tensor probs,
|
||||
at::Tensor renorm_probs,
|
||||
std::optional<at::Tensor> maybe_top_k_arr,
|
||||
unsigned int top_k_val,
|
||||
int64_t cuda_stream);
|
||||
// patch here, cause flashinfer use unsigned int. but torch must use int64_t for extension.
|
||||
inline void top_k_renorm_probs_wrapper(
|
||||
at::Tensor probs,
|
||||
at::Tensor renorm_probs,
|
||||
std::optional<at::Tensor> maybe_top_k_arr,
|
||||
int64_t top_k_val,
|
||||
int64_t cuda_stream) {
|
||||
top_k_renorm_probs(probs, renorm_probs, maybe_top_k_arr, static_cast<unsigned int>(top_k_val), cuda_stream);
|
||||
}
|
||||
int64_t cuda_stream);
|
||||
void top_p_renorm_probs(
|
||||
at::Tensor probs,
|
||||
at::Tensor renorm_probs,
|
||||
|
||||
@@ -15,14 +15,190 @@ limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ATen/Tensor.h>
|
||||
#include <cuda_runtime.h>
|
||||
#ifndef USE_ROCM
|
||||
#include <pytorch_extension_utils.h>
|
||||
#endif
|
||||
#include <torch/extension.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <sstream>
|
||||
|
||||
#ifndef USE_ROCM
|
||||
// Adapt from FlashInfer
|
||||
#ifdef FLASHINFER_ENABLE_F16
|
||||
#define _DISPATCH_CASE_F16(c_type, ...) \
|
||||
case at::ScalarType::Half: { \
|
||||
using c_type = nv_half; \
|
||||
return __VA_ARGS__(); \
|
||||
}
|
||||
#else
|
||||
#define _DISPATCH_CASE_F16(c_type, ...)
|
||||
#endif
|
||||
|
||||
#ifdef FLASHINFER_ENABLE_BF16
|
||||
#define _DISPATCH_CASE_BF16(c_type, ...) \
|
||||
case at::ScalarType::BFloat16: { \
|
||||
using c_type = nv_bfloat16; \
|
||||
return __VA_ARGS__(); \
|
||||
}
|
||||
#else
|
||||
#define _DISPATCH_CASE_BF16(c_type, ...)
|
||||
#endif
|
||||
|
||||
#ifdef FLASHINFER_ENABLE_FP8_E4M3
|
||||
#define _DISPATCH_CASE_FP8_E4M3(c_type, ...) \
|
||||
case at::ScalarType::Float8_e4m3fn: { \
|
||||
using c_type = __nv_fp8_e4m3; \
|
||||
return __VA_ARGS__(); \
|
||||
}
|
||||
#else
|
||||
#define _DISPATCH_CASE_FP8_E4M3(c_type, ...)
|
||||
#endif
|
||||
|
||||
#ifdef FLASHINFER_ENABLE_FP8_E5M2
|
||||
#define _DISPATCH_CASE_FP8_E5M2(c_type, ...) \
|
||||
case at::ScalarType::Float8_e5m2: { \
|
||||
using c_type = __nv_fp8_e5m2; \
|
||||
return __VA_ARGS__(); \
|
||||
}
|
||||
#else
|
||||
#define _DISPATCH_CASE_FP8_E5M2(c_type, ...)
|
||||
#endif
|
||||
|
||||
#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(pytorch_dtype, c_type, ...) \
|
||||
[&]() -> bool { \
|
||||
switch (pytorch_dtype) { \
|
||||
_DISPATCH_CASE_F16(c_type, __VA_ARGS__) \
|
||||
_DISPATCH_CASE_BF16(c_type, __VA_ARGS__) \
|
||||
default: \
|
||||
std::ostringstream oss; \
|
||||
oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \
|
||||
TORCH_CHECK(false, oss.str()); \
|
||||
return false; \
|
||||
} \
|
||||
}()
|
||||
|
||||
#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(pytorch_dtype, c_type, ...) \
|
||||
[&]() -> bool { \
|
||||
switch (pytorch_dtype) { \
|
||||
_DISPATCH_CASE_FP8_E4M3(c_type, __VA_ARGS__) \
|
||||
_DISPATCH_CASE_FP8_E5M2(c_type, __VA_ARGS__) \
|
||||
default: \
|
||||
std::ostringstream oss; \
|
||||
oss << __PRETTY_FUNCTION__ << " failed to dispatch fp8 data type " << pytorch_dtype; \
|
||||
TORCH_CHECK(false, oss.str()); \
|
||||
return false; \
|
||||
} \
|
||||
}()
|
||||
|
||||
#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE(pytorch_dtype, c_type, ...) \
|
||||
[&]() -> bool { \
|
||||
switch (pytorch_dtype) { \
|
||||
_DISPATCH_CASE_F16(c_type, __VA_ARGS__) \
|
||||
_DISPATCH_CASE_BF16(c_type, __VA_ARGS__) \
|
||||
_DISPATCH_CASE_FP8_E4M3(c_type, __VA_ARGS__) \
|
||||
_DISPATCH_CASE_FP8_E5M2(c_type, __VA_ARGS__) \
|
||||
default: \
|
||||
std::ostringstream oss; \
|
||||
oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \
|
||||
TORCH_CHECK(false, oss.str()); \
|
||||
return false; \
|
||||
} \
|
||||
}()
|
||||
|
||||
#define _DISPATCH_SWITCH(var_name, cond, ...) \
|
||||
[&]() -> bool { \
|
||||
switch (cond) { \
|
||||
__VA_ARGS__ \
|
||||
default: \
|
||||
std::ostringstream oss; \
|
||||
oss << __PRETTY_FUNCTION__ << " failed to dispatch " var_name " " << int(cond); \
|
||||
TORCH_CHECK(false, oss.str()); \
|
||||
return false; \
|
||||
} \
|
||||
}()
|
||||
|
||||
#define _DISPATCH_SWITCH_U16x2(var1_name, var2_name, cond1, cond2, ...) \
|
||||
[&]() -> bool { \
|
||||
switch (pack_u16(cond1, cond2)) { \
|
||||
__VA_ARGS__ \
|
||||
default: \
|
||||
std::ostringstream oss; \
|
||||
oss << __PRETTY_FUNCTION__ << " failed to dispatch (" var1_name ", " var2_name "): (" << int(cond1) << ", " \
|
||||
<< int(cond2) << ")"; \
|
||||
TORCH_CHECK(false, oss.str()); \
|
||||
return false; \
|
||||
} \
|
||||
}()
|
||||
|
||||
#define _DISPATCH_CASE(case_expr, case_var, ...) \
|
||||
case case_expr: { \
|
||||
constexpr auto case_var = case_expr; \
|
||||
return __VA_ARGS__(); \
|
||||
}
|
||||
|
||||
#define _DISPATCH_CASE_U16x2(case_expr1, case_expr2, case_var1, case_var2, ...) \
|
||||
case pack_u16(case_expr1, case_expr2): { \
|
||||
constexpr auto case_var1 = case_expr1; \
|
||||
constexpr auto case_var2 = case_expr2; \
|
||||
return __VA_ARGS__(); \
|
||||
}
|
||||
|
||||
#define DISPATCH_BOOL(expr, const_expr, ...) \
|
||||
[&]() -> bool { \
|
||||
if (expr) { \
|
||||
constexpr bool const_expr = true; \
|
||||
return __VA_ARGS__(); \
|
||||
} else { \
|
||||
constexpr bool const_expr = false; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
}()
|
||||
|
||||
inline void check_shape(const at::Tensor& a, const at::Tensor& b, const char* a_name, const char* b_name) {
|
||||
TORCH_CHECK(a.dim() == b.dim(), a_name, ".dim() != ", b_name, ".dim(). ", a.dim(), " vs ", b.dim());
|
||||
for (int i = 0; i < a.dim(); ++i) {
|
||||
TORCH_CHECK(a.size(i) == b.size(i), a_name, ".size(", i, ") != ", b_name, ".size(", i, ")");
|
||||
}
|
||||
}
|
||||
|
||||
inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) {
|
||||
return (uint32_t(a) << 16) | uint32_t(b);
|
||||
}
|
||||
|
||||
#define CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads) \
|
||||
TORCH_CHECK( \
|
||||
num_qo_heads % num_kv_heads == 0, \
|
||||
"num_qo_heads(", \
|
||||
num_qo_heads, \
|
||||
") must be divisible by num_kv_heads(", \
|
||||
num_kv_heads, \
|
||||
")")
|
||||
|
||||
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
|
||||
|
||||
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
||||
#define CHECK_LAST_DIM_CONTIGUOUS(x) \
|
||||
TORCH_CHECK(x.strides()[x.strides().size() - 1] == 1, #x "must be contiguous at last dimension")
|
||||
|
||||
#define CHECK_INPUT(x) \
|
||||
CHECK_CUDA(x); \
|
||||
CHECK_CONTIGUOUS(x)
|
||||
#define CHECK_LAST_DIM_CONTIGUOUS_INPUT(x) \
|
||||
CHECK_CUDA(x); \
|
||||
CHECK_LAST_DIM_CONTIGUOUS(x)
|
||||
|
||||
#define CHECK_DIM(d, x) TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor")
|
||||
|
||||
#define CHECK_SHAPE(a, b) check_shape(a, b, #a, #b)
|
||||
|
||||
#define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)
|
||||
|
||||
#define CHECK_GE(a, b) TORCH_CHECK((a) >= (b), "CHECK_GE(" #a ", " #b ") failed. ", a, " vs ", b)
|
||||
|
||||
inline bool is_float8_tensor(const at::Tensor& tensor) {
|
||||
return tensor.scalar_type() == at::ScalarType::Float8_e4m3fn || tensor.scalar_type() == at::ScalarType::Float8_e5m2;
|
||||
}
|
||||
#endif
|
||||
|
||||
struct cuda_error : public std::runtime_error {
|
||||
/**
|
||||
* @brief Constructs a `cuda_error` object with the given `message`.
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
[build-system]
|
||||
requires = [
|
||||
"setuptools>=75.0",
|
||||
"scikit-build-core>=0.10",
|
||||
"torch==2.5.1",
|
||||
"wheel",
|
||||
]
|
||||
build-backend = "setuptools.build_meta"
|
||||
build-backend = "scikit_build_core.build"
|
||||
|
||||
[project]
|
||||
name = "sgl-kernel"
|
||||
@@ -30,3 +29,11 @@ exclude = [
|
||||
"dist*",
|
||||
"tests*",
|
||||
]
|
||||
|
||||
[tool.scikit-build]
|
||||
cmake.build-type = "Release"
|
||||
minimum-version = "build-system.requires"
|
||||
|
||||
wheel.py-api = "cp39"
|
||||
wheel.license-files = []
|
||||
wheel.packages = ["python/sgl_kernel"]
|
||||
|
||||
32
sgl-kernel/pyproject_rocm.toml
Normal file
32
sgl-kernel/pyproject_rocm.toml
Normal file
@@ -0,0 +1,32 @@
|
||||
[build-system]
|
||||
requires = [
|
||||
"setuptools>=75.0",
|
||||
"scikit-build-core>=0.10",
|
||||
"torch==2.5.1",
|
||||
"wheel",
|
||||
]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "sgl-kernel"
|
||||
version = "0.0.5.post3"
|
||||
description = "Kernel Library for SGLang"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.9"
|
||||
license = { file = "LICENSE" }
|
||||
classifiers = [
|
||||
"Programming Language :: Python :: 3",
|
||||
"License :: OSI Approved :: Apache Software License",
|
||||
"Environment :: GPU :: NVIDIA CUDA"
|
||||
]
|
||||
dependencies = []
|
||||
|
||||
[project.urls]
|
||||
"Homepage" = "https://github.com/sgl-project/sglang/tree/main/sgl-kernel"
|
||||
"Bug Tracker" = "https://github.com/sgl-project/sglang/issues"
|
||||
|
||||
[tool.wheel]
|
||||
exclude = [
|
||||
"dist*",
|
||||
"tests*",
|
||||
]
|
||||
@@ -12,7 +12,7 @@ def _top_k_renorm_probs_internal(
|
||||
probs = probs.float()
|
||||
maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
|
||||
renorm_probs = torch.empty_like(probs)
|
||||
torch.ops.sgl_kernel.top_k_renorm_probs_wrapper(
|
||||
torch.ops.sgl_kernel.top_k_renorm_probs(
|
||||
probs,
|
||||
renorm_probs,
|
||||
maybe_top_k_arr,
|
||||
|
||||
17
sgl-kernel/rename_wheels.sh
Executable file
17
sgl-kernel/rename_wheels.sh
Executable file
@@ -0,0 +1,17 @@
|
||||
#!/usr/bin/env bash
|
||||
set -ex
|
||||
|
||||
WHEEL_DIR="dist"
|
||||
|
||||
wheel_files=($WHEEL_DIR/*.whl)
|
||||
for wheel in "${wheel_files[@]}"; do
|
||||
new_wheel="${wheel/linux/manylinux2014}"
|
||||
|
||||
if [[ "$wheel" != "$new_wheel" ]]; then
|
||||
echo "Renaming $wheel to $new_wheel"
|
||||
mv -- "$wheel" "$new_wheel"
|
||||
|
||||
fi
|
||||
done
|
||||
|
||||
echo "Wheel renaming completed."
|
||||
@@ -21,9 +21,6 @@ from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||
|
||||
root = Path(__file__).parent.resolve()
|
||||
|
||||
if "bdist_wheel" in sys.argv and "--plat-name" not in sys.argv:
|
||||
sys.argv.extend(["--plat-name", "manylinux2014_x86_64"])
|
||||
|
||||
|
||||
def _get_version():
|
||||
with open(root / "pyproject.toml") as f:
|
||||
|
||||
Reference in New Issue
Block a user