[CPU] remove process_group from inputs of shm_allreduce and shm_allgather (#7486)
This commit is contained in:
@@ -212,11 +212,10 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope_fused_weight(
|
||||
void initialize(int64_t size, int64_t rank);
|
||||
|
||||
// shared memory all_reduce
|
||||
void shm_allreduce(
|
||||
at::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, c10::intrusive_ptr<c10d::ReduceOp> op);
|
||||
void shm_allreduce(at::Tensor& data, int64_t op);
|
||||
|
||||
// shared memory all_gather
|
||||
at::Tensor shm_allgather(at::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, int64_t dim);
|
||||
at::Tensor shm_allgather(at::Tensor& data, int64_t dim);
|
||||
|
||||
// rope
|
||||
std::tuple<at::Tensor, at::Tensor> rotary_embedding_cpu(
|
||||
@@ -344,11 +343,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
// all reduce
|
||||
m.def("initialize(int size, int rank) -> ()");
|
||||
m.impl("initialize", torch::kCPU, &initialize);
|
||||
m.def(
|
||||
"shm_allreduce(Tensor data, __torch__.torch.classes.c10d.ProcessGroup process_group, "
|
||||
"__torch__.torch.classes.c10d.ReduceOp reduce_op) -> ()");
|
||||
m.def("shm_allreduce(Tensor data, int reduce_op) -> ()");
|
||||
m.impl("shm_allreduce", torch::kCPU, &shm_allreduce);
|
||||
m.def("shm_allgather(Tensor data, __torch__.torch.classes.c10d.ProcessGroup process_group, int dim) -> Tensor");
|
||||
m.def("shm_allgather(Tensor data, int dim) -> Tensor");
|
||||
m.impl("shm_allgather", torch::kCPU, &shm_allgather);
|
||||
|
||||
// rope
|
||||
|
||||
Reference in New Issue
Block a user