fix allreduce test (#4909)
This commit is contained in:
@@ -1,22 +1,17 @@
|
|||||||
import ctypes
|
import ctypes
|
||||||
import logging
|
import multiprocessing as mp
|
||||||
import random
|
import random
|
||||||
import socket
|
import socket
|
||||||
import time
|
|
||||||
import unittest
|
import unittest
|
||||||
from typing import Any, List, Optional
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
import ray
|
|
||||||
import sgl_kernel.allreduce as custom_ops
|
import sgl_kernel.allreduce as custom_ops
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from torch.distributed import ProcessGroup
|
from torch.distributed import ProcessGroup
|
||||||
from vllm import _custom_ops as vllm_ops
|
|
||||||
|
|
||||||
from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
|
from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def get_open_port() -> int:
|
def get_open_port() -> int:
|
||||||
# try ipv4
|
# try ipv4
|
||||||
@@ -33,22 +28,21 @@ def get_open_port() -> int:
|
|||||||
|
|
||||||
def multi_process_parallel(
|
def multi_process_parallel(
|
||||||
world_size: int,
|
world_size: int,
|
||||||
cls: Any,
|
|
||||||
test_target: Any,
|
test_target: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
# Using ray helps debugging the error when it failed
|
procs = []
|
||||||
# as compared to multiprocessing.
|
|
||||||
# NOTE: We need to set working_dir for distributed tests,
|
|
||||||
# otherwise we may get import errors on ray workers
|
|
||||||
ray.init(log_to_driver=True)
|
|
||||||
|
|
||||||
distributed_init_port = get_open_port()
|
distributed_init_port = get_open_port()
|
||||||
refs = []
|
for i in range(world_size):
|
||||||
for rank in range(world_size):
|
proc = mp.Process(
|
||||||
refs.append(test_target.remote(cls, world_size, rank, distributed_init_port))
|
target=test_target,
|
||||||
ray.get(refs)
|
args=(world_size, i, distributed_init_port),
|
||||||
|
)
|
||||||
|
proc.start()
|
||||||
|
procs.append(proc)
|
||||||
|
|
||||||
ray.shutdown()
|
for i in range(world_size):
|
||||||
|
procs[i].join()
|
||||||
|
assert procs[i].exitcode == 0
|
||||||
|
|
||||||
|
|
||||||
class TestCustomAllReduce(unittest.TestCase):
|
class TestCustomAllReduce(unittest.TestCase):
|
||||||
@@ -95,13 +89,8 @@ class TestCustomAllReduce(unittest.TestCase):
|
|||||||
for world_size in self.world_sizes:
|
for world_size in self.world_sizes:
|
||||||
if world_size > torch.cuda.device_count():
|
if world_size > torch.cuda.device_count():
|
||||||
continue
|
continue
|
||||||
multi_process_parallel(world_size, self, self.correctness)
|
multi_process_parallel(world_size, self.correctness)
|
||||||
|
print(f"custom allreduce tp = {world_size}: OK")
|
||||||
def test_performance(self):
|
|
||||||
for world_size in self.world_sizes:
|
|
||||||
if world_size > torch.cuda.device_count():
|
|
||||||
continue
|
|
||||||
multi_process_parallel(world_size, self, self.performance)
|
|
||||||
|
|
||||||
def init_custom_allreduce(self, rank, world_size, group):
|
def init_custom_allreduce(self, rank, world_size, group):
|
||||||
buffer_max_size = 8 * 1024 * 1024
|
buffer_max_size = 8 * 1024 * 1024
|
||||||
@@ -137,37 +126,6 @@ class TestCustomAllReduce(unittest.TestCase):
|
|||||||
self.free_shared_buffer(self.barrier_out_ptrs, group)
|
self.free_shared_buffer(self.barrier_out_ptrs, group)
|
||||||
custom_ops.custom_dispose(self.custom_ptr)
|
custom_ops.custom_dispose(self.custom_ptr)
|
||||||
|
|
||||||
def init_vllm_allreduce(self, rank, group):
|
|
||||||
self.vllm_rank = rank
|
|
||||||
self.vllm_max_size = 8 * 1024 * 1024
|
|
||||||
self.vllm_meta_ptrs = self.create_shared_buffer(
|
|
||||||
vllm_ops.meta_size() + self.vllm_max_size, group=group
|
|
||||||
)
|
|
||||||
self.vllm_buffer_ptrs = self.create_shared_buffer(
|
|
||||||
self.vllm_max_size, group=group
|
|
||||||
)
|
|
||||||
self.vllm_rank_data = torch.empty(
|
|
||||||
8 * 1024 * 1024, dtype=torch.uint8, device=torch.device("cuda:0")
|
|
||||||
)
|
|
||||||
self.vllm_ptr = vllm_ops.init_custom_ar(
|
|
||||||
self.vllm_meta_ptrs, self.vllm_rank_data, rank, True
|
|
||||||
)
|
|
||||||
vllm_ops.register_buffer(self.vllm_ptr, self.vllm_buffer_ptrs)
|
|
||||||
|
|
||||||
def vllm_allreduce(self, inp, out):
|
|
||||||
vllm_ops.all_reduce(
|
|
||||||
self.vllm_ptr,
|
|
||||||
inp,
|
|
||||||
out,
|
|
||||||
self.vllm_buffer_ptrs[self.vllm_rank],
|
|
||||||
self.vllm_max_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
def free_vllm_allreduce(self, group):
|
|
||||||
vllm_ops.dispose(self.vllm_ptr)
|
|
||||||
self.free_shared_buffer(self.vllm_meta_ptrs, group)
|
|
||||||
self.free_shared_buffer(self.vllm_buffer_ptrs, group)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def init_distributed_env(world_size, rank, distributed_init_port):
|
def init_distributed_env(world_size, rank, distributed_init_port):
|
||||||
device = torch.device("cuda:0")
|
device = torch.device("cuda:0")
|
||||||
@@ -184,7 +142,6 @@ class TestCustomAllReduce(unittest.TestCase):
|
|||||||
return group
|
return group
|
||||||
|
|
||||||
# compare result with torch.distributed
|
# compare result with torch.distributed
|
||||||
@ray.remote(num_gpus=1, max_calls=1)
|
|
||||||
def correctness(self, world_size, rank, distributed_init_port):
|
def correctness(self, world_size, rank, distributed_init_port):
|
||||||
group = self.init_distributed_env(world_size, rank, distributed_init_port)
|
group = self.init_distributed_env(world_size, rank, distributed_init_port)
|
||||||
|
|
||||||
@@ -205,40 +162,6 @@ class TestCustomAllReduce(unittest.TestCase):
|
|||||||
|
|
||||||
self.free_custom_allreduce(group)
|
self.free_custom_allreduce(group)
|
||||||
|
|
||||||
# compare performance with vllm
|
|
||||||
@ray.remote(num_gpus=1, max_calls=1)
|
|
||||||
def performance(self, world_size, rank, distributed_init_port):
|
|
||||||
group = self.init_distributed_env(world_size, rank, distributed_init_port)
|
|
||||||
|
|
||||||
self.init_vllm_allreduce(rank, group)
|
|
||||||
self.init_custom_allreduce(rank=rank, world_size=world_size, group=group)
|
|
||||||
|
|
||||||
for sz in self.test_sizes:
|
|
||||||
inp1 = torch.randint(
|
|
||||||
1, 16, (sz,), dtype=torch.float32, device=torch.cuda.current_device()
|
|
||||||
)
|
|
||||||
out1 = torch.empty_like(inp1)
|
|
||||||
test_loop = 5000
|
|
||||||
start = time.time()
|
|
||||||
for _ in range(test_loop):
|
|
||||||
self.custom_allreduce(inp1, out1)
|
|
||||||
elapse_custom = time.time() - start
|
|
||||||
|
|
||||||
start = time.time()
|
|
||||||
for _ in range(test_loop):
|
|
||||||
self.vllm_allreduce(inp1, out1)
|
|
||||||
elapse_vllm = time.time() - start
|
|
||||||
|
|
||||||
if rank == 0:
|
|
||||||
logger.warning(
|
|
||||||
f"test_size = {sz}, world_size = {world_size}, "
|
|
||||||
f"vllm time = {elapse_vllm * 1000 / test_loop:.4f}ms, "
|
|
||||||
f"custom time = {elapse_custom * 1000 / test_loop:.4f}ms "
|
|
||||||
)
|
|
||||||
|
|
||||||
self.free_custom_allreduce(group)
|
|
||||||
self.free_vllm_allreduce(group)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user