sgl-kernel transfer custom allreduce from trt kernel to vllm kernel (#5079)
This commit is contained in:
@@ -26,15 +26,18 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
|
||||
m.def("register_graph_buffers", &register_graph_buffers);
|
||||
m.def("dispose", &dispose);
|
||||
m.def("meta_size", &meta_size);
|
||||
m.def("register_buffer", &register_buffer);
|
||||
|
||||
m.def(
|
||||
"init_custom_ar(int rank_id, int world_size, Tensor rank_data, int[] buffers, int[] tmp_result_buffers, int[] "
|
||||
"barrier_in, int[] barrier_out) -> int");
|
||||
"init_custom_ar(int[] ipc_tensors, Tensor rank_data, "
|
||||
"int rank, bool full_nvlink) -> int");
|
||||
m.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);
|
||||
|
||||
m.def("all_reduce(int fa, Tensor inp, Tensor! out) -> ()");
|
||||
m.def(
|
||||
"all_reduce(int fa, Tensor inp, Tensor! out, int reg_buffer, "
|
||||
"int reg_buffer_sz_bytes) -> ()");
|
||||
m.impl("all_reduce", torch::kCUDA, &all_reduce);
|
||||
|
||||
/*
|
||||
* From csrc/attention
|
||||
*/
|
||||
|
||||
Reference in New Issue
Block a user