[CPU] [sgl-kernel] set dispatch key of initialize to CatchAll (#7734)
This commit is contained in:
@@ -342,7 +342,6 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
|||||||
|
|
||||||
// all reduce
|
// all reduce
|
||||||
m.def("initialize(int size, int rank) -> ()");
|
m.def("initialize(int size, int rank) -> ()");
|
||||||
m.impl("initialize", torch::kCPU, &initialize);
|
|
||||||
m.def("shm_allreduce(Tensor data, int reduce_op) -> ()");
|
m.def("shm_allreduce(Tensor data, int reduce_op) -> ()");
|
||||||
m.impl("shm_allreduce", torch::kCPU, &shm_allreduce);
|
m.impl("shm_allreduce", torch::kCPU, &shm_allreduce);
|
||||||
m.def("shm_allgather(Tensor data, int dim) -> Tensor");
|
m.def("shm_allgather(Tensor data, int dim) -> Tensor");
|
||||||
@@ -360,6 +359,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
|||||||
|
|
||||||
TORCH_LIBRARY_IMPL(sgl_kernel, CatchAll, m) {
|
TORCH_LIBRARY_IMPL(sgl_kernel, CatchAll, m) {
|
||||||
m.impl("init_cpu_threads_env", init_cpu_threads_env);
|
m.impl("init_cpu_threads_env", init_cpu_threads_env);
|
||||||
|
m.impl("initialize", &initialize);
|
||||||
}
|
}
|
||||||
|
|
||||||
REGISTER_EXTENSION(common_ops)
|
REGISTER_EXTENSION(common_ops)
|
||||||
|
|||||||
Reference in New Issue
Block a user