[Feature] use pytest for sgl-kernel (#4896)

Adarsh Shirawalmath
2025-03-30 23:06:52 +05:30
committed by GitHub
parent 4ede6770cd
commit 9fccda3111
10 changed files with 263 additions and 290 deletions


@@ -13,155 +13,186 @@ from torch.distributed import ProcessGroup
from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
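# Worker entry point: spawned once per rank by multi_process_parallel below.
# Each rank runs the custom allreduce and cross-checks the result against
# torch.distributed's NCCL all_reduce on identical inputs.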
def _run_correctness_worker(world_size, rank, distributed_init_port, test_sizes):
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
ranks = list(range(world_size))
distributed_init_method = f"tcp://localhost:{distributed_init_port}"
dist.init_process_group(
backend="nccl",
init_method=distributed_init_method,
rank=rank,
world_size=world_size,
)
group = dist.group.WORLD
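    # Buffer sizing (a reading of the constants, not a spec): an 8 MiB data
    # buffer plus a small fixed-size barrier region expected by the kernel.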
buffer_max_size = 8 * 1024 * 1024
barrier_max_size = 8 * (24 + 2) * 8
buffer_ptrs = None
tmp_result_buffer_ptrs = None
barrier_in_ptrs = None
barrier_out_ptrs = None
custom_ptr = None
try:
buffer_ptrs = TestCustomAllReduce.create_shared_buffer(
buffer_max_size, group=group
)
tmp_result_buffer_ptrs = TestCustomAllReduce.create_shared_buffer(
buffer_max_size, group=group
)
barrier_in_ptrs = TestCustomAllReduce.create_shared_buffer(
barrier_max_size, group=group
)
barrier_out_ptrs = TestCustomAllReduce.create_shared_buffer(
barrier_max_size, group=group
)
rank_data = torch.empty(8 * 1024 * 1024, dtype=torch.uint8, device=device)
custom_ptr = custom_ops.init_custom_reduce(
rank,
world_size,
rank_data,
buffer_ptrs,
tmp_result_buffer_ptrs,
barrier_in_ptrs,
barrier_out_ptrs,
)
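        # Every (size, dtype) pair is run several times; NCCL's all_reduce on
        # a cloned input serves as the reference result.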
test_loop = 10
for sz in test_sizes:
for dtype in [torch.float32, torch.float16, torch.bfloat16]:
for _ in range(test_loop):
inp1 = torch.randint(1, 16, (sz,), dtype=dtype, device=device)
inp1_ref = inp1.clone()
out1 = torch.empty_like(inp1)
custom_ops.custom_reduce(custom_ptr, inp1, out1)
dist.all_reduce(inp1_ref, group=group)
torch.testing.assert_close(out1, inp1_ref)
finally:
dist.barrier(group=group)
if custom_ptr is not None:
custom_ops.custom_dispose(custom_ptr)
if buffer_ptrs:
TestCustomAllReduce.free_shared_buffer(buffer_ptrs, group)
if tmp_result_buffer_ptrs:
TestCustomAllReduce.free_shared_buffer(tmp_result_buffer_ptrs, group)
if barrier_in_ptrs:
TestCustomAllReduce.free_shared_buffer(barrier_in_ptrs, group)
if barrier_out_ptrs:
TestCustomAllReduce.free_shared_buffer(barrier_out_ptrs, group)
dist.destroy_process_group(group=group)
def get_open_port() -> int:
# try ipv4
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("127.0.0.1", 0))
return s.getsockname()[1]
except OSError:
# try ipv6
with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
s.bind(("127.0.0.1", 0))
s.bind(("::1", 0))
return s.getsockname()[1]
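# Binding to port 0 lets the OS pick a free port; the socket is closed before
# the workers reuse the port, so a rare race with another process is possible
# but acceptable for a test harness.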
def multi_process_parallel(
    world_size: int, test_target: Any, target_args: tuple = ()
) -> None:
mp.set_start_method("spawn", force=True)
procs = []
distributed_init_port = get_open_port()
for i in range(world_size):
        proc_args = (world_size, i, distributed_init_port) + target_args
        proc = mp.Process(target=test_target, args=proc_args, name=f"Worker-{i}")
proc.start()
procs.append(proc)
for i in range(world_size):
procs[i].join()
        assert (
            procs[i].exitcode == 0
        ), f"Process {i} failed with exit code {procs[i].exitcode}"
class TestCustomAllReduce(unittest.TestCase):
    test_sizes = [512, 4096, 32768, 262144, 524288, 1048576, 2097152]
    world_sizes = [2, 4, 8]
@staticmethod
def create_shared_buffer(
size_in_bytes: int, group: Optional[ProcessGroup] = None
) -> List[int]:
"""
Creates a shared buffer and returns a list of pointers
representing the buffer on all processes in the group.
"""
lib = CudaRTLibrary()
pointer = lib.cudaMalloc(size_in_bytes)
handle = lib.cudaIpcGetMemHandle(pointer)
if group is None:
group = dist.group.WORLD
world_size = dist.get_world_size(group=group)
rank = dist.get_rank(group=group)
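        # Exchange CUDA IPC handles across ranks: the ctypes handle struct is
        # flattened to bytes and all-gathered as a CUDA byte tensor so the
        # exchange stays on the NCCL-backed default group.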
handle_bytes = ctypes.string_at(ctypes.addressof(handle), ctypes.sizeof(handle))
input_tensor = torch.ByteTensor(list(handle_bytes)).to(f"cuda:{rank}")
gathered_tensors = [torch.empty_like(input_tensor) for _ in range(world_size)]
dist.all_gather(gathered_tensors, input_tensor, group=group)
handles = []
handle_type = type(handle)
for tensor in gathered_tensors:
bytes_list = tensor.cpu().tolist()
bytes_data = bytes(bytes_list)
handle_obj = handle_type()
ctypes.memmove(ctypes.addressof(handle_obj), bytes_data, len(bytes_data))
handles.append(handle_obj)
pointers: List[int] = []
for i, h in enumerate(handles):
if i == rank:
                pointers.append(pointer.value)
            else:
try:
opened_ptr = lib.cudaIpcOpenMemHandle(h)
pointers.append(opened_ptr.value)
except Exception as e:
print(f"Rank {rank}: Failed to open IPC handle from rank {i}: {e}")
raise
dist.barrier(group=group)
return pointers
@staticmethod
def free_shared_buffer(
pointers: List[int], group: Optional[ProcessGroup] = None
) -> None:
if group is None:
group = dist.group.WORLD
rank = dist.get_rank(group=group)
lib = CudaRTLibrary()
        if pointers and len(pointers) > rank and pointers[rank] is not None:
            lib.cudaFree(ctypes.c_void_p(pointers[rank]))
dist.barrier(group=group)
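    # Note: only the rank-local allocation is cudaFree'd here; mappings opened
    # with cudaIpcOpenMemHandle on peer buffers are left for process teardown.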
def test_correctness(self):
for world_size in self.world_sizes:
            available_gpus = torch.cuda.device_count()
            if world_size > available_gpus:
print(
f"Skipping world_size={world_size}, requires {world_size} GPUs, found {available_gpus}"
)
continue
            print(f"Running test for world_size={world_size}")
            multi_process_parallel(
                world_size, _run_correctness_worker, target_args=(self.test_sizes,)
            )
print(f"custom allreduce tp = {world_size}: OK")
if __name__ == "__main__":
unittest.main()
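# Because the unittest.TestCase shape is preserved, pytest can collect this
# file directly; a typical invocation (path illustrative) is:
#   pytest path/to/this_test_file.py -q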