Crash the server correctly during error (#2231)
This commit is contained in:
@@ -15,16 +15,19 @@
|
||||
|
||||
import dataclasses
|
||||
import logging
|
||||
import signal
|
||||
import threading
|
||||
from queue import Queue
|
||||
from typing import Optional
|
||||
|
||||
import psutil
|
||||
import torch
|
||||
|
||||
from sglang.srt.managers.io_struct import UpdateWeightReqInput
|
||||
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
|
||||
from sglang.srt.managers.tp_worker import TpModelWorker
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.utils import get_exception_traceback
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -70,6 +73,7 @@ class TpModelWorkerClient:
|
||||
target=self.forward_thread_func,
|
||||
)
|
||||
self.forward_thread.start()
|
||||
self.parent_process = psutil.Process().parent()
|
||||
|
||||
# Return the worker info by delegating directly to the wrapped worker (self.worker).
def get_worker_info(self):
|
||||
return self.worker.get_worker_info()
|
||||
@@ -87,8 +91,13 @@ class TpModelWorkerClient:
|
||||
)
|
||||
|
||||
def forward_thread_func(self):
    """Run the forward loop on ``self.forward_stream`` and crash the server on error.

    The whole call to ``forward_thread_func_`` is guarded: an uncaught
    exception in a non-main thread would only end that thread, so instead
    the traceback is logged and SIGQUIT is sent to the parent process
    stored in ``self.parent_process``.

    NOTE(review): this appears to be the ``target`` of ``self.forward_thread``
    -- confirm against the constructor.
    """
    try:
        # Keep the stream context inside the try so that failures while
        # entering the stream are reported the same way as loop failures.
        with torch.cuda.stream(self.forward_stream):
            self.forward_thread_func_()
    except Exception:
        # Local is not named `traceback` to avoid shadowing the stdlib module.
        tb_text = get_exception_traceback()
        logger.error(f"TpModelWorkerClient hit an exception: {tb_text}")
        # Signal the parent so the whole server goes down instead of hanging
        # with a dead forward thread.
        self.parent_process.send_signal(signal.SIGQUIT)
|
||||
|
||||
@torch.no_grad()
|
||||
def forward_thread_func_(self):
|
||||
|
||||
Reference in New Issue
Block a user