feat: refactor sgl-kernel and use TORCH_LIBRARY instead of PYBIND11_MODULE for custom ops (#3130)
Co-authored-by: yinfan.1024 <yinfan.1024@bytedance.com> Co-authored-by: yinfan98 <1106110035@qq.com> Co-authored-by: Yineng Zhang <me@zhyncs.com>
This commit is contained in:
@@ -26,10 +26,11 @@ Third-party libraries:
|
||||
Steps to add a new kernel:
|
||||
|
||||
1. Implement in [src/sgl-kernel/csrc/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/src/sgl-kernel/csrc)
|
||||
2. Expose interface in [csrc/sgl_kernel_ops.cu](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu) with pybind11
|
||||
3. Create Python wrapper in [src/sgl-kernel/ops/\_\_init\_\_.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/ops/__init__.py)
|
||||
4. Expose Python interface in [src/sgl-kernel/\_\_init\_\_.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/__init__.py)
|
||||
5. Update [setup.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/setup.py) to include new CUDA source
|
||||
2. Expose interface in [src/sgl-kernel/include/sgl_kernels_ops.h](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h)
|
||||
3. Create torch extension in [src/sgl-kernel/torch_extension.cc](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/torch_extension.cc)
|
||||
4. Create Python wrapper in [src/sgl-kernel/ops/\_\_init\_\_.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/ops/__init__.py)
|
||||
5. Expose Python interface in [src/sgl-kernel/\_\_init\_\_.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/__init__.py)
|
||||
6. Update [setup.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/setup.py) to include new CUDA source
|
||||
|
||||
### Build & Install
|
||||
|
||||
@@ -37,8 +38,6 @@ Development build:
|
||||
|
||||
```bash
|
||||
make build
|
||||
pip3 install dist/*whl --force-reinstall --no-deps
|
||||
# Or use: make install (runs pip install -e .)
|
||||
```
|
||||
|
||||
### Testing & Benchmarking
|
||||
|
||||
@@ -38,6 +38,7 @@ def _get_version():
|
||||
return line.split("=")[1].strip().strip('"')
|
||||
|
||||
|
||||
operator_namespace = "sgl_kernels"
|
||||
cutlass_default = root / "3rdparty" / "cutlass"
|
||||
cutlass = Path(os.environ.get("CUSTOM_CUTLASS_SRC_DIR", default=cutlass_default))
|
||||
flashinfer = root / "3rdparty" / "flashinfer"
|
||||
@@ -45,15 +46,19 @@ turbomind = root / "3rdparty" / "turbomind"
|
||||
# Header search paths handed to the extension build.
# Only directories belong here: cuBLAS / cuBLASLt are link-time libraries and
# are named in `libraries`, so they are not repeated as include paths.
include_dirs = [
    cutlass.resolve() / "include",
    cutlass.resolve() / "tools" / "util" / "include",
    root / "src" / "sgl-kernel" / "include",
    root / "src" / "sgl-kernel" / "csrc",
    flashinfer.resolve() / "include",
    flashinfer.resolve() / "include" / "gemm",
    flashinfer.resolve() / "csrc",
    turbomind.resolve(),
    turbomind.resolve() / "src",
]
|
||||
nvcc_flags = [
|
||||
"-DNDEBUG",
|
||||
f"-DOPERATOR_NAMESPACE={operator_namespace}",
|
||||
"-O3",
|
||||
"-Xcompiler",
|
||||
"-fPIC",
|
||||
@@ -72,13 +77,13 @@ nvcc_flags_fp8 = [
|
||||
]
|
||||
|
||||
sources = [
|
||||
"src/sgl-kernel/torch_extension.cc",
|
||||
"src/sgl-kernel/csrc/trt_reduce_internal.cu",
|
||||
"src/sgl-kernel/csrc/trt_reduce_kernel.cu",
|
||||
"src/sgl-kernel/csrc/moe_align_kernel.cu",
|
||||
"src/sgl-kernel/csrc/int8_gemm_kernel.cu",
|
||||
"src/sgl-kernel/csrc/sampling_scaling_penalties.cu",
|
||||
"src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu",
|
||||
"src/sgl-kernel/csrc/sgl_kernel_ops.cu",
|
||||
"src/sgl-kernel/csrc/rotary_embedding.cu",
|
||||
"src/sgl-kernel/csrc/fused_add_rms_norm.cu",
|
||||
"3rdparty/flashinfer/csrc/activation.cu",
|
||||
@@ -125,7 +130,7 @@ for flag in [
|
||||
pass
|
||||
|
||||
cxx_flags = ["-O3"]
|
||||
libraries = ["c10", "torch", "torch_python", "cuda"]
|
||||
libraries = ["c10", "torch", "torch_python", "cuda", "cublas", "cublasLt"]
|
||||
extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib", "-L/usr/lib/x86_64-linux-gnu"]
|
||||
|
||||
ext_modules = [
|
||||
@@ -139,6 +144,7 @@ ext_modules = [
|
||||
},
|
||||
libraries=libraries,
|
||||
extra_link_args=extra_link_args,
|
||||
py_limited_api=True,
|
||||
),
|
||||
]
|
||||
|
||||
@@ -149,6 +155,7 @@ setup(
|
||||
package_dir={"": "src"},
|
||||
ext_modules=ext_modules,
|
||||
cmdclass={"build_ext": BuildExtension},
|
||||
options={"bdist_wheel": {"py_limited_api": "cp39"}},
|
||||
)
|
||||
|
||||
_update_wheel_platform_tag()
|
||||
|
||||
@@ -1,7 +1,25 @@
|
||||
#pragma once
|
||||
#include <Python.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
// Token-pasting and stringification helpers. Each public macro goes through a
// second macro level so that arguments which are themselves macros get
// expanded before `##` / `#` is applied.
// The helper names deliberately avoid the `_Uppercase` form, which C++ reserves
// for the implementation (and which some standard-library headers define).
#define SGL_KERNEL_CONCAT_IMPL(A, B) A##B
#define CONCAT(A, B) SGL_KERNEL_CONCAT_IMPL(A, B)

#define SGL_KERNEL_STRINGIFY_IMPL(A) #A
#define STRINGIFY(A) SGL_KERNEL_STRINGIFY_IMPL(A)

// Forwards to TORCH_LIBRARY through one extra macro level, so NAME may itself
// be a macro that is expanded first.
#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)

// Defines the CPython entry point `PyInit_<NAME>`, which creates a module named
// NAME from a PyModuleDef with no doc string and no method table.
// NOTE(review): presumably the module only needs to be importable so that the
// shared library is loaded and its TORCH_LIBRARY registrations run — confirm.
#define REGISTER_EXTENSION(NAME)                                                                      \
  PyMODINIT_FUNC CONCAT(PyInit_, NAME)() {                                                            \
    static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, STRINGIFY(NAME), nullptr, 0, nullptr}; \
    return PyModule_Create(&module);                                                                  \
  }
|
||||
|
||||
// trt_reduce
|
||||
using fptr_t = int64_t;
|
||||
fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, torch::Tensor& rank_data, const std::vector<fptr_t>& buffers,
|
||||
@@ -67,9 +85,18 @@ void min_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at:
|
||||
int64_t cuda_stream);
|
||||
|
||||
// top k renorm probs
|
||||
// Patch: FlashInfer declares top_k_val as unsigned int, but torch extension ops must use int64_t for integers, so the op is bound through top_k_renorm_probs_wrapper.
|
||||
void top_k_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, std::optional<at::Tensor> maybe_top_k_arr,
|
||||
unsigned int top_k_val, int64_t cuda_stream);
|
||||
|
||||
// patch here, cause flashinfer use unsigned int. but torch must use int64_t for extension.
|
||||
// wrapper for binding
|
||||
inline void top_k_renorm_probs_wrapper(at::Tensor probs, at::Tensor renorm_probs,
|
||||
std::optional<at::Tensor> maybe_top_k_arr, int64_t top_k_val,
|
||||
int64_t cuda_stream) {
|
||||
top_k_renorm_probs(probs, renorm_probs, maybe_top_k_arr, static_cast<unsigned int>(top_k_val), cuda_stream);
|
||||
}
|
||||
|
||||
// top p renorm probs
|
||||
void top_p_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, std::optional<at::Tensor> maybe_top_p_arr,
|
||||
double top_p_val, int64_t cuda_stream);
|
||||
@@ -84,48 +111,3 @@ void top_k_top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_sample
|
||||
void top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples, at::Tensor success,
|
||||
std::optional<at::Tensor> maybe_top_p_arr, double top_p_val, bool deterministic,
|
||||
int64_t cuda_stream);
|
||||
|
||||
// pybind11 bindings for the sgl-kernel custom ops.
// Each entry exposes one C++ function to Python under the same name, with a
// short doc string as the third argument.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // trt_reduce
  m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)");
  m.def("dispose", &dispose, "dispose custom allreduce meta");
  m.def("all_reduce", &all_reduce, "custom all reduce (CUDA)");
  m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta, "custom all reduce get graph ipc meta");
  m.def("register_graph_buffers", &register_graph_buffers, "custom all reduce register graph buffers");
  // moe_align_block_size
  m.def("moe_align_block_size", &moe_align_block_size, "MOE Align Block Size (CUDA)");
  // sampling_scaling_penalties
  m.def("sampling_scaling_penalties", &sampling_scaling_penalties, "Sampling scaling penalties (CUDA)");
  // int8_scaled_mm
  m.def("int8_scaled_mm", &int8_scaled_mm, "INT8 scaled matmul (CUDA)");
  // lightning_attention_decode
  m.def("lightning_attention_decode", &lightning_attention_decode, "Lightning Attention Ddecode (CUDA)");
  // rotary embedding
  m.def("rotary_embedding", &rotary_embedding, "Rotary Embedding (CUDA)");
  // rms norm
  m.def("rmsnorm", &rmsnorm, "RMSNorm (CUDA)");
  // fused rms norm
  m.def("fused_add_rmsnorm", &fused_add_rmsnorm, "Fused Add RMSNorm (CUDA)");
  // gemma rms norm
  m.def("gemma_rmsnorm", &gemma_rmsnorm, "Gemma RMSNorm (CUDA)");
  // fused gemma rms norm
  m.def("gemma_fused_add_rmsnorm", &gemma_fused_add_rmsnorm, "Gemma Fused Add RMSNorm (CUDA)");
  // silu and mul
  m.def("silu_and_mul", &silu_and_mul, "Silu and Mul (CUDA)");
  // gelu tanh and mul
  m.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, "Gelu Tanh and Mul (CUDA)");
  // gelu and mul
  m.def("gelu_and_mul", &gelu_and_mul, "Gelu and Mul (CUDA)");
  // bmm fp8
  m.def("bmm_fp8", &bmm_fp8, "BMM FP8 (CUDA)");
  // min p sampling from probs
  m.def("min_p_sampling_from_probs", &min_p_sampling_from_probs, "Min P Sampling From Probs (CUDA)");
  // top k renorm probs
  m.def("top_k_renorm_probs", &top_k_renorm_probs, "Top K Renorm Probs (CUDA)");
  // top p renorm probs
  m.def("top_p_renorm_probs", &top_p_renorm_probs, "Top P Renorm Probs (CUDA)");
  // top k top p sampling from probs
  m.def("top_k_top_p_sampling_from_probs", &top_k_top_p_sampling_from_probs, "Top K Top P Sampling From Probs (CUDA)");
  // top p sampling from probs
  m.def("top_p_sampling_from_probs", &top_p_sampling_from_probs, "Top P Sampling From Probs (CUDA)");
}
|
||||
@@ -1,9 +1,12 @@
|
||||
#pragma once
|
||||
#include <cuda_runtime.h>
|
||||
#include <pytorch_extension_utils.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include <sstream>
|
||||
|
||||
#include "sgl_kernels_ops.h"
|
||||
|
||||
struct cuda_error : public std::runtime_error {
|
||||
/**
|
||||
* @brief Constructs a `cuda_error` object with the given `message`.
|
||||
@@ -1,41 +1,8 @@
|
||||
import os
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import sgl_kernel.ops._kernels
|
||||
import torch
|
||||
from sgl_kernel.ops._kernels import all_reduce as _all_reduce
|
||||
from sgl_kernel.ops._kernels import bmm_fp8 as _bmm_fp8
|
||||
from sgl_kernel.ops._kernels import dispose as _dispose
|
||||
from sgl_kernel.ops._kernels import fused_add_rmsnorm as _fused_add_rmsnorm
|
||||
from sgl_kernel.ops._kernels import gelu_and_mul as _gelu_and_mul
|
||||
from sgl_kernel.ops._kernels import gelu_tanh_and_mul as _gelu_tanh_and_mul
|
||||
from sgl_kernel.ops._kernels import gemma_fused_add_rmsnorm as _gemma_fused_add_rmsnorm
|
||||
from sgl_kernel.ops._kernels import gemma_rmsnorm as _gemma_rmsnorm
|
||||
from sgl_kernel.ops._kernels import (
|
||||
get_graph_buffer_ipc_meta as _get_graph_buffer_ipc_meta,
|
||||
)
|
||||
from sgl_kernel.ops._kernels import init_custom_ar as _init_custom_ar
|
||||
from sgl_kernel.ops._kernels import int8_scaled_mm as _int8_scaled_mm
|
||||
from sgl_kernel.ops._kernels import (
|
||||
lightning_attention_decode as _lightning_attention_decode,
|
||||
)
|
||||
from sgl_kernel.ops._kernels import (
|
||||
min_p_sampling_from_probs as _min_p_sampling_from_probs,
|
||||
)
|
||||
from sgl_kernel.ops._kernels import moe_align_block_size as _moe_align_block_size
|
||||
from sgl_kernel.ops._kernels import register_graph_buffers as _register_graph_buffers
|
||||
from sgl_kernel.ops._kernels import rmsnorm as _rmsnorm
|
||||
from sgl_kernel.ops._kernels import rotary_embedding as _rotary_embedding
|
||||
from sgl_kernel.ops._kernels import (
|
||||
sampling_scaling_penalties as _sampling_scaling_penalties,
|
||||
)
|
||||
from sgl_kernel.ops._kernels import silu_and_mul as _silu_and_mul
|
||||
from sgl_kernel.ops._kernels import top_k_renorm_probs as _top_k_renorm_probs
|
||||
from sgl_kernel.ops._kernels import (
|
||||
top_k_top_p_sampling_from_probs as _top_k_top_p_sampling_from_probs,
|
||||
)
|
||||
from sgl_kernel.ops._kernels import top_p_renorm_probs as _top_p_renorm_probs
|
||||
from sgl_kernel.ops._kernels import (
|
||||
top_p_sampling_from_probs as _top_p_sampling_from_probs,
|
||||
)
|
||||
from sgl_kernel.ops.utils import (
|
||||
_get_cache_buf,
|
||||
_get_cuda_stream,
|
||||
@@ -46,25 +13,25 @@ from sgl_kernel.ops.utils import (
|
||||
def init_custom_reduce(
    rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out
):
    """Forward to the `init_custom_ar` op of the `sgl_kernels` torch library.

    All arguments are passed through positionally and unchanged; the op's
    return value is returned as-is.
    """
    return torch.ops.sgl_kernels.init_custom_ar(
        rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out
    )
|
||||
|
||||
|
||||
def custom_dispose(fa):
    """Call the `dispose` op of the `sgl_kernels` torch library on `fa` exactly once."""
    torch.ops.sgl_kernels.dispose(fa)
|
||||
|
||||
|
||||
def custom_reduce(fa, inp, out):
    """Call the `all_reduce` op of the `sgl_kernels` torch library once with (fa, inp, out).

    Nothing is returned; the op schema marks `out` as the mutated tensor.
    """
    torch.ops.sgl_kernels.all_reduce(fa, inp, out)
|
||||
|
||||
|
||||
def get_graph_buffer_ipc_meta(fa):
    """Return the result of the `get_graph_buffer_ipc_meta` op of the `sgl_kernels` library for `fa`."""
    return torch.ops.sgl_kernels.get_graph_buffer_ipc_meta(fa)
|
||||
|
||||
|
||||
def register_graph_buffers(fa, handles, offsets):
    """Call the `register_graph_buffers` op of the `sgl_kernels` library once; returns nothing."""
    torch.ops.sgl_kernels.register_graph_buffers(fa, handles, offsets)
|
||||
|
||||
|
||||
def moe_align_block_size(
|
||||
@@ -77,7 +44,7 @@ def moe_align_block_size(
|
||||
token_cnts_buffer,
|
||||
cumsum_buffer,
|
||||
):
|
||||
_moe_align_block_size(
|
||||
torch.ops.sgl_kernels.moe_align_block_size(
|
||||
topk_ids,
|
||||
num_experts,
|
||||
block_size,
|
||||
@@ -90,11 +57,11 @@ def moe_align_block_size(
|
||||
|
||||
|
||||
def sampling_scaling_penalties(logits, scaling_penalties):
    """Return the result of the `sampling_scaling_penalties` op of the `sgl_kernels` library.

    Both arguments are forwarded unchanged.
    """
    return torch.ops.sgl_kernels.sampling_scaling_penalties(logits, scaling_penalties)
|
||||
|
||||
|
||||
def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
|
||||
return _int8_scaled_mm(
|
||||
return torch.ops.sgl_kernels.int8_scaled_mm(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scales_a,
|
||||
@@ -105,11 +72,15 @@ def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
|
||||
|
||||
|
||||
def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
    """Call the `lightning_attention_decode` op of the `sgl_kernels` library once.

    Nothing is returned; the op schema marks `output` and `new_kv` as the
    mutated tensors.
    """
    torch.ops.sgl_kernels.lightning_attention_decode(
        q, k, v, past_kv, slope, output, new_kv
    )
|
||||
|
||||
|
||||
def rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox):
    """Forward to the `rotary_embedding` op of the `sgl_kernels` library and return its result.

    All arguments are passed through positionally and unchanged.
    """
    return torch.ops.sgl_kernels.rotary_embedding(
        positions, query, key, head_size, cos_sin_cache, is_neox
    )
|
||||
|
||||
|
||||
# These implementations extensively draw from and build upon the FlashInfer project https://github.com/flashinfer-ai/flashinfer
|
||||
@@ -123,7 +94,7 @@ def rmsnorm(
|
||||
with input.device as device:
|
||||
if out is None:
|
||||
out = torch.empty_like(input)
|
||||
_rmsnorm(out, input, weight, eps, _get_cuda_stream(device))
|
||||
torch.ops.sgl_kernels.rmsnorm(out, input, weight, eps, _get_cuda_stream(device))
|
||||
return out
|
||||
|
||||
|
||||
@@ -131,7 +102,9 @@ def fused_add_rmsnorm(
|
||||
input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
|
||||
) -> None:
|
||||
with input.device as device:
|
||||
_fused_add_rmsnorm(input, residual, weight, eps, _get_cuda_stream(device))
|
||||
torch.ops.sgl_kernels.fused_add_rmsnorm(
|
||||
input, residual, weight, eps, _get_cuda_stream(device)
|
||||
)
|
||||
|
||||
|
||||
def gemma_rmsnorm(
|
||||
@@ -143,7 +116,9 @@ def gemma_rmsnorm(
|
||||
with input.device as device:
|
||||
if out is None:
|
||||
out = torch.empty_like(input)
|
||||
_gemma_rmsnorm(out, input, weight, eps, _get_cuda_stream(device))
|
||||
torch.ops.sgl_kernels.gemma_rmsnorm(
|
||||
out, input, weight, eps, _get_cuda_stream(device)
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
@@ -151,7 +126,9 @@ def gemma_fused_add_rmsnorm(
|
||||
input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
|
||||
) -> None:
|
||||
with input.device as device:
|
||||
_gemma_fused_add_rmsnorm(input, residual, weight, eps, _get_cuda_stream(device))
|
||||
torch.ops.sgl_kernels.gemma_fused_add_rmsnorm(
|
||||
input, residual, weight, eps, _get_cuda_stream(device)
|
||||
)
|
||||
|
||||
|
||||
def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None:
|
||||
@@ -176,7 +153,7 @@ def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
|
||||
dtype=input.dtype,
|
||||
)
|
||||
with input.device as device:
|
||||
_silu_and_mul(out, input, _get_cuda_stream(device))
|
||||
torch.ops.sgl_kernels.silu_and_mul(out, input, _get_cuda_stream(device))
|
||||
return out
|
||||
|
||||
|
||||
@@ -192,7 +169,7 @@ def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Te
|
||||
dtype=input.dtype,
|
||||
)
|
||||
with input.device as device:
|
||||
_gelu_tanh_and_mul(out, input, _get_cuda_stream(device))
|
||||
torch.ops.sgl_kernels.gelu_tanh_and_mul(out, input, _get_cuda_stream(device))
|
||||
return out
|
||||
|
||||
|
||||
@@ -208,7 +185,7 @@ def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
|
||||
dtype=input.dtype,
|
||||
)
|
||||
with input.device as device:
|
||||
_gelu_and_mul(out, input, _get_cuda_stream(device))
|
||||
torch.ops.sgl_kernels.gelu_and_mul(out, input, _get_cuda_stream(device))
|
||||
return out
|
||||
|
||||
|
||||
@@ -222,7 +199,7 @@ def _bmm_fp8_internal(
|
||||
) -> None:
|
||||
with A.device as device:
|
||||
cublas_handle = torch.cuda.current_blas_handle()
|
||||
_bmm_fp8(
|
||||
torch.ops.sgl_kernels.bmm_fp8(
|
||||
A,
|
||||
B,
|
||||
D,
|
||||
@@ -262,7 +239,7 @@ def _top_k_renorm_probs_internal(
|
||||
probs = probs.float()
|
||||
maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
|
||||
renorm_probs = torch.empty_like(probs)
|
||||
_top_k_renorm_probs(
|
||||
torch.ops.sgl_kernels.top_k_renorm_probs_wrapper(
|
||||
probs,
|
||||
renorm_probs,
|
||||
maybe_top_k_arr,
|
||||
@@ -293,7 +270,7 @@ def _top_p_renorm_probs_internal(
|
||||
maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
|
||||
)
|
||||
renorm_probs = torch.empty_like(probs)
|
||||
_top_p_renorm_probs(
|
||||
torch.ops.sgl_kernels.top_p_renorm_probs(
|
||||
probs,
|
||||
renorm_probs,
|
||||
maybe_top_p_arr,
|
||||
@@ -328,7 +305,7 @@ def _top_p_sampling_from_probs_internal(
|
||||
)
|
||||
samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
|
||||
success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
|
||||
_top_p_sampling_from_probs(
|
||||
torch.ops.sgl_kernels.top_p_sampling_from_probs(
|
||||
probs,
|
||||
uniform_samples,
|
||||
samples,
|
||||
@@ -374,7 +351,7 @@ def _top_k_top_p_sampling_from_probs_internal(
|
||||
)
|
||||
samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
|
||||
success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
|
||||
_top_k_top_p_sampling_from_probs(
|
||||
torch.ops.sgl_kernels.top_k_top_p_sampling_from_probs(
|
||||
probs,
|
||||
uniform_samples,
|
||||
samples,
|
||||
@@ -432,7 +409,7 @@ def _min_p_sampling_from_probs_internal(
|
||||
maybe_min_p_arr.float() if maybe_min_p_arr is not None else None
|
||||
)
|
||||
samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
|
||||
_min_p_sampling_from_probs(
|
||||
torch.ops.sgl_kernels.min_p_sampling_from_probs(
|
||||
probs,
|
||||
uniform_samples,
|
||||
samples,
|
||||
|
||||
119
sgl-kernel/src/sgl-kernel/torch_extension.cc
Normal file
119
sgl-kernel/src/sgl-kernel/torch_extension.cc
Normal file
@@ -0,0 +1,119 @@
|
||||
|
||||
#include <ATen/core/dispatch/Dispatcher.h>
|
||||
#include <torch/library.h>
|
||||
|
||||
#include "sgl_kernels_ops.h"
|
||||
|
||||
// Registers every sgl-kernel custom op in the `sgl_kernels` torch library.
// Most ops are declared with an explicit schema string via m.def(...) and then
// bound to their implementation for the CUDA dispatch key via m.impl(...).
// In the schemas, `Tensor!` marks an argument the op writes to and `Tensor?`
// an optional tensor.
// NOTE(review): get_graph_buffer_ipc_meta and register_graph_buffers take no
// Tensor arguments yet are registered only for torch::kCUDA, unlike `dispose`,
// which is registered directly from its function pointer — confirm they
// dispatch correctly when called from Python.
TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
  // trt_reduce
  m.def(
      "init_custom_ar(int rank_id, int world_size, Tensor rank_data, int[] buffers, int[] tmp_result_buffers, int[] "
      "barrier_in, int[] barrier_out) -> int");
  m.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);

  m.def("dispose", &dispose);

  m.def("all_reduce(int fa, Tensor inp, Tensor! out) -> ()");
  m.impl("all_reduce", torch::kCUDA, &all_reduce);

  m.def("get_graph_buffer_ipc_meta(int fa) -> (int[], int[])");
  m.impl("get_graph_buffer_ipc_meta", torch::kCUDA, &get_graph_buffer_ipc_meta);

  m.def("register_graph_buffers(int fa, int[][] handles, int[][] offsets) -> ()");
  m.impl("register_graph_buffers", torch::kCUDA, &register_graph_buffers);

  // moe_align_block_size
  m.def(
      "moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! "
      "experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer) -> ()");
  m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);

  // sampling_scaling_penalties
  m.def("sampling_scaling_penalties(Tensor logits, Tensor scaling_penalties) -> Tensor");
  m.impl("sampling_scaling_penalties", torch::kCUDA, &sampling_scaling_penalties);

  // int8_scaled_mm
  m.def(
      "int8_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype, Tensor? "
      "bias) -> Tensor");
  m.impl("int8_scaled_mm", torch::kCUDA, &int8_scaled_mm);

  // lightning_attention_decode
  m.def(
      "lightning_attention_decode(Tensor q, Tensor k, Tensor v, Tensor past_kv, Tensor slope, Tensor! output, Tensor! "
      "new_kv) -> ()");
  m.impl("lightning_attention_decode", torch::kCUDA, &lightning_attention_decode);

  // rotary embedding
  m.def(
      "rotary_embedding(Tensor positions, Tensor! query, Tensor! key, int head_size, Tensor cos_sin_cache, bool "
      "is_neox) -> ()");
  m.impl("rotary_embedding", torch::kCUDA, &rotary_embedding);

  // rms norm
  m.def("rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
  m.impl("rmsnorm", torch::kCUDA, &rmsnorm);

  // fused rms norm
  m.def("fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, int cuda_stream) -> ()");
  m.impl("fused_add_rmsnorm", torch::kCUDA, &fused_add_rmsnorm);

  // gemma rms norm
  m.def("gemma_rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
  m.impl("gemma_rmsnorm", torch::kCUDA, &gemma_rmsnorm);

  // fused gemma rms norm
  m.def("gemma_fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, int cuda_stream) -> ()");
  m.impl("gemma_fused_add_rmsnorm", torch::kCUDA, &gemma_fused_add_rmsnorm);

  // silu and mul
  m.def("silu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
  m.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);

  // gelu tanh and mul
  m.def("gelu_tanh_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
  m.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul);

  // gelu and mul
  m.def("gelu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
  m.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);

  // bmm fp8
  m.def(
      "bmm_fp8(Tensor A, Tensor B, Tensor! D, Tensor A_scale, Tensor B_scale, Tensor workspace_buffer, int "
      "cublas_handle, int cuda_stream) -> ()");
  m.impl("bmm_fp8", torch::kCUDA, &bmm_fp8);

  // min p sampling from probs
  m.def(
      "min_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor? maybe_min_p_arr, float "
      "min_p_val, bool deterministic, int cuda_stream) -> ()");
  m.impl("min_p_sampling_from_probs", torch::kCUDA, &min_p_sampling_from_probs);

  // top k renorm probs
  // Bound through the int64_t wrapper because the underlying function takes `unsigned int top_k_val`.
  m.def(
      "top_k_renorm_probs_wrapper(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_k_arr, int top_k_val, int "
      "cuda_stream) -> ()");
  m.impl("top_k_renorm_probs_wrapper", torch::kCUDA, &top_k_renorm_probs_wrapper);

  // top p renorm probs
  m.def(
      "top_p_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_p_arr, float top_p_val, int "
      "cuda_stream) -> ()");
  m.impl("top_p_renorm_probs", torch::kCUDA, &top_p_renorm_probs);

  // top k top p sampling from probs
  m.def(
      "top_k_top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? "
      "maybe_top_k_arr, float top_k_val, Tensor? maybe_top_p_arr, float top_p_val, bool deterministic, int "
      "cuda_stream) -> ()");
  m.impl("top_k_top_p_sampling_from_probs", torch::kCUDA, &top_k_top_p_sampling_from_probs);

  // top p sampling from probs
  m.def(
      "top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? "
      "maybe_top_p_arr, float top_p_val, bool deterministic, int cuda_stream) -> ()");
  m.impl("top_p_sampling_from_probs", torch::kCUDA, &top_p_sampling_from_probs);
}
|
||||
|
||||
REGISTER_EXTENSION(_kernels)
|
||||
Reference in New Issue
Block a user