Crash the server correctly during error (#2231)

This commit is contained in:
Lianmin Zheng
2024-11-28 00:22:39 -08:00
committed by GitHub
parent db674e3d24
commit d4fc1a70e3
46 changed files with 147 additions and 139 deletions

View File

@@ -15,16 +15,19 @@
import dataclasses
import logging
import signal
import threading
from queue import Queue
from typing import Optional
import psutil
import torch
from sglang.srt.managers.io_struct import UpdateWeightReqInput
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.server_args import ServerArgs
from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__)
@@ -70,6 +73,7 @@ class TpModelWorkerClient:
target=self.forward_thread_func,
)
self.forward_thread.start()
self.parent_process = psutil.Process().parent()
def get_worker_info(self):
    """Return the worker info reported by the wrapped ``self.worker``.

    This is a pure pass-through: the result of
    ``self.worker.get_worker_info()`` is returned unchanged.
    """
    worker_info = self.worker.get_worker_info()
    return worker_info
@@ -87,8 +91,13 @@ class TpModelWorkerClient:
)
def forward_thread_func(self):
    """Run the forward loop once; signal the parent process if it fails.

    ``forward_thread_func_`` is invoked exactly once, inside
    ``torch.cuda.stream(self.forward_stream)``. If it raises any
    ``Exception``, the traceback is logged and ``SIGQUIT`` is sent to the
    parent process instead of letting the exception end only this thread.

    Returns:
        None.
    """
    try:
        # Single guarded invocation: the loop must not also run outside the
        # ``try`` block, otherwise failures would bypass the handler below.
        with torch.cuda.stream(self.forward_stream):
            self.forward_thread_func_()
    except Exception:
        tb = get_exception_traceback()
        logger.error(f"TpModelWorkerClient hit an exception: {tb}")
        # The thread is started before ``self.parent_process`` is assigned,
        # so an early failure may reach this handler before the attribute
        # exists; look the parent up directly in that case.
        parent = getattr(self, "parent_process", None)
        if parent is None:
            parent = psutil.Process().parent()
        parent.send_signal(signal.SIGQUIT)
@torch.no_grad()
def forward_thread_func_(self):