From ec3ee0289d762cc66870e8073028353f5b894ab6 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Fri, 28 Mar 2025 17:23:51 -0700 Subject: [PATCH] fix sgl-kernel cu118 build (#4872) --- sgl-kernel/build.sh | 3 ++- sgl-kernel/csrc/gemm/awq_kernel.cu | 5 +++++ sgl-kernel/csrc/gemm/nvfp4_quant_kernels.cu | 5 +++++ 3 files changed, 12 insertions(+), 1 deletion(-) diff --git a/sgl-kernel/build.sh b/sgl-kernel/build.sh index 7d5ee6b6b..3936b4d26 100755 --- a/sgl-kernel/build.sh +++ b/sgl-kernel/build.sh @@ -25,5 +25,6 @@ docker run --rm \ ln -s /usr/local/cuda-${CUDA_VERSION}/targets/x86_64-linux/lib/stubs/libcuda.so /usr/lib/x86_64-linux-gnu/libcuda.so && \ cd /sgl-kernel && \ ls -la ${PYTHON_ROOT_PATH}/lib/python${PYTHON_VERSION}/site-packages/wheel/ && \ - PYTHONPATH=${PYTHON_ROOT_PATH}/lib/python${PYTHON_VERSION}/site-packages ${PYTHON_ROOT_PATH}/bin/python -m uv build --wheel -Cbuild-dir=build . --color=always + PYTHONPATH=${PYTHON_ROOT_PATH}/lib/python${PYTHON_VERSION}/site-packages ${PYTHON_ROOT_PATH}/bin/python -m uv build --wheel -Cbuild-dir=build . --color=always && \ + ./rename_wheels.sh " diff --git a/sgl-kernel/csrc/gemm/awq_kernel.cu b/sgl-kernel/csrc/gemm/awq_kernel.cu index 0c144d40f..188f0cb3f 100644 --- a/sgl-kernel/csrc/gemm/awq_kernel.cu +++ b/sgl-kernel/csrc/gemm/awq_kernel.cu @@ -1,6 +1,7 @@ // Adapted from // https://github.com/vllm-project/vllm/blob/eb59b5a6cba6727d3727c0372258db9002f687c1/csrc/quantization/awq/gemm_kernels.cu#L350 #include +#include #include #include #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 @@ -79,6 +80,7 @@ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) { } __device__ uint4 dequantize_s4_to_bf16x2(uint32_t const& source) { +#if CUDA_VERSION >= 12000 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 uint4 result; uint32_t* h = reinterpret_cast(&result); @@ -118,6 +120,7 @@ __device__ uint4 dequantize_s4_to_bf16x2(uint32_t const& source) { assert(false); return {}; #endif +#endif } template @@ -128,6 +131,7 @@ __global__ void __launch_bounds__(256) dequantize_weights( OutputT* __restrict__ output, int group_size, int qweight_cols) { +#if CUDA_VERSION >= 12000 int col = blockIdx.x * blockDim.x + threadIdx.x; int row = blockIdx.y * blockDim.y + threadIdx.y; @@ -174,6 +178,7 @@ __global__ void __launch_bounds__(256) dequantize_weights( static_assert(sizeof(uint4) == 8 * sizeof(OutputT), "Memory layout mismatch"); *reinterpret_cast(output_ptr) = weight_raw; } +#endif } torch::Tensor awq_dequantize(torch::Tensor qweight, torch::Tensor scales, torch::Tensor qzeros) { diff --git a/sgl-kernel/csrc/gemm/nvfp4_quant_kernels.cu b/sgl-kernel/csrc/gemm/nvfp4_quant_kernels.cu index fa96442df..5024d20af 100644 --- a/sgl-kernel/csrc/gemm/nvfp4_quant_kernels.cu +++ b/sgl-kernel/csrc/gemm/nvfp4_quant_kernels.cu @@ -15,6 +15,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -56,6 +57,7 @@ constexpr int CVT_FP4_SF_VEC_SIZE = 16; // Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t). inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) { // PTX instructions used here requires sm100a. +#if CUDA_VERSION >= 12080 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) && __CUDA_ARCH_HAS_FEATURE__(SM100_ALL) uint32_t val; asm volatile( @@ -83,11 +85,13 @@ inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) { #else return 0; #endif +#endif } // Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t). inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) { // PTX instructions used here requires sm100a. +#if CUDA_VERSION >= 12080 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) && __CUDA_ARCH_HAS_FEATURE__(SM100_ALL) uint32_t val; asm volatile( @@ -115,6 +119,7 @@ inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) { #else return 0; #endif +#endif } // Fast reciprocal.