diff --git a/python/pyproject.toml b/python/pyproject.toml index 83b365eb7..1a04389b5 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -21,7 +21,7 @@ dependencies = [ [project.optional-dependencies] srt = ["aiohttp", "fastapi", "hf_transfer", "huggingface_hub", "interegular", "packaging", "pillow", - "psutil", "pydantic", "rpyc", "torch", "uvicorn", "uvloop", "zmq", "vllm==0.5.1", "outlines>=0.0.44"] + "psutil", "pydantic", "torch", "uvicorn", "uvloop", "zmq", "vllm==0.5.1", "outlines>=0.0.44"] openai = ["openai>=1.0", "tiktoken"] anthropic = ["anthropic>=0.20.0"] litellm = ["litellm>=1.0.0"] diff --git a/python/sglang/launch_server.py b/python/sglang/launch_server.py index cb2b07251..91dc0dc4e 100644 --- a/python/sglang/launch_server.py +++ b/python/sglang/launch_server.py @@ -11,4 +11,4 @@ if __name__ == "__main__": args = parser.parse_args() server_args = ServerArgs.from_cli_args(args) - launch_server(server_args, None) + launch_server(server_args) diff --git a/python/sglang/launch_server_llavavid.py b/python/sglang/launch_server_llavavid.py index b71d8701d..c34dd2116 100644 --- a/python/sglang/launch_server_llavavid.py +++ b/python/sglang/launch_server_llavavid.py @@ -1,7 +1,6 @@ """Launch the inference server for Llava-video model.""" import argparse -import multiprocessing as mp from sglang.srt.server import ServerArgs, launch_server @@ -27,6 +26,4 @@ if __name__ == "__main__": server_args = ServerArgs.from_cli_args(args) - pipe_reader, pipe_writer = mp.Pipe(duplex=False) - - launch_server(server_args, pipe_writer, model_overide_args) + launch_server(server_args, model_overide_args, None) diff --git a/python/sglang/srt/managers/controller/dp_worker.py b/python/sglang/srt/managers/controller/dp_worker.py deleted file mode 100644 index 3b6becfd2..000000000 --- a/python/sglang/srt/managers/controller/dp_worker.py +++ /dev/null @@ -1,113 +0,0 @@ -"""A data parallel worker thread.""" - -import asyncio -import logging -import queue -import threading -from typing import Callable, List - -import uvloop -import zmq - -from sglang.global_config import global_config -from sglang.srt.managers.controller.tp_worker import ModelTpClient -from sglang.srt.managers.io_struct import BatchTokenIDOut -from sglang.srt.server_args import PortArgs, ServerArgs -from sglang.srt.utils import kill_parent_process -from sglang.utils import get_exception_traceback - -logger = logging.getLogger("srt.controller") -CHECKING_INTERVAL = 5 - -asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) - - -class DataParallelWorkerThread(threading.Thread): - def __init__( - self, - worker_id: int, - request_queue: queue.Queue, - detokenizer_port: int, - step_func: Callable, - ): - super(DataParallelWorkerThread, self).__init__() - self.worker_id = worker_id - self.request_queue = request_queue - self.liveness = True - self.request_dependency_delay = global_config.request_dependency_delay - - context = zmq.asyncio.Context() - self.send_to_detokenizer = context.socket(zmq.PUSH) - self.send_to_detokenizer.connect(f"tcp://127.0.0.1:{detokenizer_port}") - - self.step = step_func - - async def loop_for_forward(self): - while self.liveness: - requests = [] - while not self.request_queue.empty(): - requests.append(self.request_queue.get()) - - out_pyobjs: List[BatchTokenIDOut] = [] - try: - out_pyobjs = await self.step(requests) - except Exception: - for r in requests: - self.request_queue.put(r) - logger.error( - f"Worker thread {self.worker_id}: " - f"failed to get back from Model Server\n" - 
f"{get_exception_traceback()}" - ) - self.liveness = False - # Crash the whole server when there are any errors. - # TODO(lianmin): make this an option. - kill_parent_process() - return - - for obj in out_pyobjs: - self.send_to_detokenizer.send_pyobj(obj) - - # async sleep for receiving the subsequent request and avoiding cache miss - if len(out_pyobjs) != 0: - has_finished = any( - [obj.finished_reason is not None for obj in out_pyobjs] - ) - if has_finished: - await asyncio.sleep(self.request_dependency_delay) - await asyncio.sleep(global_config.wait_for_new_request_delay) - - async def monitoring(self): - while True: - await asyncio.sleep(CHECKING_INTERVAL) - # can plug in monitoring logic here - - def run(self): - logger.info(f"DataParallelWorkerThread {self.worker_id} start") - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - loop.create_task(self.monitoring()) - loop.run_until_complete(self.loop_for_forward()) - - -def start_data_parallel_worker( - server_args: ServerArgs, - port_args: PortArgs, - model_overide_args, - gpu_ids: List[int], - worker_id: int, -): - model_tp_client = ModelTpClient( - gpu_ids, - server_args, - port_args.model_port_args[worker_id], - model_overide_args, - ) - worker_thread = DataParallelWorkerThread( - worker_id=worker_id, - request_queue=queue.Queue(), - detokenizer_port=port_args.detokenizer_port, - step_func=model_tp_client.step, - ) - worker_thread.start() - return worker_thread diff --git a/python/sglang/srt/managers/controller/manager_multi.py b/python/sglang/srt/managers/controller/manager_multi.py index ea942093a..188ee0e20 100644 --- a/python/sglang/srt/managers/controller/manager_multi.py +++ b/python/sglang/srt/managers/controller/manager_multi.py @@ -3,19 +3,17 @@ A controller that manages multiple data parallel workers. Each data parallel worker can manage multiple tensor parallel workers. 
""" -import asyncio +import dataclasses import logging -from concurrent.futures import ThreadPoolExecutor +import multiprocessing +import os from enum import Enum, auto -from typing import Dict +import numpy as np import zmq -import zmq.asyncio -from sglang.global_config import global_config -from sglang.srt.managers.controller.dp_worker import ( - DataParallelWorkerThread, - start_data_parallel_worker, +from sglang.srt.managers.controller.manager_single import ( + start_controller_process as start_controller_process_single, ) from sglang.srt.managers.io_struct import ( AbortReq, @@ -23,12 +21,14 @@ from sglang.srt.managers.io_struct import ( TokenizedGenerateReqInput, ) from sglang.srt.server_args import PortArgs, ServerArgs +from sglang.srt.utils import kill_parent_process from sglang.utils import get_exception_traceback logger = logging.getLogger("srt.controller") class LoadBalanceMethod(Enum): + """Load balance method.""" ROUND_ROBIN = auto() SHORTEST_QUEUE = auto() @@ -41,155 +41,155 @@ class LoadBalanceMethod(Enum): raise ValueError(f"Invalid load balance method: {method}") from exc -class Controller: +@dataclasses.dataclass +class WorkerHandle: + """Store the handle of a data parallel worker.""" + proc: multiprocessing.Process + queue: multiprocessing.Queue + + +class ControllerMulti: """A controller that manages multiple data parallel workers.""" def __init__( self, - load_balance_method: str, server_args: ServerArgs, port_args: PortArgs, model_overide_args, ): - self.load_balance_method = LoadBalanceMethod.from_str(load_balance_method) + # Parse args self.server_args = server_args self.port_args = port_args + self.model_overide_args = model_overide_args + self.load_balance_method = LoadBalanceMethod.from_str( + server_args.load_balance_method) - if self.load_balance_method == LoadBalanceMethod.ROUND_ROBIN: - self.round_robin_counter = 0 + # Init communication + context = zmq.Context() + self.recv_from_tokenizer = context.socket(zmq.PULL) + self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.controller_port}") - self.dispatch_lookup = { + # Dispatch method + self.round_robin_counter = 0 + dispatch_lookup = { LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler, LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler, } - self.dispatching = self.dispatch_lookup[self.load_balance_method] - - # Init communication - context = zmq.asyncio.Context() - self.recv_from_tokenizer = context.socket(zmq.PULL) - self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.router_port}") - - # Init status - self.recv_reqs = [] + self.dispatching = dispatch_lookup[self.load_balance_method] # Start data parallel workers - self.workers: Dict[int, DataParallelWorkerThread] = {} - tp_size = server_args.tp_size - - def start_dp_worker(i): - try: - gpu_ids = list(range(i * tp_size, (i + 1) * tp_size)) - worker_thread = start_data_parallel_worker( - server_args, port_args, model_overide_args, gpu_ids, i - ) - self.workers[i] = worker_thread - except Exception: - logger.error( - f"Failed to start local worker {i}\n{get_exception_traceback()}" - ) - + self.workers = [] for i in range(server_args.dp_size): - start_dp_worker(i) + self.start_dp_worker(i) - # Parallel launch is slower, probably due to the disk bandwidth limitations. 
- # with ThreadPoolExecutor(server_args.dp_size) as executor: - # executor.map(start_dp_worker, range(server_args.dp_size)) + def start_dp_worker(self, dp_worker_id: int): + tp_size = self.server_args.tp_size - def have_any_live_worker(self): - return any(worker_thread.liveness for worker_thread in self.workers.values()) + pipe_controller_reader, pipe_controller_writer = multiprocessing.Pipe(duplex=False) - def put_req_to_worker(self, worker_id, req): - self.workers[worker_id].request_queue.put(req) + gpu_ids = list(range(dp_worker_id * tp_size, (dp_worker_id + 1) * tp_size)) + queue = multiprocessing.Queue() + proc = multiprocessing.Process( + target=start_controller_process_single, + args=( + self.server_args, + self.port_args, + pipe_controller_writer, + self.model_overide_args, + True, + gpu_ids, + dp_worker_id, + queue, + ) + ) + proc.start() - async def round_robin_scheduler(self, input_requests): - available_workers = list(self.workers.keys()) + controller_init_state = pipe_controller_reader.recv() + if controller_init_state != "init ok": + raise RuntimeError( + f"Initialization failed. controller_init_state: {controller_init_state}" + ) + self.workers.append(WorkerHandle( + proc=proc, + queue=queue, + )) + + def round_robin_scheduler(self, input_requests): for r in input_requests: - self.put_req_to_worker(available_workers[self.round_robin_counter], r) + self.workers[self.round_robin_counter].queue.put(r) self.round_robin_counter = (self.round_robin_counter + 1) % len( - available_workers + self.workers ) - return - async def shortest_queue_scheduler(self, input_requests): + def shortest_queue_scheduler(self, input_requests): for r in input_requests: - worker = min( - self.workers, key=lambda w: self.workers[w].request_queue.qsize() - ) - self.put_req_to_worker(worker, r) - return + queue_sizes = [worker.queue.qsize() for worker in self.workers] + wid = np.argmin(queue_sizes) + self.workers[wid].queue.put(r) - async def remove_dead_workers(self): - for i in list(self.workers.keys()): - worker_thread = self.workers[i] - if not worker_thread.liveness: - worker_thread.join() - # move unsuccessful requests back to the queue - while not worker_thread.request_queue.empty(): - self.recv_reqs.append(worker_thread.request_queue.get()) - del self.workers[i] - logger.info(f"Stale worker {i} removed") - - async def loop_for_forward(self): + def loop_for_forward(self): while True: - await self.remove_dead_workers() + recv_reqs = self.recv_requests() + self.dispatching(recv_reqs) - if self.have_any_live_worker(): - next_step_input = list(self.recv_reqs) - self.recv_reqs = [] - if next_step_input: - await self.dispatching(next_step_input) - # else: - # logger.error("There is no live worker.") + def recv_requests(self): + recv_reqs = [] - await asyncio.sleep(global_config.wait_for_new_request_delay) - - async def loop_for_recv_requests(self): while True: - recv_req = await self.recv_from_tokenizer.recv_pyobj() + try: + recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK) + except zmq.ZMQError: + break + if isinstance(recv_req, FlushCacheReq): # TODO(lsyin): apply more specific flushCacheReq - for worker_thread in self.workers.values(): - worker_thread.request_queue.put(recv_req) - elif isinstance(recv_req, TokenizedGenerateReqInput): - self.recv_reqs.append(recv_req) + for worker in self.workers: + worker.queue.put(recv_req) elif isinstance(recv_req, AbortReq): in_queue = False - for i, req in enumerate(self.recv_reqs): + for i, req in enumerate(recv_reqs): if req.rid == recv_req.rid: - 
self.recv_reqs[i] = recv_req + recv_reqs[i] = recv_req in_queue = True break if not in_queue: # Send abort req to all TP groups - for worker in list(self.workers.keys()): - self.put_req_to_worker(worker, recv_req) + for worker in self.workers: + worker.queue.put(recv_req) + elif isinstance(recv_req, TokenizedGenerateReqInput): + recv_reqs.append(recv_req) else: logger.error(f"Invalid object: {recv_req}") + return recv_reqs + def start_controller_process( server_args: ServerArgs, port_args: PortArgs, pipe_writer, - model_overide_args=None, + model_overide_args: dict, ): + """Start a controller process.""" + logging.basicConfig( level=getattr(logging, server_args.log_level.upper()), format="%(message)s", ) try: - controller = Controller( - server_args.load_balance_method, server_args, port_args, model_overide_args - ) + controller = ControllerMulti(server_args, port_args, model_overide_args) except Exception: pipe_writer.send(get_exception_traceback()) raise + pipe_writer.send("init ok") - loop = asyncio.new_event_loop() - loop.set_default_executor(ThreadPoolExecutor(max_workers=256)) - - asyncio.set_event_loop(loop) - loop.create_task(controller.loop_for_recv_requests()) - loop.run_until_complete(controller.loop_for_forward()) + try: + controller.loop_for_forward() + except Exception: + logger.error("Exception in ControllerMulti:\n" + get_exception_traceback()) + finally: + for w in controller.workers: + os.kill(w.proc.pid, 9) + kill_parent_process() diff --git a/python/sglang/srt/managers/controller/manager_single.py b/python/sglang/srt/managers/controller/manager_single.py index 37af98e9a..9326945f9 100644 --- a/python/sglang/srt/managers/controller/manager_single.py +++ b/python/sglang/srt/managers/controller/manager_single.py @@ -3,126 +3,61 @@ import logging import multiprocessing import os -import pickle +from typing import List -import torch -import torch.distributed as dist import zmq -import zmq.asyncio -from sglang.srt.managers.controller.tp_worker import ModelTpServer -from sglang.srt.server_args import ModelPortArgs, PortArgs, ServerArgs +from sglang.srt.managers.controller.tp_worker import ( + broadcast_recv_input, launch_tp_servers, ModelTpServer +) +from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.utils import kill_parent_process from sglang.utils import get_exception_traceback logger = logging.getLogger("srt.controller") -def run_tp_server( - gpu_id: int, - tp_rank: int, - server_args: ServerArgs, - model_port_args: ModelPortArgs, - model_overide_args: dict, -): - """Run a tp server.""" - try: - model_server = ModelTpServer( - gpu_id, - tp_rank, - server_args, - model_port_args, - model_overide_args, - ) - tp_cpu_group = model_server.model_runner.tp_group.cpu_group - - while True: - recv_reqs = broadcast_recv_input(None, tp_rank, tp_cpu_group) - model_server.exposed_step(recv_reqs) - except Exception: - logger.error("Exception in run_tp_server:\n" + get_exception_traceback()) - raise - - -def launch_tp_servers( - gpu_ids, tp_rank_range, server_args, model_port_args, model_overide_args -): - """Launch multiple tp servers.""" - procs = [] - for i in tp_rank_range: - proc = multiprocessing.Process( - target=run_tp_server, - args=(gpu_ids[i], i, server_args, model_port_args, model_overide_args), - ) - proc.start() - procs.append(proc) - - return procs - - -def broadcast_recv_input(data, rank, dist_group): - """Broadcast inputs from rank=0 to all other ranks with torch.dist backend.""" - - if rank == 0: - if len(data) == 0: - tensor_size = torch.tensor([0], 
dtype=torch.long) - dist.broadcast(tensor_size, src=0, group=dist_group) - else: - serialized_data = pickle.dumps(data) - size = len(serialized_data) - tensor_data = torch.ByteTensor(list(serialized_data)) - tensor_size = torch.tensor([size], dtype=torch.long) - - dist.broadcast(tensor_size, src=0, group=dist_group) - dist.broadcast(tensor_data, src=0, group=dist_group) - else: - tensor_size = torch.tensor([0], dtype=torch.long) - dist.broadcast(tensor_size, src=0, group=dist_group) - size = tensor_size.item() - - if size == 0: - return [] - - tensor_data = torch.empty(size, dtype=torch.uint8) - dist.broadcast(tensor_data, src=0, group=dist_group) - - serialized_data = bytes(tensor_data.tolist()) - data = pickle.loads(serialized_data) - return data - - class ControllerSingle: """A controller that manages a group of tensor parallel workers.""" def __init__( - self, server_args: ServerArgs, port_args: PortArgs, model_overide_args: dict + self, + server_args: ServerArgs, + port_args: PortArgs, + model_overide_args: dict, + gpu_ids: List[int], + is_data_parallel_worker: bool, + dp_worker_id: int, + mp_queue: multiprocessing.Queue, ): # Parse args - self.server_args = server_args - self.tp_procs = [] + self.tp_size = server_args.tp_size + self.is_dp_worker = is_data_parallel_worker + self.dp_worker_id = dp_worker_id + self.mp_queue = mp_queue # Init communication context = zmq.Context(2) - self.recv_from_tokenizer = context.socket(zmq.PULL) - self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.router_port}") + + if not self.is_dp_worker: + self.recv_from_tokenizer = context.socket(zmq.PULL) + self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.controller_port}") self.send_to_detokenizer = context.socket(zmq.PUSH) self.send_to_detokenizer.connect( f"tcp://127.0.0.1:{port_args.detokenizer_port}" ) - # Init model server - tp_size_local = server_args.tp_size // server_args.nnodes - gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)] - # Launch other tp ranks + tp_size_local = server_args.tp_size // server_args.nnodes + self.tp_procs = [] if tp_size_local > 1: tp_rank_range = range(1, tp_size_local) self.tp_procs = launch_tp_servers( gpu_ids, tp_rank_range, server_args, - port_args.model_port_args[0], + port_args.nccl_ports[dp_worker_id], model_overide_args, ) @@ -131,16 +66,19 @@ class ControllerSingle: gpu_ids[0], 0, server_args, - port_args.model_port_args[0], + port_args.nccl_ports[dp_worker_id], model_overide_args, ) self.tp_cpu_group = self.tp_server.model_runner.tp_group.cpu_group def loop_for_forward(self): while True: - recv_reqs = self.recv_requests() + if not self.is_dp_worker: + recv_reqs = self.recv_requests_from_zmq() + else: + recv_reqs = self.recv_requests_from_mp_queue() - if self.server_args.tp_size > 1: + if self.tp_size > 1: broadcast_recv_input(recv_reqs, 0, self.tp_cpu_group) out_pyobjs = self.tp_server.exposed_step(recv_reqs) @@ -148,27 +86,51 @@ class ControllerSingle: for obj in out_pyobjs: self.send_to_detokenizer.send_pyobj(obj) - def recv_requests(self): + def recv_requests_from_zmq(self): recv_reqs = [] while True: try: recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK) - recv_reqs.append(recv_req) except zmq.ZMQError: break + recv_reqs.append(recv_req) + + return recv_reqs + + def recv_requests_from_mp_queue(self): + recv_reqs = [] + while not self.mp_queue.empty(): + recv_reqs.append(self.mp_queue.get()) return recv_reqs def start_controller_process( - server_args: ServerArgs, port_args: PortArgs, pipe_writer, 
model_overide_args: dict + server_args: ServerArgs, + port_args: PortArgs, + pipe_writer: multiprocessing.connection.Connection, + model_overide_args: dict, + is_data_parallel_worker: bool = False, + gpu_ids: List[int] = None, + dp_worker_id: int = None, + queue: multiprocessing.connection.Connection = None, ): + """Start a controller process.""" + logging.basicConfig( level=getattr(logging, server_args.log_level.upper()), format="%(message)s", ) + if not is_data_parallel_worker: + tp_size_local = server_args.tp_size // server_args.nnodes + gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)] + dp_worker_id = 0 + queue = None + try: - controller = ControllerSingle(server_args, port_args, model_overide_args) + controller = ControllerSingle(server_args, port_args, model_overide_args, + gpu_ids, is_data_parallel_worker, + dp_worker_id, queue) except Exception: pipe_writer.send(get_exception_traceback()) raise diff --git a/python/sglang/srt/managers/controller/tp_worker.py b/python/sglang/srt/managers/controller/tp_worker.py index 80b051644..14a557e27 100644 --- a/python/sglang/srt/managers/controller/tp_worker.py +++ b/python/sglang/srt/managers/controller/tp_worker.py @@ -1,15 +1,14 @@ """A tensor parallel worker.""" -import asyncio import logging +import multiprocessing +import pickle import time import warnings -from concurrent.futures import ThreadPoolExecutor from typing import List, Optional -import rpyc import torch -from rpyc.utils.classic import obtain +import torch.distributed as dist from sglang.global_config import global_config from sglang.srt.constrained.fsm_cache import FSMCache @@ -32,13 +31,11 @@ from sglang.srt.managers.io_struct import ( TokenizedGenerateReqInput, ) from sglang.srt.model_config import ModelConfig -from sglang.srt.server_args import ModelPortArgs, ServerArgs +from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ( - connect_rpyc_service, get_int_token_logit_bias, is_multimodal_model, set_random_seed, - start_rpyc_service_process, suppress_other_loggers, ) from sglang.utils import get_exception_traceback @@ -52,10 +49,9 @@ class ModelTpServer: gpu_id: int, tp_rank: int, server_args: ServerArgs, - model_port_args: ModelPortArgs, + nccl_port: int, model_overide_args: dict, ): - server_args, model_port_args = obtain(server_args), obtain(model_port_args) suppress_other_loggers() # Copy arguments @@ -79,7 +75,7 @@ class ModelTpServer: gpu_id=gpu_id, tp_rank=tp_rank, tp_size=server_args.tp_size, - nccl_port=model_port_args.nccl_port, + nccl_port=nccl_port, server_args=server_args, ) @@ -178,9 +174,6 @@ class ModelTpServer: self.new_token_ratio_recovery = global_config.new_token_ratio_recovery def exposed_step(self, recv_reqs): - if not isinstance(recv_reqs, list): - recv_reqs = obtain(recv_reqs) - try: # Recv requests for recv_req in recv_reqs: @@ -425,12 +418,6 @@ class ModelTpServer: f"#running-req: {running_bs}, " f"#queue-req: {len(self.forward_queue) - len(can_run_list)}" ) - # logger.debug( - # f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. " - # f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. " - # f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. " - # f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. 
" - # ) # Return the new batch new_batch = Batch.init_new( @@ -733,87 +720,74 @@ class ModelTpServer: break -class ModelTpService(rpyc.Service): - exposed_ModelTpServer = ModelTpServer +def run_tp_server( + gpu_id: int, + tp_rank: int, + server_args: ServerArgs, + nccl_port: int, + model_overide_args: dict, +): + """Run a tensor parallel server.""" + try: + model_server = ModelTpServer( + gpu_id, + tp_rank, + server_args, + nccl_port, + model_overide_args, + ) + tp_cpu_group = model_server.model_runner.tp_group.cpu_group + + while True: + recv_reqs = broadcast_recv_input(None, tp_rank, tp_cpu_group) + model_server.exposed_step(recv_reqs) + except Exception: + logger.error("Exception in run_tp_server:\n" + get_exception_traceback()) + raise -class ModelTpClient: - def __init__( - self, - gpu_ids: List[int], - server_args: ServerArgs, - model_port_args: ModelPortArgs, - model_overide_args, - ): - server_args, model_port_args = obtain(server_args), obtain(model_port_args) - self.tp_size = server_args.tp_size +def launch_tp_servers( + gpu_ids, tp_rank_range, server_args, nccl_port, model_overide_args +): + """Launch multiple tensor parallel servers.""" + procs = [] + for i in tp_rank_range: + proc = multiprocessing.Process( + target=run_tp_server, + args=(gpu_ids[i], i, server_args, nccl_port, model_overide_args), + ) + proc.start() + procs.append(proc) - if self.tp_size * server_args.dp_size == 1: - # Init model - assert len(gpu_ids) == 1 - self.model_server = ModelTpService().exposed_ModelTpServer( - gpu_ids[0], - 0, - server_args, - model_port_args, - model_overide_args, - ) + return procs - # Wrap functions - def async_wrap(f): - async def _func(*args, **kwargs): - return f(*args, **kwargs) - return _func +def broadcast_recv_input(data, rank, dist_group): + """Broadcast inputs from rank=0 to all other ranks with torch.dist backend.""" - self.step = async_wrap(self.model_server.exposed_step) + if rank == 0: + if len(data) == 0: + tensor_size = torch.tensor([0], dtype=torch.long) + dist.broadcast(tensor_size, src=0, group=dist_group) else: - with ThreadPoolExecutor(self.tp_size) as executor: - # Launch model processes - if server_args.nnodes == 1: - self.procs = list( - executor.map( - lambda args: start_rpyc_service_process(*args), - [ - (ModelTpService, p) - for p in model_port_args.model_tp_ports - ], - ) - ) - addrs = [("localhost", p) for p in model_port_args.model_tp_ports] - else: - addrs = [ - (ip, port) - for ip, port in zip( - model_port_args.model_tp_ips, model_port_args.model_tp_ports - ) - ] + serialized_data = pickle.dumps(data) + size = len(serialized_data) + tensor_data = torch.ByteTensor(list(serialized_data)) + tensor_size = torch.tensor([size], dtype=torch.long) - self.model_services = list( - executor.map(lambda args: connect_rpyc_service(*args), addrs) - ) + dist.broadcast(tensor_size, src=0, group=dist_group) + dist.broadcast(tensor_data, src=0, group=dist_group) + else: + tensor_size = torch.tensor([0], dtype=torch.long) + dist.broadcast(tensor_size, src=0, group=dist_group) + size = tensor_size.item() - # Init model - def init_model(i): - return self.model_services[i].ModelTpServer( - gpu_ids[i], - i, - server_args, - model_port_args, - model_overide_args, - ) + if size == 0: + return [] - self.model_servers = list(executor.map(init_model, range(self.tp_size))) + tensor_data = torch.empty(size, dtype=torch.uint8) + dist.broadcast(tensor_data, src=0, group=dist_group) - # Wrap functions - def async_wrap(func_name): - fs = [rpyc.async_(getattr(m, func_name)) for m in 
self.model_servers] - - async def _func(*args, **kwargs): - tasks = [f(*args, **kwargs) for f in fs] - await asyncio.gather(*[asyncio.to_thread(t.wait) for t in tasks]) - return obtain(tasks[0].value) - - return _func - - self.step = async_wrap("step") + serialized_data = bytes(tensor_data.tolist()) + data = pickle.loads(serialized_data) + return data diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 75af8e62c..0d3f1aa91 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -61,7 +61,7 @@ class TokenizerManager: self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}") self.send_to_router = context.socket(zmq.PUSH) - self.send_to_router.connect(f"tcp://127.0.0.1:{port_args.router_port}") + self.send_to_router.connect(f"tcp://127.0.0.1:{port_args.controller_port}") self.model_path = server_args.model_path self.hf_config = get_config( diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index cef04bc4b..9467f95c0 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -44,15 +44,13 @@ from sglang.srt.openai_api_adapter import ( v1_chat_completions, v1_completions, ) -from sglang.srt.server_args import ModelPortArgs, PortArgs, ServerArgs +from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.utils import ( API_KEY_HEADER_NAME, APIKeyValidatorMiddleware, allocate_init_ports, assert_pkg_version, enable_show_time_cost, - receive_addrs, - send_addrs_to_rank_0, ) from sglang.utils import get_exception_traceback @@ -98,6 +96,7 @@ async def flush_cache(): async def generate_request(obj: GenerateReqInput, request: Request): + """Handle a generate request.""" if obj.stream: async def stream_results(): @@ -146,7 +145,10 @@ def _set_global_server_args(server_args: ServerArgs): } -def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_args=None): +def launch_server(server_args: ServerArgs, + model_overide_args: Optional[dict] = None, + pipe_finish_writer: Optional[mp.connection.Connection] = None): + """Launch an HTTP server.""" global tokenizer_manager logging.basicConfig( @@ -173,39 +175,23 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg if server_args.chat_template: # TODO: replace this with huggingface transformers template load_chat_template_for_openai_api(server_args.chat_template) - _set_global_server_args(server_args) # Allocate ports - assert server_args.tp_size % server_args.nnodes == 0 - tp_size_local = server_args.tp_size // server_args.nnodes server_args.port, server_args.additional_ports = allocate_init_ports( server_args.port, server_args.additional_ports, - tp_size_local, server_args.dp_size, ) - ports = server_args.additional_ports - model_port_args = [] - for i in range(server_args.dp_size): - model_port_args.append( - ModelPortArgs( - nccl_port=ports[3 + i * (tp_size_local + 1)], - model_tp_ips=[None] * tp_size_local, - model_tp_ports=ports[ - 3 + i * (tp_size_local + 1) + 1 : 3 + (i + 1) * (tp_size_local + 1) - ], - ) - ) port_args = PortArgs( tokenizer_port=ports[0], - router_port=ports[1], + controller_port=ports[1], detokenizer_port=ports[2], - model_port_args=model_port_args, + nccl_ports=ports[3:], ) - # Handle multi-node tp + # Handle multi-node tensor parallelism if server_args.nnodes > 1: assert server_args.dp_size == 1, "Multi-node dp is not supported." 
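With `ModelPortArgs` removed, `additional_ports` now carries just `3 + dp_size` ports beyond the HTTP port (see the `allocate_init_ports` change in `sglang/srt/utils.py` at the end of this patch). A minimal sketch of the new mapping under that assumption; the port numbers are purely illustrative:

    from sglang.srt.server_args import PortArgs

    # Illustrative values only; the real ones come from allocate_init_ports().
    ports = [30000, 30001, 30002, 30003, 30004]  # additional_ports, dp_size = 2
    port_args = PortArgs(
        tokenizer_port=ports[0],    # detokenizer -> tokenizer
        controller_port=ports[1],   # tokenizer -> controller (was router_port)
        detokenizer_port=ports[2],  # controller / TP rank 0 -> detokenizer
        nccl_ports=ports[3:],       # one NCCL init port per data parallel worker
    )
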
@@ -224,7 +210,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg gpu_ids, tp_rank_range, server_args, - port_args.model_port_args[0], + ports[3], model_overide_args, ) while True: @@ -232,18 +218,18 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg # Launch processes tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args) - pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False) + pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False) pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False) if server_args.dp_size == 1: start_process = start_controller_process_single else: start_process = start_controller_process_multi - proc_router = mp.Process( + proc_controller = mp.Process( target=start_process, - args=(server_args, port_args, pipe_router_writer, model_overide_args), + args=(server_args, port_args, pipe_controller_writer, model_overide_args), ) - proc_router.start() + proc_controller.start() proc_detoken = mp.Process( target=start_detokenizer_process, args=( @@ -255,68 +241,27 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg proc_detoken.start() # Wait for the model to finish loading - router_init_state = pipe_router_reader.recv() + controller_init_state = pipe_controller_reader.recv() detoken_init_state = pipe_detoken_reader.recv() - if router_init_state != "init ok" or detoken_init_state != "init ok": - proc_router.kill() + if controller_init_state != "init ok" or detoken_init_state != "init ok": + proc_controller.kill() proc_detoken.kill() print( - f"Initialization failed. router_init_state: {router_init_state}", flush=True + f"Initialization failed. controller_init_state: {controller_init_state}", flush=True ) print( f"Initialization failed. detoken_init_state: {detoken_init_state}", flush=True, ) sys.exit(1) - assert proc_router.is_alive() and proc_detoken.is_alive() + assert proc_controller.is_alive() and proc_detoken.is_alive() if server_args.api_key and server_args.api_key != "": app.add_middleware(APIKeyValidatorMiddleware, api_key=server_args.api_key) # Send a warmup request - def _wait_and_warmup(): - headers = {} - url = server_args.url() - if server_args.api_key: - headers[API_KEY_HEADER_NAME] = server_args.api_key - - # Wait until the server is launched - for _ in range(120): - time.sleep(0.5) - try: - requests.get(url + "/get_model_info", timeout=5, headers=headers) - break - except requests.exceptions.RequestException: - pass - - # Send a warmup request - try: - for _ in range(server_args.dp_size): - res = requests.post( - url + "/generate", - json={ - "text": "The capital city of France is", - "sampling_params": { - "temperature": 0, - "max_new_tokens": 8, - }, - }, - headers=headers, - timeout=600, - ) - assert res.status_code == 200 - except Exception as e: - if pipe_finish_writer is not None: - pipe_finish_writer.send(get_exception_traceback()) - print(f"Initialization failed. 
warmup error: {e}", flush=True) - raise e - - logger.info("The server is fired up and ready to roll!") - if pipe_finish_writer is not None: - pipe_finish_writer.send("init ok") - - t = threading.Thread(target=_wait_and_warmup) + t = threading.Thread(target=_wait_and_warmup, args=(server_args, pipe_finish_writer)) t.start() # Listen for requests @@ -333,6 +278,48 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg t.join() +def _wait_and_warmup(server_args, pipe_finish_writer): + headers = {} + url = server_args.url() + if server_args.api_key: + headers[API_KEY_HEADER_NAME] = server_args.api_key + + # Wait until the server is launched + for _ in range(120): + time.sleep(0.5) + try: + requests.get(url + "/get_model_info", timeout=5, headers=headers) + break + except requests.exceptions.RequestException: + pass + + # Send a warmup request + try: + for _ in range(server_args.dp_size): + res = requests.post( + url + "/generate", + json={ + "text": "The capital city of France is", + "sampling_params": { + "temperature": 0, + "max_new_tokens": 8, + }, + }, + headers=headers, + timeout=600, + ) + assert res.status_code == 200 + except Exception as e: + if pipe_finish_writer is not None: + pipe_finish_writer.send(get_exception_traceback()) + print(f"Initialization failed. warmup error: {e}", flush=True) + raise e + + logger.info("The server is fired up and ready to roll!") + if pipe_finish_writer is not None: + pipe_finish_writer.send("init ok") + + class Runtime: """ A wrapper for the server. @@ -354,7 +341,6 @@ class Runtime: self.server_args.port, self.server_args.additional_ports = allocate_init_ports( self.server_args.port, self.server_args.additional_ports, - self.server_args.tp_size, self.server_args.dp_size, ) @@ -367,7 +353,7 @@ class Runtime: pipe_reader, pipe_writer = mp.Pipe(duplex=False) proc = mp.Process( target=launch_server, - args=(self.server_args, pipe_writer, model_overide_args), + args=(self.server_args, model_overide_args, pipe_writer), ) proc.start() pipe_writer.close() diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index b4f79c066..7c0317fc0 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -337,16 +337,9 @@ class ServerArgs: ) -@dataclasses.dataclass -class ModelPortArgs: - nccl_port: int - model_tp_ips: List[str] - model_tp_ports: List[int] - - @dataclasses.dataclass class PortArgs: tokenizer_port: int - router_port: int + controller_port: int detokenizer_port: int - model_port_args: List[ModelPortArgs] + nccl_ports: List[int] diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 981b5e218..66f051ea7 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -3,7 +3,6 @@ import base64 import fcntl import logging -import multiprocessing import os import random import socket @@ -16,12 +15,10 @@ from typing import List, Optional import numpy as np import psutil import requests -import rpyc import torch import triton from fastapi.responses import JSONResponse from packaging import version as pkg_version -from rpyc.utils.server import ThreadedServer from starlette.middleware.base import BaseHTTPMiddleware logger = logging.getLogger(__name__) @@ -148,7 +145,6 @@ def is_port_available(port): def allocate_init_ports( port: Optional[int] = None, additional_ports: Optional[List[int]] = None, - tp_size: int = 1, dp_size: int = 1, ): """Allocate ports for all connections.""" @@ -160,8 +156,8 @@ def allocate_init_ports( ret_ports = list(set(x for 
x in ret_ports if is_port_available(x))) cur_port = ret_ports[-1] + 1 if len(ret_ports) > 0 else 10000 - # HTTP + Tokenizer + Controller + Detokenizer + dp_size * (nccl + tp_size) - num_ports_needed = 4 + dp_size * (1 + tp_size) + # HTTP + Tokenizer + Controller + Detokenizer + dp_size * 1 (nccl) + num_ports_needed = 4 + dp_size while len(ret_ports) < num_ports_needed: if cur_port not in ret_ports and is_port_available(cur_port): ret_ports.append(cur_port) @@ -371,49 +367,6 @@ def load_image(image_file): return image, image_size -def connect_rpyc_service(host, port): - repeat_count = 0 - while repeat_count < 20: - try: - con = rpyc.connect( - host, - port, - config={ - "allow_public_attrs": True, - "allow_pickle": True, - "sync_request_timeout": 3600, - }, - ) - break - except ConnectionRefusedError as e: - time.sleep(1) - repeat_count += 1 - if repeat_count == 20: - raise RuntimeError(f"Connect rpyc error: {e}") - - return con.root - - -def start_rpyc_service(service: rpyc.Service, port: int): - t = ThreadedServer( - service=service, - port=port, - protocol_config={ - "allow_public_attrs": True, - "allow_pickle": True, - "sync_request_timeout": 3600, - }, - ) - t.logger.setLevel(logging.WARN) - t.start() - - -def start_rpyc_service_process(service: rpyc.Service, port: int): - proc = multiprocessing.Process(target=start_rpyc_service, args=(service, port)) - proc.start() - return proc - - def suppress_other_loggers(): from vllm.logger import logger as vllm_default_logger
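
Taken together, the rpyc `ModelTpService`/`ModelTpClient` pair is replaced by plain `multiprocessing` workers that exchange request batches through `torch.distributed` over the TP CPU group. A self-contained sketch of that pattern, simplified from the `broadcast_recv_input` in `tp_worker.py` above (it drops the empty-list fast path and, for symmetry, also returns the data on rank 0; run under `torchrun --nproc-per-node=2`):

    import pickle

    import torch
    import torch.distributed as dist


    def broadcast_recv_input(data, rank, dist_group):
        """Rank 0 broadcasts a pickled object; other ranks receive and return it."""
        if rank == 0:
            payload = torch.tensor(list(pickle.dumps(data)), dtype=torch.uint8)
            size = torch.tensor([payload.numel()], dtype=torch.long)
            dist.broadcast(size, src=0, group=dist_group)
            dist.broadcast(payload, src=0, group=dist_group)
            return data
        size = torch.tensor([0], dtype=torch.long)
        dist.broadcast(size, src=0, group=dist_group)
        payload = torch.empty(size.item(), dtype=torch.uint8)
        dist.broadcast(payload, src=0, group=dist_group)
        return pickle.loads(bytes(payload.tolist()))


    if __name__ == "__main__":
        # Mirrors the pairing in this patch: ControllerSingle (rank 0) forwards
        # the batch it pulled from ZMQ or the mp queue; run_tp_server (ranks > 0)
        # calls with data=None to receive the same batch.
        dist.init_process_group("gloo")
        rank = dist.get_rank()
        reqs = ["req-1", "req-2"] if rank == 0 else None
        print(rank, broadcast_recv_input(reqs, rank, dist.group.WORLD))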