feat: refactor sgl-kernel and use TORCH_LIBRARY instead of PYBIND11_MODULE for custom ops (#3130)

Co-authored-by: yinfan.1024 <yinfan.1024@bytedance.com>
Co-authored-by: yinfan98 <1106110035@qq.com>
Co-authored-by: Yineng Zhang <me@zhyncs.com>
This commit is contained in:
yinfan98
2025-01-26 02:55:08 +08:00
committed by GitHub
parent 896c07441e
commit 9286740eff
7 changed files with 198 additions and 111 deletions

View File

@@ -38,6 +38,7 @@ def _get_version():
return line.split("=")[1].strip().strip('"')
# Namespace string for the compiled operators; it is interpolated into the
# -DOPERATOR_NAMESPACE=... entry of nvcc_flags further down.
operator_namespace = "sgl_kernels"
# Default CUTLASS source location: <root>/3rdparty/cutlass.
cutlass_default = root / "3rdparty" / "cutlass"
# The CUSTOM_CUTLASS_SRC_DIR environment variable, when set, takes the place
# of the default CUTLASS location above.
cutlass = Path(os.environ.get("CUSTOM_CUTLASS_SRC_DIR", default=cutlass_default))
# FlashInfer source location: <root>/3rdparty/flashinfer.
flashinfer = root / "3rdparty" / "flashinfer"
@@ -45,15 +46,19 @@ turbomind = root / "3rdparty" / "turbomind"
# Directories collected for header lookup when building the extension
# (presumably handed to the extension's include_dirs argument; the consumer
# is outside this view). The third-party roots are resolved to absolute paths
# via .resolve() before sub-directories are appended.
include_dirs = [
cutlass.resolve() / "include",
cutlass.resolve() / "tools" / "util" / "include",
# Project-local headers and kernel sources.
root / "src" / "sgl-kernel" / "include",
root / "src" / "sgl-kernel" / "csrc",
flashinfer.resolve() / "include",
flashinfer.resolve() / "include" / "gemm",
flashinfer.resolve() / "csrc",
# NOTE(review): these two entries are bare relative strings, not Path objects
# rooted at the project; they look like library names rather than include
# directories -- confirm they are intended here and not only in `libraries`.
"cublas",
"cublasLt",
turbomind.resolve(),
turbomind.resolve() / "src",
]
nvcc_flags = [
"-DNDEBUG",
f"-DOPERATOR_NAMESPACE={operator_namespace}",
"-O3",
"-Xcompiler",
"-fPIC",
@@ -72,13 +77,13 @@ nvcc_flags_fp8 = [
]
sources = [
"src/sgl-kernel/torch_extension.cc",
"src/sgl-kernel/csrc/trt_reduce_internal.cu",
"src/sgl-kernel/csrc/trt_reduce_kernel.cu",
"src/sgl-kernel/csrc/moe_align_kernel.cu",
"src/sgl-kernel/csrc/int8_gemm_kernel.cu",
"src/sgl-kernel/csrc/sampling_scaling_penalties.cu",
"src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu",
"src/sgl-kernel/csrc/sgl_kernel_ops.cu",
"src/sgl-kernel/csrc/rotary_embedding.cu",
"src/sgl-kernel/csrc/fused_add_rms_norm.cu",
"3rdparty/flashinfer/csrc/activation.cu",
@@ -125,7 +130,7 @@ for flag in [
pass
# Optimization level for the C++ compiler flags list.
cxx_flags = ["-O3"]
# NOTE(review): `libraries` is assigned twice back to back, so only the second
# value (which adds "cublas" and "cublasLt") takes effect. In this flattened
# diff view the first line is presumably the removed pre-change version --
# confirm against the actual file.
libraries = ["c10", "torch", "torch_python", "cuda"]
libraries = ["c10", "torch", "torch_python", "cuda", "cublas", "cublasLt"]
# Linker arguments: an rpath expressed relative to $ORIGIN (the location of the
# built shared object) pointing at ../../torch/lib, plus an additional
# link-time library search directory.
extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib", "-L/usr/lib/x86_64-linux-gnu"]
ext_modules = [
@@ -139,6 +144,7 @@ ext_modules = [
},
libraries=libraries,
extra_link_args=extra_link_args,
py_limited_api=True,
),
]
@@ -149,6 +155,7 @@ setup(
package_dir={"": "src"},
ext_modules=ext_modules,
cmdclass={"build_ext": BuildExtension},
options={"bdist_wheel": {"py_limited_api": "cp39"}},
)
_update_wheel_platform_tag()