From 9286740eff9b735a005e14cf5dfae986c75e3533 Mon Sep 17 00:00:00 2001 From: yinfan98 <1106310035@qq.com> Date: Sun, 26 Jan 2025 02:55:08 +0800 Subject: [PATCH] feat: refactor sgl-kernel and use TORCH_LIBRARY instead of PYBIND11_MODULE for custom ops (#3130) Co-authored-by: yinfan.1024 Co-authored-by: yinfan98 <1106110035@qq.com> Co-authored-by: Yineng Zhang --- sgl-kernel/developer_guide.md | 11 +- sgl-kernel/setup.py | 11 +- .../sgl_kernels_ops.h} | 72 ++++------- .../{csrc => include}/trt_reduce_internal.cuh | 0 .../src/sgl-kernel/{csrc => include}/utils.h | 3 + sgl-kernel/src/sgl-kernel/ops/__init__.py | 93 ++++++-------- sgl-kernel/src/sgl-kernel/torch_extension.cc | 119 ++++++++++++++++++ 7 files changed, 198 insertions(+), 111 deletions(-) rename sgl-kernel/src/sgl-kernel/{csrc/sgl_kernel_ops.cu => include/sgl_kernels_ops.h} (65%) rename sgl-kernel/src/sgl-kernel/{csrc => include}/trt_reduce_internal.cuh (100%) rename sgl-kernel/src/sgl-kernel/{csrc => include}/utils.h (98%) create mode 100644 sgl-kernel/src/sgl-kernel/torch_extension.cc diff --git a/sgl-kernel/developer_guide.md b/sgl-kernel/developer_guide.md index 91e93ff75..26b68535c 100644 --- a/sgl-kernel/developer_guide.md +++ b/sgl-kernel/developer_guide.md @@ -26,10 +26,11 @@ Third-party libraries: Steps to add a new kernel: 1. Implement in [src/sgl-kernel/csrc/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/src/sgl-kernel/csrc) -2. Expose interface in [csrc/sgl_kernel_ops.cu](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu) with pybind11 -3. Create Python wrapper in [src/sgl-kernel/ops/\_\_init\_\_.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/ops/__init__.py) -4. Expose Python interface in [src/sgl-kernel/\_\_init\_\_.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/__init__.py) -5. 
Update [setup.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/setup.py) to include new CUDA source +2. Expose interface in [src/sgl-kernel/include/sgl_kernel_ops.h](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/include/sgl_kernel_ops.h) +3. Create torch extension in [src/sgl-kernel/torch_extension.cc](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/torch_extension.cc) +4. Create Python wrapper in [src/sgl-kernel/ops/\_\_init\_\_.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/ops/__init__.py) +5. Expose Python interface in [src/sgl-kernel/\_\_init\_\_.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/__init__.py) +6. Update [setup.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/setup.py) to include new CUDA source ### Build & Install @@ -37,8 +38,6 @@ Development build: ```bash make build -pip3 install dist/*whl --force-reinstall --no-deps -# Or use: make install (runs pip install -e .) 
``` ### Testing & Benchmarking diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py index 56c5b1bb5..95b040fe1 100644 --- a/sgl-kernel/setup.py +++ b/sgl-kernel/setup.py @@ -38,6 +38,7 @@ def _get_version(): return line.split("=")[1].strip().strip('"') +operator_namespace = "sgl_kernels" cutlass_default = root / "3rdparty" / "cutlass" cutlass = Path(os.environ.get("CUSTOM_CUTLASS_SRC_DIR", default=cutlass_default)) flashinfer = root / "3rdparty" / "flashinfer" @@ -45,15 +46,19 @@ turbomind = root / "3rdparty" / "turbomind" include_dirs = [ cutlass.resolve() / "include", cutlass.resolve() / "tools" / "util" / "include", + root / "src" / "sgl-kernel" / "include", root / "src" / "sgl-kernel" / "csrc", flashinfer.resolve() / "include", flashinfer.resolve() / "include" / "gemm", flashinfer.resolve() / "csrc", + "cublas", + "cublasLt", turbomind.resolve(), turbomind.resolve() / "src", ] nvcc_flags = [ "-DNDEBUG", + f"-DOPERATOR_NAMESPACE={operator_namespace}", "-O3", "-Xcompiler", "-fPIC", @@ -72,13 +77,13 @@ nvcc_flags_fp8 = [ ] sources = [ + "src/sgl-kernel/torch_extension.cc", "src/sgl-kernel/csrc/trt_reduce_internal.cu", "src/sgl-kernel/csrc/trt_reduce_kernel.cu", "src/sgl-kernel/csrc/moe_align_kernel.cu", "src/sgl-kernel/csrc/int8_gemm_kernel.cu", "src/sgl-kernel/csrc/sampling_scaling_penalties.cu", "src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu", - "src/sgl-kernel/csrc/sgl_kernel_ops.cu", "src/sgl-kernel/csrc/rotary_embedding.cu", "src/sgl-kernel/csrc/fused_add_rms_norm.cu", "3rdparty/flashinfer/csrc/activation.cu", @@ -125,7 +130,7 @@ for flag in [ pass cxx_flags = ["-O3"] -libraries = ["c10", "torch", "torch_python", "cuda"] +libraries = ["c10", "torch", "torch_python", "cuda", "cublas", "cublasLt"] extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib", "-L/usr/lib/x86_64-linux-gnu"] ext_modules = [ @@ -139,6 +144,7 @@ ext_modules = [ }, libraries=libraries, extra_link_args=extra_link_args, + py_limited_api=True, ), ] @@ -149,6 +155,7 @@ setup( 
package_dir={"": "src"}, ext_modules=ext_modules, cmdclass={"build_ext": BuildExtension}, + options={"bdist_wheel": {"py_limited_api": "cp39"}}, ) _update_wheel_platform_tag() diff --git a/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu b/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h similarity index 65% rename from sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu rename to sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h index 876d62b7e..91e350895 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu +++ b/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h @@ -1,7 +1,25 @@ +#pragma once +#include +#include + #include #include "utils.h" +#define _CONCAT(A, B) A##B +#define CONCAT(A, B) _CONCAT(A, B) + +#define _STRINGIFY(A) #A +#define STRINGIFY(A) _STRINGIFY(A) + +#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE) + +#define REGISTER_EXTENSION(NAME) \ + PyMODINIT_FUNC CONCAT(PyInit_, NAME)() { \ + static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, STRINGIFY(NAME), nullptr, 0, nullptr}; \ + return PyModule_Create(&module); \ + } + // trt_reduce using fptr_t = int64_t; fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, torch::Tensor& rank_data, const std::vector& buffers, @@ -67,9 +85,18 @@ void min_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at: int64_t cuda_stream); // top k renorm probs +// Patched here because flashinfer uses unsigned int, but torch extension schemas must use int64_t. void top_k_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, std::optional maybe_top_k_arr, unsigned int top_k_val, int64_t cuda_stream); +// Patched here because flashinfer uses unsigned int, but torch extension schemas must use int64_t. 
+// wrapper for binding +inline void top_k_renorm_probs_wrapper(at::Tensor probs, at::Tensor renorm_probs, + std::optional maybe_top_k_arr, int64_t top_k_val, + int64_t cuda_stream) { + top_k_renorm_probs(probs, renorm_probs, maybe_top_k_arr, static_cast(top_k_val), cuda_stream); +} + // top p renorm probs void top_p_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, std::optional maybe_top_p_arr, double top_p_val, int64_t cuda_stream); @@ -84,48 +111,3 @@ void top_k_top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_sample void top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples, at::Tensor success, std::optional maybe_top_p_arr, double top_p_val, bool deterministic, int64_t cuda_stream); - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - // trt_reduce - m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)"); - m.def("dispose", &dispose, "dispose custom allreduce meta"); - m.def("all_reduce", &all_reduce, "custom all reduce (CUDA)"); - m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta, "custom all reduce get graph ipc meta"); - m.def("register_graph_buffers", ®ister_graph_buffers, "custom all reduce register graph buffers"); - // moe_align_block_size - m.def("moe_align_block_size", &moe_align_block_size, "MOE Align Block Size (CUDA)"); - // sampling_scaling_penalties - m.def("sampling_scaling_penalties", &sampling_scaling_penalties, "Sampling scaling penalties (CUDA)"); - // int8_scaled_mm - m.def("int8_scaled_mm", &int8_scaled_mm, "INT8 scaled matmul (CUDA)"); - // lightning_attention_decode - m.def("lightning_attention_decode", &lightning_attention_decode, "Lightning Attention Ddecode (CUDA)"); - // rotary embedding - m.def("rotary_embedding", &rotary_embedding, "Rotary Embedding (CUDA)"); - // rms norm - m.def("rmsnorm", &rmsnorm, "RMSNorm (CUDA)"); - // fused rms norm - m.def("fused_add_rmsnorm", &fused_add_rmsnorm, "Fused Add RMSNorm (CUDA)"); - // gemma rms norm - 
m.def("gemma_rmsnorm", &gemma_rmsnorm, "Gemma RMSNorm (CUDA)"); - // fused gemma rms norm - m.def("gemma_fused_add_rmsnorm", &gemma_fused_add_rmsnorm, "Gemma Fused Add RMSNorm (CUDA)"); - // silu and mul - m.def("silu_and_mul", &silu_and_mul, "Silu and Mul (CUDA)"); - // gelu tanh and mul - m.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, "Gelu Tanh and Mul (CUDA)"); - // gelu and mul - m.def("gelu_and_mul", &gelu_and_mul, "Gelu and Mul (CUDA)"); - // bmm fp8 - m.def("bmm_fp8", &bmm_fp8, "BMM FP8 (CUDA)"); - // min p sampling from probs - m.def("min_p_sampling_from_probs", &min_p_sampling_from_probs, "Min P Sampling From Probs (CUDA)"); - // top k renorm probs - m.def("top_k_renorm_probs", &top_k_renorm_probs, "Top K Renorm Probs (CUDA)"); - // top p renorm probs - m.def("top_p_renorm_probs", &top_p_renorm_probs, "Top P Renorm Probs (CUDA)"); - // top k top p sampling from probs - m.def("top_k_top_p_sampling_from_probs", &top_k_top_p_sampling_from_probs, "Top K Top P Sampling From Probs (CUDA)"); - // top p sampling from probs - m.def("top_p_sampling_from_probs", &top_p_sampling_from_probs, "Top P Sampling From Probs (CUDA)"); -} diff --git a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh b/sgl-kernel/src/sgl-kernel/include/trt_reduce_internal.cuh similarity index 100% rename from sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh rename to sgl-kernel/src/sgl-kernel/include/trt_reduce_internal.cuh diff --git a/sgl-kernel/src/sgl-kernel/csrc/utils.h b/sgl-kernel/src/sgl-kernel/include/utils.h similarity index 98% rename from sgl-kernel/src/sgl-kernel/csrc/utils.h rename to sgl-kernel/src/sgl-kernel/include/utils.h index ed802d4fd..1cca35d5c 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/utils.h +++ b/sgl-kernel/src/sgl-kernel/include/utils.h @@ -1,9 +1,12 @@ #pragma once +#include #include #include #include +#include "sgl_kernels_ops.h" + struct cuda_error : public std::runtime_error { /** * @brief Constructs a `cuda_error` object with the given `message`. 
diff --git a/sgl-kernel/src/sgl-kernel/ops/__init__.py b/sgl-kernel/src/sgl-kernel/ops/__init__.py index cd69eb3c2..3a21ced87 100644 --- a/sgl-kernel/src/sgl-kernel/ops/__init__.py +++ b/sgl-kernel/src/sgl-kernel/ops/__init__.py @@ -1,41 +1,8 @@ +import os from typing import Optional, Tuple, Union +import sgl_kernel.ops._kernels import torch -from sgl_kernel.ops._kernels import all_reduce as _all_reduce -from sgl_kernel.ops._kernels import bmm_fp8 as _bmm_fp8 -from sgl_kernel.ops._kernels import dispose as _dispose -from sgl_kernel.ops._kernels import fused_add_rmsnorm as _fused_add_rmsnorm -from sgl_kernel.ops._kernels import gelu_and_mul as _gelu_and_mul -from sgl_kernel.ops._kernels import gelu_tanh_and_mul as _gelu_tanh_and_mul -from sgl_kernel.ops._kernels import gemma_fused_add_rmsnorm as _gemma_fused_add_rmsnorm -from sgl_kernel.ops._kernels import gemma_rmsnorm as _gemma_rmsnorm -from sgl_kernel.ops._kernels import ( - get_graph_buffer_ipc_meta as _get_graph_buffer_ipc_meta, -) -from sgl_kernel.ops._kernels import init_custom_ar as _init_custom_ar -from sgl_kernel.ops._kernels import int8_scaled_mm as _int8_scaled_mm -from sgl_kernel.ops._kernels import ( - lightning_attention_decode as _lightning_attention_decode, -) -from sgl_kernel.ops._kernels import ( - min_p_sampling_from_probs as _min_p_sampling_from_probs, -) -from sgl_kernel.ops._kernels import moe_align_block_size as _moe_align_block_size -from sgl_kernel.ops._kernels import register_graph_buffers as _register_graph_buffers -from sgl_kernel.ops._kernels import rmsnorm as _rmsnorm -from sgl_kernel.ops._kernels import rotary_embedding as _rotary_embedding -from sgl_kernel.ops._kernels import ( - sampling_scaling_penalties as _sampling_scaling_penalties, -) -from sgl_kernel.ops._kernels import silu_and_mul as _silu_and_mul -from sgl_kernel.ops._kernels import top_k_renorm_probs as _top_k_renorm_probs -from sgl_kernel.ops._kernels import ( - top_k_top_p_sampling_from_probs as 
_top_k_top_p_sampling_from_probs, -) -from sgl_kernel.ops._kernels import top_p_renorm_probs as _top_p_renorm_probs -from sgl_kernel.ops._kernels import ( - top_p_sampling_from_probs as _top_p_sampling_from_probs, -) from sgl_kernel.ops.utils import ( _get_cache_buf, _get_cuda_stream, @@ -46,25 +13,25 @@ from sgl_kernel.ops.utils import ( def init_custom_reduce( rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out ): - return _init_custom_ar( + return torch.ops.sgl_kernels.init_custom_ar( rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out ) def custom_dispose(fa): - _dispose(fa) + torch.ops.sgl_kernels.dispose(fa) def custom_reduce(fa, inp, out): - _all_reduce(fa, inp, out) + torch.ops.sgl_kernels.all_reduce(fa, inp, out) def get_graph_buffer_ipc_meta(fa): - return _get_graph_buffer_ipc_meta(fa) + return torch.ops.sgl_kernels.get_graph_buffer_ipc_meta(fa) def register_graph_buffers(fa, handles, offsets): - _register_graph_buffers(fa, handles, offsets) + torch.ops.sgl_kernels.register_graph_buffers(fa, handles, offsets) def moe_align_block_size( @@ -77,7 +44,7 @@ def moe_align_block_size( token_cnts_buffer, cumsum_buffer, ): - _moe_align_block_size( + torch.ops.sgl_kernels.moe_align_block_size( topk_ids, num_experts, block_size, @@ -90,11 +57,11 @@ def moe_align_block_size( def sampling_scaling_penalties(logits, scaling_penalties): - return _sampling_scaling_penalties(logits, scaling_penalties) + return torch.ops.sgl_kernels.sampling_scaling_penalties(logits, scaling_penalties) def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None): - return _int8_scaled_mm( + return torch.ops.sgl_kernels.int8_scaled_mm( mat_a, mat_b, scales_a, @@ -105,11 +72,15 @@ def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None): def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv): - _lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv) + 
torch.ops.sgl_kernels.lightning_attention_decode( + q, k, v, past_kv, slope, output, new_kv + ) def rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox): - return _rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox) + return torch.ops.sgl_kernels.rotary_embedding( + positions, query, key, head_size, cos_sin_cache, is_neox + ) # These implementations extensively draw from and build upon the FlashInfer project https://github.com/flashinfer-ai/flashinfer @@ -123,7 +94,7 @@ def rmsnorm( with input.device as device: if out is None: out = torch.empty_like(input) - _rmsnorm(out, input, weight, eps, _get_cuda_stream(device)) + torch.ops.sgl_kernels.rmsnorm(out, input, weight, eps, _get_cuda_stream(device)) return out @@ -131,7 +102,9 @@ def fused_add_rmsnorm( input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6 ) -> None: with input.device as device: - _fused_add_rmsnorm(input, residual, weight, eps, _get_cuda_stream(device)) + torch.ops.sgl_kernels.fused_add_rmsnorm( + input, residual, weight, eps, _get_cuda_stream(device) + ) def gemma_rmsnorm( @@ -143,7 +116,9 @@ def gemma_rmsnorm( with input.device as device: if out is None: out = torch.empty_like(input) - _gemma_rmsnorm(out, input, weight, eps, _get_cuda_stream(device)) + torch.ops.sgl_kernels.gemma_rmsnorm( + out, input, weight, eps, _get_cuda_stream(device) + ) return out @@ -151,7 +126,9 @@ def gemma_fused_add_rmsnorm( input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6 ) -> None: with input.device as device: - _gemma_fused_add_rmsnorm(input, residual, weight, eps, _get_cuda_stream(device)) + torch.ops.sgl_kernels.gemma_fused_add_rmsnorm( + input, residual, weight, eps, _get_cuda_stream(device) + ) def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None: @@ -176,7 +153,7 @@ def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor: dtype=input.dtype, ) with input.device as 
device: - _silu_and_mul(out, input, _get_cuda_stream(device)) + torch.ops.sgl_kernels.silu_and_mul(out, input, _get_cuda_stream(device)) return out @@ -192,7 +169,7 @@ def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Te dtype=input.dtype, ) with input.device as device: - _gelu_tanh_and_mul(out, input, _get_cuda_stream(device)) + torch.ops.sgl_kernels.gelu_tanh_and_mul(out, input, _get_cuda_stream(device)) return out @@ -208,7 +185,7 @@ def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor: dtype=input.dtype, ) with input.device as device: - _gelu_and_mul(out, input, _get_cuda_stream(device)) + torch.ops.sgl_kernels.gelu_and_mul(out, input, _get_cuda_stream(device)) return out @@ -222,7 +199,7 @@ def _bmm_fp8_internal( ) -> None: with A.device as device: cublas_handle = torch.cuda.current_blas_handle() - _bmm_fp8( + torch.ops.sgl_kernels.bmm_fp8( A, B, D, @@ -262,7 +239,7 @@ def _top_k_renorm_probs_internal( probs = probs.float() maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None renorm_probs = torch.empty_like(probs) - _top_k_renorm_probs( + torch.ops.sgl_kernels.top_k_renorm_probs_wrapper( probs, renorm_probs, maybe_top_k_arr, @@ -293,7 +270,7 @@ def _top_p_renorm_probs_internal( maybe_top_p_arr.float() if maybe_top_p_arr is not None else None ) renorm_probs = torch.empty_like(probs) - _top_p_renorm_probs( + torch.ops.sgl_kernels.top_p_renorm_probs( probs, renorm_probs, maybe_top_p_arr, @@ -328,7 +305,7 @@ def _top_p_sampling_from_probs_internal( ) samples = torch.empty(probs.size(0), dtype=torch.int32, device=device) success = torch.empty(probs.size(0), dtype=torch.bool, device=device) - _top_p_sampling_from_probs( + torch.ops.sgl_kernels.top_p_sampling_from_probs( probs, uniform_samples, samples, @@ -374,7 +351,7 @@ def _top_k_top_p_sampling_from_probs_internal( ) samples = torch.empty(probs.size(0), dtype=torch.int32, device=device) success = torch.empty(probs.size(0), 
dtype=torch.bool, device=device) - _top_k_top_p_sampling_from_probs( + torch.ops.sgl_kernels.top_k_top_p_sampling_from_probs( probs, uniform_samples, samples, @@ -432,7 +409,7 @@ def _min_p_sampling_from_probs_internal( maybe_min_p_arr.float() if maybe_min_p_arr is not None else None ) samples = torch.empty(probs.size(0), dtype=torch.int32, device=device) - _min_p_sampling_from_probs( + torch.ops.sgl_kernels.min_p_sampling_from_probs( probs, uniform_samples, samples, diff --git a/sgl-kernel/src/sgl-kernel/torch_extension.cc b/sgl-kernel/src/sgl-kernel/torch_extension.cc new file mode 100644 index 000000000..f8a061c15 --- /dev/null +++ b/sgl-kernel/src/sgl-kernel/torch_extension.cc @@ -0,0 +1,119 @@ + +#include +#include + +#include "sgl_kernels_ops.h" + +TORCH_LIBRARY_EXPAND(sgl_kernels, m) { + // trt_reduce + m.def( + "init_custom_ar(int rank_id, int world_size, Tensor rank_data, int[] buffers, int[] tmp_result_buffers, int[] " + "barrier_in, int[] barrier_out) -> int"); + m.impl("init_custom_ar", torch::kCUDA, &init_custom_ar); + + m.def("dispose", &dispose); + + m.def("all_reduce(int fa, Tensor inp, Tensor! out) -> ()"); + m.impl("all_reduce", torch::kCUDA, &all_reduce); + + m.def("get_graph_buffer_ipc_meta(int fa) -> (int[], int[])"); + m.impl("get_graph_buffer_ipc_meta", torch::kCUDA, &get_graph_buffer_ipc_meta); + + m.def("register_graph_buffers(int fa, int[][] handles, int[][] offsets) -> ()"); + m.impl("register_graph_buffers", torch::kCUDA, ®ister_graph_buffers); + + // moe_align_block_size + m.def( + "moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! " + "experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! 
cumsum_buffer) -> ()"); + m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size); + + // sampling_scaling_penalties + m.def("sampling_scaling_penalties(Tensor logits, Tensor scaling_penalties) -> Tensor"); + m.impl("sampling_scaling_penalties", torch::kCUDA, &sampling_scaling_penalties); + + // int8_scaled_mm + m.def( + "int8_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype, Tensor? " + "bias) -> Tensor"); + m.impl("int8_scaled_mm", torch::kCUDA, &int8_scaled_mm); + + // lightning_attention_decode + m.def( + "lightning_attention_decode(Tensor q, Tensor k, Tensor v, Tensor past_kv, Tensor slope, Tensor! output, Tensor! " + "new_kv) -> ()"); + m.impl("lightning_attention_decode", torch::kCUDA, &lightning_attention_decode); + + // rotary embedding + m.def( + "rotary_embedding(Tensor positions, Tensor! query, Tensor! key, int head_size, Tensor cos_sin_cache, bool " + "is_neox) -> ()"); + m.impl("rotary_embedding", torch::kCUDA, &rotary_embedding); + + // rms norm + m.def("rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()"); + m.impl("rmsnorm", torch::kCUDA, &rmsnorm); + + // fused rms norm + m.def("fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, int cuda_stream) -> ()"); + m.impl("fused_add_rmsnorm", torch::kCUDA, &fused_add_rmsnorm); + + // gemma rms norm + m.def("gemma_rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()"); + m.impl("gemma_rmsnorm", torch::kCUDA, &gemma_rmsnorm); + + // fused gemma rms norm + m.def("gemma_fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, int cuda_stream) -> ()"); + m.impl("gemma_fused_add_rmsnorm", torch::kCUDA, &gemma_fused_add_rmsnorm); + + // silu and mul + m.def("silu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()"); + m.impl("silu_and_mul", torch::kCUDA, &silu_and_mul); + + // gelu tanh and mul + m.def("gelu_tanh_and_mul(Tensor! 
out, Tensor input, int cuda_stream) -> ()"); + m.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul); + + // gelu and mul + m.def("gelu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()"); + m.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul); + + // bmm fp8 + m.def( + "bmm_fp8(Tensor A, Tensor B, Tensor! D, Tensor A_scale, Tensor B_scale, Tensor workspace_buffer, int " + "cublas_handle, int cuda_stream) -> ()"); + m.impl("bmm_fp8", torch::kCUDA, &bmm_fp8); + + // min p sampling from probs + m.def( + "min_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor? maybe_min_p_arr, float " + "min_p_val, bool deterministic, int cuda_stream) -> ()"); + m.impl("min_p_sampling_from_probs", torch::kCUDA, &min_p_sampling_from_probs); + + // top k renorm probs + m.def( + "top_k_renorm_probs_wrapper(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_k_arr, int top_k_val, int " + "cuda_stream) -> ()"); + m.impl("top_k_renorm_probs_wrapper", torch::kCUDA, &top_k_renorm_probs_wrapper); + + // top p renorm probs + m.def( + "top_p_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_p_arr, float top_p_val, int " + "cuda_stream) -> ()"); + m.impl("top_p_renorm_probs", torch::kCUDA, &top_p_renorm_probs); + + // top k top p sampling from probs + m.def( + "top_k_top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? " + "maybe_top_k_arr, float top_k_val, Tensor? maybe_top_p_arr, float top_p_val, bool deterministic, int " + "cuda_stream) -> ()"); + m.impl("top_k_top_p_sampling_from_probs", torch::kCUDA, &top_k_top_p_sampling_from_probs); + + // top p sampling from probs + m.def( + "top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? 
" + "maybe_top_p_arr, float top_p_val, bool deterministic, int cuda_stream) -> ()"); + m.impl("top_p_sampling_from_probs", torch::kCUDA, &top_p_sampling_from_probs); +} + +REGISTER_EXTENSION(_kernels)