Crash the server when error or OOM happens (#514)

Data parallel controller (DataParallelWorkerThread):

@@ -12,6 +12,7 @@ from sglang.global_config import global_config
 from sglang.srt.managers.controller.tp_worker import ModelTpClient
 from sglang.srt.managers.io_struct import BatchTokenIDOut
 from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.utils import kill_parent_process
 from sglang.utils import get_exception_traceback

 logger = logging.getLogger("srt.controller")
@@ -58,6 +59,10 @@ class DataParallelWorkerThread(threading.Thread):
                 f"{get_exception_traceback()}"
             )
             self.liveness = False
+            # Crash the whole server when there are any errors.
+            # TODO(lianmin): make this an option.
+            kill_parent_process()
             return

         for obj in out_pyobjs:
             self.send_to_detokenizer.send_pyobj(obj)
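This is the heart of the PR: when a data parallel worker hits an unrecoverable error (typically an OOM during a forward pass), the server used to limp along with `liveness = False` while the frontend kept accepting requests it could never answer. Now the whole process tree is torn down so a supervisor can restart it. A standalone sketch of that crash-loudly pattern (all names below are illustrative, not sglang's):

```python
# Illustrative sketch of the crash-loudly pattern; not sglang code.
import os
import signal
import threading


def crash_whole_process() -> None:
    # SIGKILL cannot be caught, so the exit is immediate and a
    # supervisor (systemd, Kubernetes, a shell loop) can restart us.
    os.kill(os.getpid(), signal.SIGKILL)


def worker_loop() -> None:
    try:
        raise MemoryError("simulated OOM during a forward pass")
    except Exception:
        # A dead worker inside a live server is the worst state:
        # requests are accepted but never answered. Crash instead.
        crash_whole_process()


threading.Thread(target=worker_loop).start()
```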
Single-process controller (ControllerSingle):

@@ -1,6 +1,7 @@
 """A controller that manages a group of tensor parallel workers."""
 import asyncio
+import logging
 import time

 import uvloop
 import zmq
@@ -9,10 +10,13 @@ import zmq.asyncio
 from sglang.global_config import global_config
 from sglang.srt.managers.controller.tp_worker import ModelTpClient
 from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.utils import kill_parent_process
 from sglang.utils import get_exception_traceback

 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

+logger = logging.getLogger("srt.controller")
+

 class ControllerSingle:
     def __init__(self, model_client: ModelTpClient, port_args: PortArgs):
@@ -85,4 +89,9 @@ def start_controller_process(
     loop = asyncio.new_event_loop()
     asyncio.set_event_loop(loop)
     loop.create_task(controller.loop_for_recv_requests())
-    loop.run_until_complete(controller.loop_for_forward())
+    try:
+        loop.run_until_complete(controller.loop_for_forward())
+    except Exception:
+        logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
+    finally:
+        kill_parent_process()
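The same policy applies to the single controller: if the forward loop ever escapes `run_until_complete`, the traceback is logged and the server is killed rather than left wedged. A minimal sketch of that shape, with `bring_down_server` standing in for `kill_parent_process`:

```python
# Minimal sketch of guarding an asyncio event loop; names like
# bring_down_server are placeholders, not sglang APIs.
import asyncio
import logging
import sys
import traceback

logger = logging.getLogger("controller")


async def loop_for_forward() -> None:
    raise RuntimeError("simulated scheduler failure")


def bring_down_server() -> None:
    sys.exit(1)  # stand-in for kill_parent_process()


loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
    loop.run_until_complete(loop_for_forward())
except Exception:
    # Log before the finally block runs: after teardown there is
    # no process left to report the error.
    logger.error("Exception in controller:\n" + traceback.format_exc())
finally:
    bring_down_server()
```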
Model runner (ModelRunner):

@@ -18,7 +18,7 @@ from vllm.model_executor.models import ModelRegistry
 from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import get_available_gpu_memory, is_multimodal_model
+from sglang.srt.utils import get_available_gpu_memory, is_multimodal_model, monkey_patch_vllm_p2p_access_check

 logger = logging.getLogger("srt.model_runner")

@@ -240,10 +240,12 @@ class ModelRunner:
         logger.info(f"[gpu_id={self.gpu_id}] Set cuda device.")
         torch.cuda.set_device(self.gpu_id)
         logger.info(f"[gpu_id={self.gpu_id}] Init nccl begin.")
+
+        monkey_patch_vllm_p2p_access_check()
         init_distributed_environment(
             backend="nccl",
             world_size=self.tp_size,
             rank=self.tp_rank,
             local_rank=self.gpu_id,
             distributed_init_method=f"tcp://127.0.0.1:{self.nccl_port}",
         )
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
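`monkey_patch_vllm_p2p_access_check` (defined in the utils hunk below) is installed before `init_distributed_environment` so the stub is already in place when vLLM's custom all-reduce setup consults the peer-to-peer access check. The technique is plain module-attribute replacement; here is a generic, self-contained sketch of it on a fabricated stand-in module (nothing below is the real vLLM module):

```python
# Generic monkey-patching sketch on a fabricated stand-in module.
import types

# Stand-in for vllm.distributed.device_communicators.custom_all_reduce_utils.
fake_utils = types.ModuleType("fake_utils")


def slow_gpu_p2p_access_check(src: int, tgt: int) -> bool:
    # Imagine a probe that spawns subprocesses and times CUDA copies.
    raise TimeoutError("expensive probe")


fake_utils.gpu_p2p_access_check = slow_gpu_p2p_access_check

# The patch: swap in a constant-true stub. As the diff's own NOTE says,
# this assumes P2P access is always allowed, which some setups violate.
setattr(fake_utils, "gpu_p2p_access_check", lambda *args, **kwargs: True)

assert fake_utils.gpu_p2p_access_check(0, 1) is True
```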
@@ -265,7 +267,7 @@ class ModelRunner:
     def load_model(self):
         logger.info(
             f"[gpu_id={self.gpu_id}] Load weight begin. "
-            f"Avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )

         device_config = DeviceConfig()
@@ -295,8 +297,8 @@ class ModelRunner:
         )
         logger.info(
             f"[gpu_id={self.gpu_id}] Load weight end. "
-            f"Type={type(self.model).__name__}. "
-            f"Avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+            f"type={type(self.model).__name__}, "
+            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )

     def profile_max_num_token(self, total_gpu_memory):
@@ -333,7 +335,7 @@ class ModelRunner:
         )
         logger.info(
             f"[gpu_id={self.gpu_id}] Memory pool end. "
-            f"Avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )

     @torch.inference_mode()
Tensor parallel worker (ModelTpServer):

@@ -34,7 +34,7 @@ from sglang.srt.utils import (
 )
 from sglang.utils import get_exception_traceback

-logger = logging.getLogger("srt.model_tp")
+logger = logging.getLogger("srt.tp_worker")


 class ModelTpServer:
@@ -187,7 +187,8 @@ class ModelTpServer:
                 # Forward
                 self.forward_step()
             except Exception:
-                logger.error("Exception in ModelTpClient:\n" + get_exception_traceback())
+                logger.error("Exception in ModelTpServer:\n" + get_exception_traceback())
+                raise

             # Return results
             ret = self.out_pyobjs
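Besides correcting the logger name and the message (this code lives in ModelTpServer, not ModelTpClient), the added `raise` changes behavior: the exception now propagates to the caller after being logged, instead of the loop continuing as if nothing happened, which is what lets the controller-level handlers above trigger the crash. A small sketch of the log-and-reraise shape, with placeholder names:

```python
# Log-and-reraise sketch: record the traceback locally, then let the
# exception keep propagating so callers can react (here, by crashing).
import logging
import traceback

logger = logging.getLogger("tp_worker")


def forward_step() -> None:
    raise RuntimeError("simulated forward error")


def exposed_step() -> None:
    try:
        forward_step()
    except Exception:
        logger.error("Exception in worker:\n" + traceback.format_exc())
        raise  # without this, the failure is logged and then swallowed


try:
    exposed_step()
except RuntimeError:
    pass  # the caller still sees the original exception
```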
Detokenizer (start_detokenizer_process):

@@ -87,7 +87,7 @@ def start_detokenizer_process(

     try:
         manager = DetokenizerManager(server_args, port_args)
-    except Exception as e:
+    except Exception:
         pipe_writer.send(get_exception_traceback())
         raise
     pipe_writer.send("init ok")
Server launch (launch_server):

@@ -228,20 +228,21 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg

     # Send a warmup request
     try:
-        res = requests.post(
-            url + "/generate",
-            json={
-                "text": "The capital city of France is",
-                "sampling_params": {
-                    "temperature": 0,
-                    "max_new_tokens": 16,
-                },
-            },
-            headers=headers,
-            timeout=600,
-        )
-        assert res.status_code == 200
-    except Exception as e:
+        for _ in range(server_args.dp_size):
+            res = requests.post(
+                url + "/generate",
+                json={
+                    "text": "The capital city of France is",
+                    "sampling_params": {
+                        "temperature": 0,
+                        "max_new_tokens": 16,
+                    },
+                },
+                headers=headers,
+                timeout=600,
+            )
+            assert res.status_code == 200
+    except Exception:
         if pipe_finish_writer is not None:
             pipe_finish_writer.send(get_exception_traceback())
         print(f"Initialization failed. warmup error: {e}")
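One warmup request is no longer enough once `dp_size > 1`: each data parallel replica allocates buffers and compiles kernels on its first request, and a broken replica should take the server down during startup rather than on the first unlucky user request. Sending `dp_size` requests makes it likely every replica gets exercised, assuming the controller spreads consecutive requests across replicas (that scheduling assumption is mine, not stated in the diff). A standalone sketch:

```python
# Standalone warmup sketch; the URL path and payload mirror the diff,
# while the server address and dp_size value are assumed for illustration.
import requests

url = "http://127.0.0.1:30000"
dp_size = 2  # number of data parallel replicas

for _ in range(dp_size):
    res = requests.post(
        url + "/generate",
        json={
            "text": "The capital city of France is",
            "sampling_params": {"temperature": 0, "max_new_tokens": 16},
        },
        timeout=600,
    )
    res.raise_for_status()  # fail fast if any replica cannot serve
```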
Utilities (sglang.srt.utils):

@@ -12,6 +12,7 @@ from io import BytesIO
 from typing import List, Optional

 import numpy as np
+import psutil
 import requests
 import rpyc
 import torch
@@ -441,6 +442,27 @@ def assert_pkg_version(pkg: str, min_version: str):
         )


+def kill_parent_process():
+    """Kill the parent process and all children of the parent process."""
+    current_process = psutil.Process()
+    parent_process = current_process.parent()
+    children = parent_process.children(recursive=True)
+    for child in children:
+        if child.pid != current_process.pid:
+            os.kill(child.pid, 9)
+    os.kill(parent_process.pid, 9)
+
+
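`kill_parent_process` is invoked from a child process (a controller or worker), so it walks up to the parent and SIGKILLs (signal 9, uncatchable) every process in the parent's subtree, skipping itself so it survives long enough to deliver the remaining kills; the parent goes last. A read-only sketch of the same psutil traversal that prints its targets instead of killing them:

```python
# Read-only version of the traversal: print what would be SIGKILLed.
import psutil

current = psutil.Process()
parent = current.parent()  # may be None for an orphaned process
if parent is not None:
    for child in parent.children(recursive=True):
        if child.pid != current.pid:  # spare ourselves until the end
            print(f"would SIGKILL child pid={child.pid} ({child.name()})")
    print(f"would SIGKILL parent pid={parent.pid} ({parent.name()})")
```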
+def monkey_patch_vllm_p2p_access_check():
+    """
+    Monkey patch the slow p2p access check in vllm.
+    NOTE: We assume the p2p access is always allowed, which can be wrong for some setups.
+    """
+    import vllm.distributed.device_communicators.custom_all_reduce_utils as tgt
+
+    setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True)
+
+
 API_KEY_HEADER_NAME = "X-API-Key"

@@ -459,3 +481,4 @@ class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
         )
         response = await call_next(request)
         return response