diff --git a/python/sglang/srt/managers/controller/dp_worker.py b/python/sglang/srt/managers/controller/dp_worker.py
index a1b67396d..16f5d2308 100644
--- a/python/sglang/srt/managers/controller/dp_worker.py
+++ b/python/sglang/srt/managers/controller/dp_worker.py
@@ -12,6 +12,7 @@ from sglang.global_config import global_config
 from sglang.srt.managers.controller.tp_worker import ModelTpClient
 from sglang.srt.managers.io_struct import BatchTokenIDOut
 from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.utils import kill_parent_process
 from sglang.utils import get_exception_traceback
 
 logger = logging.getLogger("srt.controller")
@@ -58,6 +59,10 @@ class DataParallelWorkerThread(threading.Thread):
                     f"{get_exception_traceback()}"
                 )
                 self.liveness = False
+                # Crash the whole server when there are any errors.
+                # TODO(lianmin): make this an option.
+                kill_parent_process()
+                return
 
             for obj in out_pyobjs:
                 self.send_to_detokenizer.send_pyobj(obj)
diff --git a/python/sglang/srt/managers/controller/manager_single.py b/python/sglang/srt/managers/controller/manager_single.py
index 7b39a56de..d1c49c6e2 100644
--- a/python/sglang/srt/managers/controller/manager_single.py
+++ b/python/sglang/srt/managers/controller/manager_single.py
@@ -1,6 +1,7 @@
 """A controller that manages a group of tensor parallel workers."""
 import asyncio
 import logging
+import time
 
 import uvloop
 import zmq
@@ -9,10 +10,13 @@ import zmq.asyncio
 from sglang.global_config import global_config
 from sglang.srt.managers.controller.tp_worker import ModelTpClient
 from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.utils import kill_parent_process
 from sglang.utils import get_exception_traceback
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
+logger = logging.getLogger("srt.controller")
+
 
 class ControllerSingle:
     def __init__(self, model_client: ModelTpClient, port_args: PortArgs):
@@ -85,4 +89,9 @@ def start_controller_process(
     loop = asyncio.new_event_loop()
     asyncio.set_event_loop(loop)
     loop.create_task(controller.loop_for_recv_requests())
-    loop.run_until_complete(controller.loop_for_forward())
\ No newline at end of file
+    try:
+        loop.run_until_complete(controller.loop_for_forward())
+    except Exception:
+        logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
+    finally:
+        kill_parent_process()
\ No newline at end of file
diff --git a/python/sglang/srt/managers/controller/model_runner.py b/python/sglang/srt/managers/controller/model_runner.py
index 0c22e2720..0033acbf8 100644
--- a/python/sglang/srt/managers/controller/model_runner.py
+++ b/python/sglang/srt/managers/controller/model_runner.py
@@ -18,7 +18,7 @@ from vllm.model_executor.models import ModelRegistry
 from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import get_available_gpu_memory, is_multimodal_model
+from sglang.srt.utils import get_available_gpu_memory, is_multimodal_model, monkey_patch_vllm_p2p_access_check
 
 logger = logging.getLogger("srt.model_runner")
 
@@ -240,10 +240,12 @@ class ModelRunner:
         logger.info(f"[gpu_id={self.gpu_id}] Set cuda device.")
         torch.cuda.set_device(self.gpu_id)
         logger.info(f"[gpu_id={self.gpu_id}] Init nccl begin.")
+        monkey_patch_vllm_p2p_access_check()
         init_distributed_environment(
             backend="nccl",
             world_size=self.tp_size,
             rank=self.tp_rank,
+            local_rank=self.gpu_id,
distributed_init_method=f"tcp://127.0.0.1:{self.nccl_port}", ) initialize_model_parallel(tensor_model_parallel_size=self.tp_size) @@ -265,7 +267,7 @@ class ModelRunner: def load_model(self): logger.info( f"[gpu_id={self.gpu_id}] Load weight begin. " - f"Avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB" + f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB" ) device_config = DeviceConfig() @@ -295,8 +297,8 @@ class ModelRunner: ) logger.info( f"[gpu_id={self.gpu_id}] Load weight end. " - f"Type={type(self.model).__name__}. " - f"Avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB" + f"type={type(self.model).__name__}, " + f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB" ) def profile_max_num_token(self, total_gpu_memory): @@ -333,7 +335,7 @@ class ModelRunner: ) logger.info( f"[gpu_id={self.gpu_id}] Memory pool end. " - f"Avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB" + f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB" ) @torch.inference_mode() diff --git a/python/sglang/srt/managers/controller/tp_worker.py b/python/sglang/srt/managers/controller/tp_worker.py index 8343429db..d85873117 100644 --- a/python/sglang/srt/managers/controller/tp_worker.py +++ b/python/sglang/srt/managers/controller/tp_worker.py @@ -34,7 +34,7 @@ from sglang.srt.utils import ( ) from sglang.utils import get_exception_traceback -logger = logging.getLogger("srt.model_tp") +logger = logging.getLogger("srt.tp_worker") class ModelTpServer: @@ -187,7 +187,8 @@ class ModelTpServer: # Forward self.forward_step() except Exception: - logger.error("Exception in ModelTpClient:\n" + get_exception_traceback()) + logger.error("Exception in ModelTpServer:\n" + get_exception_traceback()) + raise # Return results ret = self.out_pyobjs diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index c77625eb9..d60edf273 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -87,7 +87,7 @@ def start_detokenizer_process( try: manager = DetokenizerManager(server_args, port_args) - except Exception as e: + except Exception: pipe_writer.send(get_exception_traceback()) raise pipe_writer.send("init ok") diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index e19c76e0a..2403ef57f 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -228,20 +228,21 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg # Send a warmup request try: - res = requests.post( - url + "/generate", - json={ - "text": "The capital city of France is", - "sampling_params": { - "temperature": 0, - "max_new_tokens": 16, + for _ in range(server_args.dp_size): + res = requests.post( + url + "/generate", + json={ + "text": "The capital city of France is", + "sampling_params": { + "temperature": 0, + "max_new_tokens": 16, + }, }, - }, - headers=headers, - timeout=600, - ) - assert res.status_code == 200 - except Exception as e: + headers=headers, + timeout=600, + ) + assert res.status_code == 200 + except Exception: if pipe_finish_writer is not None: pipe_finish_writer.send(get_exception_traceback()) print(f"Initialization failed. 
warmup error: {e}") diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 272c2beac..bddb3ded5 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -12,6 +12,7 @@ from io import BytesIO from typing import List, Optional import numpy as np +import psutil import requests import rpyc import torch @@ -441,6 +442,27 @@ def assert_pkg_version(pkg: str, min_version: str): ) +def kill_parent_process(): + """Kill the parent process and all children of the parent process.""" + current_process = psutil.Process() + parent_process = current_process.parent() + children = current_process.children(recursive=True) + for child in children: + if child.pid != current_process.pid: + os.kill(child.pid, 9) + os.kill(parent_process.pid, 9) + + +def monkey_patch_vllm_p2p_access_check(): + """ + Monkey patch the slow p2p access check in vllm. + NOTE: We assume the p2p access is always allowed, which can be wrong for some setups. + """ + import vllm.distributed.device_communicators.custom_all_reduce_utils as tgt + + setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True) + + API_KEY_HEADER_NAME = "X-API-Key" @@ -459,3 +481,4 @@ class APIKeyValidatorMiddleware(BaseHTTPMiddleware): ) response = await call_next(request) return response +