Crash the server when error or OOM happens (#514)
This commit is contained in:
@@ -12,6 +12,7 @@ from sglang.global_config import global_config
|
|||||||
from sglang.srt.managers.controller.tp_worker import ModelTpClient
|
from sglang.srt.managers.controller.tp_worker import ModelTpClient
|
||||||
from sglang.srt.managers.io_struct import BatchTokenIDOut
|
from sglang.srt.managers.io_struct import BatchTokenIDOut
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
|
from sglang.srt.utils import kill_parent_process
|
||||||
from sglang.utils import get_exception_traceback
|
from sglang.utils import get_exception_traceback
|
||||||
|
|
||||||
logger = logging.getLogger("srt.controller")
|
logger = logging.getLogger("srt.controller")
|
||||||
@@ -58,6 +59,10 @@ class DataParallelWorkerThread(threading.Thread):
|
|||||||
f"{get_exception_traceback()}"
|
f"{get_exception_traceback()}"
|
||||||
)
|
)
|
||||||
self.liveness = False
|
self.liveness = False
|
||||||
|
# Crash the whole server when there are any errors.
|
||||||
|
# TODO(lianmin): make this an option.
|
||||||
|
kill_parent_process()
|
||||||
|
return
|
||||||
|
|
||||||
for obj in out_pyobjs:
|
for obj in out_pyobjs:
|
||||||
self.send_to_detokenizer.send_pyobj(obj)
|
self.send_to_detokenizer.send_pyobj(obj)
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""A controller that manages a group of tensor parallel workers."""
|
"""A controller that manages a group of tensor parallel workers."""
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
import time
|
||||||
|
|
||||||
import uvloop
|
import uvloop
|
||||||
import zmq
|
import zmq
|
||||||
@@ -9,10 +10,13 @@ import zmq.asyncio
|
|||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
from sglang.srt.managers.controller.tp_worker import ModelTpClient
|
from sglang.srt.managers.controller.tp_worker import ModelTpClient
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
|
from sglang.srt.utils import kill_parent_process
|
||||||
from sglang.utils import get_exception_traceback
|
from sglang.utils import get_exception_traceback
|
||||||
|
|
||||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||||
|
|
||||||
|
logger = logging.getLogger("srt.controller")
|
||||||
|
|
||||||
|
|
||||||
class ControllerSingle:
|
class ControllerSingle:
|
||||||
def __init__(self, model_client: ModelTpClient, port_args: PortArgs):
|
def __init__(self, model_client: ModelTpClient, port_args: PortArgs):
|
||||||
@@ -85,4 +89,9 @@ def start_controller_process(
|
|||||||
loop = asyncio.new_event_loop()
|
loop = asyncio.new_event_loop()
|
||||||
asyncio.set_event_loop(loop)
|
asyncio.set_event_loop(loop)
|
||||||
loop.create_task(controller.loop_for_recv_requests())
|
loop.create_task(controller.loop_for_recv_requests())
|
||||||
loop.run_until_complete(controller.loop_for_forward())
|
try:
|
||||||
|
loop.run_until_complete(controller.loop_for_forward())
|
||||||
|
except Exception:
|
||||||
|
logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
|
||||||
|
finally:
|
||||||
|
kill_parent_process()
|
||||||
@@ -18,7 +18,7 @@ from vllm.model_executor.models import ModelRegistry
|
|||||||
from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode
|
from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode
|
||||||
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
|
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.utils import get_available_gpu_memory, is_multimodal_model
|
from sglang.srt.utils import get_available_gpu_memory, is_multimodal_model, monkey_patch_vllm_p2p_access_check
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger("srt.model_runner")
|
logger = logging.getLogger("srt.model_runner")
|
||||||
@@ -240,10 +240,12 @@ class ModelRunner:
|
|||||||
logger.info(f"[gpu_id={self.gpu_id}] Set cuda device.")
|
logger.info(f"[gpu_id={self.gpu_id}] Set cuda device.")
|
||||||
torch.cuda.set_device(self.gpu_id)
|
torch.cuda.set_device(self.gpu_id)
|
||||||
logger.info(f"[gpu_id={self.gpu_id}] Init nccl begin.")
|
logger.info(f"[gpu_id={self.gpu_id}] Init nccl begin.")
|
||||||
|
monkey_patch_vllm_p2p_access_check()
|
||||||
init_distributed_environment(
|
init_distributed_environment(
|
||||||
backend="nccl",
|
backend="nccl",
|
||||||
world_size=self.tp_size,
|
world_size=self.tp_size,
|
||||||
rank=self.tp_rank,
|
rank=self.tp_rank,
|
||||||
|
local_rank=self.gpu_id,
|
||||||
distributed_init_method=f"tcp://127.0.0.1:{self.nccl_port}",
|
distributed_init_method=f"tcp://127.0.0.1:{self.nccl_port}",
|
||||||
)
|
)
|
||||||
initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
|
initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
|
||||||
@@ -265,7 +267,7 @@ class ModelRunner:
|
|||||||
def load_model(self):
|
def load_model(self):
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[gpu_id={self.gpu_id}] Load weight begin. "
|
f"[gpu_id={self.gpu_id}] Load weight begin. "
|
||||||
f"Avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
|
f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
|
||||||
)
|
)
|
||||||
|
|
||||||
device_config = DeviceConfig()
|
device_config = DeviceConfig()
|
||||||
@@ -295,8 +297,8 @@ class ModelRunner:
|
|||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[gpu_id={self.gpu_id}] Load weight end. "
|
f"[gpu_id={self.gpu_id}] Load weight end. "
|
||||||
f"Type={type(self.model).__name__}. "
|
f"type={type(self.model).__name__}, "
|
||||||
f"Avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
|
f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
|
||||||
)
|
)
|
||||||
|
|
||||||
def profile_max_num_token(self, total_gpu_memory):
|
def profile_max_num_token(self, total_gpu_memory):
|
||||||
@@ -333,7 +335,7 @@ class ModelRunner:
|
|||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[gpu_id={self.gpu_id}] Memory pool end. "
|
f"[gpu_id={self.gpu_id}] Memory pool end. "
|
||||||
f"Avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
|
f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
|
||||||
)
|
)
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ from sglang.srt.utils import (
|
|||||||
)
|
)
|
||||||
from sglang.utils import get_exception_traceback
|
from sglang.utils import get_exception_traceback
|
||||||
|
|
||||||
logger = logging.getLogger("srt.model_tp")
|
logger = logging.getLogger("srt.tp_worker")
|
||||||
|
|
||||||
|
|
||||||
class ModelTpServer:
|
class ModelTpServer:
|
||||||
@@ -187,7 +187,8 @@ class ModelTpServer:
|
|||||||
# Forward
|
# Forward
|
||||||
self.forward_step()
|
self.forward_step()
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.error("Exception in ModelTpClient:\n" + get_exception_traceback())
|
logger.error("Exception in ModelTpServer:\n" + get_exception_traceback())
|
||||||
|
raise
|
||||||
|
|
||||||
# Return results
|
# Return results
|
||||||
ret = self.out_pyobjs
|
ret = self.out_pyobjs
|
||||||
|
|||||||
@@ -87,7 +87,7 @@ def start_detokenizer_process(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
manager = DetokenizerManager(server_args, port_args)
|
manager = DetokenizerManager(server_args, port_args)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
pipe_writer.send(get_exception_traceback())
|
pipe_writer.send(get_exception_traceback())
|
||||||
raise
|
raise
|
||||||
pipe_writer.send("init ok")
|
pipe_writer.send("init ok")
|
||||||
|
|||||||
@@ -228,20 +228,21 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
|
|||||||
|
|
||||||
# Send a warmup request
|
# Send a warmup request
|
||||||
try:
|
try:
|
||||||
res = requests.post(
|
for _ in range(server_args.dp_size):
|
||||||
url + "/generate",
|
res = requests.post(
|
||||||
json={
|
url + "/generate",
|
||||||
"text": "The capital city of France is",
|
json={
|
||||||
"sampling_params": {
|
"text": "The capital city of France is",
|
||||||
"temperature": 0,
|
"sampling_params": {
|
||||||
"max_new_tokens": 16,
|
"temperature": 0,
|
||||||
|
"max_new_tokens": 16,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
headers=headers,
|
||||||
headers=headers,
|
timeout=600,
|
||||||
timeout=600,
|
)
|
||||||
)
|
assert res.status_code == 200
|
||||||
assert res.status_code == 200
|
except Exception:
|
||||||
except Exception as e:
|
|
||||||
if pipe_finish_writer is not None:
|
if pipe_finish_writer is not None:
|
||||||
pipe_finish_writer.send(get_exception_traceback())
|
pipe_finish_writer.send(get_exception_traceback())
|
||||||
print(f"Initialization failed. warmup error: {e}")
|
print(f"Initialization failed. warmup error: {e}")
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from io import BytesIO
|
|||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import psutil
|
||||||
import requests
|
import requests
|
||||||
import rpyc
|
import rpyc
|
||||||
import torch
|
import torch
|
||||||
@@ -441,6 +442,27 @@ def assert_pkg_version(pkg: str, min_version: str):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def kill_parent_process():
|
||||||
|
"""Kill the parent process and all children of the parent process."""
|
||||||
|
current_process = psutil.Process()
|
||||||
|
parent_process = current_process.parent()
|
||||||
|
children = current_process.children(recursive=True)
|
||||||
|
for child in children:
|
||||||
|
if child.pid != current_process.pid:
|
||||||
|
os.kill(child.pid, 9)
|
||||||
|
os.kill(parent_process.pid, 9)
|
||||||
|
|
||||||
|
|
||||||
|
def monkey_patch_vllm_p2p_access_check():
|
||||||
|
"""
|
||||||
|
Monkey patch the slow p2p access check in vllm.
|
||||||
|
NOTE: We assume the p2p access is always allowed, which can be wrong for some setups.
|
||||||
|
"""
|
||||||
|
import vllm.distributed.device_communicators.custom_all_reduce_utils as tgt
|
||||||
|
|
||||||
|
setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True)
|
||||||
|
|
||||||
|
|
||||||
API_KEY_HEADER_NAME = "X-API-Key"
|
API_KEY_HEADER_NAME = "X-API-Key"
|
||||||
|
|
||||||
|
|
||||||
@@ -459,3 +481,4 @@ class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
|
|||||||
)
|
)
|
||||||
response = await call_next(request)
|
response = await call_next(request)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user