Sync distributed package from vllm 0.6.4.post1 (#3010)
This commit is contained in:
@@ -1,10 +1,11 @@
|
|||||||
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/communication_op.py
|
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/communication_op.py
|
||||||
|
|
||||||
from typing import Any, Dict, Optional, Union
|
from typing import Any, Dict, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed
|
import torch.distributed
|
||||||
|
|
||||||
from sglang.srt.distributed.parallel_state import get_tp_group
|
from .parallel_state import get_tp_group
|
||||||
|
|
||||||
|
|
||||||
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
|
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/cuda_wrapper.py
|
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/cuda_wrapper.py
|
||||||
|
|
||||||
"""This file is a pure Python wrapper for the cudart library.
|
"""This file is a pure Python wrapper for the cudart library.
|
||||||
It avoids the need to compile a separate shared library, and is
|
It avoids the need to compile a separate shared library, and is
|
||||||
convenient for use when we just need to call a few functions.
|
convenient for use when we just need to call a few functions.
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/custom_all_reduce.py
|
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/custom_all_reduce.py
|
||||||
|
|
||||||
import ctypes
|
import ctypes
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/custom_all_reduce_utils.py
|
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/custom_all_reduce_utils.py
|
||||||
|
|
||||||
import ctypes
|
import ctypes
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/hpu_communicator.py
|
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/hpu_communicator.py
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from torch.distributed import ProcessGroup
|
from torch.distributed import ProcessGroup
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/pynccl.py
|
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/pynccl.py
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
# ===================== import region =====================
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from torch.distributed import ProcessGroup, ReduceOp
|
from torch.distributed import ProcessGroup, ReduceOp
|
||||||
@@ -143,6 +145,57 @@ class PyNcclCommunicator:
|
|||||||
cudaStream_t(stream.cuda_stream),
|
cudaStream_t(stream.cuda_stream),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def all_gather(
|
||||||
|
self, output_tensor: torch.Tensor, input_tensor: torch.Tensor, stream=None
|
||||||
|
):
|
||||||
|
if self.disabled:
|
||||||
|
return
|
||||||
|
# nccl communicator created on a specific device
|
||||||
|
# will only work on tensors on the same device
|
||||||
|
# otherwise it will cause "illegal memory access"
|
||||||
|
assert input_tensor.device == self.device, (
|
||||||
|
f"this nccl communicator is created to work on {self.device}, "
|
||||||
|
f"but the input tensor is on {input_tensor.device}"
|
||||||
|
)
|
||||||
|
if stream is None:
|
||||||
|
stream = self.stream
|
||||||
|
self.nccl.ncclAllGather(
|
||||||
|
buffer_type(input_tensor.data_ptr()),
|
||||||
|
buffer_type(output_tensor.data_ptr()),
|
||||||
|
input_tensor.numel(),
|
||||||
|
ncclDataTypeEnum.from_torch(input_tensor.dtype),
|
||||||
|
self.comm,
|
||||||
|
cudaStream_t(stream.cuda_stream),
|
||||||
|
)
|
||||||
|
|
||||||
|
def reduce_scatter(
|
||||||
|
self,
|
||||||
|
output_tensor: torch.Tensor,
|
||||||
|
input_tensor: torch.Tensor,
|
||||||
|
op: ReduceOp = ReduceOp.SUM,
|
||||||
|
stream=None,
|
||||||
|
):
|
||||||
|
if self.disabled:
|
||||||
|
return
|
||||||
|
# nccl communicator created on a specific device
|
||||||
|
# will only work on tensors on the same device
|
||||||
|
# otherwise it will cause "illegal memory access"
|
||||||
|
assert input_tensor.device == self.device, (
|
||||||
|
f"this nccl communicator is created to work on {self.device}, "
|
||||||
|
f"but the input tensor is on {input_tensor.device}"
|
||||||
|
)
|
||||||
|
if stream is None:
|
||||||
|
stream = self.stream
|
||||||
|
self.nccl.ncclReduceScatter(
|
||||||
|
buffer_type(input_tensor.data_ptr()),
|
||||||
|
buffer_type(output_tensor.data_ptr()),
|
||||||
|
output_tensor.numel(),
|
||||||
|
ncclDataTypeEnum.from_torch(input_tensor.dtype),
|
||||||
|
ncclRedOpTypeEnum.from_torch(op),
|
||||||
|
self.comm,
|
||||||
|
cudaStream_t(stream.cuda_stream),
|
||||||
|
)
|
||||||
|
|
||||||
def send(self, tensor: torch.Tensor, dst: int, stream=None):
|
def send(self, tensor: torch.Tensor, dst: int, stream=None):
|
||||||
if self.disabled:
|
if self.disabled:
|
||||||
return
|
return
|
||||||
@@ -179,6 +232,32 @@ class PyNcclCommunicator:
|
|||||||
cudaStream_t(stream.cuda_stream),
|
cudaStream_t(stream.cuda_stream),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
|
||||||
|
if self.disabled:
|
||||||
|
return
|
||||||
|
assert tensor.device == self.device, (
|
||||||
|
f"this nccl communicator is created to work on {self.device}, "
|
||||||
|
f"but the input tensor is on {tensor.device}"
|
||||||
|
)
|
||||||
|
if stream is None:
|
||||||
|
stream = self.stream
|
||||||
|
if src == self.rank:
|
||||||
|
sendbuff = buffer_type(tensor.data_ptr())
|
||||||
|
# NCCL requires the sender also to have a receive buffer
|
||||||
|
recvbuff = buffer_type(tensor.data_ptr())
|
||||||
|
else:
|
||||||
|
sendbuff = buffer_type()
|
||||||
|
recvbuff = buffer_type(tensor.data_ptr())
|
||||||
|
self.nccl.ncclBroadcast(
|
||||||
|
sendbuff,
|
||||||
|
recvbuff,
|
||||||
|
tensor.numel(),
|
||||||
|
ncclDataTypeEnum.from_torch(tensor.dtype),
|
||||||
|
src,
|
||||||
|
self.comm,
|
||||||
|
cudaStream_t(stream.cuda_stream),
|
||||||
|
)
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def change_state(
|
def change_state(
|
||||||
self, enable: Optional[bool] = None, stream: Optional[torch.cuda.Stream] = None
|
self, enable: Optional[bool] = None, stream: Optional[torch.cuda.Stream] = None
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/pynccl.py
|
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/pynccl.py
|
||||||
|
|
||||||
# This file is a pure Python wrapper for the NCCL library.
|
# This file is a pure Python wrapper for the NCCL library.
|
||||||
# The main purpose is to use NCCL combined with CUDA graph.
|
# The main purpose is to use NCCL combined with CUDA graph.
|
||||||
@@ -187,6 +187,43 @@ class NCCLLibrary:
|
|||||||
cudaStream_t,
|
cudaStream_t,
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
|
# ncclResult_t ncclAllGather(
|
||||||
|
# const void* sendbuff, void* recvbuff, size_t count,
|
||||||
|
# ncclDataType_t datatype, ncclComm_t comm,
|
||||||
|
# cudaStream_t stream);
|
||||||
|
# note that cudaStream_t is a pointer type, so the last argument
|
||||||
|
# is a pointer
|
||||||
|
Function(
|
||||||
|
"ncclAllGather",
|
||||||
|
ncclResult_t,
|
||||||
|
[
|
||||||
|
buffer_type,
|
||||||
|
buffer_type,
|
||||||
|
ctypes.c_size_t,
|
||||||
|
ncclDataType_t,
|
||||||
|
ncclComm_t,
|
||||||
|
cudaStream_t,
|
||||||
|
],
|
||||||
|
),
|
||||||
|
# ncclResult_t ncclReduceScatter(
|
||||||
|
# const void* sendbuff, void* recvbuff, size_t count,
|
||||||
|
# ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
|
||||||
|
# cudaStream_t stream);
|
||||||
|
# note that cudaStream_t is a pointer type, so the last argument
|
||||||
|
# is a pointer
|
||||||
|
Function(
|
||||||
|
"ncclReduceScatter",
|
||||||
|
ncclResult_t,
|
||||||
|
[
|
||||||
|
buffer_type,
|
||||||
|
buffer_type,
|
||||||
|
ctypes.c_size_t,
|
||||||
|
ncclDataType_t,
|
||||||
|
ncclRedOp_t,
|
||||||
|
ncclComm_t,
|
||||||
|
cudaStream_t,
|
||||||
|
],
|
||||||
|
),
|
||||||
# ncclResult_t ncclSend(
|
# ncclResult_t ncclSend(
|
||||||
# const void* sendbuff, size_t count, ncclDataType_t datatype,
|
# const void* sendbuff, size_t count, ncclDataType_t datatype,
|
||||||
# int dest, ncclComm_t comm, cudaStream_t stream);
|
# int dest, ncclComm_t comm, cudaStream_t stream);
|
||||||
@@ -217,6 +254,23 @@ class NCCLLibrary:
|
|||||||
cudaStream_t,
|
cudaStream_t,
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
|
# ncclResult_t ncclBroadcast(
|
||||||
|
# const void* sendbuff, void* recvbuff, size_t count,
|
||||||
|
# ncclDataType_t datatype, int root, ncclComm_t comm,
|
||||||
|
# cudaStream_t stream);
|
||||||
|
Function(
|
||||||
|
"ncclBroadcast",
|
||||||
|
ncclResult_t,
|
||||||
|
[
|
||||||
|
buffer_type,
|
||||||
|
buffer_type,
|
||||||
|
ctypes.c_size_t,
|
||||||
|
ncclDataType_t,
|
||||||
|
ctypes.c_int,
|
||||||
|
ncclComm_t,
|
||||||
|
cudaStream_t,
|
||||||
|
],
|
||||||
|
),
|
||||||
# be cautious! this is a collective call, it will block until all
|
# be cautious! this is a collective call, it will block until all
|
||||||
# processes in the communicator have called this function.
|
# processes in the communicator have called this function.
|
||||||
# because Python object destruction can happen in random order,
|
# because Python object destruction can happen in random order,
|
||||||
@@ -321,6 +375,46 @@ class NCCLLibrary:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def ncclReduceScatter(
|
||||||
|
self,
|
||||||
|
sendbuff: buffer_type,
|
||||||
|
recvbuff: buffer_type,
|
||||||
|
count: int,
|
||||||
|
datatype: int,
|
||||||
|
op: int,
|
||||||
|
comm: ncclComm_t,
|
||||||
|
stream: cudaStream_t,
|
||||||
|
) -> None:
|
||||||
|
# `datatype` actually should be `ncclDataType_t`
|
||||||
|
# and `op` should be `ncclRedOp_t`
|
||||||
|
# both are aliases of `ctypes.c_int`
|
||||||
|
# when we pass int to a function, it will be converted to `ctypes.c_int`
|
||||||
|
# by ctypes automatically
|
||||||
|
self.NCCL_CHECK(
|
||||||
|
self._funcs["ncclReduceScatter"](
|
||||||
|
sendbuff, recvbuff, count, datatype, op, comm, stream
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def ncclAllGather(
|
||||||
|
self,
|
||||||
|
sendbuff: buffer_type,
|
||||||
|
recvbuff: buffer_type,
|
||||||
|
count: int,
|
||||||
|
datatype: int,
|
||||||
|
comm: ncclComm_t,
|
||||||
|
stream: cudaStream_t,
|
||||||
|
) -> None:
|
||||||
|
# `datatype` actually should be `ncclDataType_t`
|
||||||
|
# which is an aliases of `ctypes.c_int`
|
||||||
|
# when we pass int to a function, it will be converted to `ctypes.c_int`
|
||||||
|
# by ctypes automatically
|
||||||
|
self.NCCL_CHECK(
|
||||||
|
self._funcs["ncclAllGather"](
|
||||||
|
sendbuff, recvbuff, count, datatype, comm, stream
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def ncclSend(
|
def ncclSend(
|
||||||
self,
|
self,
|
||||||
sendbuff: buffer_type,
|
sendbuff: buffer_type,
|
||||||
@@ -347,6 +441,22 @@ class NCCLLibrary:
|
|||||||
self._funcs["ncclRecv"](recvbuff, count, datatype, src, comm, stream)
|
self._funcs["ncclRecv"](recvbuff, count, datatype, src, comm, stream)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def ncclBroadcast(
|
||||||
|
self,
|
||||||
|
sendbuff: buffer_type,
|
||||||
|
recvbuff: buffer_type,
|
||||||
|
count: int,
|
||||||
|
datatype: int,
|
||||||
|
root: int,
|
||||||
|
comm: ncclComm_t,
|
||||||
|
stream: cudaStream_t,
|
||||||
|
) -> None:
|
||||||
|
self.NCCL_CHECK(
|
||||||
|
self._funcs["ncclBroadcast"](
|
||||||
|
sendbuff, recvbuff, count, datatype, root, comm, stream
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def ncclCommDestroy(self, comm: ncclComm_t) -> None:
|
def ncclCommDestroy(self, comm: ncclComm_t) -> None:
|
||||||
self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))
|
self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,9 @@
|
|||||||
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/shm_broadcast.py
|
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/shm_broadcast.py
|
||||||
import ipaddress
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
import socket
|
|
||||||
import time
|
import time
|
||||||
import warnings
|
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from multiprocessing import shared_memory
|
from multiprocessing import shared_memory
|
||||||
@@ -18,6 +16,8 @@ from torch.distributed import ProcessGroup
|
|||||||
from zmq import IPV6 # type: ignore
|
from zmq import IPV6 # type: ignore
|
||||||
from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore
|
from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore
|
||||||
|
|
||||||
|
from sglang.srt.utils import get_ip, get_open_port, is_valid_ipv6_address
|
||||||
|
|
||||||
# SGLANG_RINGBUFFER_WARNING_INTERVAL can be set to 60
|
# SGLANG_RINGBUFFER_WARNING_INTERVAL can be set to 60
|
||||||
SGLANG_RINGBUFFER_WARNING_INTERVAL = int(
|
SGLANG_RINGBUFFER_WARNING_INTERVAL = int(
|
||||||
os.environ.get("SGLANG_RINGBUFFER_WARNING_INTERVAL", "60")
|
os.environ.get("SGLANG_RINGBUFFER_WARNING_INTERVAL", "60")
|
||||||
@@ -26,73 +26,6 @@ SGLANG_RINGBUFFER_WARNING_INTERVAL = int(
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def get_ip() -> str:
|
|
||||||
# SGLANG_HOST_IP env can be ignore
|
|
||||||
host_ip = os.getenv("SGLANG_HOST_IP", "") or os.getenv("HOST_IP", "")
|
|
||||||
if host_ip:
|
|
||||||
return host_ip
|
|
||||||
|
|
||||||
# IP is not set, try to get it from the network interface
|
|
||||||
|
|
||||||
# try ipv4
|
|
||||||
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
|
||||||
try:
|
|
||||||
s.connect(("8.8.8.8", 80)) # Doesn't need to be reachable
|
|
||||||
return s.getsockname()[0]
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# try ipv6
|
|
||||||
try:
|
|
||||||
s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
|
|
||||||
# Google's public DNS server, see
|
|
||||||
# https://developers.google.com/speed/public-dns/docs/using#addresses
|
|
||||||
s.connect(("2001:4860:4860::8888", 80)) # Doesn't need to be reachable
|
|
||||||
return s.getsockname()[0]
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
warnings.warn(
|
|
||||||
"Failed to get the IP address, using 0.0.0.0 by default."
|
|
||||||
"The value can be set by the environment variable"
|
|
||||||
" SGLANG_HOST_IP or HOST_IP.",
|
|
||||||
stacklevel=2,
|
|
||||||
)
|
|
||||||
return "0.0.0.0"
|
|
||||||
|
|
||||||
|
|
||||||
def get_open_port() -> int:
|
|
||||||
|
|
||||||
port = os.getenv("SGLANG_PORT")
|
|
||||||
if port is not None:
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
|
||||||
s.bind(("", port))
|
|
||||||
return port
|
|
||||||
except OSError:
|
|
||||||
port += 1 # Increment port number if already in use
|
|
||||||
logger.info("Port %d is already in use, trying port %d", port - 1, port)
|
|
||||||
# try ipv4
|
|
||||||
try:
|
|
||||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
|
||||||
s.bind(("", 0))
|
|
||||||
return s.getsockname()[1]
|
|
||||||
except OSError:
|
|
||||||
# try ipv6
|
|
||||||
with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
|
|
||||||
s.bind(("", 0))
|
|
||||||
return s.getsockname()[1]
|
|
||||||
|
|
||||||
|
|
||||||
def is_valid_ipv6_address(address: str) -> bool:
|
|
||||||
try:
|
|
||||||
ipaddress.IPv6Address(address)
|
|
||||||
return True
|
|
||||||
except ValueError:
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
class ShmRingBuffer:
|
class ShmRingBuffer:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/xpu_communicator.py
|
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/xpu_communicator.py
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from torch.distributed import ProcessGroup
|
from torch.distributed import ProcessGroup
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/parallel_state.py
|
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/parallel_state.py
|
||||||
|
|
||||||
# Copyright 2023 The vLLM team.
|
# Copyright 2023 The vLLM team.
|
||||||
# Adapted from
|
# Adapted from
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/utils.py
|
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/utils.py
|
||||||
|
|
||||||
# Copyright 2023 The vLLM team.
|
# Copyright 2023 The vLLM team.
|
||||||
# Adapted from
|
# Adapted from
|
||||||
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
|
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
|
||||||
|
|||||||
@@ -29,8 +29,8 @@ from sglang.srt.utils import (
|
|||||||
get_nvgpu_memory_capacity,
|
get_nvgpu_memory_capacity,
|
||||||
is_flashinfer_available,
|
is_flashinfer_available,
|
||||||
is_hip,
|
is_hip,
|
||||||
is_ipv6,
|
|
||||||
is_port_available,
|
is_port_available,
|
||||||
|
is_valid_ipv6_address,
|
||||||
nullable_str,
|
nullable_str,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -883,7 +883,7 @@ class ServerArgs:
|
|||||||
return cls(**{attr: getattr(args, attr) for attr in attrs})
|
return cls(**{attr: getattr(args, attr) for attr in attrs})
|
||||||
|
|
||||||
def url(self):
|
def url(self):
|
||||||
if is_ipv6(self.host):
|
if is_valid_ipv6_address(self.host):
|
||||||
return f"http://[{self.host}]:{self.port}"
|
return f"http://[{self.host}]:{self.port}"
|
||||||
else:
|
else:
|
||||||
return f"http://{self.host}:{self.port}"
|
return f"http://{self.host}:{self.port}"
|
||||||
|
|||||||
@@ -102,14 +102,6 @@ def is_cuda_available():
|
|||||||
return torch.cuda.is_available() and torch.version.cuda
|
return torch.cuda.is_available() and torch.version.cuda
|
||||||
|
|
||||||
|
|
||||||
def is_ipv6(address):
|
|
||||||
try:
|
|
||||||
ipaddress.IPv6Address(address)
|
|
||||||
return True
|
|
||||||
except ipaddress.AddressValueError:
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def enable_show_time_cost():
|
def enable_show_time_cost():
|
||||||
global show_time_cost
|
global show_time_cost
|
||||||
show_time_cost = True
|
show_time_cost = True
|
||||||
@@ -1383,3 +1375,70 @@ def set_uvicorn_logging_configs():
|
|||||||
"fmt"
|
"fmt"
|
||||||
] = '[%(asctime)s] %(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s'
|
] = '[%(asctime)s] %(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s'
|
||||||
LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
|
LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
|
||||||
|
|
||||||
|
|
||||||
|
def get_ip() -> str:
|
||||||
|
# SGLANG_HOST_IP env can be ignore
|
||||||
|
host_ip = os.getenv("SGLANG_HOST_IP", "") or os.getenv("HOST_IP", "")
|
||||||
|
if host_ip:
|
||||||
|
return host_ip
|
||||||
|
|
||||||
|
# IP is not set, try to get it from the network interface
|
||||||
|
|
||||||
|
# try ipv4
|
||||||
|
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||||
|
try:
|
||||||
|
s.connect(("8.8.8.8", 80)) # Doesn't need to be reachable
|
||||||
|
return s.getsockname()[0]
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# try ipv6
|
||||||
|
try:
|
||||||
|
s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
|
||||||
|
# Google's public DNS server, see
|
||||||
|
# https://developers.google.com/speed/public-dns/docs/using#addresses
|
||||||
|
s.connect(("2001:4860:4860::8888", 80)) # Doesn't need to be reachable
|
||||||
|
return s.getsockname()[0]
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
warnings.warn(
|
||||||
|
"Failed to get the IP address, using 0.0.0.0 by default."
|
||||||
|
"The value can be set by the environment variable"
|
||||||
|
" SGLANG_HOST_IP or HOST_IP.",
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
return "0.0.0.0"
|
||||||
|
|
||||||
|
|
||||||
|
def get_open_port() -> int:
|
||||||
|
|
||||||
|
port = os.getenv("SGLANG_PORT")
|
||||||
|
if port is not None:
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||||
|
s.bind(("", port))
|
||||||
|
return port
|
||||||
|
except OSError:
|
||||||
|
port += 1 # Increment port number if already in use
|
||||||
|
logger.info("Port %d is already in use, trying port %d", port - 1, port)
|
||||||
|
# try ipv4
|
||||||
|
try:
|
||||||
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||||
|
s.bind(("", 0))
|
||||||
|
return s.getsockname()[1]
|
||||||
|
except OSError:
|
||||||
|
# try ipv6
|
||||||
|
with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
|
||||||
|
s.bind(("", 0))
|
||||||
|
return s.getsockname()[1]
|
||||||
|
|
||||||
|
|
||||||
|
def is_valid_ipv6_address(address: str) -> bool:
|
||||||
|
try:
|
||||||
|
ipaddress.IPv6Address(address)
|
||||||
|
return True
|
||||||
|
except ValueError:
|
||||||
|
return False
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
"""Common utilities"""
|
"""Common utilities"""
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
import gc
|
|
||||||
import importlib
|
import importlib
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
|||||||
Reference in New Issue
Block a user