feat: refactor sgl-kernel and use TORCH_LIBRARY instead of PYBIND11_MODULE for custom ops (#3130)
Co-authored-by: yinfan.1024 <yinfan.1024@bytedance.com> Co-authored-by: yinfan98 <1106110035@qq.com> Co-authored-by: Yineng Zhang <me@zhyncs.com>
This commit is contained in:
@@ -26,10 +26,11 @@ Third-party libraries:
|
|||||||
Steps to add a new kernel:
|
Steps to add a new kernel:
|
||||||
|
|
||||||
1. Implement in [src/sgl-kernel/csrc/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/src/sgl-kernel/csrc)
|
1. Implement in [src/sgl-kernel/csrc/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/src/sgl-kernel/csrc)
|
||||||
2. Expose interface in [csrc/sgl_kernel_ops.cu](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu) with pybind11
|
2. Expose interface in [src/sgl-kernel/include/sgl_kernel_ops.h](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/include/sgl_kernel_ops.h)
|
||||||
3. Create Python wrapper in [src/sgl-kernel/ops/\_\_init\_\_.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/ops/__init__.py)
|
3. Create torch extension in [src/sgl-kernel/torch_extension.cc](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/torch_extension.cc)
|
||||||
4. Expose Python interface in [src/sgl-kernel/\_\_init\_\_.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/__init__.py)
|
4. Create Python wrapper in [src/sgl-kernel/ops/\_\_init\_\_.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/ops/__init__.py)
|
||||||
5. Update [setup.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/setup.py) to include new CUDA source
|
5. Expose Python interface in [src/sgl-kernel/\_\_init\_\_.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/__init__.py)
|
||||||
|
6. Update [setup.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/setup.py) to include new CUDA source
|
||||||
|
|
||||||
### Build & Install
|
### Build & Install
|
||||||
|
|
||||||
@@ -37,8 +38,6 @@ Development build:
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
make build
|
make build
|
||||||
pip3 install dist/*whl --force-reinstall --no-deps
|
|
||||||
# Or use: make install (runs pip install -e .)
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### Testing & Benchmarking
|
### Testing & Benchmarking
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ def _get_version():
|
|||||||
return line.split("=")[1].strip().strip('"')
|
return line.split("=")[1].strip().strip('"')
|
||||||
|
|
||||||
|
|
||||||
|
operator_namespace = "sgl_kernels"
|
||||||
cutlass_default = root / "3rdparty" / "cutlass"
|
cutlass_default = root / "3rdparty" / "cutlass"
|
||||||
cutlass = Path(os.environ.get("CUSTOM_CUTLASS_SRC_DIR", default=cutlass_default))
|
cutlass = Path(os.environ.get("CUSTOM_CUTLASS_SRC_DIR", default=cutlass_default))
|
||||||
flashinfer = root / "3rdparty" / "flashinfer"
|
flashinfer = root / "3rdparty" / "flashinfer"
|
||||||
@@ -45,15 +46,19 @@ turbomind = root / "3rdparty" / "turbomind"
|
|||||||
include_dirs = [
|
include_dirs = [
|
||||||
cutlass.resolve() / "include",
|
cutlass.resolve() / "include",
|
||||||
cutlass.resolve() / "tools" / "util" / "include",
|
cutlass.resolve() / "tools" / "util" / "include",
|
||||||
|
root / "src" / "sgl-kernel" / "include",
|
||||||
root / "src" / "sgl-kernel" / "csrc",
|
root / "src" / "sgl-kernel" / "csrc",
|
||||||
flashinfer.resolve() / "include",
|
flashinfer.resolve() / "include",
|
||||||
flashinfer.resolve() / "include" / "gemm",
|
flashinfer.resolve() / "include" / "gemm",
|
||||||
flashinfer.resolve() / "csrc",
|
flashinfer.resolve() / "csrc",
|
||||||
|
"cublas",
|
||||||
|
"cublasLt",
|
||||||
turbomind.resolve(),
|
turbomind.resolve(),
|
||||||
turbomind.resolve() / "src",
|
turbomind.resolve() / "src",
|
||||||
]
|
]
|
||||||
nvcc_flags = [
|
nvcc_flags = [
|
||||||
"-DNDEBUG",
|
"-DNDEBUG",
|
||||||
|
f"-DOPERATOR_NAMESPACE={operator_namespace}",
|
||||||
"-O3",
|
"-O3",
|
||||||
"-Xcompiler",
|
"-Xcompiler",
|
||||||
"-fPIC",
|
"-fPIC",
|
||||||
@@ -72,13 +77,13 @@ nvcc_flags_fp8 = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
sources = [
|
sources = [
|
||||||
|
"src/sgl-kernel/torch_extension.cc",
|
||||||
"src/sgl-kernel/csrc/trt_reduce_internal.cu",
|
"src/sgl-kernel/csrc/trt_reduce_internal.cu",
|
||||||
"src/sgl-kernel/csrc/trt_reduce_kernel.cu",
|
"src/sgl-kernel/csrc/trt_reduce_kernel.cu",
|
||||||
"src/sgl-kernel/csrc/moe_align_kernel.cu",
|
"src/sgl-kernel/csrc/moe_align_kernel.cu",
|
||||||
"src/sgl-kernel/csrc/int8_gemm_kernel.cu",
|
"src/sgl-kernel/csrc/int8_gemm_kernel.cu",
|
||||||
"src/sgl-kernel/csrc/sampling_scaling_penalties.cu",
|
"src/sgl-kernel/csrc/sampling_scaling_penalties.cu",
|
||||||
"src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu",
|
"src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu",
|
||||||
"src/sgl-kernel/csrc/sgl_kernel_ops.cu",
|
|
||||||
"src/sgl-kernel/csrc/rotary_embedding.cu",
|
"src/sgl-kernel/csrc/rotary_embedding.cu",
|
||||||
"src/sgl-kernel/csrc/fused_add_rms_norm.cu",
|
"src/sgl-kernel/csrc/fused_add_rms_norm.cu",
|
||||||
"3rdparty/flashinfer/csrc/activation.cu",
|
"3rdparty/flashinfer/csrc/activation.cu",
|
||||||
@@ -125,7 +130,7 @@ for flag in [
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
cxx_flags = ["-O3"]
|
cxx_flags = ["-O3"]
|
||||||
libraries = ["c10", "torch", "torch_python", "cuda"]
|
libraries = ["c10", "torch", "torch_python", "cuda", "cublas", "cublasLt"]
|
||||||
extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib", "-L/usr/lib/x86_64-linux-gnu"]
|
extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib", "-L/usr/lib/x86_64-linux-gnu"]
|
||||||
|
|
||||||
ext_modules = [
|
ext_modules = [
|
||||||
@@ -139,6 +144,7 @@ ext_modules = [
|
|||||||
},
|
},
|
||||||
libraries=libraries,
|
libraries=libraries,
|
||||||
extra_link_args=extra_link_args,
|
extra_link_args=extra_link_args,
|
||||||
|
py_limited_api=True,
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -149,6 +155,7 @@ setup(
|
|||||||
package_dir={"": "src"},
|
package_dir={"": "src"},
|
||||||
ext_modules=ext_modules,
|
ext_modules=ext_modules,
|
||||||
cmdclass={"build_ext": BuildExtension},
|
cmdclass={"build_ext": BuildExtension},
|
||||||
|
options={"bdist_wheel": {"py_limited_api": "cp39"}},
|
||||||
)
|
)
|
||||||
|
|
||||||
_update_wheel_platform_tag()
|
_update_wheel_platform_tag()
|
||||||
|
|||||||
@@ -1,7 +1,25 @@
|
|||||||
|
#pragma once
|
||||||
|
#include <Python.h>
|
||||||
|
#include <torch/extension.h>
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "utils.h"
|
#include "utils.h"
|
||||||
|
|
||||||
|
#define _CONCAT(A, B) A##B
|
||||||
|
#define CONCAT(A, B) _CONCAT(A, B)
|
||||||
|
|
||||||
|
#define _STRINGIFY(A) #A
|
||||||
|
#define STRINGIFY(A) _STRINGIFY(A)
|
||||||
|
|
||||||
|
#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
|
||||||
|
|
||||||
|
#define REGISTER_EXTENSION(NAME) \
|
||||||
|
PyMODINIT_FUNC CONCAT(PyInit_, NAME)() { \
|
||||||
|
static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, STRINGIFY(NAME), nullptr, 0, nullptr}; \
|
||||||
|
return PyModule_Create(&module); \
|
||||||
|
}
|
||||||
|
|
||||||
// trt_reduce
|
// trt_reduce
|
||||||
using fptr_t = int64_t;
|
using fptr_t = int64_t;
|
||||||
fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, torch::Tensor& rank_data, const std::vector<fptr_t>& buffers,
|
fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, torch::Tensor& rank_data, const std::vector<fptr_t>& buffers,
|
||||||
@@ -67,9 +85,18 @@ void min_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at:
|
|||||||
int64_t cuda_stream);
|
int64_t cuda_stream);
|
||||||
|
|
||||||
// top k renorm probs
|
// top k renorm probs
|
||||||
|
// patch here, cause flashinfer use unsigned int. but torch must use int64_t for extension.
|
||||||
void top_k_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, std::optional<at::Tensor> maybe_top_k_arr,
|
void top_k_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, std::optional<at::Tensor> maybe_top_k_arr,
|
||||||
unsigned int top_k_val, int64_t cuda_stream);
|
unsigned int top_k_val, int64_t cuda_stream);
|
||||||
|
|
||||||
|
// patch here, cause flashinfer use unsigned int. but torch must use int64_t for extension.
|
||||||
|
// wrapper for binding
|
||||||
|
inline void top_k_renorm_probs_wrapper(at::Tensor probs, at::Tensor renorm_probs,
|
||||||
|
std::optional<at::Tensor> maybe_top_k_arr, int64_t top_k_val,
|
||||||
|
int64_t cuda_stream) {
|
||||||
|
top_k_renorm_probs(probs, renorm_probs, maybe_top_k_arr, static_cast<unsigned int>(top_k_val), cuda_stream);
|
||||||
|
}
|
||||||
|
|
||||||
// top p renorm probs
|
// top p renorm probs
|
||||||
void top_p_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, std::optional<at::Tensor> maybe_top_p_arr,
|
void top_p_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, std::optional<at::Tensor> maybe_top_p_arr,
|
||||||
double top_p_val, int64_t cuda_stream);
|
double top_p_val, int64_t cuda_stream);
|
||||||
@@ -84,48 +111,3 @@ void top_k_top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_sample
|
|||||||
void top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples, at::Tensor success,
|
void top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples, at::Tensor success,
|
||||||
std::optional<at::Tensor> maybe_top_p_arr, double top_p_val, bool deterministic,
|
std::optional<at::Tensor> maybe_top_p_arr, double top_p_val, bool deterministic,
|
||||||
int64_t cuda_stream);
|
int64_t cuda_stream);
|
||||||
|
|
||||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|
||||||
// trt_reduce
|
|
||||||
m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)");
|
|
||||||
m.def("dispose", &dispose, "dispose custom allreduce meta");
|
|
||||||
m.def("all_reduce", &all_reduce, "custom all reduce (CUDA)");
|
|
||||||
m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta, "custom all reduce get graph ipc meta");
|
|
||||||
m.def("register_graph_buffers", ®ister_graph_buffers, "custom all reduce register graph buffers");
|
|
||||||
// moe_align_block_size
|
|
||||||
m.def("moe_align_block_size", &moe_align_block_size, "MOE Align Block Size (CUDA)");
|
|
||||||
// sampling_scaling_penalties
|
|
||||||
m.def("sampling_scaling_penalties", &sampling_scaling_penalties, "Sampling scaling penalties (CUDA)");
|
|
||||||
// int8_scaled_mm
|
|
||||||
m.def("int8_scaled_mm", &int8_scaled_mm, "INT8 scaled matmul (CUDA)");
|
|
||||||
// lightning_attention_decode
|
|
||||||
m.def("lightning_attention_decode", &lightning_attention_decode, "Lightning Attention Ddecode (CUDA)");
|
|
||||||
// rotary embedding
|
|
||||||
m.def("rotary_embedding", &rotary_embedding, "Rotary Embedding (CUDA)");
|
|
||||||
// rms norm
|
|
||||||
m.def("rmsnorm", &rmsnorm, "RMSNorm (CUDA)");
|
|
||||||
// fused rms norm
|
|
||||||
m.def("fused_add_rmsnorm", &fused_add_rmsnorm, "Fused Add RMSNorm (CUDA)");
|
|
||||||
// gemma rms norm
|
|
||||||
m.def("gemma_rmsnorm", &gemma_rmsnorm, "Gemma RMSNorm (CUDA)");
|
|
||||||
// fused gemma rms norm
|
|
||||||
m.def("gemma_fused_add_rmsnorm", &gemma_fused_add_rmsnorm, "Gemma Fused Add RMSNorm (CUDA)");
|
|
||||||
// silu and mul
|
|
||||||
m.def("silu_and_mul", &silu_and_mul, "Silu and Mul (CUDA)");
|
|
||||||
// gelu tanh and mul
|
|
||||||
m.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, "Gelu Tanh and Mul (CUDA)");
|
|
||||||
// gelu and mul
|
|
||||||
m.def("gelu_and_mul", &gelu_and_mul, "Gelu and Mul (CUDA)");
|
|
||||||
// bmm fp8
|
|
||||||
m.def("bmm_fp8", &bmm_fp8, "BMM FP8 (CUDA)");
|
|
||||||
// min p sampling from probs
|
|
||||||
m.def("min_p_sampling_from_probs", &min_p_sampling_from_probs, "Min P Sampling From Probs (CUDA)");
|
|
||||||
// top k renorm probs
|
|
||||||
m.def("top_k_renorm_probs", &top_k_renorm_probs, "Top K Renorm Probs (CUDA)");
|
|
||||||
// top p renorm probs
|
|
||||||
m.def("top_p_renorm_probs", &top_p_renorm_probs, "Top P Renorm Probs (CUDA)");
|
|
||||||
// top k top p sampling from probs
|
|
||||||
m.def("top_k_top_p_sampling_from_probs", &top_k_top_p_sampling_from_probs, "Top K Top P Sampling From Probs (CUDA)");
|
|
||||||
// top p sampling from probs
|
|
||||||
m.def("top_p_sampling_from_probs", &top_p_sampling_from_probs, "Top P Sampling From Probs (CUDA)");
|
|
||||||
}
|
|
||||||
@@ -1,9 +1,12 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
#include <cuda_runtime.h>
|
||||||
#include <pytorch_extension_utils.h>
|
#include <pytorch_extension_utils.h>
|
||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
|
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
|
||||||
|
#include "sgl_kernels_ops.h"
|
||||||
|
|
||||||
struct cuda_error : public std::runtime_error {
|
struct cuda_error : public std::runtime_error {
|
||||||
/**
|
/**
|
||||||
* @brief Constructs a `cuda_error` object with the given `message`.
|
* @brief Constructs a `cuda_error` object with the given `message`.
|
||||||
@@ -1,41 +1,8 @@
|
|||||||
|
import os
|
||||||
from typing import Optional, Tuple, Union
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
|
import sgl_kernel.ops._kernels
|
||||||
import torch
|
import torch
|
||||||
from sgl_kernel.ops._kernels import all_reduce as _all_reduce
|
|
||||||
from sgl_kernel.ops._kernels import bmm_fp8 as _bmm_fp8
|
|
||||||
from sgl_kernel.ops._kernels import dispose as _dispose
|
|
||||||
from sgl_kernel.ops._kernels import fused_add_rmsnorm as _fused_add_rmsnorm
|
|
||||||
from sgl_kernel.ops._kernels import gelu_and_mul as _gelu_and_mul
|
|
||||||
from sgl_kernel.ops._kernels import gelu_tanh_and_mul as _gelu_tanh_and_mul
|
|
||||||
from sgl_kernel.ops._kernels import gemma_fused_add_rmsnorm as _gemma_fused_add_rmsnorm
|
|
||||||
from sgl_kernel.ops._kernels import gemma_rmsnorm as _gemma_rmsnorm
|
|
||||||
from sgl_kernel.ops._kernels import (
|
|
||||||
get_graph_buffer_ipc_meta as _get_graph_buffer_ipc_meta,
|
|
||||||
)
|
|
||||||
from sgl_kernel.ops._kernels import init_custom_ar as _init_custom_ar
|
|
||||||
from sgl_kernel.ops._kernels import int8_scaled_mm as _int8_scaled_mm
|
|
||||||
from sgl_kernel.ops._kernels import (
|
|
||||||
lightning_attention_decode as _lightning_attention_decode,
|
|
||||||
)
|
|
||||||
from sgl_kernel.ops._kernels import (
|
|
||||||
min_p_sampling_from_probs as _min_p_sampling_from_probs,
|
|
||||||
)
|
|
||||||
from sgl_kernel.ops._kernels import moe_align_block_size as _moe_align_block_size
|
|
||||||
from sgl_kernel.ops._kernels import register_graph_buffers as _register_graph_buffers
|
|
||||||
from sgl_kernel.ops._kernels import rmsnorm as _rmsnorm
|
|
||||||
from sgl_kernel.ops._kernels import rotary_embedding as _rotary_embedding
|
|
||||||
from sgl_kernel.ops._kernels import (
|
|
||||||
sampling_scaling_penalties as _sampling_scaling_penalties,
|
|
||||||
)
|
|
||||||
from sgl_kernel.ops._kernels import silu_and_mul as _silu_and_mul
|
|
||||||
from sgl_kernel.ops._kernels import top_k_renorm_probs as _top_k_renorm_probs
|
|
||||||
from sgl_kernel.ops._kernels import (
|
|
||||||
top_k_top_p_sampling_from_probs as _top_k_top_p_sampling_from_probs,
|
|
||||||
)
|
|
||||||
from sgl_kernel.ops._kernels import top_p_renorm_probs as _top_p_renorm_probs
|
|
||||||
from sgl_kernel.ops._kernels import (
|
|
||||||
top_p_sampling_from_probs as _top_p_sampling_from_probs,
|
|
||||||
)
|
|
||||||
from sgl_kernel.ops.utils import (
|
from sgl_kernel.ops.utils import (
|
||||||
_get_cache_buf,
|
_get_cache_buf,
|
||||||
_get_cuda_stream,
|
_get_cuda_stream,
|
||||||
@@ -46,25 +13,25 @@ from sgl_kernel.ops.utils import (
|
|||||||
def init_custom_reduce(
|
def init_custom_reduce(
|
||||||
rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out
|
rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out
|
||||||
):
|
):
|
||||||
return _init_custom_ar(
|
return torch.ops.sgl_kernels.init_custom_ar(
|
||||||
rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out
|
rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def custom_dispose(fa):
|
def custom_dispose(fa):
|
||||||
_dispose(fa)
|
torch.ops.sgl_kernels.dispose(fa)
|
||||||
|
|
||||||
|
|
||||||
def custom_reduce(fa, inp, out):
|
def custom_reduce(fa, inp, out):
|
||||||
_all_reduce(fa, inp, out)
|
torch.ops.sgl_kernels.all_reduce(fa, inp, out)
|
||||||
|
|
||||||
|
|
||||||
def get_graph_buffer_ipc_meta(fa):
|
def get_graph_buffer_ipc_meta(fa):
|
||||||
return _get_graph_buffer_ipc_meta(fa)
|
return torch.ops.sgl_kernels.get_graph_buffer_ipc_meta(fa)
|
||||||
|
|
||||||
|
|
||||||
def register_graph_buffers(fa, handles, offsets):
|
def register_graph_buffers(fa, handles, offsets):
|
||||||
_register_graph_buffers(fa, handles, offsets)
|
torch.ops.sgl_kernels.register_graph_buffers(fa, handles, offsets)
|
||||||
|
|
||||||
|
|
||||||
def moe_align_block_size(
|
def moe_align_block_size(
|
||||||
@@ -77,7 +44,7 @@ def moe_align_block_size(
|
|||||||
token_cnts_buffer,
|
token_cnts_buffer,
|
||||||
cumsum_buffer,
|
cumsum_buffer,
|
||||||
):
|
):
|
||||||
_moe_align_block_size(
|
torch.ops.sgl_kernels.moe_align_block_size(
|
||||||
topk_ids,
|
topk_ids,
|
||||||
num_experts,
|
num_experts,
|
||||||
block_size,
|
block_size,
|
||||||
@@ -90,11 +57,11 @@ def moe_align_block_size(
|
|||||||
|
|
||||||
|
|
||||||
def sampling_scaling_penalties(logits, scaling_penalties):
|
def sampling_scaling_penalties(logits, scaling_penalties):
|
||||||
return _sampling_scaling_penalties(logits, scaling_penalties)
|
return torch.ops.sgl_kernels.sampling_scaling_penalties(logits, scaling_penalties)
|
||||||
|
|
||||||
|
|
||||||
def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
|
def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
|
||||||
return _int8_scaled_mm(
|
return torch.ops.sgl_kernels.int8_scaled_mm(
|
||||||
mat_a,
|
mat_a,
|
||||||
mat_b,
|
mat_b,
|
||||||
scales_a,
|
scales_a,
|
||||||
@@ -105,11 +72,15 @@ def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
|
|||||||
|
|
||||||
|
|
||||||
def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
|
def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
|
||||||
_lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv)
|
torch.ops.sgl_kernels.lightning_attention_decode(
|
||||||
|
q, k, v, past_kv, slope, output, new_kv
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox):
|
def rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox):
|
||||||
return _rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox)
|
return torch.ops.sgl_kernels.rotary_embedding(
|
||||||
|
positions, query, key, head_size, cos_sin_cache, is_neox
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# These implementations extensively draw from and build upon the FlashInfer project https://github.com/flashinfer-ai/flashinfer
|
# These implementations extensively draw from and build upon the FlashInfer project https://github.com/flashinfer-ai/flashinfer
|
||||||
@@ -123,7 +94,7 @@ def rmsnorm(
|
|||||||
with input.device as device:
|
with input.device as device:
|
||||||
if out is None:
|
if out is None:
|
||||||
out = torch.empty_like(input)
|
out = torch.empty_like(input)
|
||||||
_rmsnorm(out, input, weight, eps, _get_cuda_stream(device))
|
torch.ops.sgl_kernels.rmsnorm(out, input, weight, eps, _get_cuda_stream(device))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
@@ -131,7 +102,9 @@ def fused_add_rmsnorm(
|
|||||||
input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
|
input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
|
||||||
) -> None:
|
) -> None:
|
||||||
with input.device as device:
|
with input.device as device:
|
||||||
_fused_add_rmsnorm(input, residual, weight, eps, _get_cuda_stream(device))
|
torch.ops.sgl_kernels.fused_add_rmsnorm(
|
||||||
|
input, residual, weight, eps, _get_cuda_stream(device)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def gemma_rmsnorm(
|
def gemma_rmsnorm(
|
||||||
@@ -143,7 +116,9 @@ def gemma_rmsnorm(
|
|||||||
with input.device as device:
|
with input.device as device:
|
||||||
if out is None:
|
if out is None:
|
||||||
out = torch.empty_like(input)
|
out = torch.empty_like(input)
|
||||||
_gemma_rmsnorm(out, input, weight, eps, _get_cuda_stream(device))
|
torch.ops.sgl_kernels.gemma_rmsnorm(
|
||||||
|
out, input, weight, eps, _get_cuda_stream(device)
|
||||||
|
)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
@@ -151,7 +126,9 @@ def gemma_fused_add_rmsnorm(
|
|||||||
input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
|
input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
|
||||||
) -> None:
|
) -> None:
|
||||||
with input.device as device:
|
with input.device as device:
|
||||||
_gemma_fused_add_rmsnorm(input, residual, weight, eps, _get_cuda_stream(device))
|
torch.ops.sgl_kernels.gemma_fused_add_rmsnorm(
|
||||||
|
input, residual, weight, eps, _get_cuda_stream(device)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None:
|
def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None:
|
||||||
@@ -176,7 +153,7 @@ def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
|
|||||||
dtype=input.dtype,
|
dtype=input.dtype,
|
||||||
)
|
)
|
||||||
with input.device as device:
|
with input.device as device:
|
||||||
_silu_and_mul(out, input, _get_cuda_stream(device))
|
torch.ops.sgl_kernels.silu_and_mul(out, input, _get_cuda_stream(device))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
@@ -192,7 +169,7 @@ def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Te
|
|||||||
dtype=input.dtype,
|
dtype=input.dtype,
|
||||||
)
|
)
|
||||||
with input.device as device:
|
with input.device as device:
|
||||||
_gelu_tanh_and_mul(out, input, _get_cuda_stream(device))
|
torch.ops.sgl_kernels.gelu_tanh_and_mul(out, input, _get_cuda_stream(device))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
@@ -208,7 +185,7 @@ def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
|
|||||||
dtype=input.dtype,
|
dtype=input.dtype,
|
||||||
)
|
)
|
||||||
with input.device as device:
|
with input.device as device:
|
||||||
_gelu_and_mul(out, input, _get_cuda_stream(device))
|
torch.ops.sgl_kernels.gelu_and_mul(out, input, _get_cuda_stream(device))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
@@ -222,7 +199,7 @@ def _bmm_fp8_internal(
|
|||||||
) -> None:
|
) -> None:
|
||||||
with A.device as device:
|
with A.device as device:
|
||||||
cublas_handle = torch.cuda.current_blas_handle()
|
cublas_handle = torch.cuda.current_blas_handle()
|
||||||
_bmm_fp8(
|
torch.ops.sgl_kernels.bmm_fp8(
|
||||||
A,
|
A,
|
||||||
B,
|
B,
|
||||||
D,
|
D,
|
||||||
@@ -262,7 +239,7 @@ def _top_k_renorm_probs_internal(
|
|||||||
probs = probs.float()
|
probs = probs.float()
|
||||||
maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
|
maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
|
||||||
renorm_probs = torch.empty_like(probs)
|
renorm_probs = torch.empty_like(probs)
|
||||||
_top_k_renorm_probs(
|
torch.ops.sgl_kernels.top_k_renorm_probs_wrapper(
|
||||||
probs,
|
probs,
|
||||||
renorm_probs,
|
renorm_probs,
|
||||||
maybe_top_k_arr,
|
maybe_top_k_arr,
|
||||||
@@ -293,7 +270,7 @@ def _top_p_renorm_probs_internal(
|
|||||||
maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
|
maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
|
||||||
)
|
)
|
||||||
renorm_probs = torch.empty_like(probs)
|
renorm_probs = torch.empty_like(probs)
|
||||||
_top_p_renorm_probs(
|
torch.ops.sgl_kernels.top_p_renorm_probs(
|
||||||
probs,
|
probs,
|
||||||
renorm_probs,
|
renorm_probs,
|
||||||
maybe_top_p_arr,
|
maybe_top_p_arr,
|
||||||
@@ -328,7 +305,7 @@ def _top_p_sampling_from_probs_internal(
|
|||||||
)
|
)
|
||||||
samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
|
samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
|
||||||
success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
|
success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
|
||||||
_top_p_sampling_from_probs(
|
torch.ops.sgl_kernels.top_p_sampling_from_probs(
|
||||||
probs,
|
probs,
|
||||||
uniform_samples,
|
uniform_samples,
|
||||||
samples,
|
samples,
|
||||||
@@ -374,7 +351,7 @@ def _top_k_top_p_sampling_from_probs_internal(
|
|||||||
)
|
)
|
||||||
samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
|
samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
|
||||||
success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
|
success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
|
||||||
_top_k_top_p_sampling_from_probs(
|
torch.ops.sgl_kernels.top_k_top_p_sampling_from_probs(
|
||||||
probs,
|
probs,
|
||||||
uniform_samples,
|
uniform_samples,
|
||||||
samples,
|
samples,
|
||||||
@@ -432,7 +409,7 @@ def _min_p_sampling_from_probs_internal(
|
|||||||
maybe_min_p_arr.float() if maybe_min_p_arr is not None else None
|
maybe_min_p_arr.float() if maybe_min_p_arr is not None else None
|
||||||
)
|
)
|
||||||
samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
|
samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
|
||||||
_min_p_sampling_from_probs(
|
torch.ops.sgl_kernels.min_p_sampling_from_probs(
|
||||||
probs,
|
probs,
|
||||||
uniform_samples,
|
uniform_samples,
|
||||||
samples,
|
samples,
|
||||||
|
|||||||
119
sgl-kernel/src/sgl-kernel/torch_extension.cc
Normal file
119
sgl-kernel/src/sgl-kernel/torch_extension.cc
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
|
||||||
|
#include <ATen/core/dispatch/Dispatcher.h>
|
||||||
|
#include <torch/library.h>
|
||||||
|
|
||||||
|
#include "sgl_kernels_ops.h"
|
||||||
|
|
||||||
|
TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
|
||||||
|
// trt_reduce
|
||||||
|
m.def(
|
||||||
|
"init_custom_ar(int rank_id, int world_size, Tensor rank_data, int[] buffers, int[] tmp_result_buffers, int[] "
|
||||||
|
"barrier_in, int[] barrier_out) -> int");
|
||||||
|
m.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);
|
||||||
|
|
||||||
|
m.def("dispose", &dispose);
|
||||||
|
|
||||||
|
m.def("all_reduce(int fa, Tensor inp, Tensor! out) -> ()");
|
||||||
|
m.impl("all_reduce", torch::kCUDA, &all_reduce);
|
||||||
|
|
||||||
|
m.def("get_graph_buffer_ipc_meta(int fa) -> (int[], int[])");
|
||||||
|
m.impl("get_graph_buffer_ipc_meta", torch::kCUDA, &get_graph_buffer_ipc_meta);
|
||||||
|
|
||||||
|
m.def("register_graph_buffers(int fa, int[][] handles, int[][] offsets) -> ()");
|
||||||
|
m.impl("register_graph_buffers", torch::kCUDA, ®ister_graph_buffers);
|
||||||
|
|
||||||
|
// moe_align_block_size
|
||||||
|
m.def(
|
||||||
|
"moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! "
|
||||||
|
"experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer) -> ()");
|
||||||
|
m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
|
||||||
|
|
||||||
|
// sampling_scaling_penalties
|
||||||
|
m.def("sampling_scaling_penalties(Tensor logits, Tensor scaling_penalties) -> Tensor");
|
||||||
|
m.impl("sampling_scaling_penalties", torch::kCUDA, &sampling_scaling_penalties);
|
||||||
|
|
||||||
|
// int8_scaled_mm
|
||||||
|
m.def(
|
||||||
|
"int8_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype, Tensor? "
|
||||||
|
"bias) -> Tensor");
|
||||||
|
m.impl("int8_scaled_mm", torch::kCUDA, &int8_scaled_mm);
|
||||||
|
|
||||||
|
// lightning_attention_decode
|
||||||
|
m.def(
|
||||||
|
"lightning_attention_decode(Tensor q, Tensor k, Tensor v, Tensor past_kv, Tensor slope, Tensor! output, Tensor! "
|
||||||
|
"new_kv) -> ()");
|
||||||
|
m.impl("lightning_attention_decode", torch::kCUDA, &lightning_attention_decode);
|
||||||
|
|
||||||
|
// rotary embedding
|
||||||
|
m.def(
|
||||||
|
"rotary_embedding(Tensor positions, Tensor! query, Tensor! key, int head_size, Tensor cos_sin_cache, bool "
|
||||||
|
"is_neox) -> ()");
|
||||||
|
m.impl("rotary_embedding", torch::kCUDA, &rotary_embedding);
|
||||||
|
|
||||||
|
// rms norm
|
||||||
|
m.def("rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
|
||||||
|
m.impl("rmsnorm", torch::kCUDA, &rmsnorm);
|
||||||
|
|
||||||
|
// fused rms norm
|
||||||
|
m.def("fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, int cuda_stream) -> ()");
|
||||||
|
m.impl("fused_add_rmsnorm", torch::kCUDA, &fused_add_rmsnorm);
|
||||||
|
|
||||||
|
// gemma rms norm
|
||||||
|
m.def("gemma_rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
|
||||||
|
m.impl("gemma_rmsnorm", torch::kCUDA, &gemma_rmsnorm);
|
||||||
|
|
||||||
|
// fused gemma rms norm
|
||||||
|
m.def("gemma_fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, int cuda_stream) -> ()");
|
||||||
|
m.impl("gemma_fused_add_rmsnorm", torch::kCUDA, &gemma_fused_add_rmsnorm);
|
||||||
|
|
||||||
|
// silu and mul
|
||||||
|
m.def("silu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
|
||||||
|
m.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
|
||||||
|
|
||||||
|
// gelu tanh and mul
|
||||||
|
m.def("gelu_tanh_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
|
||||||
|
m.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul);
|
||||||
|
|
||||||
|
// gelu and mul
|
||||||
|
m.def("gelu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
|
||||||
|
m.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);
|
||||||
|
|
||||||
|
// bmm fp8
|
||||||
|
m.def(
|
||||||
|
"bmm_fp8(Tensor A, Tensor B, Tensor! D, Tensor A_scale, Tensor B_scale, Tensor workspace_buffer, int "
|
||||||
|
"cublas_handle, int cuda_stream) -> ()");
|
||||||
|
m.impl("bmm_fp8", torch::kCUDA, &bmm_fp8);
|
||||||
|
|
||||||
|
// min p sampling from probs
|
||||||
|
m.def(
|
||||||
|
"min_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor? maybe_min_p_arr, float "
|
||||||
|
"min_p_val, bool deterministic, int cuda_stream) -> ()");
|
||||||
|
m.impl("min_p_sampling_from_probs", torch::kCUDA, &min_p_sampling_from_probs);
|
||||||
|
|
||||||
|
// top k renorm probs
|
||||||
|
m.def(
|
||||||
|
"top_k_renorm_probs_wrapper(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_k_arr, int top_k_val, int "
|
||||||
|
"cuda_stream) -> ()");
|
||||||
|
m.impl("top_k_renorm_probs_wrapper", torch::kCUDA, &top_k_renorm_probs_wrapper);
|
||||||
|
|
||||||
|
// top p renorm probs
|
||||||
|
m.def(
|
||||||
|
"top_p_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_p_arr, float top_p_val, int "
|
||||||
|
"cuda_stream) -> ()");
|
||||||
|
m.impl("top_p_renorm_probs", torch::kCUDA, &top_p_renorm_probs);
|
||||||
|
|
||||||
|
// top k top p sampling from probs
|
||||||
|
m.def(
|
||||||
|
"top_k_top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? "
|
||||||
|
"maybe_top_k_arr, float top_k_val, Tensor? maybe_top_p_arr, float top_p_val, bool deterministic, int "
|
||||||
|
"cuda_stream) -> ()");
|
||||||
|
m.impl("top_k_top_p_sampling_from_probs", torch::kCUDA, &top_k_top_p_sampling_from_probs);
|
||||||
|
|
||||||
|
// top p sampling from probs
|
||||||
|
m.def(
|
||||||
|
"top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? "
|
||||||
|
"maybe_top_p_arr, float top_p_val, bool deterministic, int cuda_stream) -> ()");
|
||||||
|
m.impl("top_p_sampling_from_probs", torch::kCUDA, &top_p_sampling_from_probs);
|
||||||
|
}
|
||||||
|
|
||||||
|
REGISTER_EXTENSION(_kernels)
|
||||||
Reference in New Issue
Block a user