From 5ec5eaf76067ae804b1e2aa74a4df132837fd0bf Mon Sep 17 00:00:00 2001
From: Yi Zhang <1109276519@qq.com>
Date: Sun, 30 Mar 2025 14:16:53 +0800
Subject: [PATCH] fix allreduce test (#4909)

---
 sgl-kernel/tests/test_trt_allreduce.py | 105 ++++---------------------
 1 file changed, 14 insertions(+), 91 deletions(-)

diff --git a/sgl-kernel/tests/test_trt_allreduce.py b/sgl-kernel/tests/test_trt_allreduce.py
index 9bbc4e76f..910bcb253 100644
--- a/sgl-kernel/tests/test_trt_allreduce.py
+++ b/sgl-kernel/tests/test_trt_allreduce.py
@@ -1,22 +1,17 @@
 import ctypes
-import logging
+import multiprocessing as mp
 import random
 import socket
-import time
 import unittest
 from typing import Any, List, Optional
 
-import ray
 import sgl_kernel.allreduce as custom_ops
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
-from vllm import _custom_ops as vllm_ops
 
 from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
 
-logger = logging.getLogger(__name__)
-
 
 def get_open_port() -> int:
     # try ipv4
@@ -33,22 +28,21 @@ def get_open_port() -> int:
 
 def multi_process_parallel(
     world_size: int,
-    cls: Any,
     test_target: Any,
 ) -> None:
-    # Using ray helps debugging the error when it failed
-    # as compared to multiprocessing.
-    # NOTE: We need to set working_dir for distributed tests,
-    # otherwise we may get import errors on ray workers
-    ray.init(log_to_driver=True)
-
+    procs = []
     distributed_init_port = get_open_port()
-    refs = []
-    for rank in range(world_size):
-        refs.append(test_target.remote(cls, world_size, rank, distributed_init_port))
-    ray.get(refs)
+    for i in range(world_size):
+        proc = mp.Process(
+            target=test_target,
+            args=(world_size, i, distributed_init_port),
+        )
+        proc.start()
+        procs.append(proc)
 
-    ray.shutdown()
+    for i in range(world_size):
+        procs[i].join()
+        assert procs[i].exitcode == 0
 
 
 class TestCustomAllReduce(unittest.TestCase):
@@ -95,13 +89,8 @@ class TestCustomAllReduce(unittest.TestCase):
         for world_size in self.world_sizes:
             if world_size > torch.cuda.device_count():
                 continue
-            multi_process_parallel(world_size, self, self.correctness)
-
-    def test_performance(self):
-        for world_size in self.world_sizes:
-            if world_size > torch.cuda.device_count():
-                continue
-            multi_process_parallel(world_size, self, self.performance)
+            multi_process_parallel(world_size, self.correctness)
+            print(f"custom allreduce tp = {world_size}: OK")
 
     def init_custom_allreduce(self, rank, world_size, group):
         buffer_max_size = 8 * 1024 * 1024
@@ -137,37 +126,6 @@ class TestCustomAllReduce(unittest.TestCase):
         self.free_shared_buffer(self.barrier_out_ptrs, group)
         custom_ops.custom_dispose(self.custom_ptr)
 
-    def init_vllm_allreduce(self, rank, group):
-        self.vllm_rank = rank
-        self.vllm_max_size = 8 * 1024 * 1024
-        self.vllm_meta_ptrs = self.create_shared_buffer(
-            vllm_ops.meta_size() + self.vllm_max_size, group=group
-        )
-        self.vllm_buffer_ptrs = self.create_shared_buffer(
-            self.vllm_max_size, group=group
-        )
-        self.vllm_rank_data = torch.empty(
-            8 * 1024 * 1024, dtype=torch.uint8, device=torch.device("cuda:0")
-        )
-        self.vllm_ptr = vllm_ops.init_custom_ar(
-            self.vllm_meta_ptrs, self.vllm_rank_data, rank, True
-        )
-        vllm_ops.register_buffer(self.vllm_ptr, self.vllm_buffer_ptrs)
-
-    def vllm_allreduce(self, inp, out):
-        vllm_ops.all_reduce(
-            self.vllm_ptr,
-            inp,
-            out,
-            self.vllm_buffer_ptrs[self.vllm_rank],
-            self.vllm_max_size,
-        )
-
-    def free_vllm_allreduce(self, group):
-        vllm_ops.dispose(self.vllm_ptr)
-        self.free_shared_buffer(self.vllm_meta_ptrs, group)
-        self.free_shared_buffer(self.vllm_buffer_ptrs, group)
-
     @staticmethod
     def init_distributed_env(world_size, rank, distributed_init_port):
         device = torch.device("cuda:0")
@@ -184,7 +142,6 @@ class TestCustomAllReduce(unittest.TestCase):
         return group
 
     # compare result with torch.distributed
-    @ray.remote(num_gpus=1, max_calls=1)
    def correctness(self, world_size, rank, distributed_init_port):
         group = self.init_distributed_env(world_size, rank, distributed_init_port)
 
@@ -205,40 +162,6 @@ class TestCustomAllReduce(unittest.TestCase):
 
         self.free_custom_allreduce(group)
 
-    # compare performance with vllm
-    @ray.remote(num_gpus=1, max_calls=1)
-    def performance(self, world_size, rank, distributed_init_port):
-        group = self.init_distributed_env(world_size, rank, distributed_init_port)
-
-        self.init_vllm_allreduce(rank, group)
-        self.init_custom_allreduce(rank=rank, world_size=world_size, group=group)
-
-        for sz in self.test_sizes:
-            inp1 = torch.randint(
-                1, 16, (sz,), dtype=torch.float32, device=torch.cuda.current_device()
-            )
-            out1 = torch.empty_like(inp1)
-            test_loop = 5000
-            start = time.time()
-            for _ in range(test_loop):
-                self.custom_allreduce(inp1, out1)
-            elapse_custom = time.time() - start
-
-            start = time.time()
-            for _ in range(test_loop):
-                self.vllm_allreduce(inp1, out1)
-            elapse_vllm = time.time() - start
-
-            if rank == 0:
-                logger.warning(
-                    f"test_size = {sz}, world_size = {world_size}, "
-                    f"vllm time = {elapse_vllm * 1000 / test_loop:.4f}ms, "
-                    f"custom time = {elapse_custom * 1000 / test_loop:.4f}ms "
-                )
-
-        self.free_custom_allreduce(group)
-        self.free_vllm_allreduce(group)
-
 
 if __name__ == "__main__":
     unittest.main()
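
For reference, the per-rank launch pattern this patch switches to can be exercised outside the test suite. The sketch below is not part of the patch: it spawns one multiprocessing worker per rank, performs a torch.distributed all_reduce, and has the parent assert on each child's exitcode, which is how a failure in any rank now fails the test. It uses the gloo backend so it runs without GPUs, and the names _worker, run_parallel, and the fixed port are illustrative only.

    # Minimal standalone sketch of the multiprocessing test-launch pattern.
    import multiprocessing as mp

    import torch
    import torch.distributed as dist


    def _worker(world_size: int, rank: int, port: int) -> None:
        # Each child joins the same process group via a TCP rendezvous.
        dist.init_process_group(
            backend="gloo",
            init_method=f"tcp://127.0.0.1:{port}",
            rank=rank,
            world_size=world_size,
        )
        x = torch.ones(4)
        dist.all_reduce(x)  # default op is SUM across ranks
        # A failed assertion makes the child exit non-zero, which the parent checks.
        assert torch.allclose(x, torch.full((4,), float(world_size)))
        dist.destroy_process_group()


    def run_parallel(world_size: int, port: int) -> None:
        procs = []
        for rank in range(world_size):
            p = mp.Process(target=_worker, args=(world_size, rank, port))
            p.start()
            procs.append(p)
        for p in procs:
            p.join()
            # Fail the parent if any rank crashed or asserted.
            assert p.exitcode == 0


    if __name__ == "__main__":
        run_parallel(2, 29501)

Compared with the removed ray-based launcher, this keeps the test self-contained: there is no external scheduler to initialize or shut down, and a crashed or asserting child surfaces directly as a non-zero exitcode in the parent.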