[CPU] remove process_group from inputs of shm_allreduce and shm_allgather (#7486)

Author: Chunyuan WU
Date: 2025-07-01 12:54:11 +08:00
Committed by: GitHub
Parent: ff2e9c9479
Commit: 6005eceee3
2 changed files with 9 additions and 57 deletions

@@ -212,11 +212,10 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope_fused_weight(
 void initialize(int64_t size, int64_t rank);
 // shared mmeory all_reduce
-void shm_allreduce(
-    at::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, c10::intrusive_ptr<c10d::ReduceOp> op);
+void shm_allreduce(at::Tensor& data, int64_t op);
 // shared memory all_gather
-at::Tensor shm_allgather(at::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, int64_t dim);
+at::Tensor shm_allgather(at::Tensor& data, int64_t dim);
 // rope
 std::tuple<at::Tensor, at::Tensor> rotary_embedding_cpu(
@@ -344,11 +343,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   // all reduce
   m.def("initialize(int size, int rank) -> ()");
   m.impl("initialize", torch::kCPU, &initialize);
-  m.def(
-      "shm_allreduce(Tensor data, __torch__.torch.classes.c10d.ProcessGroup process_group, "
-      "__torch__.torch.classes.c10d.ReduceOp reduce_op) -> ()");
+  m.def("shm_allreduce(Tensor data, int reduce_op) -> ()");
   m.impl("shm_allreduce", torch::kCPU, &shm_allreduce);
-  m.def("shm_allgather(Tensor data, __torch__.torch.classes.c10d.ProcessGroup process_group, int dim) -> Tensor");
+  m.def("shm_allgather(Tensor data, int dim) -> Tensor");
   m.impl("shm_allgather", torch::kCPU, &shm_allgather);
   // rope
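
For context, a minimal caller-side sketch of the updated op schema (not part of this commit): after the change, shm_allreduce and shm_allgather take a plain integer for reduce_op / dim and no ProcessGroup. The import that loads the extension and the integer value passed as reduce_op below are assumptions for illustration only.

# Hypothetical usage sketch; assumes the sgl-kernel CPU extension is built
# and its ops are registered under torch.ops.sgl_kernel (per the
# TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) block above).
import torch
import sgl_kernel  # noqa: F401  -- assumed to load and register the ops

world_size, rank = 1, 0
torch.ops.sgl_kernel.initialize(world_size, rank)  # shared-memory comm setup

x = torch.randn(8, 16)
# reduce_op is now a plain int (0 is only a placeholder here); previously a
# c10d ProcessGroup and a c10d ReduceOp object had to be passed in.
torch.ops.sgl_kernel.shm_allreduce(x, 0)
gathered = torch.ops.sgl_kernel.shm_allgather(x, 0)  # gather along dim 0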