diff --git a/sgl-kernel/README.md b/sgl-kernel/README.md index aafc69a17..fc7a46417 100644 --- a/sgl-kernel/README.md +++ b/sgl-kernel/README.md @@ -48,7 +48,7 @@ Steps to add a new kernel: 1. Implement the kernel in [csrc](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc) 2. Expose the interface in [include/sgl_kernel_ops.h](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/include/sgl_kernel_ops.h) 3. Create torch extension in [csrc/torch_extension.cc](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/csrc/torch_extension.cc) -4. Update [setup.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/setup.py) to include new CUDA source +4. Update [CMakeLists.txt](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/CMakeLists.txt) to include new CUDA source 5. Expose Python interface in [python](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/python/sgl_kernel) ### Build & Install diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py deleted file mode 100644 index f516758c1..000000000 --- a/sgl-kernel/setup.py +++ /dev/null @@ -1,250 +0,0 @@ -# Copyright 2025 SGLang Team. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -import os -import shutil -import sys -from pathlib import Path - -import torch -from setuptools import find_packages, setup -from setuptools.command.build_py import build_py -from torch.utils.cpp_extension import BuildExtension, CUDAExtension - -root = Path(__file__).parent.resolve() - - -if "bdist_wheel" in sys.argv and "--plat-name" not in sys.argv: - sys.argv.extend(["--plat-name", "manylinux2014_x86_64"]) - - -def _get_cuda_version(): - if torch.version.cuda: - return tuple(map(int, torch.version.cuda.split("."))) - return (0, 0) - - -def _get_device_sm(): - if torch.cuda.is_available(): - major, minor = torch.cuda.get_device_capability() - return major * 10 + minor - return 0 - - -def _get_version(): - with open(root / "pyproject.toml") as f: - for line in f: - if line.startswith("version"): - return line.split("=")[1].strip().strip('"') - - -operator_namespace = "sgl_kernel" -cutlass_default = root / "3rdparty" / "cutlass" -cutlass = Path(os.environ.get("CUSTOM_CUTLASS_SRC_DIR", default=cutlass_default)) -flashinfer = root / "3rdparty" / "flashinfer" -deepgemm = root / "3rdparty" / "deepgemm" -include_dirs = [ - root / "include", - root / "csrc", - cutlass.resolve() / "include", - cutlass.resolve() / "tools" / "util" / "include", - flashinfer.resolve() / "include", - flashinfer.resolve() / "include" / "gemm", - flashinfer.resolve() / "csrc", - "cublas", -] - - -class CustomBuildPy(build_py): - def run(self): - self.copy_deepgemm_to_build_lib() - self.make_jit_include_symlinks() - build_py.run(self) - - def make_jit_include_symlinks(self): - # Make symbolic links of third-party include directories - build_include_dir = os.path.join(self.build_lib, "deep_gemm/include") - os.makedirs(build_include_dir, exist_ok=True) - - third_party_include_dirs = [ - cutlass.resolve() / "include" / "cute", - cutlass.resolve() / "include" / "cutlass", - ] - - for d in third_party_include_dirs: - dirname = str(d).split("/")[-1] - src_dir = d - dst_dir = f"{build_include_dir}/{dirname}" - assert os.path.exists(src_dir) - if os.path.exists(dst_dir): - assert os.path.islink(dst_dir) - os.unlink(dst_dir) - os.symlink(src_dir, dst_dir, target_is_directory=True) - - def copy_deepgemm_to_build_lib(self): - """ - This function copies DeepGemm to python's site-packages - """ - dst_dir = os.path.join(self.build_lib, "deep_gemm") - os.makedirs(dst_dir, exist_ok=True) - - # Copy deepgemm/deep_gemm to the build directory - src_dir = os.path.join(str(deepgemm.resolve()), "deep_gemm") - - # Remove existing directory if it exists - if os.path.exists(dst_dir): - shutil.rmtree(dst_dir) - - # Copy the directory - shutil.copytree(src_dir, dst_dir) - - -nvcc_flags = [ - "-DNDEBUG", - f"-DOPERATOR_NAMESPACE={operator_namespace}", - "-O3", - "-Xcompiler", - "-fPIC", - "-gencode=arch=compute_75,code=sm_75", - "-gencode=arch=compute_80,code=sm_80", - "-gencode=arch=compute_89,code=sm_89", - "-gencode=arch=compute_90,code=sm_90", - "-std=c++17", - "-DFLASHINFER_ENABLE_F16", - "-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1", - "-DCUTLASS_VERSIONS_GENERATED", - "-DCUTE_USE_PACKED_TUPLE=1", - "-DCUTLASS_TEST_LEVEL=0", - "-DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1", - "-DCUTLASS_DEBUG_TRACE_LEVEL=0", - "--ptxas-options=-v", - "--expt-relaxed-constexpr", - "-Xcompiler=-Wconversion", - "-Xcompiler=-fno-strict-aliasing", -] -nvcc_flags_fp8 = [ - "-DFLASHINFER_ENABLE_FP8", - "-DFLASHINFER_ENABLE_FP8_E4M3", - "-DFLASHINFER_ENABLE_FP8_E5M2", -] - -sources = [ - "csrc/allreduce/trt_reduce_internal.cu", - "csrc/allreduce/trt_reduce_kernel.cu", - "csrc/attention/lightning_attention_decode_kernel.cu", - "csrc/elementwise/activation.cu", - "csrc/elementwise/fused_add_rms_norm_kernel.cu", - "csrc/elementwise/rope.cu", - "csrc/gemm/bmm_fp8.cu", - "csrc/gemm/cublas_grouped_gemm.cu", - "csrc/gemm/awq_kernel.cu", - "csrc/gemm/fp8_gemm_kernel.cu", - "csrc/gemm/fp8_blockwise_gemm_kernel.cu", - "csrc/gemm/int8_gemm_kernel.cu", - "csrc/gemm/nvfp4_quant_entry.cu", - "csrc/gemm/nvfp4_quant_kernels.cu", - "csrc/gemm/nvfp4_scaled_mm_entry.cu", - "csrc/gemm/nvfp4_scaled_mm_kernels.cu", - "csrc/gemm/per_token_group_quant_8bit.cu", - "csrc/gemm/per_token_quant_fp8.cu", - "csrc/gemm/per_tensor_quant_fp8.cu", - "csrc/moe/moe_align_kernel.cu", - "csrc/moe/moe_fused_gate.cu", - "csrc/moe/moe_topk_softmax_kernels.cu", - "csrc/speculative/eagle_utils.cu", - "csrc/speculative/speculative_sampling.cu", - "csrc/speculative/packbit.cu", - "csrc/torch_extension.cc", - "3rdparty/flashinfer/csrc/norm.cu", - "3rdparty/flashinfer/csrc/renorm.cu", - "3rdparty/flashinfer/csrc/sampling.cu", -] - -enable_bf16 = os.getenv("SGL_KERNEL_ENABLE_BF16", "0") == "1" -enable_fp8 = os.getenv("SGL_KERNEL_ENABLE_FP8", "0") == "1" -enable_fp4 = os.getenv("SGL_KERNEL_ENABLE_FP4", "0") == "1" -enable_sm90a = os.getenv("SGL_KERNEL_ENABLE_SM90A", "0") == "1" -enable_sm100a = os.getenv("SGL_KERNEL_ENABLE_SM100A", "0") == "1" -cuda_version = _get_cuda_version() -sm_version = _get_device_sm() - -if torch.cuda.is_available(): - if cuda_version >= (12, 0) and sm_version >= 90: - nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a") - if cuda_version >= (12, 8) and sm_version >= 100: - nvcc_flags.append("-gencode=arch=compute_100,code=sm_100") - nvcc_flags.append("-gencode=arch=compute_100a,code=sm_100a") - nvcc_flags.append("-DENABLE_NVFP4=1") - else: - nvcc_flags.append("-use_fast_math") - if sm_version >= 90: - nvcc_flags.extend(nvcc_flags_fp8) - if sm_version >= 80: - nvcc_flags.append("-DFLASHINFER_ENABLE_BF16") -else: - # compilation environment without GPU - if enable_sm100a: - nvcc_flags.append("-gencode=arch=compute_100a,code=sm_100a") - if enable_sm90a: - nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a") - if enable_fp4: - nvcc_flags.append("-DENABLE_NVFP4=1") - if enable_fp8: - nvcc_flags.extend(nvcc_flags_fp8) - if enable_bf16: - nvcc_flags.append("-DFLASHINFER_ENABLE_BF16") - -for flag in [ - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", -]: - try: - torch.utils.cpp_extension.COMMON_NVCC_FLAGS.remove(flag) - except ValueError: - pass - -cxx_flags = ["-O3"] -libraries = ["c10", "torch", "torch_python", "cuda", "cublas"] -extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib", "-L/usr/lib/x86_64-linux-gnu"] - -ext_modules = [ - CUDAExtension( - name="sgl_kernel.common_ops", - sources=sources, - include_dirs=include_dirs, - extra_compile_args={ - "nvcc": nvcc_flags, - "cxx": cxx_flags, - }, - libraries=libraries, - extra_link_args=extra_link_args, - py_limited_api=True, - ), -] - -setup( - name="sgl-kernel", - version=_get_version(), - packages=find_packages(where="python"), - package_dir={"": "python"}, - ext_modules=ext_modules, - cmdclass={ - "build_ext": BuildExtension.with_options(use_ninja=True), - "build_py": CustomBuildPy, - }, - options={"bdist_wheel": {"py_limited_api": "cp39"}}, -)