sgl-kernel transfer custom allreduce from trt kernel to vllm kernel (#5079)

This commit is contained in:
Yi Zhang
2025-04-06 05:23:20 +08:00
committed by GitHub
parent 0d99adb715
commit bcbbf519f9
10 changed files with 692 additions and 937 deletions

View File

@@ -26,15 +26,18 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
m.def("register_graph_buffers", &register_graph_buffers);
m.def("dispose", &dispose);
m.def("meta_size", &meta_size);
m.def("register_buffer", &register_buffer);
m.def(
"init_custom_ar(int rank_id, int world_size, Tensor rank_data, int[] buffers, int[] tmp_result_buffers, int[] "
"barrier_in, int[] barrier_out) -> int");
"init_custom_ar(int[] ipc_tensors, Tensor rank_data, "
"int rank, bool full_nvlink) -> int");
m.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);
m.def("all_reduce(int fa, Tensor inp, Tensor! out) -> ()");
m.def(
"all_reduce(int fa, Tensor inp, Tensor! out, int reg_buffer, "
"int reg_buffer_sz_bytes) -> ()");
m.impl("all_reduce", torch::kCUDA, &all_reduce);
/*
* From csrc/attention
*/