From c05956e53495a219bdb12d9f995d22afa89fd6cd Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Thu, 16 May 2024 18:07:30 -0700 Subject: [PATCH] Simplify port allocation (#447) --- .../srt/managers/detokenizer_manager.py | 2 +- python/sglang/srt/managers/router/manager.py | 2 +- .../sglang/srt/managers/router/model_rpc.py | 3 +- .../sglang/srt/managers/tokenizer_manager.py | 4 +- python/sglang/srt/server.py | 3 +- python/sglang/srt/utils.py | 64 ++++++------------- 6 files changed, 28 insertions(+), 50 deletions(-) diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index 436d91525..52bad9792 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -7,7 +7,7 @@ import zmq.asyncio from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut from sglang.srt.server_args import PortArgs, ServerArgs -from sglang.srt.utils import get_exception_traceback +from sglang.utils import get_exception_traceback asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) diff --git a/python/sglang/srt/managers/router/manager.py b/python/sglang/srt/managers/router/manager.py index 66adc2e59..f0e856998 100644 --- a/python/sglang/srt/managers/router/manager.py +++ b/python/sglang/srt/managers/router/manager.py @@ -8,7 +8,7 @@ import zmq.asyncio from sglang.global_config import global_config from sglang.srt.managers.router.model_rpc import ModelRpcClient from sglang.srt.server_args import PortArgs, ServerArgs -from sglang.srt.utils import get_exception_traceback +from sglang.utils import get_exception_traceback asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) diff --git a/python/sglang/srt/managers/router/model_rpc.py b/python/sglang/srt/managers/router/model_rpc.py index e9b57d23c..660f09f3e 100644 --- a/python/sglang/srt/managers/router/model_rpc.py +++ b/python/sglang/srt/managers/router/model_rpc.py @@ -31,11 +31,12 @@ from sglang.srt.managers.router.scheduler import Scheduler from sglang.srt.model_config import ModelConfig from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.utils import ( - get_exception_traceback, get_int_token_logit_bias, is_multimodal_model, set_random_seed, ) +from sglang.utils import get_exception_traceback + logger = logging.getLogger("model_rpc") vllm_default_logger.setLevel(logging.WARN) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index f4cd4ad86..8cc27f849 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -20,7 +20,6 @@ from sglang.srt.hf_transformers_utils import ( ) from sglang.srt.managers.io_struct import ( BatchStrOut, - DetokenizeReqInput, FlushCacheReq, GenerateReqInput, TokenizedGenerateReqInput, @@ -28,7 +27,8 @@ from sglang.srt.managers.io_struct import ( from sglang.srt.mm_utils import expand2square, process_anyres_image from sglang.srt.sampling_params import SamplingParams from sglang.srt.server_args import PortArgs, ServerArgs -from sglang.srt.utils import get_exception_traceback, is_multimodal_model, load_image +from sglang.srt.utils import is_multimodal_model, load_image +from sglang.utils import get_exception_traceback asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index f3a437ab0..6f471bfbc 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -41,8 +41,9 @@ from sglang.srt.utils import ( allocate_init_ports, assert_pkg_version, enable_show_time_cost, - get_exception_traceback, ) +from sglang.utils import get_exception_traceback + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 9a1e6400d..fbd98b3bb 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -1,6 +1,7 @@ """Common utilities.""" import base64 +import logging import os import random import socket @@ -18,7 +19,9 @@ from packaging import version as pkg_version from pydantic import BaseModel from starlette.middleware.base import BaseHTTPMiddleware -from sglang.utils import get_exception_traceback + +logger = logging.getLogger(__name__) + show_time_cost = False time_infos = {} @@ -124,31 +127,12 @@ def set_random_seed(seed: int) -> None: torch.cuda.manual_seed_all(seed) -def alloc_usable_network_port(num, used_list=()): - port_list = [] - for port in range(10000, 65536): - if port in used_list: - continue - - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - try: - s.bind(("", port)) - s.listen(1) # Attempt to listen on the port - port_list.append(port) - except socket.error: - pass # If any error occurs, this port is not usable - - if len(port_list) == num: - return port_list - return None - - -def check_port(port): +def is_port_available(port): with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: try: s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) s.bind(("", port)) + s.listen(1) return True except socket.error: return False @@ -159,31 +143,23 @@ def allocate_init_ports( additional_ports: Optional[List[int]] = None, tp_size: int = 1, ): - port = 30000 if port is None else port - additional_ports = [] if additional_ports is None else additional_ports - additional_ports = ( - [additional_ports] if isinstance(additional_ports, int) else additional_ports - ) - # first check on server port - if not check_port(port): - new_port = alloc_usable_network_port(1, used_list=[port])[0] - print(f"WARNING: Port {port} is not available. Use {new_port} instead.") - port = new_port + if additional_ports: + ret_ports = [port] + additional_ports + else: + ret_ports = [port] - # then we check on additional ports - additional_unique_ports = set(additional_ports) - {port} - # filter out ports that are already in use - can_use_ports = [port for port in additional_unique_ports if check_port(port)] + ret_ports = list(set(x for x in ret_ports if is_port_available(x))) + cur_port = ret_ports[-1] + 1 if len(ret_ports) > 0 else 10000 - num_specified_ports = len(can_use_ports) - if num_specified_ports < 4 + tp_size: - addtional_can_use_ports = alloc_usable_network_port( - num=4 + tp_size - num_specified_ports, used_list=can_use_ports + [port] - ) - can_use_ports.extend(addtional_can_use_ports) + while len(ret_ports) < 5 + tp_size: + if cur_port not in ret_ports and is_port_available(cur_port): + ret_ports.append(cur_port) + cur_port += 1 - additional_ports = can_use_ports[: 4 + tp_size] - return port, additional_ports + if port and ret_ports[0] != port: + logger.warn(f"WARNING: Port {port} is not available. Use port {ret_ports[0]} instead.") + + return ret_ports[0], ret_ports[1:] def get_int_token_logit_bias(tokenizer, vocab_size):