Support 1-shot allreduce for 1-node and 2-node setups using mscclpp (#6277)
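The new test below drives the kernel through `sgl_kernel.allreduce`. As a condensed sketch of the call sequence it exercises (names taken from the test; `unique_id`, `rank`, `world_size`, `nranks_per_node`, `rank_to_node`, `rank_to_ib`, `selection`, and the input tensor `inp` are assumed to be set up as in the test body):

import torch
import sgl_kernel.allreduce as custom_ops

MAX_BYTES = 2**20
# Persistent buffers handed to the context once and reused by every allreduce call.
scratch = torch.empty(MAX_BYTES * 8, dtype=torch.bfloat16, device="cuda")
put_buffer = torch.empty(MAX_BYTES, dtype=torch.bfloat16, device="cuda")

ctx = custom_ops.mscclpp_init_context(
    unique_id, rank, world_size, scratch, put_buffer,
    nranks_per_node, rank_to_node, rank_to_ib,
    selection,  # 1 = 1-shot 1-node LL, 2 = 1-shot 2-node LL (see MscclContextSelection)
)
out = torch.empty_like(inp)
custom_ops.mscclpp_allreduce(ctx, inp, out, nthreads=512, nblocks=21)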
sgl-kernel/tests/test_mscclpp.py (new file, 146 lines)
@@ -0,0 +1,146 @@
import multiprocessing as mp
import os
import socket
import unittest
from enum import IntEnum
from typing import Any

import sgl_kernel.allreduce as custom_ops
import torch
import torch.distributed as dist


class MscclContextSelection(IntEnum):
    MSCCL1SHOT1NODELL = 1
    MSCCL1SHOT2NODELL = 2


def _run_correctness_worker(world_size, rank, distributed_init_port, test_sizes):
    device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
    torch.cuda.set_device(device)
    distributed_init_method = f"tcp://localhost:{distributed_init_port}"
    dist.init_process_group(
        backend="nccl",
        init_method=distributed_init_method,
        rank=rank,
        world_size=world_size,
    )
    group = dist.group.WORLD
    cpu_group = torch.distributed.new_group(list(range(world_size)), backend="gloo")
    # Create the mscclpp unique id on rank 0 and share it with every rank via gloo.
    if rank == 0:
        unique_id = [custom_ops.mscclpp_generate_unique_id()]
    else:
        unique_id = [None]
    dist.broadcast_object_list(
        unique_id, src=0, device=torch.device("cpu"), group=cpu_group
    )
    unique_id = unique_id[0]
    # Map every rank to its node (8 GPUs per node) and to its IB device on that node.
    rank_to_node, rank_to_ib = list(range(world_size)), list(range(world_size))
    for r in range(world_size):
        rank_to_node[r] = r // 8
        rank_to_ib[r] = r % 8
    MAX_BYTES = 2**20
    # Persistent scratch / put buffers handed to the mscclpp context.
    scratch = torch.empty(
        MAX_BYTES * 8, dtype=torch.bfloat16, device=torch.cuda.current_device()
    )
    put_buffer = torch.empty(
        MAX_BYTES, dtype=torch.bfloat16, device=torch.cuda.current_device()
    )
    print(f"[{rank}] start mscclpp_context init")
    nranks_per_node = torch.cuda.device_count()
    selection = int(MscclContextSelection.MSCCL1SHOT1NODELL)
    mscclpp_context = custom_ops.mscclpp_init_context(
        unique_id,
        rank,
        world_size,
        scratch,
        put_buffer,
        nranks_per_node,
        rank_to_node,
        rank_to_ib,
        selection,
    )
    try:
        test_loop = 10
        for sz in test_sizes:
            for dtype in [torch.float32, torch.float16, torch.bfloat16]:
                # Skip messages that exceed the preallocated buffer limit.
                if sz * dtype.itemsize > MAX_BYTES:
                    continue
                if rank == 0:
                    print(f"mscclpp allreduce test sz {sz}, dtype {dtype}")
                for _ in range(test_loop):
                    inp1 = torch.randint(1, 16, (sz,), dtype=dtype, device=device)
                    inp1_ref = inp1.clone()
                    out1 = torch.empty_like(inp1)
                    custom_ops.mscclpp_allreduce(
                        mscclpp_context, inp1, out1, nthreads=512, nblocks=21
                    )
                    # Check the result against NCCL's all_reduce on the same input.
                    dist.all_reduce(inp1_ref, group=group)
                    torch.testing.assert_close(out1, inp1_ref)
    finally:
        dist.barrier(group=group)
        dist.destroy_process_group(group=group)


def get_open_port() -> int:
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("127.0.0.1", 0))
            return s.getsockname()[1]
    except OSError:
        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
            s.bind(("::1", 0))
            return s.getsockname()[1]


def multi_process_parallel(
    world_size: int, test_target: Any, target_args: tuple = ()
) -> None:
    mp.set_start_method("spawn", force=True)

    procs = []
    distributed_init_port = get_open_port()
    for i in range(world_size):
        proc_args = (world_size, i, distributed_init_port) + target_args
        proc = mp.Process(target=test_target, args=proc_args, name=f"Worker-{i}")
        proc.start()
        procs.append(proc)

    for i in range(world_size):
        procs[i].join()
        assert (
            procs[i].exitcode == 0
        ), f"Process {i} failed with exit code {procs[i].exitcode}"


class TestMSCCLAllReduce(unittest.TestCase):
    test_sizes = [
        512,
        2560,
        4096,
        5120,
        7680,
        32768,
        262144,
        524288,
    ]
    world_sizes = [8]

    def test_correctness(self):
        for world_size in self.world_sizes:
            available_gpus = torch.cuda.device_count()
            if world_size > available_gpus:
                print(
                    f"Skipping world_size={world_size}: only {available_gpus} GPUs "
                    "are available and multi-node launch (e.g. via ray) is not supported here"
                )
                continue

            print(f"Running test for world_size={world_size}")
            multi_process_parallel(
                world_size, _run_correctness_worker, target_args=(self.test_sizes,)
            )
            print(f"mscclpp allreduce tp = {world_size}: OK")


if __name__ == "__main__":
    unittest.main()