[Feature] use pytest for sgl-kernel (#4896)
This commit is contained in:
committed by
GitHub
parent
4ede6770cd
commit
9fccda3111
@@ -13,155 +13,186 @@ from torch.distributed import ProcessGroup
|
||||
from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
|
||||
|
||||
|
||||
def _run_correctness_worker(world_size, rank, distributed_init_port, test_sizes):
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
torch.cuda.set_device(device)
|
||||
ranks = list(range(world_size))
|
||||
distributed_init_method = f"tcp://localhost:{distributed_init_port}"
|
||||
dist.init_process_group(
|
||||
backend="nccl",
|
||||
init_method=distributed_init_method,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
group = dist.group.WORLD
|
||||
|
||||
buffer_max_size = 8 * 1024 * 1024
|
||||
barrier_max_size = 8 * (24 + 2) * 8
|
||||
buffer_ptrs = None
|
||||
tmp_result_buffer_ptrs = None
|
||||
barrier_in_ptrs = None
|
||||
barrier_out_ptrs = None
|
||||
custom_ptr = None
|
||||
|
||||
try:
|
||||
buffer_ptrs = TestCustomAllReduce.create_shared_buffer(
|
||||
buffer_max_size, group=group
|
||||
)
|
||||
tmp_result_buffer_ptrs = TestCustomAllReduce.create_shared_buffer(
|
||||
buffer_max_size, group=group
|
||||
)
|
||||
barrier_in_ptrs = TestCustomAllReduce.create_shared_buffer(
|
||||
barrier_max_size, group=group
|
||||
)
|
||||
barrier_out_ptrs = TestCustomAllReduce.create_shared_buffer(
|
||||
barrier_max_size, group=group
|
||||
)
|
||||
|
||||
rank_data = torch.empty(8 * 1024 * 1024, dtype=torch.uint8, device=device)
|
||||
|
||||
custom_ptr = custom_ops.init_custom_reduce(
|
||||
rank,
|
||||
world_size,
|
||||
rank_data,
|
||||
buffer_ptrs,
|
||||
tmp_result_buffer_ptrs,
|
||||
barrier_in_ptrs,
|
||||
barrier_out_ptrs,
|
||||
)
|
||||
|
||||
test_loop = 10
|
||||
for sz in test_sizes:
|
||||
for dtype in [torch.float32, torch.float16, torch.bfloat16]:
|
||||
for _ in range(test_loop):
|
||||
inp1 = torch.randint(1, 16, (sz,), dtype=dtype, device=device)
|
||||
inp1_ref = inp1.clone()
|
||||
out1 = torch.empty_like(inp1)
|
||||
|
||||
custom_ops.custom_reduce(custom_ptr, inp1, out1)
|
||||
|
||||
dist.all_reduce(inp1_ref, group=group)
|
||||
|
||||
torch.testing.assert_close(out1, inp1_ref)
|
||||
|
||||
finally:
|
||||
dist.barrier(group=group)
|
||||
if custom_ptr is not None:
|
||||
custom_ops.custom_dispose(custom_ptr)
|
||||
if buffer_ptrs:
|
||||
TestCustomAllReduce.free_shared_buffer(buffer_ptrs, group)
|
||||
if tmp_result_buffer_ptrs:
|
||||
TestCustomAllReduce.free_shared_buffer(tmp_result_buffer_ptrs, group)
|
||||
if barrier_in_ptrs:
|
||||
TestCustomAllReduce.free_shared_buffer(barrier_in_ptrs, group)
|
||||
if barrier_out_ptrs:
|
||||
TestCustomAllReduce.free_shared_buffer(barrier_out_ptrs, group)
|
||||
|
||||
dist.destroy_process_group(group=group)
|
||||
|
||||
|
||||
def get_open_port() -> int:
|
||||
# try ipv4
|
||||
try:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(("127.0.0.1", 0))
|
||||
return s.getsockname()[1]
|
||||
except OSError:
|
||||
# try ipv6
|
||||
with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
|
||||
s.bind(("127.0.0.1", 0))
|
||||
s.bind(("::1", 0))
|
||||
return s.getsockname()[1]
|
||||
|
||||
|
||||
def multi_process_parallel(
|
||||
world_size: int,
|
||||
test_target: Any,
|
||||
world_size: int, test_target: Any, target_args: tuple = ()
|
||||
) -> None:
|
||||
mp.set_start_method("spawn", force=True)
|
||||
|
||||
procs = []
|
||||
distributed_init_port = get_open_port()
|
||||
for i in range(world_size):
|
||||
proc = mp.Process(
|
||||
target=test_target,
|
||||
args=(world_size, i, distributed_init_port),
|
||||
)
|
||||
proc_args = (world_size, i, distributed_init_port) + target_args
|
||||
proc = mp.Process(target=test_target, args=proc_args, name=f"Worker-{i}")
|
||||
proc.start()
|
||||
procs.append(proc)
|
||||
|
||||
for i in range(world_size):
|
||||
procs[i].join()
|
||||
assert procs[i].exitcode == 0
|
||||
assert (
|
||||
procs[i].exitcode == 0
|
||||
), f"Process {i} failed with exit code {procs[i].exitcode}"
|
||||
|
||||
|
||||
class TestCustomAllReduce(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
random.seed(42)
|
||||
cls.test_sizes = [512, 4096, 32768, 262144, 524288, 1048576, 2097152]
|
||||
cls.world_sizes = [2, 4, 8]
|
||||
test_sizes = [512, 4096, 32768, 262144, 524288, 1048576, 2097152]
|
||||
world_sizes = [2, 4, 8]
|
||||
|
||||
@staticmethod
|
||||
def create_shared_buffer(
|
||||
size_in_bytes: int, group: Optional[ProcessGroup] = None
|
||||
) -> List[int]:
|
||||
"""
|
||||
Creates a shared buffer and returns a list of pointers
|
||||
representing the buffer on all processes in the group.
|
||||
"""
|
||||
lib = CudaRTLibrary()
|
||||
pointer = lib.cudaMalloc(size_in_bytes)
|
||||
handle = lib.cudaIpcGetMemHandle(pointer)
|
||||
if group is None:
|
||||
group = dist.group.WORLD
|
||||
world_size = dist.get_world_size(group=group)
|
||||
rank = dist.get_rank(group=group)
|
||||
handles = [None] * world_size
|
||||
dist.all_gather_object(handles, handle, group=group)
|
||||
|
||||
handle_bytes = ctypes.string_at(ctypes.addressof(handle), ctypes.sizeof(handle))
|
||||
input_tensor = torch.ByteTensor(list(handle_bytes)).to(f"cuda:{rank}")
|
||||
gathered_tensors = [torch.empty_like(input_tensor) for _ in range(world_size)]
|
||||
dist.all_gather(gathered_tensors, input_tensor, group=group)
|
||||
|
||||
handles = []
|
||||
handle_type = type(handle)
|
||||
for tensor in gathered_tensors:
|
||||
bytes_list = tensor.cpu().tolist()
|
||||
bytes_data = bytes(bytes_list)
|
||||
handle_obj = handle_type()
|
||||
ctypes.memmove(ctypes.addressof(handle_obj), bytes_data, len(bytes_data))
|
||||
handles.append(handle_obj)
|
||||
|
||||
pointers: List[int] = []
|
||||
for i, h in enumerate(handles):
|
||||
if i == rank:
|
||||
pointers.append(pointer.value) # type: ignore
|
||||
pointers.append(pointer.value)
|
||||
else:
|
||||
pointers.append(lib.cudaIpcOpenMemHandle(h).value) # type: ignore
|
||||
try:
|
||||
opened_ptr = lib.cudaIpcOpenMemHandle(h)
|
||||
pointers.append(opened_ptr.value)
|
||||
except Exception as e:
|
||||
print(f"Rank {rank}: Failed to open IPC handle from rank {i}: {e}")
|
||||
raise
|
||||
|
||||
dist.barrier(group=group)
|
||||
return pointers
|
||||
|
||||
@staticmethod
|
||||
def free_shared_buffer(
|
||||
pointers: List[int], group: Optional[ProcessGroup] = None
|
||||
) -> None:
|
||||
if group is None:
|
||||
group = dist.group.WORLD
|
||||
rank = dist.get_rank(group=group)
|
||||
lib = CudaRTLibrary()
|
||||
lib.cudaFree(ctypes.c_void_p(pointers[rank]))
|
||||
if pointers and len(pointers) > rank and pointers[rank] is not None:
|
||||
lib.cudaFree(ctypes.c_void_p(pointers[rank]))
|
||||
dist.barrier(group=group)
|
||||
|
||||
def test_correctness(self):
|
||||
for world_size in self.world_sizes:
|
||||
if world_size > torch.cuda.device_count():
|
||||
available_gpus = torch.cuda.device_count()
|
||||
if world_size > available_gpus:
|
||||
print(
|
||||
f"Skipping world_size={world_size}, requires {world_size} GPUs, found {available_gpus}"
|
||||
)
|
||||
continue
|
||||
multi_process_parallel(world_size, self.correctness)
|
||||
|
||||
print(f"Running test for world_size={world_size}")
|
||||
multi_process_parallel(
|
||||
world_size, _run_correctness_worker, target_args=(self.test_sizes,)
|
||||
)
|
||||
print(f"custom allreduce tp = {world_size}: OK")
|
||||
|
||||
def init_custom_allreduce(self, rank, world_size, group):
|
||||
buffer_max_size = 8 * 1024 * 1024
|
||||
barrier_max_size = 8 * (24 + 2) * 8
|
||||
|
||||
self.buffer_ptrs = self.create_shared_buffer(buffer_max_size, group=group)
|
||||
self.tmp_result_buffer_ptrs = self.create_shared_buffer(
|
||||
buffer_max_size, group=group
|
||||
)
|
||||
self.barrier_in_ptrs = self.create_shared_buffer(barrier_max_size, group=group)
|
||||
self.barrier_out_ptrs = self.create_shared_buffer(barrier_max_size, group=group)
|
||||
self.rank_data = torch.empty(
|
||||
8 * 1024 * 1024, dtype=torch.uint8, device=torch.device("cuda:0")
|
||||
)
|
||||
|
||||
self.custom_ptr = custom_ops.init_custom_reduce(
|
||||
rank,
|
||||
world_size,
|
||||
self.rank_data,
|
||||
self.buffer_ptrs,
|
||||
self.tmp_result_buffer_ptrs,
|
||||
self.barrier_in_ptrs,
|
||||
self.barrier_out_ptrs,
|
||||
)
|
||||
|
||||
def custom_allreduce(self, inp, out):
|
||||
custom_ops.custom_reduce(self.custom_ptr, inp, out)
|
||||
|
||||
def free_custom_allreduce(self, group):
|
||||
self.free_shared_buffer(self.buffer_ptrs, group)
|
||||
self.free_shared_buffer(self.tmp_result_buffer_ptrs, group)
|
||||
self.free_shared_buffer(self.barrier_in_ptrs, group)
|
||||
self.free_shared_buffer(self.barrier_out_ptrs, group)
|
||||
custom_ops.custom_dispose(self.custom_ptr)
|
||||
|
||||
@staticmethod
|
||||
def init_distributed_env(world_size, rank, distributed_init_port):
|
||||
device = torch.device("cuda:0")
|
||||
torch.cuda.set_device(device)
|
||||
ranks = [i for i in range(world_size)]
|
||||
distributed_init_method = f"tcp://localhost:{distributed_init_port}"
|
||||
dist.init_process_group(
|
||||
backend="nccl",
|
||||
init_method=distributed_init_method,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
group = torch.distributed.new_group(ranks, backend="gloo")
|
||||
return group
|
||||
|
||||
# compare result with torch.distributed
|
||||
def correctness(self, world_size, rank, distributed_init_port):
|
||||
group = self.init_distributed_env(world_size, rank, distributed_init_port)
|
||||
|
||||
self.init_custom_allreduce(rank=rank, world_size=world_size, group=group)
|
||||
|
||||
test_loop = 10
|
||||
for sz in self.test_sizes:
|
||||
for dtype in [torch.float32, torch.float16, torch.bfloat16]:
|
||||
for _ in range(test_loop):
|
||||
inp1 = torch.randint(
|
||||
1, 16, (sz,), dtype=dtype, device=torch.cuda.current_device()
|
||||
)
|
||||
out1 = torch.empty_like(inp1)
|
||||
self.custom_allreduce(inp1, out1)
|
||||
|
||||
dist.all_reduce(inp1, group=group)
|
||||
torch.testing.assert_close(out1, inp1)
|
||||
|
||||
self.free_custom_allreduce(group)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user