Rename files in sgl kernel to avoid nested folder structure (#4213)

Co-authored-by: zhyncs <me@zhyncs.com>
This commit is contained in:
Lianmin Zheng
2025-03-08 22:54:51 -08:00
committed by GitHub
parent ee132a4515
commit 8abf74e3c9
47 changed files with 184 additions and 199 deletions

View File

@@ -13,12 +13,9 @@
# limitations under the License.
# ==============================================================================
import multiprocessing
import os
import sys
from pathlib import Path
import torch
from setuptools import find_packages, setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
@@ -35,16 +32,16 @@ def _get_version():
return line.split("=")[1].strip().strip('"')
operator_namespace = "sgl_kernels"
operator_namespace = "sgl_kernel"
include_dirs = [
root / "src" / "sgl-kernel" / "include",
root / "src" / "sgl-kernel" / "csrc",
root / "include",
root / "csrc",
]
sources = [
"src/sgl-kernel/torch_extension_rocm.cc",
"src/sgl-kernel/csrc/allreduce/custom_all_reduce.hip",
"src/sgl-kernel/csrc/moe/moe_align_kernel.cu",
"csrc/allreduce/custom_all_reduce.hip",
"csrc/moe/moe_align_kernel.cu",
"csrc/torch_extension_rocm.cc",
]
cxx_flags = ["-O3"]
@@ -64,26 +61,27 @@ hipcc_flags = [
"-DENABLE_FP8",
]
ext_modules = [
CUDAExtension(
name="sgl_kernel.common_ops",
sources=sources,
include_dirs=include_dirs,
extra_compile_args={
"nvcc": hipcc_flags,
"cxx": cxx_flags,
},
libraries=libraries,
extra_link_args=extra_link_args,
py_limited_api=True,
),
]
setup(
name="sgl-kernel",
version=_get_version(),
packages=find_packages(),
package_dir={"": "src"},
ext_modules=[
CUDAExtension(
name="sgl_kernel.ops._kernels",
sources=sources,
include_dirs=include_dirs,
extra_compile_args={
"nvcc": hipcc_flags,
"cxx": cxx_flags,
},
libraries=libraries,
extra_link_args=extra_link_args,
py_limited_api=True,
),
],
package_dir={"": "python"},
ext_modules=ext_modules,
cmdclass={"build_ext": BuildExtension.with_options(use_ninja=True)},
options={"bdist_wheel": {"py_limited_api": "cp39"}},
install_requires=["torch"],
)