Roll back to use vllm custom allreduce (#3006)
This commit is contained in:
@@ -12,7 +12,7 @@ import torch.library
|
|||||||
from sglang.srt.utils import is_hpu
|
from sglang.srt.utils import is_hpu
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
use_vllm_custom_allreduce = os.environ.get("USE_VLLM_CUSTOM_ALLREDUCE", default=False)
|
use_vllm_custom_allreduce = os.environ.get("USE_VLLM_CUSTOM_ALLREDUCE", default=True)
|
||||||
|
|
||||||
if not is_hpu():
|
if not is_hpu():
|
||||||
if use_vllm_custom_allreduce:
|
if use_vllm_custom_allreduce:
|
||||||
|
|||||||
@@ -1,3 +1,3 @@
|
|||||||
from .communication_op import *
|
from sglang.srt.distributed.communication_op import *
|
||||||
from .parallel_state import *
|
from sglang.srt.distributed.parallel_state import *
|
||||||
from .utils import *
|
from sglang.srt.distributed.utils import *
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from typing import Any, Dict, Optional, Union
|
|||||||
import torch
|
import torch
|
||||||
import torch.distributed
|
import torch.distributed
|
||||||
|
|
||||||
from .parallel_state import get_tp_group
|
from sglang.srt.distributed.parallel_state import get_tp_group
|
||||||
|
|
||||||
|
|
||||||
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
|
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import pickle
|
|||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
from functools import lru_cache
|
|
||||||
from itertools import product
|
from itertools import product
|
||||||
from typing import Dict, List, Optional, Sequence
|
from typing import Dict, List, Optional, Sequence
|
||||||
|
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ def find_nccl_library() -> str:
|
|||||||
so_file = "librccl.so.1"
|
so_file = "librccl.so.1"
|
||||||
else:
|
else:
|
||||||
raise ValueError("NCCL only supports CUDA and ROCm backends.")
|
raise ValueError("NCCL only supports CUDA and ROCm backends.")
|
||||||
logger.info("Found nccl from library %s", so_file)
|
logger.debug("Found nccl from library %s", so_file)
|
||||||
return so_file
|
return so_file
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -313,7 +313,7 @@ class MessageQueue:
|
|||||||
remote_subscribe_port=remote_subscribe_port,
|
remote_subscribe_port=remote_subscribe_port,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info("vLLM message queue communication handle: %s", self.handle)
|
logger.debug("Message queue communication handle: %s", self.handle)
|
||||||
|
|
||||||
def export_handle(self) -> Handle:
|
def export_handle(self) -> Handle:
|
||||||
return self.handle
|
return self.handle
|
||||||
|
|||||||
@@ -5,9 +5,9 @@ from typing import Optional
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
from vllm.distributed import parallel_state
|
|
||||||
from vllm.distributed import utils as dist_utils
|
|
||||||
|
|
||||||
|
from sglang.srt.distributed import parallel_state
|
||||||
|
from sglang.srt.distributed import utils as dist_utils
|
||||||
from sglang.srt.layers.attention.triton_ops.prefill_attention import (
|
from sglang.srt.layers.attention.triton_ops.prefill_attention import (
|
||||||
context_attention_fwd,
|
context_attention_fwd,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -33,7 +33,6 @@ from sglang.srt.model_executor.forward_batch_info import (
|
|||||||
ForwardBatch,
|
ForwardBatch,
|
||||||
ForwardMode,
|
ForwardMode,
|
||||||
)
|
)
|
||||||
from sglang.srt.utils import monkey_patch_vllm_all_gather
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
@@ -72,7 +71,6 @@ def patch_model(
|
|||||||
try:
|
try:
|
||||||
if enable_compile:
|
if enable_compile:
|
||||||
_to_torch(model, reverse=False, batch_size=batch_size)
|
_to_torch(model, reverse=False, batch_size=batch_size)
|
||||||
monkey_patch_vllm_all_gather()
|
|
||||||
backup_ca_comm = tp_group.ca_comm
|
backup_ca_comm = tp_group.ca_comm
|
||||||
# Use custom-allreduce here.
|
# Use custom-allreduce here.
|
||||||
# We found the custom allreduce is much faster than the built-in allreduce in torch,
|
# We found the custom allreduce is much faster than the built-in allreduce in torch,
|
||||||
@@ -88,7 +86,6 @@ def patch_model(
|
|||||||
finally:
|
finally:
|
||||||
if enable_compile:
|
if enable_compile:
|
||||||
_to_torch(model, reverse=True, batch_size=batch_size)
|
_to_torch(model, reverse=True, batch_size=batch_size)
|
||||||
monkey_patch_vllm_all_gather(reverse=True)
|
|
||||||
tp_group.ca_comm = backup_ca_comm
|
tp_group.ca_comm = backup_ca_comm
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -63,8 +63,8 @@ from sglang.srt.utils import (
|
|||||||
init_custom_process_group,
|
init_custom_process_group,
|
||||||
is_cuda,
|
is_cuda,
|
||||||
is_hip,
|
is_hip,
|
||||||
|
monkey_patch_p2p_access_check,
|
||||||
monkey_patch_vllm_gguf_config,
|
monkey_patch_vllm_gguf_config,
|
||||||
monkey_patch_vllm_p2p_access_check,
|
|
||||||
set_cpu_offload_max_bytes,
|
set_cpu_offload_max_bytes,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -229,7 +229,8 @@ class ModelRunner:
|
|||||||
backend = "gloo"
|
backend = "gloo"
|
||||||
|
|
||||||
if not self.server_args.enable_p2p_check:
|
if not self.server_args.enable_p2p_check:
|
||||||
monkey_patch_vllm_p2p_access_check(self.gpu_id)
|
monkey_patch_p2p_access_check()
|
||||||
|
|
||||||
if self.server_args.dist_init_addr:
|
if self.server_args.dist_init_addr:
|
||||||
dist_init_method = f"tcp://{self.server_args.dist_init_addr}"
|
dist_init_method = f"tcp://{self.server_args.dist_init_addr}"
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -518,68 +518,24 @@ def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = N
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def monkey_patch_vllm_p2p_access_check(gpu_id: int):
|
def monkey_patch_p2p_access_check():
|
||||||
"""
|
"""
|
||||||
Monkey patch the slow p2p access check in vllm.
|
Monkey patch the slow p2p access check.
|
||||||
NOTE: We assume the p2p access is always allowed, which can be wrong for some setups.
|
NOTE: We assume the p2p access is always allowed, which can be wrong for some setups.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import vllm.distributed.device_communicators.custom_all_reduce_utils as tgt
|
import sglang.srt.distributed.device_communicators.custom_all_reduce_utils as tgt
|
||||||
|
|
||||||
setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True)
|
setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True)
|
||||||
|
|
||||||
# Suppress the warnings from this delete function when using sglang.bench_one_batch
|
# Suppress the warnings from this delete function when using sglang.bench_one_batch
|
||||||
from vllm.distributed.device_communicators.custom_all_reduce import CustomAllreduce
|
from sglang.srt.distributed.device_communicators.custom_all_reduce import (
|
||||||
|
CustomAllreduce,
|
||||||
|
)
|
||||||
|
|
||||||
setattr(CustomAllreduce, "__del__", lambda *args, **kwargs: None)
|
setattr(CustomAllreduce, "__del__", lambda *args, **kwargs: None)
|
||||||
|
|
||||||
|
|
||||||
vllm_all_gather_backup = None
|
|
||||||
|
|
||||||
|
|
||||||
def monkey_patch_vllm_all_gather(reverse: bool = False):
|
|
||||||
"""Monkey patch all-gather to remove in-place operations."""
|
|
||||||
from torch.distributed import _functional_collectives as funcol
|
|
||||||
from vllm.distributed.parallel_state import GroupCoordinator
|
|
||||||
|
|
||||||
global vllm_all_gather_backup
|
|
||||||
if vllm_all_gather_backup is None:
|
|
||||||
vllm_all_gather_backup = GroupCoordinator.all_gather
|
|
||||||
|
|
||||||
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
|
|
||||||
world_size = self.world_size
|
|
||||||
# Bypass the function if we are using only 1 GPU.
|
|
||||||
if world_size == 1:
|
|
||||||
return input_
|
|
||||||
assert (
|
|
||||||
-input_.dim() <= dim < input_.dim()
|
|
||||||
), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
|
|
||||||
if dim < 0:
|
|
||||||
# Convert negative dim to positive.
|
|
||||||
dim += input_.dim()
|
|
||||||
input_size = input_.size()
|
|
||||||
# Allocate output tensor.
|
|
||||||
output_tensor = torch.empty(
|
|
||||||
(world_size,) + input_size, dtype=input_.dtype, device=input_.device
|
|
||||||
)
|
|
||||||
|
|
||||||
output_tensor = funcol.all_gather_tensor(
|
|
||||||
input_, gather_dim=0, group=self.device_group
|
|
||||||
).view((world_size,) + input_size)
|
|
||||||
|
|
||||||
# Reshape
|
|
||||||
output_tensor = output_tensor.movedim(0, dim)
|
|
||||||
output_tensor = output_tensor.reshape(
|
|
||||||
input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1 :]
|
|
||||||
)
|
|
||||||
return output_tensor
|
|
||||||
|
|
||||||
if reverse:
|
|
||||||
setattr(GroupCoordinator, "all_gather", vllm_all_gather_backup)
|
|
||||||
else:
|
|
||||||
setattr(GroupCoordinator, "all_gather", all_gather)
|
|
||||||
|
|
||||||
|
|
||||||
def monkey_patch_vllm_gguf_config():
|
def monkey_patch_vllm_gguf_config():
|
||||||
from vllm.model_executor.layers.quantization.gguf import (
|
from vllm.model_executor.layers.quantization.gguf import (
|
||||||
GGUFConfig,
|
GGUFConfig,
|
||||||
|
|||||||
Reference in New Issue
Block a user