Rename files in sgl kernel to avoid nested folder structure (#4213)
Co-authored-by: zhyncs <me@zhyncs.com>
This commit is contained in:
@@ -13,12 +13,9 @@
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import multiprocessing
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from setuptools import find_packages, setup
|
||||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||
|
||||
@@ -35,16 +32,16 @@ def _get_version():
|
||||
return line.split("=")[1].strip().strip('"')
|
||||
|
||||
|
||||
operator_namespace = "sgl_kernels"
|
||||
operator_namespace = "sgl_kernel"
|
||||
include_dirs = [
|
||||
root / "src" / "sgl-kernel" / "include",
|
||||
root / "src" / "sgl-kernel" / "csrc",
|
||||
root / "include",
|
||||
root / "csrc",
|
||||
]
|
||||
|
||||
sources = [
|
||||
"src/sgl-kernel/torch_extension_rocm.cc",
|
||||
"src/sgl-kernel/csrc/allreduce/custom_all_reduce.hip",
|
||||
"src/sgl-kernel/csrc/moe/moe_align_kernel.cu",
|
||||
"csrc/allreduce/custom_all_reduce.hip",
|
||||
"csrc/moe/moe_align_kernel.cu",
|
||||
"csrc/torch_extension_rocm.cc",
|
||||
]
|
||||
|
||||
cxx_flags = ["-O3"]
|
||||
@@ -64,26 +61,27 @@ hipcc_flags = [
|
||||
"-DENABLE_FP8",
|
||||
]
|
||||
|
||||
ext_modules = [
|
||||
CUDAExtension(
|
||||
name="sgl_kernel.common_ops",
|
||||
sources=sources,
|
||||
include_dirs=include_dirs,
|
||||
extra_compile_args={
|
||||
"nvcc": hipcc_flags,
|
||||
"cxx": cxx_flags,
|
||||
},
|
||||
libraries=libraries,
|
||||
extra_link_args=extra_link_args,
|
||||
py_limited_api=True,
|
||||
),
|
||||
]
|
||||
|
||||
setup(
|
||||
name="sgl-kernel",
|
||||
version=_get_version(),
|
||||
packages=find_packages(),
|
||||
package_dir={"": "src"},
|
||||
ext_modules=[
|
||||
CUDAExtension(
|
||||
name="sgl_kernel.ops._kernels",
|
||||
sources=sources,
|
||||
include_dirs=include_dirs,
|
||||
extra_compile_args={
|
||||
"nvcc": hipcc_flags,
|
||||
"cxx": cxx_flags,
|
||||
},
|
||||
libraries=libraries,
|
||||
extra_link_args=extra_link_args,
|
||||
py_limited_api=True,
|
||||
),
|
||||
],
|
||||
package_dir={"": "python"},
|
||||
ext_modules=ext_modules,
|
||||
cmdclass={"build_ext": BuildExtension.with_options(use_ninja=True)},
|
||||
options={"bdist_wheel": {"py_limited_api": "cp39"}},
|
||||
install_requires=["torch"],
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user