From 0463f7fb52f06dcae2b10b7ca2a18a86ac135f96 Mon Sep 17 00:00:00 2001 From: Ying Sheng Date: Mon, 27 May 2024 21:24:10 -0700 Subject: [PATCH] Support data parallelism (static) (#480) Co-authored-by: Ying Sheng Co-authored-by: Lianmin Zheng Co-authored-by: Liangsheng Yin Co-authored-by: Zhiqiang Xie --- python/sglang/global_config.py | 3 +- python/sglang/srt/layers/logits_processor.py | 2 +- python/sglang/srt/layers/radix_attention.py | 4 +- python/sglang/srt/layers/token_attention.py | 2 +- .../srt/managers/controller/dp_worker.py | 102 ++++++++++ .../{router => controller}/infer_batch.py | 3 +- .../srt/managers/controller/manager_multi.py | 187 ++++++++++++++++++ .../manager_single.py} | 30 +-- .../{router => controller}/model_runner.py | 62 +++--- .../{router => controller}/radix_cache.py | 0 .../schedule_heuristic.py} | 2 +- .../model_rpc.py => controller/tp_worker.py} | 151 ++++++-------- python/sglang/srt/managers/io_struct.py | 3 +- python/sglang/srt/models/commandr.py | 2 +- python/sglang/srt/models/dbrx.py | 2 +- python/sglang/srt/models/gemma.py | 2 +- python/sglang/srt/models/grok.py | 2 +- python/sglang/srt/models/llama2.py | 10 +- python/sglang/srt/models/llava.py | 4 +- python/sglang/srt/models/llavavid.py | 4 +- python/sglang/srt/models/mixtral.py | 2 +- python/sglang/srt/models/mixtral_quant.py | 2 +- python/sglang/srt/models/qwen.py | 2 +- python/sglang/srt/models/qwen2.py | 2 +- python/sglang/srt/models/stablelm.py | 2 +- python/sglang/srt/server.py | 44 +++-- python/sglang/srt/server_args.py | 31 ++- python/sglang/srt/utils.py | 87 ++++++-- test/srt/model/bench_llama_low_api.py | 2 +- test/srt/model/test_llama_extend.py | 4 +- test/srt/model/test_llama_low_api.py | 2 +- test/srt/model/test_llava_low_api.py | 4 +- 32 files changed, 580 insertions(+), 181 deletions(-) create mode 100644 python/sglang/srt/managers/controller/dp_worker.py rename python/sglang/srt/managers/{router => controller}/infer_batch.py (99%) create mode 100644 python/sglang/srt/managers/controller/manager_multi.py rename python/sglang/srt/managers/{router/manager.py => controller/manager_single.py} (70%) rename python/sglang/srt/managers/{router => controller}/model_runner.py (90%) rename python/sglang/srt/managers/{router => controller}/radix_cache.py (100%) rename python/sglang/srt/managers/{router/scheduler.py => controller/schedule_heuristic.py} (98%) rename python/sglang/srt/managers/{router/model_rpc.py => controller/tp_worker.py} (90%) diff --git a/python/sglang/global_config.py b/python/sglang/global_config.py index 452412bec..ba1a17b38 100644 --- a/python/sglang/global_config.py +++ b/python/sglang/global_config.py @@ -26,7 +26,8 @@ class GlobalConfig: self.concate_and_append_mode = "no_adjust" # Request dependency time due to network delay - self.request_dependency_time = 0.03 + self.request_dependency_delay = 0.03 + self.wait_for_new_request_delay = 0.0006 # New generation token ratio estimation self.base_new_token_ratio = 0.4 diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index e47a286eb..e7efaadec 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -5,7 +5,7 @@ from vllm.distributed import ( tensor_model_parallel_all_gather, ) -from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata +from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata class LogitsProcessor(nn.Module): diff --git 
a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py index 59d1a54da..7d0475e50 100644 --- a/python/sglang/srt/layers/radix_attention.py +++ b/python/sglang/srt/layers/radix_attention.py @@ -5,7 +5,7 @@ from torch import nn from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd from sglang.srt.layers.extend_attention import extend_attention_fwd from sglang.srt.layers.token_attention import token_attention_fwd -from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata +from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata class RadixAttention(nn.Module): @@ -20,7 +20,7 @@ class RadixAttention(nn.Module): assert np.allclose(scaling, 1.0 / (head_dim**0.5)) - from sglang.srt.managers.router.model_runner import global_server_args_dict + from sglang.srt.managers.controller.model_runner import global_server_args_dict if global_server_args_dict.get("enable_flashinfer", False): self.prefill_forward = self.prefill_forward_flashinfer diff --git a/python/sglang/srt/layers/token_attention.py b/python/sglang/srt/layers/token_attention.py index 58e3fa611..73b25aa85 100644 --- a/python/sglang/srt/layers/token_attention.py +++ b/python/sglang/srt/layers/token_attention.py @@ -5,7 +5,7 @@ import torch import triton import triton.language as tl -from sglang.srt.managers.router.model_runner import global_server_args_dict +from sglang.srt.managers.controller.model_runner import global_server_args_dict from sglang.srt.utils import wrap_kernel_launcher if global_server_args_dict.get("attention_reduce_in_fp32", False): diff --git a/python/sglang/srt/managers/controller/dp_worker.py b/python/sglang/srt/managers/controller/dp_worker.py new file mode 100644 index 000000000..3e300a17a --- /dev/null +++ b/python/sglang/srt/managers/controller/dp_worker.py @@ -0,0 +1,102 @@ +"""A data parallel worker thread.""" +import asyncio +import logging +import queue +import threading +from typing import List, Callable + +import uvloop +import zmq + +from sglang.global_config import global_config +from sglang.srt.managers.controller.tp_worker import ModelTpClient +from sglang.srt.server_args import PortArgs, ServerArgs +from sglang.utils import get_exception_traceback + +logger = logging.getLogger("srt.controller") +CHECKING_INTERVAL = 5 + +asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) + + +class DataParallelWorkerThread(threading.Thread): + def __init__( + self, + worker_id: int, + request_queue: queue.Queue, + detokenizer_port: int, + step_func: Callable, + ): + super(DataParallelWorkerThread, self).__init__() + self.worker_id = worker_id + self.request_queue = request_queue + self.liveness = True + self.request_dependency_delay = global_config.request_dependency_delay + + context = zmq.asyncio.Context() + self.send_to_detokenizer = context.socket(zmq.PUSH) + self.send_to_detokenizer.connect(f"tcp://127.0.0.1:{detokenizer_port}") + + self.step = step_func + + async def loop_for_forward(self): + while self.liveness: + requests = [] + while not self.request_queue.empty(): + requests.append(self.request_queue.get()) + try: + out_pyobjs = await self.step(requests) + except Exception: + for r in requests: + self.request_queue.put(r) + logger.error( + f"Worker thread {self.worker_id}: " + f"failed to get back from Model Server\n" + f"{get_exception_traceback()}" + ) + self.liveness = False + + for obj in out_pyobjs: + self.send_to_detokenizer.send_pyobj(obj) + + # async sleep for receiving the subsequent 
request and avoiding cache miss + if len(out_pyobjs) != 0: + has_finished = any([obj.finished for obj in out_pyobjs]) + if has_finished: + await asyncio.sleep(self.request_dependency_delay) + await asyncio.sleep(global_config.wait_for_new_request_delay) + + async def monitoring(self): + while True: + await asyncio.sleep(CHECKING_INTERVAL) + # can plug in monitoring logic here + + def run(self): + logger.info(f"DataParallelWorkerThread {self.worker_id} start") + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.create_task(self.monitoring()) + loop.run_until_complete(self.loop_for_forward()) + + +def start_data_parallel_worker( + server_args: ServerArgs, + port_args: PortArgs, + model_overide_args, + gpu_ids: List[int], + worker_id: int, +): + model_tp_client = ModelTpClient( + gpu_ids, + server_args, + port_args.model_port_args[worker_id], + model_overide_args, + ) + worker_thread = DataParallelWorkerThread( + worker_id=worker_id, + request_queue=queue.Queue(), + detokenizer_port=port_args.detokenizer_port, + step_func=model_tp_client.step, + ) + worker_thread.start() + return worker_thread \ No newline at end of file diff --git a/python/sglang/srt/managers/router/infer_batch.py b/python/sglang/srt/managers/controller/infer_batch.py similarity index 99% rename from python/sglang/srt/managers/router/infer_batch.py rename to python/sglang/srt/managers/controller/infer_batch.py index fb4afa332..6b82c9f07 100644 --- a/python/sglang/srt/managers/router/infer_batch.py +++ b/python/sglang/srt/managers/controller/infer_batch.py @@ -1,3 +1,4 @@ +"""Meta data for requests and batches""" from dataclasses import dataclass from enum import IntEnum, auto from typing import List @@ -5,7 +6,7 @@ from typing import List import numpy as np import torch -from sglang.srt.managers.router.radix_cache import RadixCache +from sglang.srt.managers.controller.radix_cache import RadixCache from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool diff --git a/python/sglang/srt/managers/controller/manager_multi.py b/python/sglang/srt/managers/controller/manager_multi.py new file mode 100644 index 000000000..a3175c92e --- /dev/null +++ b/python/sglang/srt/managers/controller/manager_multi.py @@ -0,0 +1,187 @@ +""" +A controller that manages multiple data parallel workers. +Each data parallel worker can manage multiple tensor parallel workers. 
+""" + +import asyncio +import logging +from concurrent.futures import ThreadPoolExecutor +from enum import Enum, auto +from typing import Dict + +import zmq +import zmq.asyncio + +from sglang.global_config import global_config +from sglang.srt.managers.io_struct import ( + AbortReq, + FlushCacheReq, + TokenizedGenerateReqInput, +) +from sglang.srt.managers.controller.dp_worker import ( + DataParallelWorkerThread, + start_data_parallel_worker, +) +from sglang.srt.server_args import PortArgs, ServerArgs +from sglang.utils import get_exception_traceback + +logger = logging.getLogger("srt.controller") + + +class LoadBalanceMethod(Enum): + ROUND_ROBIN = auto() + SHORTEST_QUEUE = auto() + + @classmethod + def from_str(cls, method: str): + method = method.upper() + try: + return cls[method] + except KeyError as exc: + raise ValueError(f"Invalid load balance method: {method}") from exc + + +class Controller: + def __init__( + self, + load_balance_method: str, + server_args: ServerArgs, + port_args: PortArgs, + model_overide_args, + ): + self.load_balance_method = LoadBalanceMethod.from_str(load_balance_method) + self.server_args = server_args + self.port_args = port_args + + if self.load_balance_method == LoadBalanceMethod.ROUND_ROBIN: + self.round_robin_counter = 0 + + self.dispatch_lookup = { + LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler, + LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler, + } + self.dispatching = self.dispatch_lookup[self.load_balance_method] + + # Init communication + context = zmq.asyncio.Context() + self.recv_from_tokenizer = context.socket(zmq.PULL) + self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.router_port}") + + # Init status + self.recv_reqs = [] + + # Start data parallel workers + self.workers: Dict[int, DataParallelWorkerThread] = {} + tp_size = server_args.tp_size + + def start_dp_worker(i): + try: + gpu_ids = list(range(i * tp_size, (i + 1) * tp_size)) + worker_thread = start_data_parallel_worker( + server_args, port_args, model_overide_args, gpu_ids, i + ) + self.workers[i] = worker_thread + except Exception: + logger.error( + f"Failed to start local worker {i}\n{get_exception_traceback()}" + ) + + with ThreadPoolExecutor(server_args.dp_size) as executor: + executor.map(start_dp_worker, range(server_args.dp_size)) + + def have_any_live_worker(self): + return any(worker_thread.liveness for worker_thread in self.workers.values()) + + def put_req_to_worker(self, worker_id, req): + self.workers[worker_id].request_queue.put(req) + + async def round_robin_scheduler(self, input_requests): + available_workers = list(self.workers.keys()) + for r in input_requests: + self.put_req_to_worker(available_workers[self.round_robin_counter], r) + self.round_robin_counter = (self.round_robin_counter + 1) % len( + available_workers + ) + return + + async def shortest_queue_scheduler(self, input_requests): + for r in input_requests: + worker = min( + self.workers, key=lambda w: self.workers[w].request_queue.qsize() + ) + self.put_req_to_worker(worker, r) + return + + async def remove_dead_workers(self): + for i in list(self.workers.keys()): + worker_thread = self.workers[i] + if not worker_thread.liveness: + worker_thread.join() + # move unsuccessful requests back to the queue + while not worker_thread.request_queue.empty(): + self.recv_reqs.append(worker_thread.request_queue.get()) + del self.workers[i] + logger.info(f"Stale worker {i} removed") + + async def loop_for_forward(self): + while True: + await self.remove_dead_workers() + + if 
self.have_any_live_worker(): + next_step_input = list(self.recv_reqs) + self.recv_reqs = [] + if next_step_input: + await self.dispatching(next_step_input) + #else: + # logger.error("There is no live worker.") + + await asyncio.sleep(global_config.wait_for_new_request_delay) + + async def loop_for_recv_requests(self): + while True: + recv_req = await self.recv_from_tokenizer.recv_pyobj() + if isinstance(recv_req, FlushCacheReq): + # TODO(lsyin): apply more specific flushCacheReq + for worker_thread in self.workers.values(): + worker_thread.request_queue.put(recv_req) + elif isinstance(recv_req, TokenizedGenerateReqInput): + self.recv_reqs.append(recv_req) + elif isinstance(recv_req, AbortReq): + in_queue = False + for i, req in enumerate(self.recv_reqs): + if req.rid == recv_req.rid: + self.recv_reqs[i] = recv_req + in_queue = True + break + if not in_queue: + # Send abort req to all TP groups + for worker in list(self.workers.keys()): + self.put_req_to_worker(worker, recv_req) + else: + logger.error(f"Invalid object: {recv_req}") + + +def start_controller_process( + server_args: ServerArgs, + port_args: PortArgs, + pipe_writer, + model_overide_args=None, +): + logging.basicConfig( + level=getattr(logging, server_args.log_level.upper()), + format="%(message)s", + ) + + try: + controller = Controller( + server_args.load_balance_method, server_args, port_args, model_overide_args + ) + except Exception: + pipe_writer.send(get_exception_traceback()) + raise + + pipe_writer.send("init ok") + loop = asyncio.get_event_loop() + asyncio.set_event_loop(loop) + loop.create_task(controller.loop_for_recv_requests()) + loop.run_until_complete(controller.loop_for_forward()) diff --git a/python/sglang/srt/managers/router/manager.py b/python/sglang/srt/managers/controller/manager_single.py similarity index 70% rename from python/sglang/srt/managers/router/manager.py rename to python/sglang/srt/managers/controller/manager_single.py index f0e856998..227b8a7b7 100644 --- a/python/sglang/srt/managers/router/manager.py +++ b/python/sglang/srt/managers/controller/manager_single.py @@ -1,3 +1,4 @@ +"""A controller that manages a group of tensor parallel workers.""" import asyncio import logging @@ -6,15 +7,15 @@ import zmq import zmq.asyncio from sglang.global_config import global_config -from sglang.srt.managers.router.model_rpc import ModelRpcClient +from sglang.srt.managers.controller.tp_worker import ModelTpClient from sglang.srt.server_args import PortArgs, ServerArgs from sglang.utils import get_exception_traceback asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) -class RouterManager: - def __init__(self, model_client: ModelRpcClient, port_args: PortArgs): +class ControllerSingle: + def __init__(self, model_client: ModelTpClient, port_args: PortArgs): # Init communication context = zmq.asyncio.Context(2) self.recv_from_tokenizer = context.socket(zmq.PULL) @@ -30,7 +31,7 @@ class RouterManager: self.recv_reqs = [] # Init some configs - self.request_dependency_time = global_config.request_dependency_time + self.request_dependency_delay = global_config.request_dependency_delay async def loop_for_forward(self): while True: @@ -46,12 +47,12 @@ class RouterManager: if len(out_pyobjs) != 0: has_finished = any([obj.finished for obj in out_pyobjs]) if has_finished: - if self.request_dependency_time > 0: + if self.request_dependency_delay > 0: slept = True - await asyncio.sleep(self.request_dependency_time) + await asyncio.sleep(self.request_dependency_delay) if not slept: - await asyncio.sleep(0.0006) + await 
asyncio.sleep(global_config.wait_for_new_request_delay) async def loop_for_recv_requests(self): while True: @@ -59,7 +60,7 @@ class RouterManager: self.recv_reqs.append(recv_req) -def start_router_process( +def start_controller_process( server_args: ServerArgs, port_args: PortArgs, pipe_writer, model_overide_args ): logging.basicConfig( @@ -68,8 +69,13 @@ def start_router_process( ) try: - model_client = ModelRpcClient(server_args, port_args, model_overide_args) - router = RouterManager(model_client, port_args) + model_client = ModelTpClient( + list(range(server_args.tp_size)), + server_args, + port_args.model_port_args[0], + model_overide_args, + ) + controller = ControllerSingle(model_client, port_args) except Exception: pipe_writer.send(get_exception_traceback()) raise @@ -78,5 +84,5 @@ def start_router_process( loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - loop.create_task(router.loop_for_recv_requests()) - loop.run_until_complete(router.loop_for_forward()) + loop.create_task(controller.loop_for_recv_requests()) + loop.run_until_complete(controller.loop_for_forward()) \ No newline at end of file diff --git a/python/sglang/srt/managers/router/model_runner.py b/python/sglang/srt/managers/controller/model_runner.py similarity index 90% rename from python/sglang/srt/managers/router/model_runner.py rename to python/sglang/srt/managers/controller/model_runner.py index 40be207b6..6a64c84f5 100644 --- a/python/sglang/srt/managers/router/model_runner.py +++ b/python/sglang/srt/managers/controller/model_runner.py @@ -15,13 +15,13 @@ from vllm.distributed import initialize_model_parallel from vllm.model_executor.model_loader import get_model from vllm.model_executor.models import ModelRegistry -from sglang.srt.managers.router.infer_batch import Batch, ForwardMode +from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool from sglang.srt.server_args import ServerArgs from sglang.srt.utils import get_available_gpu_memory, is_multimodal_model -logger = logging.getLogger("model_runner") +logger = logging.getLogger("srt.model_runner") # for server args in model endpoints global_server_args_dict = {} @@ -215,14 +215,16 @@ class ModelRunner: def __init__( self, model_config, - mem_fraction_static, - tp_rank, - tp_size, - nccl_port, + mem_fraction_static: float, + gpu_id: int, + tp_rank: int, + tp_size: int, + nccl_port: int, server_args: ServerArgs, ): self.model_config = model_config self.mem_fraction_static = mem_fraction_static + self.gpu_id = gpu_id self.tp_rank = tp_rank self.tp_size = tp_size self.nccl_port = nccl_port @@ -235,9 +237,9 @@ class ModelRunner: } # Init torch distributed - logger.info(f"[rank={self.tp_rank}] Set cuda device.") - torch.cuda.set_device(self.tp_rank) - logger.info(f"[rank={self.tp_rank}] Init torch begin. 
Avail mem={get_available_gpu_memory(self.tp_rank):.2f} GB") + logger.info(f"[gpu_id={self.gpu_id}] Set cuda device.") + torch.cuda.set_device(self.gpu_id) + logger.info(f"[gpu_id={self.gpu_id}] Init nccl begin.") torch.distributed.init_process_group( backend="nccl", world_size=self.tp_size, @@ -245,22 +247,26 @@ class ModelRunner: init_method=f"tcp://127.0.0.1:{self.nccl_port}", ) initialize_model_parallel(tensor_model_parallel_size=self.tp_size) - logger.info(f"[rank={self.tp_rank}] Init torch end.") - - total_gpu_memory = get_available_gpu_memory(self.tp_rank, distributed=self.tp_size > 1) + total_gpu_memory = get_available_gpu_memory( + self.gpu_id, distributed=self.tp_size > 1 + ) if self.tp_size > 1: - total_local_gpu_memory = get_available_gpu_memory(self.tp_rank) + total_local_gpu_memory = get_available_gpu_memory(self.gpu_id) if total_local_gpu_memory < total_gpu_memory * 0.9: - raise ValueError("The memory capacity is unbalanced. Some GPUs may be occupied by other processes.") + raise ValueError( + "The memory capacity is unbalanced. Some GPUs may be occupied by other processes." + ) self.load_model() self.init_memory_pool(total_gpu_memory) - self.is_multimodal_model = is_multimodal_model(self.model_config) def load_model(self): - logger.info(f"[rank={self.tp_rank}] Load weight begin.") + logger.info( + f"[gpu_id={self.gpu_id}] Load weight begin. " + f"Avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB" + ) device_config = DeviceConfig() load_config = LoadConfig(load_format=self.server_args.load_format) @@ -286,12 +292,16 @@ class ModelRunner: parallel_config=None, scheduler_config=None, ) - logger.info(f"[rank={self.tp_rank}] Load weight end. " - f"Type={type(self.model).__name__}. " - f"Avail mem={get_available_gpu_memory(self.tp_rank):.2f} GB") + logger.info( + f"[gpu_id={self.gpu_id}] Load weight end. " + f"Type={type(self.model).__name__}. " + f"Avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB" + ) def profile_max_num_token(self, total_gpu_memory): - available_gpu_memory = get_available_gpu_memory(self.tp_rank, distributed=self.tp_size > 1) + available_gpu_memory = get_available_gpu_memory( + self.gpu_id, distributed=self.tp_size > 1 + ) head_dim = self.model_config.head_dim head_num = self.model_config.num_key_value_heads // self.tp_size cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * 2 @@ -306,7 +316,7 @@ class ModelRunner: if self.max_total_num_tokens <= 0: raise RuntimeError( - "Not enought memory. " "Please try to increase --mem-fraction-static." + "Not enought memory. Please try to increase --mem-fraction-static." ) self.req_to_token_pool = ReqToTokenPool( @@ -320,6 +330,10 @@ class ModelRunner: head_dim=self.model_config.head_dim, layer_num=self.model_config.num_hidden_layers, ) + logger.info( + f"[gpu_id={self.gpu_id}] Memory pool end. 
" + f"Avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB" + ) @torch.inference_mode() def forward_prefill(self, batch: Batch): @@ -424,8 +438,8 @@ def import_model_classes(): if hasattr(module, "EntryClass"): entry = module.EntryClass if isinstance(entry, list): # To support multiple model classes in one module - for cls in entry: - model_arch_name_to_cls[cls.__name__] = cls + for tmp in entry: + model_arch_name_to_cls[tmp.__name__] = tmp else: model_arch_name_to_cls[entry.__name__] = entry return model_arch_name_to_cls @@ -442,4 +456,4 @@ def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]: # Monkey patch model loader -setattr(ModelRegistry, "load_model_cls", load_model_cls_srt) \ No newline at end of file +setattr(ModelRegistry, "load_model_cls", load_model_cls_srt) diff --git a/python/sglang/srt/managers/router/radix_cache.py b/python/sglang/srt/managers/controller/radix_cache.py similarity index 100% rename from python/sglang/srt/managers/router/radix_cache.py rename to python/sglang/srt/managers/controller/radix_cache.py diff --git a/python/sglang/srt/managers/router/scheduler.py b/python/sglang/srt/managers/controller/schedule_heuristic.py similarity index 98% rename from python/sglang/srt/managers/router/scheduler.py rename to python/sglang/srt/managers/controller/schedule_heuristic.py index def11e775..6c585eb9b 100644 --- a/python/sglang/srt/managers/router/scheduler.py +++ b/python/sglang/srt/managers/controller/schedule_heuristic.py @@ -2,7 +2,7 @@ import random from collections import defaultdict -class Scheduler: +class ScheduleHeuristic: def __init__( self, schedule_heuristic, diff --git a/python/sglang/srt/managers/router/model_rpc.py b/python/sglang/srt/managers/controller/tp_worker.py similarity index 90% rename from python/sglang/srt/managers/router/model_rpc.py rename to python/sglang/srt/managers/controller/tp_worker.py index 4a4093525..30e209b22 100644 --- a/python/sglang/srt/managers/router/model_rpc.py +++ b/python/sglang/srt/managers/controller/tp_worker.py @@ -1,20 +1,13 @@ import asyncio import logging -import multiprocessing import time import warnings from concurrent.futures import ThreadPoolExecutor -from typing import List, Optional +from typing import List import rpyc import torch from rpyc.utils.classic import obtain -from rpyc.utils.server import ThreadedServer - -try: - from vllm.logger import _default_handler as vllm_default_logger -except ImportError: - from vllm.logger import logger as vllm_default_logger from sglang.global_config import global_config from sglang.srt.constrained.fsm_cache import FSMCache @@ -26,38 +19,41 @@ from sglang.srt.managers.io_struct import ( FlushCacheReq, TokenizedGenerateReqInput, ) -from sglang.srt.managers.router.infer_batch import Batch, FinishReason, ForwardMode, Req -from sglang.srt.managers.router.model_runner import ModelRunner -from sglang.srt.managers.router.radix_cache import RadixCache -from sglang.srt.managers.router.scheduler import Scheduler +from sglang.srt.managers.controller.infer_batch import Batch, FinishReason, ForwardMode, Req +from sglang.srt.managers.controller.model_runner import ModelRunner +from sglang.srt.managers.controller.radix_cache import RadixCache +from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic from sglang.srt.model_config import ModelConfig -from sglang.srt.server_args import PortArgs, ServerArgs +from sglang.srt.server_args import ModelPortArgs, ServerArgs from sglang.srt.utils import ( get_int_token_logit_bias, is_multimodal_model, 
set_random_seed, + start_rpyc_process, + suppress_other_loggers, ) from sglang.utils import get_exception_traceback -logger = logging.getLogger("model_rpc") -vllm_default_logger.setLevel(logging.WARN) -logging.getLogger("vllm.utils").setLevel(logging.WARN) -logging.getLogger("vllm.selector").setLevel(logging.WARN) +logger = logging.getLogger("srt.model_tp") -class ModelRpcServer: +class ModelTpServer: def __init__( self, + gpu_id: int, tp_rank: int, server_args: ServerArgs, - port_args: PortArgs, - model_overide_args: Optional[dict] = None, + model_port_args: ModelPortArgs, + model_overide_args, ): - server_args, port_args = [obtain(x) for x in [server_args, port_args]] + server_args, model_port_args = obtain(server_args), obtain(model_port_args) + suppress_other_loggers() # Copy arguments + self.gpu_id = gpu_id self.tp_rank = tp_rank self.tp_size = server_args.tp_size + self.dp_size = server_args.dp_size self.schedule_heuristic = server_args.schedule_heuristic self.disable_regex_jump_forward = server_args.disable_regex_jump_forward @@ -68,16 +64,16 @@ class ModelRpcServer: context_length=server_args.context_length, model_overide_args=model_overide_args, ) - - # For model end global settings self.model_runner = ModelRunner( model_config=self.model_config, mem_fraction_static=server_args.mem_fraction_static, + gpu_id=gpu_id, tp_rank=tp_rank, tp_size=server_args.tp_size, - nccl_port=port_args.nccl_port, + nccl_port=model_port_args.nccl_port, server_args=server_args, ) + if is_multimodal_model(server_args.model_path): self.processor = get_processor( server_args.tokenizer_path, @@ -95,21 +91,21 @@ class ModelRpcServer: self.max_prefill_tokens = max( self.model_config.context_len, ( - self.max_total_num_tokens // 6 + min(self.max_total_num_tokens // 6, 65536) if server_args.max_prefill_tokens is None else server_args.max_prefill_tokens ), ) self.max_running_requests = (self.max_total_num_tokens // 2 if server_args.max_running_requests is None else server_args.max_running_requests) - self.int_token_logit_bias = torch.tensor( get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size) ) set_random_seed(server_args.random_seed) # Print info - logger.info(f"[rank={self.tp_rank}] " + logger.info( + f"[gpu_id={self.gpu_id}] " f"max_total_num_tokens={self.max_total_num_tokens}, " f"max_prefill_tokens={self.max_prefill_tokens}, " f"context_len={self.model_config.context_len}, " @@ -124,7 +120,7 @@ class ModelRpcServer: disable=server_args.disable_radix_cache, ) self.tree_cache_metrics = {"total": 0, "hit": 0} - self.scheduler = Scheduler( + self.scheduler = ScheduleHeuristic( self.schedule_heuristic, self.max_running_requests, self.max_prefill_tokens, @@ -170,7 +166,7 @@ class ModelRpcServer: self.new_token_ratio_recovery = global_config.new_token_ratio_recovery def exposed_step(self, recv_reqs): - if self.tp_size != 1: + if self.tp_size * self.dp_size != 1: recv_reqs = obtain(recv_reqs) try: @@ -188,7 +184,7 @@ class ModelRpcServer: # Forward self.forward_step() except Exception: - logger.error("Exception in ModelRpcClient:\n" + get_exception_traceback()) + logger.error("Exception in ModelTpClient:\n" + get_exception_traceback()) # Return results ret = self.out_pyobjs @@ -224,16 +220,17 @@ class ModelRpcServer: self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size() ) - throuhgput = self.num_generated_tokens / ( + throughput = self.num_generated_tokens / ( time.time() - self.last_stats_tic ) self.num_generated_tokens = 0 self.last_stats_tic = time.time() logger.info( + 
f"[gpu_id={self.gpu_id}] " f"#running-req: {len(self.running_batch.reqs)}, " f"#token: {num_used}, " f"token usage: {num_used / self.max_total_num_tokens:.2f}, " - f"gen throughput (token/s): {throuhgput:.2f}, " + f"gen throughput (token/s): {throughput:.2f}, " f"#queue-req: {len(self.forward_queue)}" ) @@ -405,7 +402,7 @@ class ModelRpcServer: f"#new_token: {new_batch_input_tokens}. " f"#remaining_req: {len(self.forward_queue) - len(can_run_list)}. " f"#running_req: {running_req}. " - f"tree_cache_hit_rate: {100.0 * tree_cache_hit_rate:.2f}%." + f"tree_cache_hit_rate: {100.0 * tree_cache_hit_rate:.2f}%. " ) # logger.debug( # f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. " @@ -724,20 +721,30 @@ class ModelRpcServer: break -class ModelRpcService(rpyc.Service): - exposed_ModelRpcServer = ModelRpcServer +class ModelTpService(rpyc.Service): + exposed_ModelTpServer = ModelTpServer -class ModelRpcClient: +class ModelTpClient: def __init__( - self, server_args: ServerArgs, port_args: PortArgs, model_overide_args + self, + gpu_ids: List[int], + server_args: ServerArgs, + model_port_args: ModelPortArgs, + model_overide_args, ): - tp_size = server_args.tp_size + server_args, model_port_args = obtain(server_args), obtain(model_port_args) + self.tp_size = server_args.tp_size - if tp_size == 1: + if self.tp_size * server_args.dp_size == 1: # Init model - self.model_server = ModelRpcService().exposed_ModelRpcServer( - 0, server_args, port_args, model_overide_args + assert len(gpu_ids) == 1 + self.model_server = ModelTpService().exposed_ModelTpServer( + 0, + gpu_ids[0], + server_args, + model_port_args, + model_overide_args, ) # Wrap functions @@ -749,19 +756,26 @@ class ModelRpcClient: self.step = async_wrap(self.model_server.exposed_step) else: - with ThreadPoolExecutor(tp_size) as executor: + with ThreadPoolExecutor(self.tp_size) as executor: # Launch model processes - rets = executor.map(start_model_process, port_args.model_rpc_ports) - self.remote_services = [x[0] for x in rets] + rets = executor.map( + lambda args: start_rpyc_process(*args), + [(ModelTpService, p) for p in model_port_args.model_tp_ports], + ) + self.model_services = [x[0] for x in rets] self.procs = [x[1] for x in rets] # Init model def init_model(i): - return self.remote_services[i].ModelRpcServer( - i, server_args, port_args, model_overide_args + return self.model_services[i].ModelTpServer( + gpu_ids[i], + i, + server_args, + model_port_args, + model_overide_args, ) - self.model_servers = executor.map(init_model, range(tp_size)) + self.model_servers = executor.map(init_model, range(self.tp_size)) # Wrap functions def async_wrap(func_name): @@ -774,45 +788,4 @@ class ModelRpcClient: return _func - self.step = async_wrap("step") - - -def _init_service(port): - t = ThreadedServer( - ModelRpcService(), - port=port, - protocol_config={ - "allow_public_attrs": True, - "allow_pickle": True, - "sync_request_timeout": 3600, - }, - ) - t.start() - - -def start_model_process(port): - proc = multiprocessing.Process(target=_init_service, args=(port,)) - proc.start() - time.sleep(1) - - repeat_count = 0 - while repeat_count < 20: - try: - con = rpyc.connect( - "localhost", - port, - config={ - "allow_public_attrs": True, - "allow_pickle": True, - "sync_request_timeout": 3600, - }, - ) - break - except ConnectionRefusedError: - time.sleep(1) - repeat_count += 1 - if repeat_count == 20: - raise RuntimeError("init rpc env error!") - - assert proc.is_alive() - return con.root, proc + self.step = async_wrap("step") 
\ No newline at end of file diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 4e8d6d74a..a07042b46 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -27,7 +27,6 @@ class GenerateReqInput: return_text_in_logprobs: bool = False # Whether to stream output stream: bool = False - # TODO: make all parameters a Union[List[T], T] to allow for batched requests def post_init(self): @@ -135,4 +134,4 @@ class AbortReq: @dataclass class DetokenizeReqInput: - input_ids: List[int] + input_ids: List[int] \ No newline at end of file diff --git a/python/sglang/srt/models/commandr.py b/python/sglang/srt/models/commandr.py index ab685ed94..6c1cd0ea3 100644 --- a/python/sglang/srt/models/commandr.py +++ b/python/sglang/srt/models/commandr.py @@ -48,7 +48,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.managers.router.model_runner import InputMetadata +from sglang.srt.managers.controller.model_runner import InputMetadata @torch.compile diff --git a/python/sglang/srt/models/dbrx.py b/python/sglang/srt/models/dbrx.py index 6b435bd56..ad4e27199 100644 --- a/python/sglang/srt/models/dbrx.py +++ b/python/sglang/srt/models/dbrx.py @@ -29,7 +29,7 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.managers.router.model_runner import InputMetadata +from sglang.srt.managers.controller.model_runner import InputMetadata class DbrxRouter(nn.Module): diff --git a/python/sglang/srt/models/gemma.py b/python/sglang/srt/models/gemma.py index 8ad77f12a..5c0b60fd6 100644 --- a/python/sglang/srt/models/gemma.py +++ b/python/sglang/srt/models/gemma.py @@ -22,7 +22,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.managers.router.model_runner import InputMetadata +from sglang.srt.managers.controller.model_runner import InputMetadata class GemmaMLP(nn.Module): diff --git a/python/sglang/srt/models/grok.py b/python/sglang/srt/models/grok.py index 3aeb72850..91cab15f6 100644 --- a/python/sglang/srt/models/grok.py +++ b/python/sglang/srt/models/grok.py @@ -37,7 +37,7 @@ from vllm.utils import print_warning_once from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.fused_moe import fused_moe from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.managers.router.model_runner import InputMetadata +from sglang.srt.managers.controller.model_runner import InputMetadata use_fused = True diff --git a/python/sglang/srt/models/llama2.py b/python/sglang/srt/models/llama2.py index cf292eeb1..aa8c4752d 100644 --- a/python/sglang/srt/models/llama2.py +++ b/python/sglang/srt/models/llama2.py @@ -4,9 +4,13 @@ from typing import Any, Dict, Optional, Tuple, Iterable import torch +import tqdm from torch import nn from transformers import LlamaConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size +) from vllm.model_executor.layers.activation import SiluAndMul from 
vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( @@ -24,7 +28,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.managers.router.model_runner import InputMetadata +from sglang.srt.managers.controller.model_runner import InputMetadata class LlamaMLP(nn.Module): @@ -284,6 +288,8 @@ class LlamaForCausalLM(nn.Module): ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) + if get_tensor_model_parallel_rank() == 0: + weights = tqdm.tqdm(weights, total=int(len(params_dict) * 1.5)) for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name or "projector" in name: continue diff --git a/python/sglang/srt/models/llava.py b/python/sglang/srt/models/llava.py index 4755939b7..efcc8d91c 100644 --- a/python/sglang/srt/models/llava.py +++ b/python/sglang/srt/models/llava.py @@ -10,8 +10,8 @@ from transformers.models.llava.modeling_llava import LlavaMultiModalProjector from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from sglang.srt.managers.router.infer_batch import ForwardMode -from sglang.srt.managers.router.model_runner import InputMetadata +from sglang.srt.managers.controller.infer_batch import ForwardMode +from sglang.srt.managers.controller.model_runner import InputMetadata from sglang.srt.mm_utils import ( get_anyres_image_grid_shape, unpad_image, diff --git a/python/sglang/srt/models/llavavid.py b/python/sglang/srt/models/llavavid.py index 0afc3f0d6..e79b81af1 100644 --- a/python/sglang/srt/models/llavavid.py +++ b/python/sglang/srt/models/llavavid.py @@ -10,8 +10,8 @@ from transformers.models.llava.modeling_llava import LlavaMultiModalProjector from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from sglang.srt.managers.router.infer_batch import ForwardMode -from sglang.srt.managers.router.model_runner import InputMetadata +from sglang.srt.managers.controller.infer_batch import ForwardMode +from sglang.srt.managers.controller.model_runner import InputMetadata from sglang.srt.mm_utils import ( get_anyres_image_grid_shape, unpad_image, diff --git a/python/sglang/srt/models/mixtral.py b/python/sglang/srt/models/mixtral.py index cfe4ab6f8..f718af47f 100644 --- a/python/sglang/srt/models/mixtral.py +++ b/python/sglang/srt/models/mixtral.py @@ -35,7 +35,7 @@ from vllm.utils import print_warning_once from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.managers.router.model_runner import InputMetadata +from sglang.srt.managers.controller.model_runner import InputMetadata diff --git a/python/sglang/srt/models/mixtral_quant.py b/python/sglang/srt/models/mixtral_quant.py index f60b4c277..e9edf43c5 100644 --- a/python/sglang/srt/models/mixtral_quant.py +++ b/python/sglang/srt/models/mixtral_quant.py @@ -30,7 +30,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.managers.router.model_runner import InputMetadata +from sglang.srt.managers.controller.model_runner import 
InputMetadata class MixtralMLP(nn.Module): diff --git a/python/sglang/srt/models/qwen.py b/python/sglang/srt/models/qwen.py index 9b4da3c36..bce76d53d 100644 --- a/python/sglang/srt/models/qwen.py +++ b/python/sglang/srt/models/qwen.py @@ -23,7 +23,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.managers.router.model_runner import InputMetadata +from sglang.srt.managers.controller.model_runner import InputMetadata class QWenMLP(nn.Module): diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py index 843d91a94..f5bee35a3 100644 --- a/python/sglang/srt/models/qwen2.py +++ b/python/sglang/srt/models/qwen2.py @@ -23,7 +23,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.managers.router.model_runner import InputMetadata +from sglang.srt.managers.controller.model_runner import InputMetadata Qwen2Config = None diff --git a/python/sglang/srt/models/stablelm.py b/python/sglang/srt/models/stablelm.py index 5850deb26..279184d8d 100644 --- a/python/sglang/srt/models/stablelm.py +++ b/python/sglang/srt/models/stablelm.py @@ -24,7 +24,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.managers.router.model_runner import InputMetadata +from sglang.srt.managers.controller.model_runner import InputMetadata class StablelmMLP(nn.Module): diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 695e129ed..e19c76e0a 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -10,7 +10,7 @@ import sys import threading import time from http import HTTPStatus -from typing import List, Optional, Union +from typing import Optional # Fix a bug of Python threading setattr(threading, "_register_atexit", lambda *args, **kwargs: None) @@ -28,14 +28,15 @@ from sglang.srt.constrained import disable_cache from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.managers.detokenizer_manager import start_detokenizer_process from sglang.srt.managers.io_struct import GenerateReqInput -from sglang.srt.managers.router.manager import start_router_process +from sglang.srt.managers.controller.manager_single import start_controller_process as start_controller_process_single +from sglang.srt.managers.controller.manager_multi import start_controller_process as start_controller_process_multi from sglang.srt.managers.tokenizer_manager import TokenizerManager from sglang.srt.openai_api_adapter import ( load_chat_template_for_openai_api, v1_chat_completions, v1_completions, ) -from sglang.srt.server_args import PortArgs, ServerArgs +from sglang.srt.server_args import ModelPortArgs, PortArgs, ServerArgs from sglang.srt.utils import ( API_KEY_HEADER_NAME, APIKeyValidatorMiddleware, @@ -141,14 +142,28 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg # Allocate ports server_args.port, server_args.additional_ports = allocate_init_ports( - server_args.port, server_args.additional_ports, server_args.tp_size + server_args.port, + server_args.additional_ports, + server_args.tp_size, + server_args.dp_size, ) + + 
# Init local models port args + ports = server_args.additional_ports + tp = server_args.tp_size + model_port_args = [] + for i in range(server_args.dp_size): + model_port_args.append( + ModelPortArgs( + nccl_port=ports[3 + i * (tp + 1)], + model_tp_ports=ports[3 + i * (tp + 1) + 1 : 3 + (i + 1) * (tp + 1)], + ) + ) port_args = PortArgs( - tokenizer_port=server_args.additional_ports[0], - router_port=server_args.additional_ports[1], - detokenizer_port=server_args.additional_ports[2], - nccl_port=server_args.additional_ports[3], - model_rpc_ports=server_args.additional_ports[4:], + tokenizer_port=ports[0], + router_port=ports[1], + detokenizer_port=ports[2], + model_port_args=model_port_args, ) # Launch processes @@ -156,8 +171,12 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False) pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False) + if server_args.dp_size == 1: + start_process = start_controller_process_single + else: + start_process = start_controller_process_multi proc_router = mp.Process( - target=start_router_process, + target=start_process, args=(server_args, port_args, pipe_router_writer, model_overide_args), ) proc_router.start() @@ -251,19 +270,20 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg class Runtime: def __init__( self, - log_evel: str = "error", + log_level: str = "error", model_overide_args: Optional[dict] = None, *args, **kwargs, ): """See the arguments in server_args.py::ServerArgs""" - self.server_args = ServerArgs(*args, log_level=log_evel, **kwargs) + self.server_args = ServerArgs(*args, log_level=log_level, **kwargs) # Pre-allocate ports self.server_args.port, self.server_args.additional_ports = allocate_init_ports( self.server_args.port, self.server_args.additional_ports, self.server_args.tp_size, + self.server_args.dp_size, ) self.url = self.server_args.url() diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 416f0cc6f..35cc02411 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -44,6 +44,10 @@ class ServerArgs: # Other api_key: str = "" + # Data parallelism + dp_size: int = 1 + load_balance_method: str = "round_robin" + # Optimization/debug options enable_flashinfer: bool = False attention_reduce_in_fp32: bool = False @@ -226,6 +230,24 @@ class ServerArgs: help="Set API key of the server", ) + # Data parallelism + parser.add_argument( + "--dp-size", + type=int, + default=ServerArgs.dp_size, + help="Data parallelism size.", + ) + parser.add_argument( + "--load-balance-method", + type=str, + default=ServerArgs.load_balance_method, + help="Load balancing strategy for data parallelism.", + choices=[ + "round_robin", + "shortest_queue", + ], + ) + # Optimization/debug options parser.add_argument( "--enable-flashinfer", @@ -271,10 +293,15 @@ class ServerArgs: ) +@dataclasses.dataclass +class ModelPortArgs: + nccl_port: int + model_tp_ports: List[int] + + @dataclasses.dataclass class PortArgs: tokenizer_port: int router_port: int detokenizer_port: int - nccl_port: int - model_rpc_ports: List[int] + model_port_args: List[ModelPortArgs] diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index c222f4378..5c32fd65d 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -1,6 +1,7 @@ """Common utilities.""" import base64 +import multiprocessing import logging import os import random @@ -12,12 +13,14 @@ from typing import 
List, Optional import numpy as np import requests +import rpyc import torch import triton +from rpyc.utils.server import ThreadedServer from fastapi.responses import JSONResponse from packaging import version as pkg_version from starlette.middleware.base import BaseHTTPMiddleware -import torch.distributed as dist + logger = logging.getLogger(__name__) @@ -120,14 +123,16 @@ def get_available_gpu_memory(gpu_id, distributed=False): def set_random_seed(seed: int) -> None: + """Set the random seed for all libraries.""" random.seed(seed) - + np.random.seed(seed) torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) def is_port_available(port): + """Return whether a port is available.""" with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: try: s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) @@ -142,7 +147,9 @@ def allocate_init_ports( port: Optional[int] = None, additional_ports: Optional[List[int]] = None, tp_size: int = 1, + dp_size: int = 1, ): + """Allocate ports for all connections.""" if additional_ports: ret_ports = [port] + additional_ports else: @@ -151,20 +158,23 @@ def allocate_init_ports( ret_ports = list(set(x for x in ret_ports if is_port_available(x))) cur_port = ret_ports[-1] + 1 if len(ret_ports) > 0 else 10000 - while len(ret_ports) < 5 + tp_size: + # HTTP + Tokenizer + Controller + Detokenizer + dp_size * (nccl + tp_size) + num_ports_needed = 4 + dp_size * (1 + tp_size) + while len(ret_ports) < num_ports_needed: if cur_port not in ret_ports and is_port_available(cur_port): ret_ports.append(cur_port) cur_port += 1 - if port and ret_ports[0] != port: + if port is not None and ret_ports[0] != port: logger.warn( f"WARNING: Port {port} is not available. Use port {ret_ports[0]} instead." ) - return ret_ports[0], ret_ports[1:] + return ret_ports[0], ret_ports[1:num_ports_needed] def get_int_token_logit_bias(tokenizer, vocab_size): + """Get the logit bias for integer-only tokens.""" # a bug when model's vocab size > tokenizer.vocab_size vocab_size = tokenizer.vocab_size logit_bias = np.zeros(vocab_size, dtype=np.float32) @@ -181,12 +191,8 @@ def wrap_kernel_launcher(kernel): if int(triton.__version__.split(".")[0]) >= 3: return None - if dist.is_initialized(): - rank = dist.get_rank() - else: - rank = 0 - - kernels = kernel.cache[rank].values() + gpu_id = torch.cuda.current_device() + kernels = kernel.cache[gpu_id].values() kernel = next(iter(kernels)) # Different trition versions use different low-level names @@ -363,6 +369,63 @@ def load_image(image_file): return image, image_size +def init_rpyc_service(service: rpyc.Service, port: int): + t = ThreadedServer( + service=service, + port=port, + protocol_config={ + "allow_public_attrs": True, + "allow_pickle": True, + "sync_request_timeout": 3600 + }, + ) + t.logger.setLevel(logging.WARN) + t.start() + + +def connect_to_rpyc_service(port, host="localhost"): + time.sleep(1) + + repeat_count = 0 + while repeat_count < 20: + try: + con = rpyc.connect( + host, + port, + config={ + "allow_public_attrs": True, + "allow_pickle": True, + "sync_request_timeout": 3600 + }, + ) + break + except ConnectionRefusedError: + time.sleep(1) + repeat_count += 1 + if repeat_count == 20: + raise RuntimeError("init rpc env error!") + + return con.root + + +def start_rpyc_process(service: rpyc.Service, port: int): + # Return the proxy and the process + proc = multiprocessing.Process(target=init_rpyc_service, args=(service, port)) + proc.start() + proxy = connect_to_rpyc_service(port) + assert proc.is_alive() + 
return proxy, proc + + +def suppress_other_loggers(): + from vllm.logger import logger as vllm_default_logger + + vllm_default_logger.setLevel(logging.WARN) + logging.getLogger("vllm.utils").setLevel(logging.WARN) + logging.getLogger("vllm.selector").setLevel(logging.WARN) + logging.getLogger("vllm.config").setLevel(logging.ERROR) + + def assert_pkg_version(pkg: str, min_version: str): try: installed_version = version(pkg) @@ -394,4 +457,4 @@ class APIKeyValidatorMiddleware(BaseHTTPMiddleware): content={"detail": "Invalid API Key"}, ) response = await call_next(request) - return response \ No newline at end of file + return response diff --git a/test/srt/model/bench_llama_low_api.py b/test/srt/model/bench_llama_low_api.py index 9c6bce91d..339574228 100644 --- a/test/srt/model/bench_llama_low_api.py +++ b/test/srt/model/bench_llama_low_api.py @@ -5,7 +5,7 @@ from dataclasses import dataclass import torch import torch.distributed as dist -from sglang.srt.managers.router.model_runner import ModelRunner +from sglang.srt.managers.controller.model_runner import ModelRunner from sglang.srt.model_config import ModelConfig diff --git a/test/srt/model/test_llama_extend.py b/test/srt/model/test_llama_extend.py index cdb40f887..2814dc2a0 100644 --- a/test/srt/model/test_llama_extend.py +++ b/test/srt/model/test_llama_extend.py @@ -7,8 +7,8 @@ import torch import torch.distributed as dist import transformers -from sglang.srt.managers.router.infer_batch import Batch, ForwardMode, Req -from sglang.srt.managers.router.model_runner import ModelRunner +from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, Req +from sglang.srt.managers.controller.model_runner import ModelRunner from sglang.srt.model_config import ModelConfig from sglang.srt.sampling_params import SamplingParams diff --git a/test/srt/model/test_llama_low_api.py b/test/srt/model/test_llama_low_api.py index 20b59e5c7..0eb1574b1 100644 --- a/test/srt/model/test_llama_low_api.py +++ b/test/srt/model/test_llama_low_api.py @@ -5,7 +5,7 @@ import numpy as np import torch import torch.distributed as dist -from sglang.srt.managers.router.model_runner import ModelRunner +from sglang.srt.managers.controller.model_runner import ModelRunner from sglang.srt.model_config import ModelConfig diff --git a/test/srt/model/test_llava_low_api.py b/test/srt/model/test_llava_low_api.py index 186a46df0..2a9fa543d 100644 --- a/test/srt/model/test_llava_low_api.py +++ b/test/srt/model/test_llava_low_api.py @@ -6,8 +6,8 @@ import torch import torch.distributed as dist from sglang.srt.hf_transformers_utils import get_processor -from sglang.srt.managers.router.infer_batch import ForwardMode -from sglang.srt.managers.router.model_runner import InputMetadata, ModelRunner +from sglang.srt.managers.controller.infer_batch import ForwardMode +from sglang.srt.managers.controller.model_runner import InputMetadata, ModelRunner from sglang.srt.model_config import ModelConfig from sglang.srt.utils import load_image