Clean up allocators (#9134)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Lianmin Zheng
2025-08-13 13:56:04 -07:00
committed by GitHub
parent 2f20f43026
commit 9e426466af
16 changed files with 288 additions and 295 deletions

View File

@@ -21,7 +21,6 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
/*
* From csrc/allreduce
*/
m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
m.def("register_graph_buffers", &register_graph_buffers);
m.def("dispose", &dispose);
@@ -46,6 +45,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m.def("mscclpp_allreduce(int context, Tensor inp, Tensor! out, int nthreads, int nblocks) -> ()");
m.impl("mscclpp_allreduce", torch::kCUDA, &mscclpp_allreduce);
/*
* From csrc/attention
*/
@@ -284,6 +284,12 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
"page_size) -> ()");
m.impl("transfer_kv_direct", torch::kCUDA, &transfer_kv_direct);
/*
* From csrc/memory
*/
m.def("store_kv_cache(Tensor k_cache, Tensor v_cache, Tensor out_loc, Tensor k, Tensor v) -> ()");
m.impl("store_kv_cache", &store_kv_cache);
/*
* From csrc/moe/cutlass_moe/w4a8
*/
@@ -390,13 +396,13 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m.impl("convert_vertical_slash_indexes_mergehead", torch::kCUDA, &convert_vertical_slash_indexes_mergehead);
/*
* From XGrammar
* From csrc/grammar
*/
m.def("apply_token_bitmask_inplace_cuda(Tensor logits, Tensor bitmask, Tensor? indices=None) -> ()");
m.impl("apply_token_bitmask_inplace_cuda", &ApplyTokenBitmaskInplace);
/*
* From QServe
* From csrc/gemm (QServe)
*/
m.def(
"qserve_w4a8_per_chn_gemm(Tensor _in_feats, Tensor _kernel, Tensor _wscales, Tensor _ascales, Tensor _w_szs, "
@@ -413,12 +419,6 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
*/
m.def("create_greenctx_stream_by_value(int smA, int smB, int device) -> int[]");
m.impl("create_greenctx_stream_by_value", &create_greenctx_stream_by_value);
/*
* From csrc/memory
*/
m.def("store_kv_cache(Tensor k_cache, Tensor v_cache, Tensor out_loc, Tensor k, Tensor v) -> ()");
m.impl("store_kv_cache", &store_kv_cache);
}
REGISTER_EXTENSION(common_ops)

View File

@@ -47,7 +47,7 @@ sources = [
"csrc/moe/moe_align_kernel.cu",
"csrc/moe/moe_topk_softmax_kernels.cu",
"csrc/speculative/eagle_utils.cu",
"csrc/torch_extension_rocm.cc",
"csrc/common_extension_rocm.cc",
]
cxx_flags = ["-O3"]