support 1 shot allreduce in 1-node and 2-node using mscclpp (#6277)
This commit is contained in:
@@ -38,6 +38,15 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
"all_reduce(int fa, Tensor inp, Tensor! out, int reg_buffer, "
|
||||
"int reg_buffer_sz_bytes) -> ()");
|
||||
m.impl("all_reduce", torch::kCUDA, &all_reduce);
|
||||
|
||||
m.def("mscclpp_generate_unique_id", &mscclpp_generate_unique_id);
|
||||
m.def(
|
||||
"mscclpp_init_context(Tensor unique_id, int rank, int world_size, Tensor scratch, Tensor put_buffer, "
|
||||
"int nranks_per_node, int[] rank_to_node, int[] rank_to_ib, int context_selection) -> int");
|
||||
m.impl("mscclpp_init_context", torch::kCUDA, &mscclpp_init_context);
|
||||
|
||||
m.def("mscclpp_allreduce(int context, Tensor inp, Tensor! out, int nthreads, int nblocks) -> ()");
|
||||
m.impl("mscclpp_allreduce", torch::kCUDA, &mscclpp_allreduce);
|
||||
/*
|
||||
* From csrc/attention
|
||||
*/
|
||||
|
||||
Reference in New Issue
Block a user