Roll back to use vllm custom allreduce (#3006)
This commit is contained in:
@@ -12,7 +12,7 @@ import torch.library
|
||||
from sglang.srt.utils import is_hpu
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
use_vllm_custom_allreduce = os.environ.get("USE_VLLM_CUSTOM_ALLREDUCE", default=False)
|
||||
use_vllm_custom_allreduce = os.environ.get("USE_VLLM_CUSTOM_ALLREDUCE", default=True)
|
||||
|
||||
if not is_hpu():
|
||||
if use_vllm_custom_allreduce:
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
from .communication_op import *
|
||||
from .parallel_state import *
|
||||
from .utils import *
|
||||
from sglang.srt.distributed.communication_op import *
|
||||
from sglang.srt.distributed.parallel_state import *
|
||||
from sglang.srt.distributed.utils import *
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import Any, Dict, Optional, Union
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from .parallel_state import get_tp_group
|
||||
from sglang.srt.distributed.parallel_state import get_tp_group
|
||||
|
||||
|
||||
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
@@ -7,7 +7,6 @@ import pickle
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
from functools import lru_cache
|
||||
from itertools import product
|
||||
from typing import Dict, List, Optional, Sequence
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ def find_nccl_library() -> str:
|
||||
so_file = "librccl.so.1"
|
||||
else:
|
||||
raise ValueError("NCCL only supports CUDA and ROCm backends.")
|
||||
logger.info("Found nccl from library %s", so_file)
|
||||
logger.debug("Found nccl from library %s", so_file)
|
||||
return so_file
|
||||
|
||||
|
||||
|
||||
@@ -313,7 +313,7 @@ class MessageQueue:
|
||||
remote_subscribe_port=remote_subscribe_port,
|
||||
)
|
||||
|
||||
logger.info("vLLM message queue communication handle: %s", self.handle)
|
||||
logger.debug("Message queue communication handle: %s", self.handle)
|
||||
|
||||
def export_handle(self) -> Handle:
|
||||
return self.handle
|
||||
|
||||
@@ -5,9 +5,9 @@ from typing import Optional
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange, repeat
|
||||
from vllm.distributed import parallel_state
|
||||
from vllm.distributed import utils as dist_utils
|
||||
|
||||
from sglang.srt.distributed import parallel_state
|
||||
from sglang.srt.distributed import utils as dist_utils
|
||||
from sglang.srt.layers.attention.triton_ops.prefill_attention import (
|
||||
context_attention_fwd,
|
||||
)
|
||||
|
||||
@@ -33,7 +33,6 @@ from sglang.srt.model_executor.forward_batch_info import (
|
||||
ForwardBatch,
|
||||
ForwardMode,
|
||||
)
|
||||
from sglang.srt.utils import monkey_patch_vllm_all_gather
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
@@ -72,7 +71,6 @@ def patch_model(
|
||||
try:
|
||||
if enable_compile:
|
||||
_to_torch(model, reverse=False, batch_size=batch_size)
|
||||
monkey_patch_vllm_all_gather()
|
||||
backup_ca_comm = tp_group.ca_comm
|
||||
# Use custom-allreduce here.
|
||||
# We found the custom allreduce is much faster than the built-in allreduce in torch,
|
||||
@@ -88,7 +86,6 @@ def patch_model(
|
||||
finally:
|
||||
if enable_compile:
|
||||
_to_torch(model, reverse=True, batch_size=batch_size)
|
||||
monkey_patch_vllm_all_gather(reverse=True)
|
||||
tp_group.ca_comm = backup_ca_comm
|
||||
|
||||
|
||||
|
||||
@@ -63,8 +63,8 @@ from sglang.srt.utils import (
|
||||
init_custom_process_group,
|
||||
is_cuda,
|
||||
is_hip,
|
||||
monkey_patch_p2p_access_check,
|
||||
monkey_patch_vllm_gguf_config,
|
||||
monkey_patch_vllm_p2p_access_check,
|
||||
set_cpu_offload_max_bytes,
|
||||
)
|
||||
|
||||
@@ -229,7 +229,8 @@ class ModelRunner:
|
||||
backend = "gloo"
|
||||
|
||||
if not self.server_args.enable_p2p_check:
|
||||
monkey_patch_vllm_p2p_access_check(self.gpu_id)
|
||||
monkey_patch_p2p_access_check()
|
||||
|
||||
if self.server_args.dist_init_addr:
|
||||
dist_init_method = f"tcp://{self.server_args.dist_init_addr}"
|
||||
else:
|
||||
|
||||
@@ -518,68 +518,24 @@ def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = N
|
||||
pass
|
||||
|
||||
|
||||
def monkey_patch_vllm_p2p_access_check(gpu_id: int):
|
||||
def monkey_patch_p2p_access_check():
|
||||
"""
|
||||
Monkey patch the slow p2p access check in vllm.
|
||||
Monkey patch the slow p2p access check.
|
||||
NOTE: We assume the p2p access is always allowed, which can be wrong for some setups.
|
||||
"""
|
||||
|
||||
import vllm.distributed.device_communicators.custom_all_reduce_utils as tgt
|
||||
import sglang.srt.distributed.device_communicators.custom_all_reduce_utils as tgt
|
||||
|
||||
setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True)
|
||||
|
||||
# Suppress the warnings from this delete function when using sglang.bench_one_batch
|
||||
from vllm.distributed.device_communicators.custom_all_reduce import CustomAllreduce
|
||||
from sglang.srt.distributed.device_communicators.custom_all_reduce import (
|
||||
CustomAllreduce,
|
||||
)
|
||||
|
||||
setattr(CustomAllreduce, "__del__", lambda *args, **kwargs: None)
|
||||
|
||||
|
||||
vllm_all_gather_backup = None
|
||||
|
||||
|
||||
def monkey_patch_vllm_all_gather(reverse: bool = False):
|
||||
"""Monkey patch all-gather to remove in-place operations."""
|
||||
from torch.distributed import _functional_collectives as funcol
|
||||
from vllm.distributed.parallel_state import GroupCoordinator
|
||||
|
||||
global vllm_all_gather_backup
|
||||
if vllm_all_gather_backup is None:
|
||||
vllm_all_gather_backup = GroupCoordinator.all_gather
|
||||
|
||||
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
|
||||
world_size = self.world_size
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if world_size == 1:
|
||||
return input_
|
||||
assert (
|
||||
-input_.dim() <= dim < input_.dim()
|
||||
), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
|
||||
if dim < 0:
|
||||
# Convert negative dim to positive.
|
||||
dim += input_.dim()
|
||||
input_size = input_.size()
|
||||
# Allocate output tensor.
|
||||
output_tensor = torch.empty(
|
||||
(world_size,) + input_size, dtype=input_.dtype, device=input_.device
|
||||
)
|
||||
|
||||
output_tensor = funcol.all_gather_tensor(
|
||||
input_, gather_dim=0, group=self.device_group
|
||||
).view((world_size,) + input_size)
|
||||
|
||||
# Reshape
|
||||
output_tensor = output_tensor.movedim(0, dim)
|
||||
output_tensor = output_tensor.reshape(
|
||||
input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1 :]
|
||||
)
|
||||
return output_tensor
|
||||
|
||||
if reverse:
|
||||
setattr(GroupCoordinator, "all_gather", vllm_all_gather_backup)
|
||||
else:
|
||||
setattr(GroupCoordinator, "all_gather", all_gather)
|
||||
|
||||
|
||||
def monkey_patch_vllm_gguf_config():
|
||||
from vllm.model_executor.layers.quantization.gguf import (
|
||||
GGUFConfig,
|
||||
|
||||
Reference in New Issue
Block a user