# sglang/sgl-kernel/tests/test_trt_reduce.py
import ctypes
import logging
import os
import random
import socket
import time
import unittest
from typing import Any, List, Optional

import ray
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from vllm import _custom_ops as vllm_ops

from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary

logger = logging.getLogger(__name__)


def get_open_port() -> int:
    # try ipv4
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("127.0.0.1", 0))
            return s.getsockname()[1]
    except OSError:
        # try ipv6 (an AF_INET6 socket cannot bind an IPv4 literal)
        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
            s.bind(("::1", 0))
            return s.getsockname()[1]
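
# Illustrative usage (a sketch, not executed by the tests): the port returned
# above seeds the torch.distributed rendezvous; `rank` and `world_size` here
# are placeholders, e.g.
#
#     port = get_open_port()
#     dist.init_process_group(
#         backend="nccl",
#         init_method=f"tcp://localhost:{port}",
#         rank=rank,
#         world_size=world_size,
#     )
#
# init_distributed_env below consumes the port in exactly this way.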


def multi_process_parallel(
    world_size: int,
    cls: Any,
    test_target: Any,
) -> None:
    # Using ray makes failures easier to debug than multiprocessing does.
    # NOTE: we need to set working_dir for distributed tests,
    # otherwise we may get import errors on ray workers.
    ray.init(log_to_driver=True)

    distributed_init_port = get_open_port()
    refs = []
    for rank in range(world_size):
        refs.append(test_target.remote(cls, world_size, rank, distributed_init_port))
    ray.get(refs)

    ray.shutdown()
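
# Illustrative call pattern (a sketch): `test_target` is expected to be a
# method decorated with @ray.remote(num_gpus=1, max_calls=1), as
# `correctness` and `performance` below are. A call such as
#
#     multi_process_parallel(4, self, self.correctness)
#
# launches four Ray workers, each pinned to one GPU, all joining the same
# rendezvous port returned by get_open_port().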


class TestCustomAllReduce(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        random.seed(42)
        cls.test_sizes = {
            2: [512, 4096, 32768, 262144, 2097152],
            4: [512, 4096, 32768, 131072],
            6: [512, 4096, 32768, 65536],
            8: [512, 4096, 32768, 65536],
        }
        cls.world_sizes = [2, 4, 6, 8]

    @staticmethod
    def create_shared_buffer(
        size_in_bytes: int, group: Optional[ProcessGroup] = None
    ) -> List[int]:
        """
        Creates a shared buffer and returns a list of pointers
        representing the buffer on all processes in the group.
        """
        lib = CudaRTLibrary()
        pointer = lib.cudaMalloc(size_in_bytes)
        handle = lib.cudaIpcGetMemHandle(pointer)
        world_size = dist.get_world_size(group=group)
        rank = dist.get_rank(group=group)
        handles = [None] * world_size
        dist.all_gather_object(handles, handle, group=group)

        pointers: List[int] = []
        for i, h in enumerate(handles):
            if i == rank:
                pointers.append(pointer.value)  # type: ignore
            else:
                pointers.append(lib.cudaIpcOpenMemHandle(h).value)  # type: ignore
        return pointers
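
    # Illustrative walk-through of the IPC exchange above for a 2-rank group
    # (a sketch; ptr0/ptr1/h0/h1 are descriptive names, not real variables):
    #
    #     rank 0: ptr0 = cudaMalloc(n); h0 = cudaIpcGetMemHandle(ptr0)
    #     rank 1: ptr1 = cudaMalloc(n); h1 = cudaIpcGetMemHandle(ptr1)
    #     all_gather_object  ->  both ranks hold [h0, h1]
    #     rank 0: pointers == [ptr0, cudaIpcOpenMemHandle(h1)]
    #     rank 1: pointers == [cudaIpcOpenMemHandle(h0), ptr1]
    #
    # Each rank thus ends up with device pointers into every rank's buffer,
    # which is what a peer-to-peer all-reduce kernel needs.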

    @staticmethod
    def free_shared_buffer(
        pointers: List[int], group: Optional[ProcessGroup] = None
    ) -> None:
        rank = dist.get_rank(group=group)
        lib = CudaRTLibrary()
        lib.cudaFree(ctypes.c_void_p(pointers[rank]))

    def test_correctness(self):
        for world_size in self.world_sizes:
            if world_size > torch.cuda.device_count():
                continue
            multi_process_parallel(world_size, self, self.correctness)

    def test_performance(self):
        for world_size in self.world_sizes:
            if world_size > torch.cuda.device_count():
                continue
            multi_process_parallel(world_size, self, self.performance)

    def init_custom_allreduce(self, rank, world_size, group):
        import sgl_kernel

        buffer_max_size = 8 * 1024 * 1024  # 8 MiB data buffer
        barrier_max_size = 8 * (24 + 2) * 8

        self.buffer_ptrs = self.create_shared_buffer(buffer_max_size, group=group)
        self.barrier_in_ptrs = self.create_shared_buffer(barrier_max_size, group=group)
        self.barrier_out_ptrs = self.create_shared_buffer(barrier_max_size, group=group)

        self.custom_ptr = sgl_kernel.ops.init_custom_reduce(
            rank,
            world_size,
            self.buffer_ptrs,
            self.barrier_in_ptrs,
            self.barrier_out_ptrs,
        )

    def custom_allreduce(self, inp, out):
        import sgl_kernel

        sgl_kernel.ops.custom_reduce(self.custom_ptr, inp, out)

    def free_custom_allreduce(self, group):
        import sgl_kernel

        self.free_shared_buffer(self.buffer_ptrs, group)
        self.free_shared_buffer(self.barrier_in_ptrs, group)
        self.free_shared_buffer(self.barrier_out_ptrs, group)
        sgl_kernel.ops.custom_dispose(self.custom_ptr)
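
    # Illustrative lifecycle of the custom all-reduce (a sketch; assumes the
    # process group and `inp` already exist on the current device):
    #
    #     self.init_custom_allreduce(rank=rank, world_size=world_size, group=group)
    #     out = torch.empty_like(inp)
    #     self.custom_allreduce(inp, out)    # out == elementwise sum over ranks
    #     self.free_custom_allreduce(group)  # free IPC buffers, dispose handle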

    def init_vllm_allreduce(self, rank, group):
        self.vllm_rank = rank
        self.vllm_max_size = 8 * 1024 * 1024
        self.vllm_meta_ptrs = self.create_shared_buffer(
            vllm_ops.meta_size() + self.vllm_max_size, group=group
        )
        self.vllm_buffer_ptrs = self.create_shared_buffer(
            self.vllm_max_size, group=group
        )
        self.vllm_rank_data = torch.empty(
            8 * 1024 * 1024, dtype=torch.uint8, device=torch.device(f"cuda:{rank}")
        )
        self.vllm_ptr = vllm_ops.init_custom_ar(
            self.vllm_meta_ptrs, self.vllm_rank_data, rank, True
        )
        vllm_ops.register_buffer(self.vllm_ptr, self.vllm_buffer_ptrs)

    def vllm_allreduce(self, inp, out):
        vllm_ops.all_reduce(
            self.vllm_ptr,
            inp,
            out,
            self.vllm_buffer_ptrs[self.vllm_rank],
            self.vllm_max_size,
        )

    def free_vllm_allreduce(self, group):
        vllm_ops.dispose(self.vllm_ptr)
        self.free_shared_buffer(self.vllm_meta_ptrs, group)
        self.free_shared_buffer(self.vllm_buffer_ptrs, group)

    @staticmethod
    def init_distributed_env(world_size, rank, distributed_init_port):
        # Ray pins each worker to one GPU via CUDA_VISIBLE_DEVICES; drop it
        # (if set) so torch.cuda.set_device can address GPUs by global index.
        os.environ.pop("CUDA_VISIBLE_DEVICES", None)
        device = torch.device(f"cuda:{rank}")
        torch.cuda.set_device(device)
        ranks = list(range(world_size))
        distributed_init_method = f"tcp://localhost:{distributed_init_port}"
        dist.init_process_group(
            backend="nccl",
            init_method=distributed_init_method,
            rank=rank,
            world_size=world_size,
        )
        # A gloo group for CPU-side collectives such as all_gather_object.
        group = torch.distributed.new_group(ranks, backend="gloo")
        return group
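
    # Note on the two groups above (illustrative): init_process_group creates
    # the default NCCL group, while the returned gloo group serves CPU-side
    # object collectives and the reference all-reduce, e.g.
    #
    #     group = self.init_distributed_env(world_size, rank, port)
    #     dist.all_gather_object(handles, handle, group=group)  # gloo
    #     dist.all_reduce(inp, group=group)                     # reference path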

    # Compare the custom all-reduce result against torch.distributed.
    @ray.remote(num_gpus=1, max_calls=1)
    def correctness(self, world_size, rank, distributed_init_port):
        group = self.init_distributed_env(world_size, rank, distributed_init_port)

        self.init_custom_allreduce(rank=rank, world_size=world_size, group=group)

        test_loop = 10
        for sz in self.test_sizes[world_size]:
            for dtype in [torch.float32, torch.float16, torch.bfloat16]:
                for _ in range(test_loop):
                    inp1 = torch.randint(
                        1, 16, (sz,), dtype=dtype, device=torch.cuda.current_device()
                    )
                    out1 = torch.empty_like(inp1)
                    self.custom_allreduce(inp1, out1)
                    dist.all_reduce(inp1, group=group)
                    torch.testing.assert_close(out1, inp1)

        self.free_custom_allreduce(group)

    # Compare performance against vLLM's custom all-reduce.
    @ray.remote(num_gpus=1, max_calls=1)
    def performance(self, world_size, rank, distributed_init_port):
        group = self.init_distributed_env(world_size, rank, distributed_init_port)

        self.init_vllm_allreduce(rank, group)
        self.init_custom_allreduce(rank=rank, world_size=world_size, group=group)

        for sz in self.test_sizes[world_size]:
            inp1 = torch.randint(
                1, 16, (sz,), dtype=torch.float32, device=torch.cuda.current_device()
            )
            out1 = torch.empty_like(inp1)

            test_loop = 5000

            # Synchronize around each timed loop so the wall-clock numbers
            # cover kernel completion, not just asynchronous launches.
            torch.cuda.synchronize()
            start = time.time()
            for _ in range(test_loop):
                self.custom_allreduce(inp1, out1)
            torch.cuda.synchronize()
            elapse_custom = time.time() - start

            torch.cuda.synchronize()
            start = time.time()
            for _ in range(test_loop):
                self.vllm_allreduce(inp1, out1)
            torch.cuda.synchronize()
            elapse_vllm = time.time() - start

            if rank == 0:
                logger.warning(
                    f"test_size = {sz}, world_size = {world_size}, "
                    f"vllm time = {elapse_vllm * 1000 / test_loop:.4f}ms, "
                    f"custom time = {elapse_custom * 1000 / test_loop:.4f}ms"
                )

        self.free_custom_allreduce(group)
        self.free_vllm_allreduce(group)
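
    # Illustrative post-processing (a sketch, not part of the test): per-call
    # latency converts to an approximate bus bandwidth with the usual ring
    # all-reduce factor 2 * (n - 1) / n, e.g. for the custom kernel:
    #
    #     bytes_moved = sz * 4  # float32 elements
    #     alg_bw = bytes_moved / (elapse_custom / test_loop)         # B/s
    #     bus_bw = alg_bw * 2 * (world_size - 1) / world_size / 1e9  # GB/s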


if __name__ == "__main__":
    unittest.main()