Simplify port allocation (#447)
This commit is contained in:
@@ -7,7 +7,7 @@ import zmq.asyncio
|
|||||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||||
from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
|
from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
from sglang.srt.utils import get_exception_traceback
|
from sglang.utils import get_exception_traceback
|
||||||
|
|
||||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import zmq.asyncio
|
|||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
from sglang.srt.managers.router.model_rpc import ModelRpcClient
|
from sglang.srt.managers.router.model_rpc import ModelRpcClient
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
from sglang.srt.utils import get_exception_traceback
|
from sglang.utils import get_exception_traceback
|
||||||
|
|
||||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||||
|
|
||||||
|
|||||||
@@ -31,11 +31,12 @@ from sglang.srt.managers.router.scheduler import Scheduler
|
|||||||
from sglang.srt.model_config import ModelConfig
|
from sglang.srt.model_config import ModelConfig
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
get_exception_traceback,
|
|
||||||
get_int_token_logit_bias,
|
get_int_token_logit_bias,
|
||||||
is_multimodal_model,
|
is_multimodal_model,
|
||||||
set_random_seed,
|
set_random_seed,
|
||||||
)
|
)
|
||||||
|
from sglang.utils import get_exception_traceback
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger("model_rpc")
|
logger = logging.getLogger("model_rpc")
|
||||||
vllm_default_logger.setLevel(logging.WARN)
|
vllm_default_logger.setLevel(logging.WARN)
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ from sglang.srt.hf_transformers_utils import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.managers.io_struct import (
|
from sglang.srt.managers.io_struct import (
|
||||||
BatchStrOut,
|
BatchStrOut,
|
||||||
DetokenizeReqInput,
|
|
||||||
FlushCacheReq,
|
FlushCacheReq,
|
||||||
GenerateReqInput,
|
GenerateReqInput,
|
||||||
TokenizedGenerateReqInput,
|
TokenizedGenerateReqInput,
|
||||||
@@ -28,7 +27,8 @@ from sglang.srt.managers.io_struct import (
|
|||||||
from sglang.srt.mm_utils import expand2square, process_anyres_image
|
from sglang.srt.mm_utils import expand2square, process_anyres_image
|
||||||
from sglang.srt.sampling_params import SamplingParams
|
from sglang.srt.sampling_params import SamplingParams
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
from sglang.srt.utils import get_exception_traceback, is_multimodal_model, load_image
|
from sglang.srt.utils import is_multimodal_model, load_image
|
||||||
|
from sglang.utils import get_exception_traceback
|
||||||
|
|
||||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||||
|
|
||||||
|
|||||||
@@ -41,8 +41,9 @@ from sglang.srt.utils import (
|
|||||||
allocate_init_ports,
|
allocate_init_ports,
|
||||||
assert_pkg_version,
|
assert_pkg_version,
|
||||||
enable_show_time_cost,
|
enable_show_time_cost,
|
||||||
get_exception_traceback,
|
|
||||||
)
|
)
|
||||||
|
from sglang.utils import get_exception_traceback
|
||||||
|
|
||||||
|
|
||||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""Common utilities."""
|
"""Common utilities."""
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import socket
|
import socket
|
||||||
@@ -18,7 +19,9 @@ from packaging import version as pkg_version
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
|
||||||
from sglang.utils import get_exception_traceback
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
show_time_cost = False
|
show_time_cost = False
|
||||||
time_infos = {}
|
time_infos = {}
|
||||||
@@ -124,31 +127,12 @@ def set_random_seed(seed: int) -> None:
|
|||||||
torch.cuda.manual_seed_all(seed)
|
torch.cuda.manual_seed_all(seed)
|
||||||
|
|
||||||
|
|
||||||
def alloc_usable_network_port(num, used_list=()):
|
def is_port_available(port):
|
||||||
port_list = []
|
|
||||||
for port in range(10000, 65536):
|
|
||||||
if port in used_list:
|
|
||||||
continue
|
|
||||||
|
|
||||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
|
||||||
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
|
||||||
try:
|
|
||||||
s.bind(("", port))
|
|
||||||
s.listen(1) # Attempt to listen on the port
|
|
||||||
port_list.append(port)
|
|
||||||
except socket.error:
|
|
||||||
pass # If any error occurs, this port is not usable
|
|
||||||
|
|
||||||
if len(port_list) == num:
|
|
||||||
return port_list
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def check_port(port):
|
|
||||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||||
try:
|
try:
|
||||||
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||||
s.bind(("", port))
|
s.bind(("", port))
|
||||||
|
s.listen(1)
|
||||||
return True
|
return True
|
||||||
except socket.error:
|
except socket.error:
|
||||||
return False
|
return False
|
||||||
@@ -159,31 +143,23 @@ def allocate_init_ports(
|
|||||||
additional_ports: Optional[List[int]] = None,
|
additional_ports: Optional[List[int]] = None,
|
||||||
tp_size: int = 1,
|
tp_size: int = 1,
|
||||||
):
|
):
|
||||||
port = 30000 if port is None else port
|
if additional_ports:
|
||||||
additional_ports = [] if additional_ports is None else additional_ports
|
ret_ports = [port] + additional_ports
|
||||||
additional_ports = (
|
else:
|
||||||
[additional_ports] if isinstance(additional_ports, int) else additional_ports
|
ret_ports = [port]
|
||||||
)
|
|
||||||
# first check on server port
|
|
||||||
if not check_port(port):
|
|
||||||
new_port = alloc_usable_network_port(1, used_list=[port])[0]
|
|
||||||
print(f"WARNING: Port {port} is not available. Use {new_port} instead.")
|
|
||||||
port = new_port
|
|
||||||
|
|
||||||
# then we check on additional ports
|
ret_ports = list(set(x for x in ret_ports if is_port_available(x)))
|
||||||
additional_unique_ports = set(additional_ports) - {port}
|
cur_port = ret_ports[-1] + 1 if len(ret_ports) > 0 else 10000
|
||||||
# filter out ports that are already in use
|
|
||||||
can_use_ports = [port for port in additional_unique_ports if check_port(port)]
|
|
||||||
|
|
||||||
num_specified_ports = len(can_use_ports)
|
while len(ret_ports) < 5 + tp_size:
|
||||||
if num_specified_ports < 4 + tp_size:
|
if cur_port not in ret_ports and is_port_available(cur_port):
|
||||||
addtional_can_use_ports = alloc_usable_network_port(
|
ret_ports.append(cur_port)
|
||||||
num=4 + tp_size - num_specified_ports, used_list=can_use_ports + [port]
|
cur_port += 1
|
||||||
)
|
|
||||||
can_use_ports.extend(addtional_can_use_ports)
|
|
||||||
|
|
||||||
additional_ports = can_use_ports[: 4 + tp_size]
|
if port and ret_ports[0] != port:
|
||||||
return port, additional_ports
|
logger.warn(f"WARNING: Port {port} is not available. Use port {ret_ports[0]} instead.")
|
||||||
|
|
||||||
|
return ret_ports[0], ret_ports[1:]
|
||||||
|
|
||||||
|
|
||||||
def get_int_token_logit_bias(tokenizer, vocab_size):
|
def get_int_token_logit_bias(tokenizer, vocab_size):
|
||||||
|
|||||||
Reference in New Issue
Block a user