feat: refactor sgl-kernel and use TORCH_LIBRARY instead of PYBIND11_MODULE for custom ops (#3130)
Co-authored-by: yinfan.1024 <yinfan.1024@bytedance.com>
Co-authored-by: yinfan98 <1106110035@qq.com>
Co-authored-by: Yineng Zhang <me@zhyncs.com>
This commit is contained in:
@@ -38,6 +38,7 @@ def _get_version():
|
||||
return line.split("=")[1].strip().strip('"')
|
||||
|
||||
|
||||
operator_namespace = "sgl_kernels"
|
||||
cutlass_default = root / "3rdparty" / "cutlass"
|
||||
cutlass = Path(os.environ.get("CUSTOM_CUTLASS_SRC_DIR", default=cutlass_default))
|
||||
flashinfer = root / "3rdparty" / "flashinfer"
|
||||
@@ -45,15 +46,19 @@ turbomind = root / "3rdparty" / "turbomind"
|
||||
include_dirs = [
|
||||
cutlass.resolve() / "include",
|
||||
cutlass.resolve() / "tools" / "util" / "include",
|
||||
root / "src" / "sgl-kernel" / "include",
|
||||
root / "src" / "sgl-kernel" / "csrc",
|
||||
flashinfer.resolve() / "include",
|
||||
flashinfer.resolve() / "include" / "gemm",
|
||||
flashinfer.resolve() / "csrc",
|
||||
"cublas",
|
||||
"cublasLt",
|
||||
turbomind.resolve(),
|
||||
turbomind.resolve() / "src",
|
||||
]
|
||||
nvcc_flags = [
|
||||
"-DNDEBUG",
|
||||
f"-DOPERATOR_NAMESPACE={operator_namespace}",
|
||||
"-O3",
|
||||
"-Xcompiler",
|
||||
"-fPIC",
|
||||
@@ -72,13 +77,13 @@ nvcc_flags_fp8 = [
|
||||
]
|
||||
|
||||
sources = [
|
||||
"src/sgl-kernel/torch_extension.cc",
|
||||
"src/sgl-kernel/csrc/trt_reduce_internal.cu",
|
||||
"src/sgl-kernel/csrc/trt_reduce_kernel.cu",
|
||||
"src/sgl-kernel/csrc/moe_align_kernel.cu",
|
||||
"src/sgl-kernel/csrc/int8_gemm_kernel.cu",
|
||||
"src/sgl-kernel/csrc/sampling_scaling_penalties.cu",
|
||||
"src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu",
|
||||
"src/sgl-kernel/csrc/sgl_kernel_ops.cu",
|
||||
"src/sgl-kernel/csrc/rotary_embedding.cu",
|
||||
"src/sgl-kernel/csrc/fused_add_rms_norm.cu",
|
||||
"3rdparty/flashinfer/csrc/activation.cu",
|
||||
@@ -125,7 +130,7 @@ for flag in [
|
||||
pass
|
||||
|
||||
cxx_flags = ["-O3"]
|
||||
-libraries = ["c10", "torch", "torch_python", "cuda"]
+libraries = ["c10", "torch", "torch_python", "cuda", "cublas", "cublasLt"]
|
||||
extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib", "-L/usr/lib/x86_64-linux-gnu"]
|
||||
|
||||
ext_modules = [
|
||||
@@ -139,6 +144,7 @@ ext_modules = [
|
||||
},
|
||||
libraries=libraries,
|
||||
extra_link_args=extra_link_args,
|
||||
py_limited_api=True,
|
||||
),
|
||||
]
|
||||
|
||||
@@ -149,6 +155,7 @@ setup(
|
||||
package_dir={"": "src"},
|
||||
ext_modules=ext_modules,
|
||||
cmdclass={"build_ext": BuildExtension},
|
||||
options={"bdist_wheel": {"py_limited_api": "cp39"}},
|
||||
)
|
||||
|
||||
_update_wheel_platform_tag()
|
||||
|
||||
Reference in New Issue
Block a user