diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py index 72d188e71..d60167435 100644 --- a/sgl-kernel/setup.py +++ b/sgl-kernel/setup.py @@ -62,6 +62,23 @@ nvcc_flags = [ "-DFLASHINFER_ENABLE_F16", ] +sources = [ + "src/sgl-kernel/csrc/trt_reduce_internal.cu", + "src/sgl-kernel/csrc/trt_reduce_kernel.cu", + "src/sgl-kernel/csrc/moe_align_kernel.cu", + "src/sgl-kernel/csrc/int8_gemm_kernel.cu", + "src/sgl-kernel/csrc/sampling_scaling_penalties.cu", + "src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu", + "src/sgl-kernel/csrc/sgl_kernel_ops.cu", + "src/sgl-kernel/csrc/rotary_embedding.cu", + "3rdparty/flashinfer/csrc/activation.cu", + "3rdparty/flashinfer/csrc/bmm_fp8.cu", + "3rdparty/flashinfer/csrc/group_gemm.cu", + "3rdparty/flashinfer/csrc/norm.cu", + "3rdparty/flashinfer/csrc/sampling.cu", + "3rdparty/flashinfer/csrc/renorm.cu", +] + enable_bf16 = os.getenv("SGL_KERNEL_ENABLE_BF16", "0") == "1" enable_fp8 = os.getenv("SGL_KERNEL_ENABLE_FP8", "0") == "1" enable_sm90a = os.getenv("SGL_KERNEL_ENABLE_SM90A", "0") == "1" @@ -71,6 +88,7 @@ sm_version = _get_device_sm() if torch.cuda.is_available(): if cuda_version >= (12, 0) and sm_version >= 90: nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a") + sources.append("3rdparty/flashinfer/csrc/group_gemm_sm90.cu") if sm_version >= 90: nvcc_flags.extend( [ @@ -85,6 +103,7 @@ else: # compilation environment without GPU if enable_sm90a: nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a") + sources.append("3rdparty/flashinfer/csrc/group_gemm_sm90.cu") if enable_fp8: nvcc_flags.extend( [ @@ -110,26 +129,11 @@ for flag in [ cxx_flags = ["-O3"] libraries = ["c10", "torch", "torch_python", "cuda"] extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib", "-L/usr/lib/x86_64-linux-gnu"] + ext_modules = [ CUDAExtension( name="sgl_kernel.ops._kernels", - sources=[ - "src/sgl-kernel/csrc/trt_reduce_internal.cu", - "src/sgl-kernel/csrc/trt_reduce_kernel.cu", - "src/sgl-kernel/csrc/moe_align_kernel.cu", - "src/sgl-kernel/csrc/int8_gemm_kernel.cu", - "src/sgl-kernel/csrc/sampling_scaling_penalties.cu", - "src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu", - "src/sgl-kernel/csrc/sgl_kernel_ops.cu", - "src/sgl-kernel/csrc/rotary_embedding.cu", - "3rdparty/flashinfer/csrc/activation.cu", - "3rdparty/flashinfer/csrc/bmm_fp8.cu", - "3rdparty/flashinfer/csrc/group_gemm.cu", - "3rdparty/flashinfer/csrc/group_gemm_sm90.cu", - "3rdparty/flashinfer/csrc/norm.cu", - "3rdparty/flashinfer/csrc/sampling.cu", - "3rdparty/flashinfer/csrc/renorm.cu", - ], + sources=sources, include_dirs=include_dirs, extra_compile_args={ "nvcc": nvcc_flags,