diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index e7a43d350..d21ff7f73 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -6,7 +6,7 @@ import os import sys import threading import time -from typing import List, Optional +from typing import List, Optional, Union # Fix a Python bug setattr(threading, "_register_atexit", lambda *args, **kwargs: None) @@ -47,7 +47,7 @@ from sglang.srt.managers.openai_protocol import ( from sglang.srt.managers.router.manager import start_router_process from sglang.srt.managers.tokenizer_manager import TokenizerManager from sglang.srt.server_args import PortArgs, ServerArgs -from sglang.srt.utils import alloc_usable_network_port +from sglang.srt.utils import alloc_usable_network_port, handle_port_init asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) @@ -306,16 +306,17 @@ def launch_server(server_args, pipe_finish_writer): global tokenizer_manager global chat_template_name - # Allocate ports - can_use_ports = alloc_usable_network_port( - num=4 + server_args.tp_size, used_list=(server_args.port,) + # Handle ports + server_args.port, server_args.additional_ports = handle_port_init( + server_args.port, server_args.additional_ports, server_args.tp_size ) + port_args = PortArgs( - tokenizer_port=can_use_ports[0], - router_port=can_use_ports[1], - detokenizer_port=can_use_ports[2], - nccl_port=can_use_ports[3], - model_rpc_ports=can_use_ports[4:], + tokenizer_port=server_args.additional_ports[0], + router_port=server_args.additional_ports[1], + detokenizer_port=server_args.additional_ports[2], + nccl_port=server_args.additional_ports[3], + model_rpc_ports=server_args.additional_ports[4:], ) # Load chat template if needed @@ -435,14 +436,19 @@ class Runtime: schedule_heuristic: str = "lpm", random_seed: int = 42, log_level: str = "error", + port: Optional[int] = None, + additional_ports: Optional[Union[List[int], int]] = None, ): host = "127.0.0.1" - port = 
alloc_usable_network_port(1)[0] + port, additional_ports = handle_port_init( + port, additional_ports, tp_size + ) self.server_args = ServerArgs( model_path=model_path, tokenizer_path=tokenizer_path, host=host, port=port, + additional_ports=additional_ports, load_format=load_format, tokenizer_mode=tokenizer_mode, trust_remote_code=trust_remote_code, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 17e436d8d..39622967b 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -1,6 +1,6 @@ import argparse import dataclasses -from typing import List, Optional +from typing import List, Optional, Union @dataclasses.dataclass @@ -9,6 +9,7 @@ class ServerArgs: tokenizer_path: Optional[str] = None host: str = "127.0.0.1" port: int = 30000 + additional_ports: Optional[Union[List[int], int]] = None load_format: str = "auto" tokenizer_mode: str = "auto" chat_template: Optional[str] = None @@ -37,6 +38,10 @@ class ServerArgs: self.mem_fraction_static = 0.85 else: self.mem_fraction_static = 0.90 + if isinstance(self.additional_ports, int): + self.additional_ports = [self.additional_ports] + elif self.additional_ports is None: + self.additional_ports = [] @staticmethod def add_cli_args(parser: argparse.ArgumentParser): @@ -54,6 +59,14 @@ class ServerArgs: ) parser.add_argument("--host", type=str, default=ServerArgs.host) parser.add_argument("--port", type=int, default=ServerArgs.port) + # we want to be able to pass a list of ports + parser.add_argument( + "--additional-ports", + type=int, + nargs="*", + default=[], + help="Additional ports specified for launching server.", + ) parser.add_argument( "--load-format", type=str, diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 8c5876602..05e035a0a 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -99,6 +99,40 @@ def alloc_usable_network_port(num, used_list=()): return None +def check_port(port): + with 
def check_port(port):
    """Return True if ``port`` can be bound right now, False otherwise.

    The probe opens a TCP/IPv4 socket, tries to bind it to ``port`` on all
    interfaces (empty host string) and closes it again. The result is only a
    snapshot: another process may grab the port after this returns.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        try:
            s.bind(("", port))
        except OSError:  # socket.error is an alias of OSError
            return False
        return True


def handle_port_init(
    port: Optional[int] = None,
    additional_ports: Optional[List[int]] = None,
    tp_size: int = 1,
):
    """Resolve the server port and the list of auxiliary ports.

    Args:
        port: Preferred server port. ``None`` means 30000. If it cannot be
            bound, a replacement is requested from
            ``alloc_usable_network_port`` and a notice is printed.
        additional_ports: Preferred auxiliary ports. May be ``None`` (no
            preference), a single int, or a list of ints. Duplicates, the
            server port itself, and ports that cannot be bound are dropped.
            The remaining ports keep the order in which the caller gave them.
        tp_size: Tensor-parallel size; ``4 + tp_size`` auxiliary ports are
            returned.

    Returns:
        A ``(port, additional_ports)`` tuple where ``additional_ports`` is a
        list of ``4 + tp_size`` ports.

    Raises:
        RuntimeError: If ``alloc_usable_network_port`` cannot supply the
            ports that are still missing.
    """
    port = 30000 if port is None else port
    if additional_ports is None:
        additional_ports = []
    elif isinstance(additional_ports, int):
        additional_ports = [additional_ports]

    num_required = 4 + tp_size

    # First settle the main server port.
    if not check_port(port):
        # NOTE(review): alloc_usable_network_port lives elsewhere in this
        # module; the visible tail of it has a `return None` path, so guard
        # against a missing result instead of indexing blindly.
        replacement = alloc_usable_network_port(1, used_list=[port])
        if not replacement:
            raise RuntimeError(
                f"Port {port} is not available and no replacement port "
                "could be allocated."
            )
        new_port = replacement[0]
        print(f"Port {port} is not available, using {new_port} instead.")
        port = new_port

    # De-duplicate while keeping the caller's order (dict preserves insertion
    # order), so the mapping of user-supplied ports to roles is deterministic.
    candidates = [p for p in dict.fromkeys(additional_ports) if p != port]
    # Keep only the ports that can actually be bound.
    can_use_ports = [p for p in candidates if check_port(p)]

    num_missing = num_required - len(can_use_ports)
    if num_missing > 0:
        extra_ports = alloc_usable_network_port(
            num=num_missing, used_list=can_use_ports + [port]
        )
        if not extra_ports or len(extra_ports) < num_missing:
            raise RuntimeError(
                f"Could not allocate {num_missing} additional usable "
                "network port(s)."
            )
        can_use_ports.extend(extra_ports)

    return port, can_use_ports[:num_required]