From 89cd923581fec16d70ed536eceac7212dc6e0898 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Mon, 20 Jan 2025 04:03:15 -0800 Subject: [PATCH] Roll back to use vllm custom allreduce (#3006) --- python/sglang/srt/_custom_ops.py | 2 +- python/sglang/srt/distributed/__init__.py | 6 +- .../srt/distributed/communication_op.py | 2 +- .../custom_all_reduce_utils.py | 1 - .../device_communicators/pynccl_wrapper.py | 2 +- .../device_communicators/shm_broadcast.py | 2 +- python/sglang/srt/layers/attention/vision.py | 4 +- .../srt/model_executor/cuda_graph_runner.py | 3 - .../sglang/srt/model_executor/model_runner.py | 5 +- python/sglang/srt/utils.py | 56 ++----------------- 10 files changed, 18 insertions(+), 65 deletions(-) diff --git a/python/sglang/srt/_custom_ops.py b/python/sglang/srt/_custom_ops.py index 3c00a8552..3cb313b91 100644 --- a/python/sglang/srt/_custom_ops.py +++ b/python/sglang/srt/_custom_ops.py @@ -12,7 +12,7 @@ import torch.library from sglang.srt.utils import is_hpu logger = logging.getLogger(__name__) -use_vllm_custom_allreduce = os.environ.get("USE_VLLM_CUSTOM_ALLREDUCE", default=False) +use_vllm_custom_allreduce = os.environ.get("USE_VLLM_CUSTOM_ALLREDUCE", default=True) if not is_hpu(): if use_vllm_custom_allreduce: diff --git a/python/sglang/srt/distributed/__init__.py b/python/sglang/srt/distributed/__init__.py index db325cfab..12f802055 100644 --- a/python/sglang/srt/distributed/__init__.py +++ b/python/sglang/srt/distributed/__init__.py @@ -1,3 +1,3 @@ -from .communication_op import * -from .parallel_state import * -from .utils import * +from sglang.srt.distributed.communication_op import * +from sglang.srt.distributed.parallel_state import * +from sglang.srt.distributed.utils import * diff --git a/python/sglang/srt/distributed/communication_op.py b/python/sglang/srt/distributed/communication_op.py index ddf3b8ef5..7895508cd 100644 --- a/python/sglang/srt/distributed/communication_op.py +++ b/python/sglang/srt/distributed/communication_op.py @@ -4,7 +4,7 @@ from typing import Any, Dict, Optional, Union import torch import torch.distributed -from .parallel_state import get_tp_group +from sglang.srt.distributed.parallel_state import get_tp_group def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py index d807dfd5c..64cf9a78d 100644 --- a/python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py @@ -7,7 +7,6 @@ import pickle import subprocess import sys import tempfile -from functools import lru_cache from itertools import product from typing import Dict, List, Optional, Sequence diff --git a/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py b/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py index e72284f51..a2eacd741 100644 --- a/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py +++ b/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py @@ -57,7 +57,7 @@ def find_nccl_library() -> str: so_file = "librccl.so.1" else: raise ValueError("NCCL only supports CUDA and ROCm backends.") - logger.info("Found nccl from library %s", so_file) + logger.debug("Found nccl from library %s", so_file) return so_file diff --git a/python/sglang/srt/distributed/device_communicators/shm_broadcast.py b/python/sglang/srt/distributed/device_communicators/shm_broadcast.py index 1afe6fca5..c9f329fb2 100644 --- a/python/sglang/srt/distributed/device_communicators/shm_broadcast.py +++ b/python/sglang/srt/distributed/device_communicators/shm_broadcast.py @@ -313,7 +313,7 @@ class MessageQueue: remote_subscribe_port=remote_subscribe_port, ) - logger.info("vLLM message queue communication handle: %s", self.handle) + logger.debug("Message queue communication handle: %s", self.handle) def export_handle(self) -> Handle: return self.handle diff --git a/python/sglang/srt/layers/attention/vision.py b/python/sglang/srt/layers/attention/vision.py index f66456b04..4fcfaad56 100644 --- a/python/sglang/srt/layers/attention/vision.py +++ b/python/sglang/srt/layers/attention/vision.py @@ -5,9 +5,9 @@ from typing import Optional import torch import torch.nn as nn from einops import rearrange, repeat -from vllm.distributed import parallel_state -from vllm.distributed import utils as dist_utils +from sglang.srt.distributed import parallel_state +from sglang.srt.distributed import utils as dist_utils from sglang.srt.layers.attention.triton_ops.prefill_attention import ( context_attention_fwd, ) diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 9fdf7a8ac..762dac140 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -33,7 +33,6 @@ from sglang.srt.model_executor.forward_batch_info import ( ForwardBatch, ForwardMode, ) -from sglang.srt.utils import monkey_patch_vllm_all_gather if TYPE_CHECKING: from sglang.srt.model_executor.model_runner import ModelRunner @@ -72,7 +71,6 @@ def patch_model( try: if enable_compile: _to_torch(model, reverse=False, batch_size=batch_size) - monkey_patch_vllm_all_gather() backup_ca_comm = tp_group.ca_comm # Use custom-allreduce here. # We found the custom allreduce is much faster than the built-in allreduce in torch, @@ -88,7 +86,6 @@ def patch_model( finally: if enable_compile: _to_torch(model, reverse=True, batch_size=batch_size) - monkey_patch_vllm_all_gather(reverse=True) tp_group.ca_comm = backup_ca_comm diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 46920d922..d5cdcf2be 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -63,8 +63,8 @@ from sglang.srt.utils import ( init_custom_process_group, is_cuda, is_hip, + monkey_patch_p2p_access_check, monkey_patch_vllm_gguf_config, - monkey_patch_vllm_p2p_access_check, set_cpu_offload_max_bytes, ) @@ -229,7 +229,8 @@ class ModelRunner: backend = "gloo" if not self.server_args.enable_p2p_check: - monkey_patch_vllm_p2p_access_check(self.gpu_id) + monkey_patch_p2p_access_check() + if self.server_args.dist_init_addr: dist_init_method = f"tcp://{self.server_args.dist_init_addr}" else: diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index c67b6635b..cf74f1d0f 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -518,68 +518,24 @@ def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = N pass -def monkey_patch_vllm_p2p_access_check(gpu_id: int): +def monkey_patch_p2p_access_check(): """ - Monkey patch the slow p2p access check in vllm. + Monkey patch the slow p2p access check. NOTE: We assume the p2p access is always allowed, which can be wrong for some setups. """ - import vllm.distributed.device_communicators.custom_all_reduce_utils as tgt + import sglang.srt.distributed.device_communicators.custom_all_reduce_utils as tgt setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True) # Suppress the warnings from this delete function when using sglang.bench_one_batch - from vllm.distributed.device_communicators.custom_all_reduce import CustomAllreduce + from sglang.srt.distributed.device_communicators.custom_all_reduce import ( + CustomAllreduce, + ) setattr(CustomAllreduce, "__del__", lambda *args, **kwargs: None) -vllm_all_gather_backup = None - - -def monkey_patch_vllm_all_gather(reverse: bool = False): - """Monkey patch all-gather to remove in-place operations.""" - from torch.distributed import _functional_collectives as funcol - from vllm.distributed.parallel_state import GroupCoordinator - - global vllm_all_gather_backup - if vllm_all_gather_backup is None: - vllm_all_gather_backup = GroupCoordinator.all_gather - - def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: - world_size = self.world_size - # Bypass the function if we are using only 1 GPU. - if world_size == 1: - return input_ - assert ( - -input_.dim() <= dim < input_.dim() - ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" - if dim < 0: - # Convert negative dim to positive. - dim += input_.dim() - input_size = input_.size() - # Allocate output tensor. - output_tensor = torch.empty( - (world_size,) + input_size, dtype=input_.dtype, device=input_.device - ) - - output_tensor = funcol.all_gather_tensor( - input_, gather_dim=0, group=self.device_group - ).view((world_size,) + input_size) - - # Reshape - output_tensor = output_tensor.movedim(0, dim) - output_tensor = output_tensor.reshape( - input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1 :] - ) - return output_tensor - - if reverse: - setattr(GroupCoordinator, "all_gather", vllm_all_gather_backup) - else: - setattr(GroupCoordinator, "all_gather", all_gather) - - def monkey_patch_vllm_gguf_config(): from vllm.model_executor.layers.quantization.gguf import ( GGUFConfig,