support 1 shot allreduce in 1-node and 2-node using mscclpp (#6277)
This commit is contained in:
@@ -49,6 +49,27 @@ if torch.version.hip is not None:
|
||||
def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
|
||||
return torch.ops.sgl_kernel.get_meta_buffer_ipc_handle.default(inp)
|
||||
|
||||
def mscclpp_generate_unique_id() -> bytes:
|
||||
raise NotImplementedError()
|
||||
|
||||
def mscclpp_init_context(
|
||||
unique_id: bytes,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
scratch: torch.Tensor,
|
||||
put_buffer: torch.Tensor,
|
||||
nranks_per_node: int,
|
||||
rank_to_node: List[int],
|
||||
rank_to_ib: List[int],
|
||||
context_selection: int,
|
||||
) -> int:
|
||||
raise NotImplementedError()
|
||||
|
||||
def mscclpp_allreduce(
|
||||
context: int, inp: torch.Tensor, out: torch.Tensor, nthreads: int, nblocks: int
|
||||
) -> None:
|
||||
raise NotImplementedError()
|
||||
|
||||
else:
|
||||
|
||||
def init_custom_ar(
|
||||
@@ -85,3 +106,36 @@ else:
|
||||
|
||||
def meta_size() -> int:
|
||||
return torch.ops.sgl_kernel.meta_size.default()
|
||||
|
||||
def mscclpp_generate_unique_id() -> torch.Tensor:
|
||||
return torch.ops.sgl_kernel.mscclpp_generate_unique_id.default()
|
||||
|
||||
def mscclpp_init_context(
|
||||
unique_id: torch.Tensor,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
scratch: torch.Tensor,
|
||||
put_buffer: torch.Tensor,
|
||||
nranks_per_node: int,
|
||||
rank_to_node: List[int],
|
||||
rank_to_ib: List[int],
|
||||
context_selection: int,
|
||||
) -> int:
|
||||
return torch.ops.sgl_kernel.mscclpp_init_context.default(
|
||||
unique_id,
|
||||
rank,
|
||||
world_size,
|
||||
scratch,
|
||||
put_buffer,
|
||||
nranks_per_node,
|
||||
rank_to_node,
|
||||
rank_to_ib,
|
||||
context_selection,
|
||||
)
|
||||
|
||||
def mscclpp_allreduce(
|
||||
context: int, inp: torch.Tensor, out: torch.Tensor, nthreads: int, nblocks: int
|
||||
) -> None:
|
||||
torch.ops.sgl_kernel.mscclpp_allreduce.default(
|
||||
context, inp, out, nthreads, nblocks
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user