support one-shot allreduce in 1-node and 2-node setups using mscclpp (#6277)

This commit is contained in:
zyksir
2025-06-05 13:11:24 +08:00
committed by GitHub
parent 4474eaf552
commit 8e3797be1c
20 changed files with 2177 additions and 12 deletions

View File

@@ -49,6 +49,27 @@ if torch.version.hip is not None:
def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
    """Fetch the IPC handle for *inp*'s meta buffer via the sgl_kernel op.

    Thin pass-through wrapper; all semantics live in the underlying
    ``torch.ops.sgl_kernel.get_meta_buffer_ipc_handle`` custom op.
    """
    op = torch.ops.sgl_kernel.get_meta_buffer_ipc_handle.default
    return op(inp)
def mscclpp_generate_unique_id() -> torch.Tensor:
    """Stub for builds without mscclpp support (ROCm/HIP branch).

    The return annotation is aligned with the real implementation in the
    non-HIP branch of this file, which returns the unique id as a
    ``torch.Tensor`` (the stub previously said ``bytes``, disagreeing
    with the actual op's signature).

    Raises:
        NotImplementedError: always — mscclpp is not available here.
    """
    raise NotImplementedError()
def mscclpp_init_context(
    unique_id: torch.Tensor,
    rank: int,
    world_size: int,
    scratch: torch.Tensor,
    put_buffer: torch.Tensor,
    nranks_per_node: int,
    rank_to_node: List[int],
    rank_to_ib: List[int],
    context_selection: int,
) -> int:
    """Stub for builds without mscclpp support (ROCm/HIP branch).

    The ``unique_id`` annotation is aligned with the real implementation
    in the non-HIP branch of this file, which passes the id as a
    ``torch.Tensor`` (the stub previously said ``bytes``).

    Raises:
        NotImplementedError: always — mscclpp is not available here.
    """
    raise NotImplementedError()
def mscclpp_allreduce(
    context: int,
    inp: torch.Tensor,
    out: torch.Tensor,
    nthreads: int,
    nblocks: int,
) -> None:
    """Stub for builds without mscclpp support; unconditionally raises.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()
else:
def init_custom_ar(
@@ -85,3 +106,36 @@ else:
def meta_size() -> int:
    """Thin wrapper around the ``sgl_kernel.meta_size`` custom op.

    Presumably reports the size of the custom-allreduce metadata buffer —
    the value comes entirely from the underlying kernel extension.
    """
    op = torch.ops.sgl_kernel.meta_size.default
    return op()
def mscclpp_generate_unique_id() -> torch.Tensor:
    """Produce a unique id tensor via the ``sgl_kernel`` mscclpp op.

    Pass-through wrapper; the id is generated entirely by the
    underlying custom op.
    """
    op = torch.ops.sgl_kernel.mscclpp_generate_unique_id.default
    return op()
def mscclpp_init_context(
    unique_id: torch.Tensor,
    rank: int,
    world_size: int,
    scratch: torch.Tensor,
    put_buffer: torch.Tensor,
    nranks_per_node: int,
    rank_to_node: List[int],
    rank_to_ib: List[int],
    context_selection: int,
) -> int:
    """Initialise an mscclpp context through the ``sgl_kernel`` extension.

    Every argument is forwarded unchanged to
    ``torch.ops.sgl_kernel.mscclpp_init_context``; the returned int is
    the opaque context handle produced by that op.
    """
    forwarded = (
        unique_id,
        rank,
        world_size,
        scratch,
        put_buffer,
        nranks_per_node,
        rank_to_node,
        rank_to_ib,
        context_selection,
    )
    return torch.ops.sgl_kernel.mscclpp_init_context.default(*forwarded)
def mscclpp_allreduce(
    context: int, inp: torch.Tensor, out: torch.Tensor, nthreads: int, nblocks: int
) -> None:
    """Invoke the mscclpp allreduce custom op on an initialised context.

    Pure side-effect wrapper: results are written by the underlying op
    (presumably into ``out``); nothing is returned.
    """
    op = torch.ops.sgl_kernel.mscclpp_allreduce.default
    op(context, inp, out, nthreads, nblocks)