Sync distributed package from vllm 0.6.4.post1 (#3010)
@@ -1,10 +1,11 @@
-# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/communication_op.py
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/communication_op.py
+
 from typing import Any, Dict, Optional, Union

 import torch
 import torch.distributed

-from sglang.srt.distributed.parallel_state import get_tp_group
+from .parallel_state import get_tp_group


 def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
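For orientation, the helper whose signature closes this hunk is a thin forwarder; a minimal sketch of its body (not shown in the diff, assuming the tensor-parallel group has been initialized as in parallel_state.py):

    # Sketch, not part of the diff: all-reduce across the tensor-parallel group.
    def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
        # Delegates to the process-group wrapper returned by get_tp_group().
        return get_tp_group().all_reduce(input_)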
@@ -1,4 +1,5 @@
-# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/cuda_wrapper.py
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/cuda_wrapper.py
+
 """This file is a pure Python wrapper for the cudart library.
 It avoids the need to compile a separate shared library, and is
 convenient for use when we just need to call a few functions.
@@ -1,4 +1,5 @@
-# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/custom_all_reduce.py
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/custom_all_reduce.py
+
 import ctypes
 import logging
 import os
@@ -1,4 +1,5 @@
-# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/custom_all_reduce_utils.py
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/custom_all_reduce_utils.py
+
 import ctypes
 import json
 import logging
@@ -1,4 +1,5 @@
-# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/hpu_communicator.py
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/hpu_communicator.py
+
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
@@ -1,8 +1,10 @@
-# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/pynccl.py
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/pynccl.py
+
 import logging
 from contextlib import contextmanager
 from typing import Optional, Union

 # ===================== import region =====================
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup, ReduceOp
@@ -143,6 +145,57 @@ class PyNcclCommunicator:
             cudaStream_t(stream.cuda_stream),
         )

+    def all_gather(
+        self, output_tensor: torch.Tensor, input_tensor: torch.Tensor, stream=None
+    ):
+        if self.disabled:
+            return
+        # nccl communicator created on a specific device
+        # will only work on tensors on the same device
+        # otherwise it will cause "illegal memory access"
+        assert input_tensor.device == self.device, (
+            f"this nccl communicator is created to work on {self.device}, "
+            f"but the input tensor is on {input_tensor.device}"
+        )
+        if stream is None:
+            stream = self.stream
+        self.nccl.ncclAllGather(
+            buffer_type(input_tensor.data_ptr()),
+            buffer_type(output_tensor.data_ptr()),
+            input_tensor.numel(),
+            ncclDataTypeEnum.from_torch(input_tensor.dtype),
+            self.comm,
+            cudaStream_t(stream.cuda_stream),
+        )
+
+    def reduce_scatter(
+        self,
+        output_tensor: torch.Tensor,
+        input_tensor: torch.Tensor,
+        op: ReduceOp = ReduceOp.SUM,
+        stream=None,
+    ):
+        if self.disabled:
+            return
+        # nccl communicator created on a specific device
+        # will only work on tensors on the same device
+        # otherwise it will cause "illegal memory access"
+        assert input_tensor.device == self.device, (
+            f"this nccl communicator is created to work on {self.device}, "
+            f"but the input tensor is on {input_tensor.device}"
+        )
+        if stream is None:
+            stream = self.stream
+        self.nccl.ncclReduceScatter(
+            buffer_type(input_tensor.data_ptr()),
+            buffer_type(output_tensor.data_ptr()),
+            output_tensor.numel(),
+            ncclDataTypeEnum.from_torch(input_tensor.dtype),
+            ncclRedOpTypeEnum.from_torch(op),
+            self.comm,
+            cudaStream_t(stream.cuda_stream),
+        )
+
     def send(self, tensor: torch.Tensor, dst: int, stream=None):
         if self.disabled:
             return
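The two new collectives follow the usual NCCL shape conventions: all_gather's output is world_size times the input, reduce_scatter's input is world_size times the output. A hypothetical usage sketch (assuming an initialized `comm: PyNcclCommunicator`; names outside the diff are illustrative):

    # all_gather: concatenate every rank's input into one output tensor.
    inp = torch.ones(1024, device=comm.device)
    out = torch.empty(1024 * comm.world_size, device=comm.device)
    comm.all_gather(out, inp)         # out now holds all ranks' 1024 elements

    # reduce_scatter: reduce elementwise, then each rank keeps one slice.
    big = torch.ones(1024 * comm.world_size, device=comm.device)
    part = torch.empty(1024, device=comm.device)
    comm.reduce_scatter(part, big)    # part is this rank's SUM-reduced slice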
@@ -179,6 +232,32 @@ class PyNcclCommunicator:
             cudaStream_t(stream.cuda_stream),
         )

+    def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
+        if self.disabled:
+            return
+        assert tensor.device == self.device, (
+            f"this nccl communicator is created to work on {self.device}, "
+            f"but the input tensor is on {tensor.device}"
+        )
+        if stream is None:
+            stream = self.stream
+        if src == self.rank:
+            sendbuff = buffer_type(tensor.data_ptr())
+            # NCCL requires the sender also to have a receive buffer
+            recvbuff = buffer_type(tensor.data_ptr())
+        else:
+            sendbuff = buffer_type()
+            recvbuff = buffer_type(tensor.data_ptr())
+        self.nccl.ncclBroadcast(
+            sendbuff,
+            recvbuff,
+            tensor.numel(),
+            ncclDataTypeEnum.from_torch(tensor.dtype),
+            src,
+            self.comm,
+            cudaStream_t(stream.cuda_stream),
+        )
+
     @contextmanager
     def change_state(
         self, enable: Optional[bool] = None, stream: Optional[torch.cuda.Stream] = None
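Note the in-place semantics: on the source rank the same pointer is passed as both send and receive buffer, since NCCL expects a receive buffer on every rank. A hypothetical sketch (names outside the diff are illustrative):

    # Rank src keeps its data; every other rank has its tensor overwritten.
    if comm.rank == 0:
        t = torch.arange(4.0, device=comm.device)
    else:
        t = torch.zeros(4, device=comm.device)
    comm.broadcast(t, src=0)   # afterwards t == [0., 1., 2., 3.] on all ranks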
@@ -1,4 +1,4 @@
-# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/pynccl.py
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/pynccl.py

 # This file is a pure Python wrapper for the NCCL library.
 # The main purpose is to use NCCL combined with CUDA graph.
@@ -187,6 +187,43 @@ class NCCLLibrary:
                 cudaStream_t,
             ],
         ),
+        # ncclResult_t ncclAllGather(
+        #   const void* sendbuff, void* recvbuff, size_t count,
+        #   ncclDataType_t datatype, ncclComm_t comm,
+        #   cudaStream_t stream);
+        # note that cudaStream_t is a pointer type, so the last argument
+        # is a pointer
+        Function(
+            "ncclAllGather",
+            ncclResult_t,
+            [
+                buffer_type,
+                buffer_type,
+                ctypes.c_size_t,
+                ncclDataType_t,
+                ncclComm_t,
+                cudaStream_t,
+            ],
+        ),
+        # ncclResult_t ncclReduceScatter(
+        #   const void* sendbuff, void* recvbuff, size_t count,
+        #   ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
+        #   cudaStream_t stream);
+        # note that cudaStream_t is a pointer type, so the last argument
+        # is a pointer
+        Function(
+            "ncclReduceScatter",
+            ncclResult_t,
+            [
+                buffer_type,
+                buffer_type,
+                ctypes.c_size_t,
+                ncclDataType_t,
+                ncclRedOp_t,
+                ncclComm_t,
+                cudaStream_t,
+            ],
+        ),
         # ncclResult_t ncclSend(
         #   const void* sendbuff, size_t count, ncclDataType_t datatype,
         #   int dest, ncclComm_t comm, cudaStream_t stream);
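Each Function entry is just a declarative (name, return type, argument types) triple that the library loader turns into ctypes bindings. A hypothetical sketch of what the ncclAllGather entry amounts to in raw ctypes (buffer_type, ncclComm_t and cudaStream_t are c_void_p aliases in this file; ncclResult_t and ncclDataType_t are c_int aliases):

    import ctypes

    lib = ctypes.CDLL("libnccl.so.2")       # assumed library path, for illustration
    lib.ncclAllGather.restype = ctypes.c_int  # ncclResult_t
    lib.ncclAllGather.argtypes = [
        ctypes.c_void_p,   # sendbuff
        ctypes.c_void_p,   # recvbuff
        ctypes.c_size_t,   # count
        ctypes.c_int,      # ncclDataType_t
        ctypes.c_void_p,   # ncclComm_t
        ctypes.c_void_p,   # cudaStream_t (a pointer type, hence void*)
    ]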
@@ -217,6 +254,23 @@ class NCCLLibrary:
                 cudaStream_t,
             ],
         ),
+        # ncclResult_t ncclBroadcast(
+        #   const void* sendbuff, void* recvbuff, size_t count,
+        #   ncclDataType_t datatype, int root, ncclComm_t comm,
+        #   cudaStream_t stream);
+        Function(
+            "ncclBroadcast",
+            ncclResult_t,
+            [
+                buffer_type,
+                buffer_type,
+                ctypes.c_size_t,
+                ncclDataType_t,
+                ctypes.c_int,
+                ncclComm_t,
+                cudaStream_t,
+            ],
+        ),
         # be cautious! this is a collective call, it will block until all
         # processes in the communicator have called this function.
         # because Python object destruction can happen in random order,
@@ -321,6 +375,46 @@ class NCCLLibrary:
             )
         )

+    def ncclReduceScatter(
+        self,
+        sendbuff: buffer_type,
+        recvbuff: buffer_type,
+        count: int,
+        datatype: int,
+        op: int,
+        comm: ncclComm_t,
+        stream: cudaStream_t,
+    ) -> None:
+        # `datatype` actually should be `ncclDataType_t`
+        # and `op` should be `ncclRedOp_t`
+        # both are aliases of `ctypes.c_int`
+        # when we pass int to a function, it will be converted to `ctypes.c_int`
+        # by ctypes automatically
+        self.NCCL_CHECK(
+            self._funcs["ncclReduceScatter"](
+                sendbuff, recvbuff, count, datatype, op, comm, stream
+            )
+        )
+
+    def ncclAllGather(
+        self,
+        sendbuff: buffer_type,
+        recvbuff: buffer_type,
+        count: int,
+        datatype: int,
+        comm: ncclComm_t,
+        stream: cudaStream_t,
+    ) -> None:
+        # `datatype` actually should be `ncclDataType_t`,
+        # which is an alias of `ctypes.c_int`
+        # when we pass int to a function, it will be converted to `ctypes.c_int`
+        # by ctypes automatically
+        self.NCCL_CHECK(
+            self._funcs["ncclAllGather"](
+                sendbuff, recvbuff, count, datatype, comm, stream
+            )
+        )
+
     def ncclSend(
         self,
         sendbuff: buffer_type,
@@ -347,6 +441,22 @@ class NCCLLibrary:
             self._funcs["ncclRecv"](recvbuff, count, datatype, src, comm, stream)
         )

+    def ncclBroadcast(
+        self,
+        sendbuff: buffer_type,
+        recvbuff: buffer_type,
+        count: int,
+        datatype: int,
+        root: int,
+        comm: ncclComm_t,
+        stream: cudaStream_t,
+    ) -> None:
+        self.NCCL_CHECK(
+            self._funcs["ncclBroadcast"](
+                sendbuff, recvbuff, count, datatype, root, comm, stream
+            )
+        )
+
     def ncclCommDestroy(self, comm: ncclComm_t) -> None:
         self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))
@@ -1,11 +1,9 @@
-# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/shm_broadcast.py
-import ipaddress
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/shm_broadcast.py
 import logging
 import os
 import pickle
 import socket
 import time
-import warnings
 from contextlib import contextmanager
 from dataclasses import dataclass, field
 from multiprocessing import shared_memory
@@ -18,6 +16,8 @@ from torch.distributed import ProcessGroup
 from zmq import IPV6  # type: ignore
 from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context  # type: ignore

+from sglang.srt.utils import get_ip, get_open_port, is_valid_ipv6_address
+
 # SGLANG_RINGBUFFER_WARNING_INTERVAL can be set to 60
 SGLANG_RINGBUFFER_WARNING_INTERVAL = int(
     os.environ.get("SGLANG_RINGBUFFER_WARNING_INTERVAL", "60")
@@ -26,73 +26,6 @@ SGLANG_RINGBUFFER_WARNING_INTERVAL = int(
 logger = logging.getLogger(__name__)


-def get_ip() -> str:
-    # SGLANG_HOST_IP env can be ignored
-    host_ip = os.getenv("SGLANG_HOST_IP", "") or os.getenv("HOST_IP", "")
-    if host_ip:
-        return host_ip
-
-    # IP is not set, try to get it from the network interface
-
-    # try ipv4
-    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
-    try:
-        s.connect(("8.8.8.8", 80))  # Doesn't need to be reachable
-        return s.getsockname()[0]
-    except Exception:
-        pass
-
-    # try ipv6
-    try:
-        s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
-        # Google's public DNS server, see
-        # https://developers.google.com/speed/public-dns/docs/using#addresses
-        s.connect(("2001:4860:4860::8888", 80))  # Doesn't need to be reachable
-        return s.getsockname()[0]
-    except Exception:
-        pass
-
-    warnings.warn(
-        "Failed to get the IP address, using 0.0.0.0 by default."
-        "The value can be set by the environment variable"
-        " SGLANG_HOST_IP or HOST_IP.",
-        stacklevel=2,
-    )
-    return "0.0.0.0"
-
-
-def get_open_port() -> int:
-    port = os.getenv("SGLANG_PORT")
-    if port is not None:
-        port = int(port)
-        while True:
-            try:
-                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-                    s.bind(("", port))
-                    return port
-            except OSError:
-                port += 1  # Increment port number if already in use
-                logger.info("Port %d is already in use, trying port %d", port - 1, port)
-    # try ipv4
-    try:
-        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-            s.bind(("", 0))
-            return s.getsockname()[1]
-    except OSError:
-        # try ipv6
-        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
-            s.bind(("", 0))
-            return s.getsockname()[1]
-
-
-def is_valid_ipv6_address(address: str) -> bool:
-    try:
-        ipaddress.IPv6Address(address)
-        return True
-    except ValueError:
-        return False
-
-
 class ShmRingBuffer:

     def __init__(
@@ -1,4 +1,5 @@
-# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/xpu_communicator.py
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/xpu_communicator.py
+
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
@@ -1,4 +1,4 @@
-# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/parallel_state.py
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/parallel_state.py

 # Copyright 2023 The vLLM team.
 # Adapted from
@@ -1,4 +1,5 @@
-# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/utils.py
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/utils.py
+
 # Copyright 2023 The vLLM team.
 # Adapted from
 # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
@@ -29,8 +29,8 @@ from sglang.srt.utils import (
     get_nvgpu_memory_capacity,
     is_flashinfer_available,
     is_hip,
-    is_ipv6,
     is_port_available,
+    is_valid_ipv6_address,
     nullable_str,
 )
@@ -883,7 +883,7 @@ class ServerArgs:
         return cls(**{attr: getattr(args, attr) for attr in attrs})

     def url(self):
-        if is_ipv6(self.host):
+        if is_valid_ipv6_address(self.host):
             return f"http://[{self.host}]:{self.port}"
         else:
             return f"http://{self.host}:{self.port}"
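The rename is behavior-preserving for the URL logic; a quick hypothetical check of the branch url() takes (values are illustrative, not from the diff):

    from sglang.srt.utils import is_valid_ipv6_address

    host, port = "::1", 30000
    url = (
        f"http://[{host}]:{port}"        # IPv6 hosts need brackets
        if is_valid_ipv6_address(host)
        else f"http://{host}:{port}"
    )
    assert url == "http://[::1]:30000"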
@@ -102,14 +102,6 @@ def is_cuda_available():
     return torch.cuda.is_available() and torch.version.cuda


-def is_ipv6(address):
-    try:
-        ipaddress.IPv6Address(address)
-        return True
-    except ipaddress.AddressValueError:
-        return False
-
-
 def enable_show_time_cost():
     global show_time_cost
     show_time_cost = True
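The replacement helper added below catches ValueError rather than ipaddress.AddressValueError; since AddressValueError subclasses ValueError, the new check accepts everything the old one did while also tolerating any other ValueError the parser raises. A small sanity sketch:

    import ipaddress

    # AddressValueError is a ValueError, so broadening the except clause is safe.
    assert issubclass(ipaddress.AddressValueError, ValueError)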
@@ -1383,3 +1375,70 @@ def set_uvicorn_logging_configs():
         "fmt"
     ] = '[%(asctime)s] %(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s'
     LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
+
+
+def get_ip() -> str:
+    # SGLANG_HOST_IP env can be ignored
+    host_ip = os.getenv("SGLANG_HOST_IP", "") or os.getenv("HOST_IP", "")
+    if host_ip:
+        return host_ip
+
+    # IP is not set, try to get it from the network interface
+
+    # try ipv4
+    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+    try:
+        s.connect(("8.8.8.8", 80))  # Doesn't need to be reachable
+        return s.getsockname()[0]
+    except Exception:
+        pass
+
+    # try ipv6
+    try:
+        s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
+        # Google's public DNS server, see
+        # https://developers.google.com/speed/public-dns/docs/using#addresses
+        s.connect(("2001:4860:4860::8888", 80))  # Doesn't need to be reachable
+        return s.getsockname()[0]
+    except Exception:
+        pass
+
+    warnings.warn(
+        "Failed to get the IP address, using 0.0.0.0 by default."
+        "The value can be set by the environment variable"
+        " SGLANG_HOST_IP or HOST_IP.",
+        stacklevel=2,
+    )
+    return "0.0.0.0"
+
+
+def get_open_port() -> int:
+    port = os.getenv("SGLANG_PORT")
+    if port is not None:
+        port = int(port)
+        while True:
+            try:
+                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+                    s.bind(("", port))
+                    return port
+            except OSError:
+                port += 1  # Increment port number if already in use
+                logger.info("Port %d is already in use, trying port %d", port - 1, port)
+    # try ipv4
+    try:
+        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+            s.bind(("", 0))
+            return s.getsockname()[1]
+    except OSError:
+        # try ipv6
+        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
+            s.bind(("", 0))
+            return s.getsockname()[1]
+
+
+def is_valid_ipv6_address(address: str) -> bool:
+    try:
+        ipaddress.IPv6Address(address)
+        return True
+    except ValueError:
+        return False
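With the network helpers consolidated here, other modules (shm_broadcast.py above) import them from a single place. A hypothetical smoke test of the relocated functions:

    from sglang.srt.utils import get_ip, get_open_port, is_valid_ipv6_address

    ip = get_ip()           # env override first, then the UDP-connect trick
    port = get_open_port()  # honors SGLANG_PORT if set, else OS-assigned
    scheme = f"http://[{ip}]:{port}" if is_valid_ipv6_address(ip) else f"http://{ip}:{port}"
    print(scheme)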
@@ -1,7 +1,6 @@
 """Common utilities"""

 import base64
 import gc
 import importlib
 import json
 import logging