From 767c9dec03e306400d162785ba77615ab40ae6c8 Mon Sep 17 00:00:00 2001
From: yizhang2077 <1109276519@qq.com>
Date: Thu, 16 Jan 2025 04:57:35 +0800
Subject: [PATCH] adapt custom allreduce for tensorrt llm (#2511)

---
 python/pyproject.toml                         |   2 +-
 python/sglang/srt/_custom_ops.py              |  49 +++---
 .../device_communicators/custom_all_reduce.py |  94 +++++-----
 test/srt/run_suite.py                         |   1 +
 test/srt/test_custom_allreduce.py             | 164 ++++++++++++++++++
 5 files changed, 242 insertions(+), 68 deletions(-)
 create mode 100644 test/srt/test_custom_allreduce.py

diff --git a/python/pyproject.toml b/python/pyproject.toml
index fe68e59ea..ea7c2482a 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -27,7 +27,7 @@ runtime_common = [
 ]
 srt = [
     "sglang[runtime_common]", "cuda-python",
-    "sgl-kernel>=0.0.2.post12", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1",
+    "sgl-kernel>=0.0.2.post14", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1",
     "flashinfer==0.1.6"
 ]

diff --git a/python/sglang/srt/_custom_ops.py b/python/sglang/srt/_custom_ops.py
index 9eb7caa1b..f59f67605 100644
--- a/python/sglang/srt/_custom_ops.py
+++ b/python/sglang/srt/_custom_ops.py
@@ -1,4 +1,4 @@
-# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/_custom_ops.py
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/_custom_ops.py
 import contextlib
 import functools
 import importlib
@@ -14,7 +14,7 @@ logger = logging.getLogger(__name__)

 if not is_hpu():
     try:
-        import custom_ar
+        import sgl_kernel
     except ImportError as e:
         logger.warning("Failed to import from custom_ar with %r", e)

@@ -50,46 +50,41 @@ def hint_on_error(fn):

 # custom ar
 def init_custom_ar(
-    ipc_tensors: List[torch.Tensor],
-    rank_data: torch.Tensor,
-    rank: int,
-    full_nvlink: bool,
+    rank_id: int,
+    world_size: int,
+    rank_data_base: torch.Tensor,
+    buffers: List[int],
+    tmp_result_buffers: List[int],
+    barrier_in: List[int],
+    barrier_out: List[int],
 ) -> int:
-    return torch.ops._C_vllm_ar.init_custom_ar(
-        ipc_tensors, rank_data, rank, full_nvlink
+    return sgl_kernel.ops.init_custom_reduce(
+        rank_id,
+        world_size,
+        rank_data_base,
+        buffers,
+        tmp_result_buffers,
+        barrier_in,
+        barrier_out,
     )


-def all_reduce(
-    fa: int,
-    inp: torch.Tensor,
-    out: torch.Tensor,
-    reg_buffer: int,
-    reg_buffer_sz_bytes: int,
-) -> None:
-    torch.ops._C_vllm_ar.all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz_bytes)
+def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
+    sgl_kernel.ops.custom_reduce(fa, inp, out)


 def dispose(fa: int) -> None:
-    torch.ops._C_vllm_ar.dispose(fa)
-
-
-def meta_size() -> int:
-    return torch.ops._C_vllm_ar.meta_size()
-
-
-def register_buffer(fa: int, ipc_tensors: List[int]) -> None:
-    return torch.ops._C_vllm_ar.register_buffer(fa, ipc_tensors)
+    sgl_kernel.ops.custom_dispose(fa)


 def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
-    return torch.ops._C_vllm_ar.get_graph_buffer_ipc_meta(fa)
+    return sgl_kernel.ops.get_graph_buffer_ipc_meta(fa)


 def register_graph_buffers(
     fa: int, handles: List[List[int]], offsets: List[List[int]]
 ) -> None:
-    torch.ops._C_vllm_ar.register_graph_buffers(fa, handles, offsets)
+    sgl_kernel.ops.register_graph_buffers(fa, handles, offsets)


 # temporary fix for https://github.com/vllm-project/vllm/issues/5456
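The _custom_ops.py hunks above swap the vLLM `torch.ops._C_vllm_ar` bindings for `sgl_kernel.ops` and change the handshake: instead of registering a meta buffer after initialization, the caller now passes the result, temporary-result and barrier buffers up front. The sketch below (illustrative only, not part of the patch) shows roughly how these wrappers fit together on one rank; the pointer lists are assumed to be IPC-shared device addresses already exchanged across ranks, which CustomAllreduce sets up in the next file.

    # Illustrative sketch, not part of the patch. Assumes `buffers`,
    # `tmp_result_buffers`, `barrier_in` and `barrier_out` are lists of
    # IPC-shared device pointers (one per rank), as prepared by
    # CustomAllreduce.create_shared_buffer() in the next file.
    import torch

    from sglang.srt import _custom_ops as ops

    def run_one_custom_allreduce(rank, world_size, buffers, tmp_result_buffers,
                                 barrier_in, barrier_out):
        # Per-rank scratch area handed to the kernel; 8 MB mirrors the
        # rank_data_base buffer allocated by CustomAllreduce below.
        rank_data_base = torch.empty(
            8 * 1024 * 1024, dtype=torch.uint8, device="cuda"
        )
        fa = ops.init_custom_ar(
            rank, world_size, rank_data_base,
            buffers, tmp_result_buffers, barrier_in, barrier_out,
        )
        try:
            inp = torch.ones(4096, dtype=torch.float16, device="cuda")
            out = torch.empty_like(inp)
            ops.all_reduce(fa, inp, out)  # out-of-place custom all-reduce
        finally:
            ops.dispose(fa)
        return out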
diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py
index b6df23440..ba9feb59d 100644
--- a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py
+++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py
@@ -21,7 +21,8 @@ from sglang.srt.distributed.parallel_state import in_the_same_node_as
 from sglang.srt.utils import cuda_device_count_stateless, is_cuda

 try:
-    ops.meta_size()
+    import sgl_kernel
+
     custom_ar = True
 except Exception:
     # For AMD GPUs and CPUs
@@ -29,7 +30,6 @@ except Exception:

 logger = logging.getLogger(__name__)

-
 _P = ParamSpec("_P")
 _R = TypeVar("_R")

@@ -47,7 +47,7 @@ def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:


 @with_nvml_context
-def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool:
+def is_full_nvlink(physical_device_ids: List[int]) -> bool:
     """
     query if the set of gpus are fully connected by nvlink (1 hop)
     """
@@ -196,32 +196,39 @@ class CustomAllreduce:
             )
             return

-        self.disabled = False
-        # Buffers memory are owned by this Python class and passed to C++.
-        # Meta data composes of two parts: meta data for synchronization and a
-        # temporary buffer for storing intermediate allreduce results.
-        self.meta_ptrs = self.create_shared_buffer(
-            ops.meta_size() + max_size, group=group
-        )
-        # This is a pre-registered IPC buffer. In eager mode, input tensors
-        # are first copied into this buffer before allreduce is performed
-        self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
-        # This is a buffer for storing the tuples of pointers pointing to
-        # IPC buffers from all ranks. Each registered tuple has size of
-        # 8*world_size bytes where world_size is at most 8. Allocating 8MB
-        # is enough for 131072 such tuples. The largest model I've seen only
-        # needs less than 10000 of registered tuples.
-        self.rank_data = torch.empty(
-            8 * 1024 * 1024, dtype=torch.uint8, device=self.device
-        )
         self.max_size = max_size
         self.rank = rank
         self.world_size = world_size
         self.full_nvlink = full_nvlink
-        self._ptr = ops.init_custom_ar(
-            self.meta_ptrs, self.rank_data, rank, self.full_nvlink
+
+        # From TensorRT-LLM getMaxRequiredWorkspaceSize
+        self.max_required_workspace_size = [16 * 1024 * 1024, 8 * 1024 * 1024]
+
+        # sizeof(uint32_t) * (MAX_ALL_REDUCE_BLOCKS + 2) * MAX_RANKS_PER_NODE;
+        self.barrier_max_size = 8 * (36 + 2) * 8
+
+        self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
+        self.tmp_result_buffer_ptrs = self.create_shared_buffer(max_size, group=group)
+        self.rank_data_base = torch.empty(
+            8 * 1024 * 1024, dtype=torch.uint8, device=self.device
         )
-        ops.register_buffer(self._ptr, self.buffer_ptrs)
+        self.barrier_in_ptrs = self.create_shared_buffer(
+            self.barrier_max_size, group=group
+        )
+        self.barrier_out_ptrs = self.create_shared_buffer(
+            self.barrier_max_size, group=group
+        )
+
+        self._ptr = ops.init_custom_ar(
+            rank,
+            world_size,
+            self.rank_data_base,
+            self.buffer_ptrs,
+            self.tmp_result_buffer_ptrs,
+            self.barrier_in_ptrs,
+            self.barrier_out_ptrs,
+        )
+        self.disabled = False

     @staticmethod
     def create_shared_buffer(
@@ -300,12 +307,25 @@ class CustomAllreduce:
             return False
         # for 4 or more non NVLink-capable GPUs, custom allreduce provides
         # little performance improvement over NCCL.
-        if self.world_size == 2 or self.full_nvlink:
-            return inp_size < self.max_size
+        if self.world_size == 2:
+            return (
+                inp_size < self.max_size
+                and inp_size < self.max_required_workspace_size[0]
+            )
+
+        if self.full_nvlink:
+            return (
+                inp_size < self.max_size
+                and inp_size < self.max_required_workspace_size[1]
+            )
+
         return False

     def all_reduce(
-        self, inp: torch.Tensor, *, out: torch.Tensor = None, registered: bool = False
+        self,
+        inp: torch.Tensor,
+        *,
+        out: torch.Tensor = None,
     ):
         """Performs an out-of-place all reduce.

@@ -315,12 +335,7 @@ class CustomAllreduce:
         """
         if out is None:
             out = torch.empty_like(inp)
-        if registered:
-            ops.all_reduce(self._ptr, inp, out, 0, 0)
-        else:
-            ops.all_reduce(
-                self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size
-            )
+        ops.all_reduce(self._ptr, inp, out)
         return out

     def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
@@ -330,23 +345,22 @@ class CustomAllreduce:
             return None
         if self._IS_CAPTURING:
             if torch.cuda.is_current_stream_capturing():
-                return self.all_reduce(input, registered=True)
+                return self.all_reduce(input)
             else:
                 # If warm up, mimic the allocation pattern since custom
                 # allreduce is out-of-place.
                 return torch.empty_like(input)
         else:
-            # Note: outside of cuda graph context, custom allreduce incurs a
-            # cost of cudaMemcpy, which should be small (<=1% of overall
-            # latency) compared to the performance gain of using custom kernels
-            return self.all_reduce(input, registered=False)
+            return self.all_reduce(input)

     def close(self):
         if not self.disabled and self._ptr:
             ops.dispose(self._ptr)
-            self._ptr = 0
-            self.free_shared_buffer(self.meta_ptrs)
             self.free_shared_buffer(self.buffer_ptrs)
+            self.free_shared_buffer(self.tmp_result_buffer_ptrs)
+            self.free_shared_buffer(self.barrier_in_ptrs)
+            self.free_shared_buffer(self.barrier_out_ptrs)
+            self._ptr = 0

     def __del__(self):
         self.close()
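The should_custom_ar change above gates the custom kernel on two limits: the input must fit the registered buffer (max_size) and stay below the workspace bound taken from TensorRT-LLM's getMaxRequiredWorkspaceSize, 16 MiB for the two-rank path and 8 MiB for the full-NVLink path; anything larger falls back to NCCL. The barrier buffers sized by `8 * (36 + 2) * 8` come to 2432 bytes each. Below is a standalone restatement of the size gate, for illustration only; the 8 MiB default for max_size is an assumption, not taken from the patch.

    # Illustration only: restates the size gate added to should_custom_ar.
    def use_custom_allreduce(
        inp_bytes: int,
        world_size: int,
        full_nvlink: bool,
        max_size: int = 8 * 1024 * 1024,
    ) -> bool:
        # From TensorRT-LLM getMaxRequiredWorkspaceSize: the first bound
        # applies to the 2-rank path, the second to the full-NVLink path.
        max_required_workspace_size = [16 * 1024 * 1024, 8 * 1024 * 1024]
        if world_size == 2:
            return inp_bytes < max_size and inp_bytes < max_required_workspace_size[0]
        if full_nvlink:
            return inp_bytes < max_size and inp_bytes < max_required_workspace_size[1]
        return False

    # A 2 MiB tensor on 8 NVLinked GPUs takes the custom kernel; a 16 MiB one
    # falls back to NCCL even with a larger registered buffer.
    assert use_custom_allreduce(2 * 1024 * 1024, 8, True)
    assert not use_custom_allreduce(16 * 1024 * 1024, 8, True, max_size=32 * 1024 * 1024)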
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index ad5aa6aa5..b00c866a9 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -12,6 +12,7 @@ suites = {
         "sampling/penaltylib",
         "test_abort.py",
         "test_chunked_prefill.py",
+        "test_custom_allreduce.py",
         "test_double_sparsity.py",
         "test_eagle_infer.py",
         "test_embedding_openai_server.py",
diff --git a/test/srt/test_custom_allreduce.py b/test/srt/test_custom_allreduce.py
new file mode 100644
index 000000000..5f6f5d9b4
--- /dev/null
+++ b/test/srt/test_custom_allreduce.py
@@ -0,0 +1,164 @@
+import os
+import random
+import socket
+import unittest
+from typing import Any
+
+import ray
+import torch
+import torch.distributed as dist
+
+from sglang.srt.distributed import init_distributed_environment
+from sglang.srt.distributed.communication_op import (  # noqa
+    tensor_model_parallel_all_reduce,
+)
+from sglang.srt.distributed.parallel_state import (
+    get_tensor_model_parallel_group,
+    graph_capture,
+    initialize_model_parallel,
+)
+
+
+def get_open_port() -> int:
+    # try ipv4
+    try:
+        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+            s.bind(("", 0))
+            return s.getsockname()[1]
+    except OSError:
+        # try ipv6
+        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
+            s.bind(("", 0))
+            return s.getsockname()[1]
+
+
+def multi_process_parallel(
+    world_size: int,
+    cls: Any,
+    test_target: Any,
+) -> None:
+
+    # Using ray helps debugging the error when it failed
+    # as compared to multiprocessing.
+    # NOTE: We need to set working_dir for distributed tests,
+    # otherwise we may get import errors on ray workers
+    ray.init(log_to_driver=False)
+
+    distributed_init_port = get_open_port()
+    refs = []
+    for rank in range(world_size):
+        refs.append(test_target.remote(cls, world_size, rank, distributed_init_port))
+    ray.get(refs)
+
+    ray.shutdown()
+
+
+class TestCustomAllReduce(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        random.seed(42)
+        # 512B to 32MB
+        cls.test_sizes = [512, 4096, 32768, 262144, 2097152, 16777216, 33554432]
+        cls.world_sizes = [2, 4, 6, 8]
+        cls.test_loop = 10
+
+    def test_graph_allreduce(self):
+        for world_size in self.world_sizes:
+            if world_size > torch.cuda.device_count():
+                continue
+            multi_process_parallel(world_size, self, self.graph_allreduce)
+
+    def test_eager_allreduce(self):
+        for world_size in self.world_sizes:
+            if world_size > torch.cuda.device_count():
+                continue
+            multi_process_parallel(world_size, self, self.eager_allreduce)
+
+    @ray.remote(num_gpus=1, max_calls=1)
+    def graph_allreduce(self, world_size, rank, distributed_init_port):
+        del os.environ["CUDA_VISIBLE_DEVICES"]
+        device = torch.device(f"cuda:{rank}")
+        torch.cuda.set_device(device)
+        distributed_init_method = f"tcp://localhost:{distributed_init_port}"
+        init_distributed_environment(
+            world_size=world_size,
+            rank=rank,
+            distributed_init_method=distributed_init_method,
+            local_rank=rank,
+        )
+        initialize_model_parallel(tensor_model_parallel_size=world_size)
+        group = get_tensor_model_parallel_group().device_group
+
+        # A small all_reduce for warmup.
+        # this is needed because device communicators might be created lazily
+        # (e.g. NCCL). This will ensure that the communicator is initialized
+        # before any communication happens, so that this group can be used for
+        # graph capture immediately.
+        data = torch.zeros(1)
+        data = data.to(device=device)
+        torch.distributed.all_reduce(data, group=group)
+        torch.cuda.synchronize()
+        del data
+
+        for sz in self.test_sizes:
+            for dtype in [torch.float32, torch.float16, torch.bfloat16]:
+                for _ in range(self.test_loop):
+                    with graph_capture() as graph_capture_context:
+                        # use integers so result matches NCCL exactly
+                        inp1 = torch.randint(
+                            1,
+                            16,
+                            (sz,),
+                            dtype=dtype,
+                            device=torch.cuda.current_device(),
+                        )
+                        inp2 = torch.randint(
+                            1,
+                            16,
+                            (sz,),
+                            dtype=dtype,
+                            device=torch.cuda.current_device(),
+                        )
+                        torch.cuda.synchronize()
+                        graph = torch.cuda.CUDAGraph()
+                        with torch.cuda.graph(
+                            graph, stream=graph_capture_context.stream
+                        ):
+                            out1 = tensor_model_parallel_all_reduce(inp1)
+                            # the input buffer is immediately modified to test
+                            # synchronization
+                            dist.all_reduce(inp1, group=group)
+                            out2 = tensor_model_parallel_all_reduce(inp2)
+                            dist.all_reduce(inp2, group=group)
+                    graph.replay()
+                    torch.testing.assert_close(out1, inp1)
+                    torch.testing.assert_close(out2, inp2)
+
+    @ray.remote(num_gpus=1, max_calls=1)
+    def eager_allreduce(self, world_size, rank, distributed_init_port):
+        del os.environ["CUDA_VISIBLE_DEVICES"]
+        device = torch.device(f"cuda:{rank}")
+        torch.cuda.set_device(device)
+        distributed_init_method = f"tcp://localhost:{distributed_init_port}"
+        init_distributed_environment(
+            world_size=world_size,
+            rank=rank,
+            distributed_init_method=distributed_init_method,
+            local_rank=rank,
+        )
+        initialize_model_parallel(tensor_model_parallel_size=world_size)
+        group = get_tensor_model_parallel_group().device_group
+
+        for sz in self.test_sizes:
+            for dtype in [torch.float32, torch.float16, torch.bfloat16]:
+                for _ in range(self.test_loop):
+                    inp1 = torch.randint(
+                        1, 16, (sz,), dtype=dtype, device=torch.cuda.current_device()
+                    )
+                    out1 = tensor_model_parallel_all_reduce(inp1)
+                    dist.all_reduce(inp1, group=group)
+                    torch.testing.assert_close(out1, inp1)
+
+
+if __name__ == "__main__":
+    unittest.main()
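The new test is registered with run_suite.py above and also runs standalone through unittest.main(). A minimal sketch of driving it programmatically is shown below; it is illustrative only and assumes at least two visible CUDA GPUs, a working ray installation, and that it is run from test/srt so the import resolves.

    # Minimal sketch: run the new test case directly with unittest.
    import unittest

    from test_custom_allreduce import TestCustomAllReduce

    suite = unittest.TestLoader().loadTestsFromTestCase(TestCustomAllReduce)
    unittest.TextTestRunner(verbosity=2).run(suite)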