Crash the server correctly during error (#2231)

This commit is contained in:
Lianmin Zheng
2024-11-28 00:22:39 -08:00
committed by GitHub
parent db674e3d24
commit d4fc1a70e3
46 changed files with 147 additions and 139 deletions

View File

@@ -47,6 +47,7 @@ import itertools
import json import json
import logging import logging
import multiprocessing import multiprocessing
import os
import time import time
from typing import Tuple from typing import Tuple
@@ -62,11 +63,7 @@ from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server import _set_envs_and_config from sglang.srt.server import _set_envs_and_config
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import ( from sglang.srt.utils import configure_logger, kill_process_tree, suppress_other_loggers
configure_logger,
kill_child_process,
suppress_other_loggers,
)
@dataclasses.dataclass @dataclasses.dataclass
@@ -468,4 +465,4 @@ if __name__ == "__main__":
main(server_args, bench_args) main(server_args, bench_args)
finally: finally:
if server_args.tp_size != 1: if server_args.tp_size != 1:
kill_child_process() kill_process_tree(os.getpid(), include_parent=False)

View File

@@ -15,6 +15,7 @@ import dataclasses
import itertools import itertools
import json import json
import multiprocessing import multiprocessing
import os
import time import time
from typing import Tuple from typing import Tuple
@@ -23,7 +24,7 @@ import requests
from sglang.srt.server import launch_server from sglang.srt.server import launch_server
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import kill_child_process from sglang.srt.utils import kill_process_tree
@dataclasses.dataclass @dataclasses.dataclass
@@ -69,7 +70,7 @@ def launch_server_internal(server_args):
except Exception as e: except Exception as e:
raise e raise e
finally: finally:
kill_child_process() kill_process_tree(os.getpid(), include_parent=False)
def launch_server_process(server_args: ServerArgs): def launch_server_process(server_args: ServerArgs):
@@ -175,7 +176,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
) )
finally: finally:
if proc: if proc:
kill_child_process(proc.pid, include_self=True) kill_process_tree(proc.pid)
print(f"\nResults are saved to {bench_args.result_filename}") print(f"\nResults are saved to {bench_args.result_filename}")

View File

@@ -4,7 +4,7 @@ import sys
from sglang.srt.server import launch_server from sglang.srt.server import launch_server
from sglang.srt.server_args import prepare_server_args from sglang.srt.server_args import prepare_server_args
from sglang.srt.utils import kill_child_process from sglang.srt.utils import kill_process_tree
if __name__ == "__main__": if __name__ == "__main__":
server_args = prepare_server_args(sys.argv[1:]) server_args = prepare_server_args(sys.argv[1:])
@@ -12,4 +12,4 @@ if __name__ == "__main__":
try: try:
launch_server(server_args) launch_server(server_args)
finally: finally:
kill_child_process() kill_process_tree(os.getpid(), include_parent=False)

View File

@@ -15,9 +15,11 @@
import logging import logging
import multiprocessing as mp import multiprocessing as mp
import signal
import threading import threading
from enum import Enum, auto from enum import Enum, auto
import psutil
import zmq import zmq
from sglang.srt.managers.io_struct import ( from sglang.srt.managers.io_struct import (
@@ -26,13 +28,7 @@ from sglang.srt.managers.io_struct import (
) )
from sglang.srt.managers.scheduler import run_scheduler_process from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import ( from sglang.srt.utils import bind_port, configure_logger, get_zmq_socket
bind_port,
configure_logger,
get_zmq_socket,
kill_parent_process,
suppress_other_loggers,
)
from sglang.utils import get_exception_traceback from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -235,7 +231,7 @@ def run_data_parallel_controller_process(
pipe_writer, pipe_writer,
): ):
configure_logger(server_args) configure_logger(server_args)
suppress_other_loggers() parent_process = psutil.Process().parent()
try: try:
controller = DataParallelController(server_args, port_args) controller = DataParallelController(server_args, port_args)
@@ -244,6 +240,6 @@ def run_data_parallel_controller_process(
) )
controller.event_loop() controller.event_loop()
except Exception: except Exception:
msg = get_exception_traceback() traceback = get_exception_traceback()
logger.error(msg) logger.error(f"DataParallelController hit an exception: {traceback}")
kill_parent_process() parent_process.send_signal(signal.SIGQUIT)

View File

@@ -15,9 +15,11 @@
import dataclasses import dataclasses
import logging import logging
import signal
from collections import OrderedDict from collections import OrderedDict
from typing import List, Union from typing import List, Union
import psutil
import zmq import zmq
from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.hf_transformers_utils import get_tokenizer
@@ -28,7 +30,7 @@ from sglang.srt.managers.io_struct import (
) )
from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR, FINISH_MATCHED_TOKEN from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR, FINISH_MATCHED_TOKEN
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import configure_logger, get_zmq_socket, kill_parent_process from sglang.srt.utils import configure_logger, get_zmq_socket
from sglang.utils import find_printable_text, get_exception_traceback from sglang.utils import find_printable_text, get_exception_traceback
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -193,11 +195,12 @@ def run_detokenizer_process(
port_args: PortArgs, port_args: PortArgs,
): ):
configure_logger(server_args) configure_logger(server_args)
parent_process = psutil.Process().parent()
try: try:
manager = DetokenizerManager(server_args, port_args) manager = DetokenizerManager(server_args, port_args)
manager.event_loop() manager.event_loop()
except Exception: except Exception:
msg = get_exception_traceback() traceback = get_exception_traceback()
logger.error(msg) logger.error(f"DetokenizerManager hit an exception: {traceback}")
kill_parent_process() parent_process.send_signal(signal.SIGQUIT)

View File

@@ -15,6 +15,7 @@
import logging import logging
import os import os
import signal
import threading import threading
import time import time
import warnings import warnings
@@ -23,6 +24,7 @@ from concurrent import futures
from types import SimpleNamespace from types import SimpleNamespace
from typing import List, Optional from typing import List, Optional
import psutil
import torch import torch
import zmq import zmq
@@ -73,7 +75,6 @@ from sglang.srt.utils import (
crash_on_warnings, crash_on_warnings,
get_bool_env_var, get_bool_env_var,
get_zmq_socket, get_zmq_socket,
kill_parent_process,
set_gpu_proc_affinity, set_gpu_proc_affinity,
set_random_seed, set_random_seed,
suppress_other_loggers, suppress_other_loggers,
@@ -316,6 +317,7 @@ class Scheduler:
self.watchdog_timeout = server_args.watchdog_timeout self.watchdog_timeout = server_args.watchdog_timeout
t = threading.Thread(target=self.watchdog_thread, daemon=True) t = threading.Thread(target=self.watchdog_thread, daemon=True)
t.start() t.start()
self.parent_process = psutil.Process().parent()
# Init profiler # Init profiler
if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "": if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "":
@@ -359,7 +361,7 @@ class Scheduler:
self.watchdog_last_time = time.time() self.watchdog_last_time = time.time()
time.sleep(self.watchdog_timeout / 2) time.sleep(self.watchdog_timeout / 2)
kill_parent_process() self.parent_process.send_signal(signal.SIGQUIT)
@torch.no_grad() @torch.no_grad()
def event_loop_normal(self): def event_loop_normal(self):
@@ -1423,6 +1425,7 @@ def run_scheduler_process(
configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}") configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}")
suppress_other_loggers() suppress_other_loggers()
parent_process = psutil.Process().parent()
try: try:
scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank) scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
@@ -1434,6 +1437,6 @@ def run_scheduler_process(
else: else:
scheduler.event_loop_normal() scheduler.event_loop_normal()
except Exception: except Exception:
msg = get_exception_traceback() traceback = get_exception_traceback()
logger.error(msg) logger.error(f"Scheduler hit an exception: {traceback}")
kill_parent_process() parent_process.send_signal(signal.SIGQUIT)

View File

@@ -58,7 +58,7 @@ from sglang.srt.managers.io_struct import (
from sglang.srt.metrics.collector import TokenizerMetricsCollector from sglang.srt.metrics.collector import TokenizerMetricsCollector
from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import get_zmq_socket, kill_child_process from sglang.srt.utils import get_zmq_socket, kill_process_tree
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@@ -532,7 +532,7 @@ class TokenizerManager:
else: else:
break break
kill_child_process(include_self=True) kill_process_tree(os.getpid(), include_parent=True)
sys.exit(0) sys.exit(0)
async def handle_loop(self): async def handle_loop(self):

View File

@@ -15,16 +15,19 @@
import dataclasses import dataclasses
import logging import logging
import signal
import threading import threading
from queue import Queue from queue import Queue
from typing import Optional from typing import Optional
import psutil
import torch import torch
from sglang.srt.managers.io_struct import UpdateWeightReqInput from sglang.srt.managers.io_struct import UpdateWeightReqInput
from sglang.srt.managers.schedule_batch import ModelWorkerBatch from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.managers.tp_worker import TpModelWorker from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -70,6 +73,7 @@ class TpModelWorkerClient:
target=self.forward_thread_func, target=self.forward_thread_func,
) )
self.forward_thread.start() self.forward_thread.start()
self.parent_process = psutil.Process().parent()
def get_worker_info(self): def get_worker_info(self):
return self.worker.get_worker_info() return self.worker.get_worker_info()
@@ -87,8 +91,13 @@ class TpModelWorkerClient:
) )
def forward_thread_func(self): def forward_thread_func(self):
with torch.cuda.stream(self.forward_stream): try:
self.forward_thread_func_() with torch.cuda.stream(self.forward_stream):
self.forward_thread_func_()
except Exception:
traceback = get_exception_traceback()
logger.error(f"TpModelWorkerClient hit an exception: {traceback}")
self.parent_process.send_signal(signal.SIGQUIT)
@torch.no_grad() @torch.no_grad()
def forward_thread_func_(self): def forward_thread_func_(self):

View File

@@ -23,6 +23,8 @@ import json
import logging import logging
import multiprocessing as mp import multiprocessing as mp
import os import os
import signal
import sys
import threading import threading
import time import time
from http import HTTPStatus from http import HTTPStatus
@@ -79,7 +81,7 @@ from sglang.srt.utils import (
configure_logger, configure_logger,
delete_directory, delete_directory,
is_port_available, is_port_available,
kill_child_process, kill_process_tree,
maybe_set_triton_cache_manager, maybe_set_triton_cache_manager,
prepare_model_and_tokenizer, prepare_model_and_tokenizer,
set_prometheus_multiproc_dir, set_prometheus_multiproc_dir,
@@ -572,6 +574,15 @@ def _set_envs_and_config(server_args: ServerArgs):
"at https://docs.flashinfer.ai/installation.html.", "at https://docs.flashinfer.ai/installation.html.",
) )
# Register the signal handler.
# The child processes will send SIGQUIT to this process when any error happens
# This process then cleans up the whole process tree
def sigquit_handler(signum, frame):
kill_process_tree(os.getpid())
signal.signal(signal.SIGQUIT, sigquit_handler)
# Set mp start method
mp.set_start_method("spawn", force=True) mp.set_start_method("spawn", force=True)
@@ -598,7 +609,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
if pipe_finish_writer is not None: if pipe_finish_writer is not None:
pipe_finish_writer.send(last_traceback) pipe_finish_writer.send(last_traceback)
logger.error(f"Initialization failed. warmup error: {last_traceback}") logger.error(f"Initialization failed. warmup error: {last_traceback}")
kill_child_process(include_self=True) kill_process_tree(os.getpid())
return return
model_info = res.json() model_info = res.json()
@@ -631,7 +642,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
if pipe_finish_writer is not None: if pipe_finish_writer is not None:
pipe_finish_writer.send(last_traceback) pipe_finish_writer.send(last_traceback)
logger.error(f"Initialization failed. warmup error: {last_traceback}") logger.error(f"Initialization failed. warmup error: {last_traceback}")
kill_child_process(include_self=True) kill_process_tree(os.getpid())
return return
# logger.info(f"{res.json()=}") # logger.info(f"{res.json()=}")
@@ -700,7 +711,7 @@ class Runtime:
def shutdown(self): def shutdown(self):
if self.pid is not None: if self.pid is not None:
kill_child_process(self.pid, include_self=True) kill_process_tree(self.pid)
self.pid = None self.pid = None
def cache_prefix(self, prefix: str): def cache_prefix(self, prefix: str):
@@ -924,7 +935,7 @@ class Engine:
return ret return ret
def shutdown(self): def shutdown(self):
kill_child_process() kill_process_tree(os.getpid(), include_parent=False)
def get_tokenizer(self): def get_tokenizer(self):
global tokenizer_manager global tokenizer_manager

View File

@@ -443,26 +443,14 @@ def assert_pkg_version(pkg: str, min_version: str, message: str):
) )
def kill_parent_process(): def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = None):
"""Kill the parent process and all children of the parent process.""" """Kill the process and all its child processes."""
current_process = psutil.Process() if parent_pid is None:
parent_process = current_process.parent() parent_pid = os.getpid()
kill_child_process( include_parent = False
parent_process.pid, include_self=True, skip_pid=current_process.pid
)
try:
current_process.kill()
except psutil.NoSuchProcess:
pass
def kill_child_process(pid=None, include_self=False, skip_pid=None):
"""Kill the process and all its children process."""
if pid is None:
pid = os.getpid()
try: try:
itself = psutil.Process(pid) itself = psutil.Process(parent_pid)
except psutil.NoSuchProcess: except psutil.NoSuchProcess:
return return
@@ -475,13 +463,13 @@ def kill_child_process(pid=None, include_self=False, skip_pid=None):
except psutil.NoSuchProcess: except psutil.NoSuchProcess:
pass pass
if include_self: if include_parent:
try: try:
itself.kill() itself.kill()
# Sometimes processes cannot be killed with SIGKILL (e.g., PID=1 launched by Kubernetes), # Sometimes processes cannot be killed with SIGKILL (e.g., PID=1 launched by Kubernetes),
# so we send an additional signal to kill them. # so we send an additional signal to kill them.
itself.send_signal(signal.SIGINT) itself.send_signal(signal.SIGQUIT)
except psutil.NoSuchProcess: except psutil.NoSuchProcess:
pass pass

View File

@@ -22,7 +22,7 @@ from sglang.bench_serving import run_benchmark
from sglang.global_config import global_config from sglang.global_config import global_config
from sglang.lang.backend.openai import OpenAI from sglang.lang.backend.openai import OpenAI
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.srt.utils import get_bool_env_var, kill_child_process from sglang.srt.utils import get_bool_env_var, kill_process_tree
from sglang.test.run_eval import run_eval from sglang.test.run_eval import run_eval
from sglang.utils import get_exception_traceback from sglang.utils import get_exception_traceback
@@ -504,7 +504,7 @@ def run_unittest_files(files: List[str], timeout_per_file: float):
) )
assert ret_code == 0 assert ret_code == 0
except TimeoutError: except TimeoutError:
kill_child_process(process.pid, include_self=True) kill_process_tree(process.pid)
time.sleep(5) time.sleep(5)
print( print(
f"\nTimeout after {timeout_per_file} seconds when running {filename}\n", f"\nTimeout after {timeout_per_file} seconds when running {filename}\n",
@@ -578,7 +578,7 @@ def run_bench_serving(
run_benchmark(warmup_args) run_benchmark(warmup_args)
res = run_benchmark(args) res = run_benchmark(args)
finally: finally:
kill_child_process(process.pid, include_self=True) kill_process_tree(process.pid)
assert res["completed"] == num_prompts assert res["completed"] == num_prompts
return res return res
@@ -611,7 +611,7 @@ def run_bench_one_batch(model, other_args):
lastline = output.split("\n")[-3] lastline = output.split("\n")[-3]
output_throughput = float(lastline.split(" ")[-2]) output_throughput = float(lastline.split(" ")[-2])
finally: finally:
kill_child_process(process.pid, include_self=True) kill_process_tree(process.pid)
return output_throughput return output_throughput
@@ -710,8 +710,8 @@ def run_and_check_memory_leak(
workload_func(base_url, model) workload_func(base_url, model)
# Clean up everything # Clean up everything
kill_child_process(process.pid, include_self=True) kill_process_tree(process.pid)
kill_child_process(process.pid, include_self=True) kill_process_tree(process.pid)
stdout.close() stdout.close()
stderr.close() stderr.close()
if os.path.exists(STDOUT_FILENAME): if os.path.exists(STDOUT_FILENAME):

View File

@@ -348,9 +348,9 @@ def wait_for_server(base_url: str, timeout: int = None) -> None:
def terminate_process(process): def terminate_process(process):
from sglang.srt.utils import kill_child_process from sglang.srt.utils import kill_process_tree
kill_child_process(process.pid, include_self=True) kill_process_tree(process.pid)
def print_highlight(html_content: str): def print_highlight(html_content: str):

View File

@@ -5,7 +5,7 @@ from types import SimpleNamespace
import requests import requests
from sglang.srt.utils import kill_child_process from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval from sglang.test.run_eval import run_eval
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_MODEL_NAME_FOR_TEST,
@@ -79,7 +79,7 @@ class TestEvalAccuracyMini(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
kill_child_process(cls.process.pid, include_self=True) kill_process_tree(cls.process.pid)
def test_mmlu(self): def test_mmlu(self):
args = SimpleNamespace( args = SimpleNamespace(

View File

@@ -4,7 +4,7 @@ from multiprocessing import Process
import requests import requests
from sglang.srt.utils import kill_child_process from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
@@ -31,7 +31,7 @@ class TestBatchPenalizerE2E(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
kill_child_process(cls.process.pid, include_self=True) kill_process_tree(cls.process.pid)
def run_decode( def run_decode(
self, self,

View File

@@ -4,7 +4,7 @@ import unittest
import openai import openai
import requests import requests
from sglang.srt.utils import kill_child_process from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST, DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_URL_FOR_TEST, DEFAULT_URL_FOR_TEST,
@@ -44,7 +44,7 @@ class TestCacheReport(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
kill_child_process(cls.process.pid, include_self=True) kill_process_tree(cls.process.pid)
def run_decode(self, return_logprob=False, top_logprobs_num=0, n=1): def run_decode(self, return_logprob=False, top_logprobs_num=0, n=1):
response = requests.post( response = requests.post(

View File

@@ -4,7 +4,7 @@ from types import SimpleNamespace
import requests import requests
from sglang.srt.utils import kill_child_process from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval from sglang.test.run_eval import run_eval
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_MODEL_NAME_FOR_TEST,
@@ -28,7 +28,7 @@ class TestDataParallelism(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
kill_child_process(cls.process.pid, include_self=True) kill_process_tree(cls.process.pid)
def test_mmlu(self): def test_mmlu(self):
args = SimpleNamespace( args = SimpleNamespace(

View File

@@ -2,7 +2,7 @@ import os
import unittest import unittest
from types import SimpleNamespace from types import SimpleNamespace
from sglang.srt.utils import kill_child_process from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval from sglang.test.run_eval import run_eval
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_MODEL_NAME_FOR_TEST,
@@ -45,7 +45,7 @@ class TestDoubleSparsity(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
kill_child_process(cls.process.pid, include_self=True) kill_process_tree(cls.process.pid)
def test_mmlu(self): def test_mmlu(self):
args = SimpleNamespace( args = SimpleNamespace(

View File

@@ -1,7 +1,7 @@
import unittest import unittest
from types import SimpleNamespace from types import SimpleNamespace
from sglang.srt.utils import kill_child_process from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval from sglang.test.run_eval import run_eval
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_MLA_MODEL_NAME_FOR_TEST, DEFAULT_MLA_MODEL_NAME_FOR_TEST,
@@ -30,7 +30,7 @@ class TestDPAttention(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
kill_child_process(cls.process.pid, include_self=True) kill_process_tree(cls.process.pid)
def test_mmlu(self): def test_mmlu(self):
args = SimpleNamespace( args = SimpleNamespace(

View File

@@ -3,7 +3,7 @@ import unittest
import openai import openai
from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.utils import kill_child_process from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST, DEFAULT_URL_FOR_TEST,
@@ -28,7 +28,7 @@ class TestOpenAIServer(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
kill_child_process(cls.process.pid, include_self=True) kill_process_tree(cls.process.pid)
def run_embedding(self, use_list_input, token_input): def run_embedding(self, use_list_input, token_input):
client = openai.Client(api_key=self.api_key, base_url=self.base_url) client = openai.Client(api_key=self.api_key, base_url=self.base_url)

View File

@@ -6,7 +6,7 @@ python -m unittest test_eval_accuracy_large.TestEvalAccuracyLarge.test_mmlu
import unittest import unittest
from types import SimpleNamespace from types import SimpleNamespace
from sglang.srt.utils import kill_child_process from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval from sglang.test.run_eval import run_eval
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_MODEL_NAME_FOR_TEST,
@@ -30,7 +30,7 @@ class TestEvalAccuracyLarge(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
kill_child_process(cls.process.pid, include_self=True) kill_process_tree(cls.process.pid)
def test_mmlu(self): def test_mmlu(self):
args = SimpleNamespace( args = SimpleNamespace(

View File

@@ -1,7 +1,7 @@
import unittest import unittest
from types import SimpleNamespace from types import SimpleNamespace
from sglang.srt.utils import kill_child_process from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval from sglang.test.run_eval import run_eval
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_MODEL_NAME_FOR_TEST,
@@ -25,7 +25,7 @@ class TestEvalAccuracyLargeChunkedPrefill(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
kill_child_process(cls.process.pid, include_self=True) kill_process_tree(cls.process.pid)
def test_mmlu(self): def test_mmlu(self):
args = SimpleNamespace( args = SimpleNamespace(

View File

@@ -1,7 +1,7 @@
import unittest import unittest
from types import SimpleNamespace from types import SimpleNamespace
from sglang.srt.utils import kill_child_process from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval from sglang.test.run_eval import run_eval
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_MODEL_NAME_FOR_TEST,
@@ -31,7 +31,7 @@ class TestEvalAccuracyLargeChunkedPrefill(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
kill_child_process(cls.process.pid, include_self=True) kill_process_tree(cls.process.pid)
def test_mmlu(self): def test_mmlu(self):
args = SimpleNamespace( args = SimpleNamespace(

View File

@@ -1,7 +1,7 @@
import unittest import unittest
from types import SimpleNamespace from types import SimpleNamespace
from sglang.srt.utils import kill_child_process from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval from sglang.test.run_eval import run_eval
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_MODEL_NAME_FOR_TEST,
@@ -22,7 +22,7 @@ class TestEvalAccuracyMini(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
kill_child_process(cls.process.pid, include_self=True) kill_process_tree(cls.process.pid)
def test_mmlu(self): def test_mmlu(self):
args = SimpleNamespace( args = SimpleNamespace(

View File

@@ -4,7 +4,7 @@ import unittest
import requests import requests
from transformers import AutoModelForCausalLM, AutoTokenizer from transformers import AutoModelForCausalLM, AutoTokenizer
from sglang.srt.utils import kill_child_process from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST, DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
@@ -107,7 +107,7 @@ class TestInputEmbeds(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
kill_child_process(cls.process.pid, include_self=True) kill_process_tree(cls.process.pid)
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -9,7 +9,7 @@ from concurrent.futures import ThreadPoolExecutor
import openai import openai
import requests import requests
from sglang.srt.utils import kill_child_process from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_URL_FOR_TEST, DEFAULT_URL_FOR_TEST,
@@ -46,7 +46,7 @@ class TestJSONConstrainedOutlinesBackend(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
kill_child_process(cls.process.pid, include_self=True) kill_process_tree(cls.process.pid)
def run_decode(self, json_schema, return_logprob=False, top_logprobs_num=0, n=1): def run_decode(self, json_schema, return_logprob=False, top_logprobs_num=0, n=1):
response = requests.post( response = requests.post(

View File

@@ -10,7 +10,7 @@ from concurrent.futures import ThreadPoolExecutor
import openai import openai
from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.utils import kill_child_process from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST, DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
@@ -52,7 +52,7 @@ class TestLargeMaxNewTokens(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
kill_child_process(cls.process.pid, include_self=True) kill_process_tree(cls.process.pid)
cls.stdout.close() cls.stdout.close()
cls.stderr.close() cls.stderr.close()
os.remove(STDOUT_FILENAME) os.remove(STDOUT_FILENAME)

View File

@@ -3,7 +3,7 @@ import unittest
import requests import requests
from sglang.srt.utils import kill_child_process from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_URL_FOR_TEST, DEFAULT_URL_FOR_TEST,
@@ -32,7 +32,7 @@ class TestMatchedStop(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
kill_child_process(cls.process.pid, include_self=True) kill_process_tree(cls.process.pid)
def run_completions_generation( def run_completions_generation(
self, self,

View File

@@ -2,7 +2,7 @@ import unittest
import requests import requests
from sglang.srt.utils import kill_child_process from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST, DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
@@ -75,7 +75,7 @@ class TestEnableMetrics(unittest.TestCase):
self.assertIn("_bucket{", metrics_content) self.assertIn("_bucket{", metrics_content)
finally: finally:
kill_child_process(process.pid, include_self=True) kill_process_tree(process.pid)
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -1,7 +1,7 @@
import unittest import unittest
from types import SimpleNamespace from types import SimpleNamespace
from sglang.srt.utils import kill_child_process from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval from sglang.test.run_eval import run_eval
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_MLA_MODEL_NAME_FOR_TEST, DEFAULT_MLA_MODEL_NAME_FOR_TEST,
@@ -25,7 +25,7 @@ class TestMLA(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
kill_child_process(cls.process.pid, include_self=True) kill_process_tree(cls.process.pid)
def test_mmlu(self): def test_mmlu(self):
args = SimpleNamespace( args = SimpleNamespace(

View File

@@ -1,7 +1,7 @@
import unittest import unittest
from types import SimpleNamespace from types import SimpleNamespace
from sglang.srt.utils import kill_child_process from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval from sglang.test.run_eval import run_eval
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_MLA_FP8_MODEL_NAME_FOR_TEST, DEFAULT_MLA_FP8_MODEL_NAME_FOR_TEST,
@@ -31,7 +31,7 @@ class TestMLA(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
kill_child_process(cls.process.pid, include_self=True) kill_process_tree(cls.process.pid)
def test_mgsm_en(self): def test_mgsm_en(self):
args = SimpleNamespace( args = SimpleNamespace(

View File

@@ -6,7 +6,7 @@ python -m unittest test_moe_eval_accuracy_large.TestMoEEvalAccuracyLarge.test_mm
import unittest import unittest
from types import SimpleNamespace from types import SimpleNamespace
from sglang.srt.utils import kill_child_process from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval from sglang.test.run_eval import run_eval
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_MOE_MODEL_NAME_FOR_TEST, DEFAULT_MOE_MODEL_NAME_FOR_TEST,
@@ -35,7 +35,7 @@ class TestMoEEvalAccuracyLarge(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
kill_child_process(cls.process.pid, include_self=True) kill_process_tree(cls.process.pid)
def test_mmlu(self): def test_mmlu(self):
args = SimpleNamespace( args = SimpleNamespace(

View File

@@ -6,7 +6,7 @@ import warnings
from datetime import datetime from datetime import datetime
from types import SimpleNamespace from types import SimpleNamespace
from sglang.srt.utils import kill_child_process from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval from sglang.test.run_eval import run_eval
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP1, DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP1,
@@ -132,7 +132,7 @@ class TestEvalAccuracyLarge(unittest.TestCase):
def tearDown(self): def tearDown(self):
if self.process: if self.process:
kill_child_process(self.process.pid, include_self=True) kill_process_tree(self.process.pid)
def test_mgsm_en_all_models(self): def test_mgsm_en_all_models(self):
warnings.filterwarnings( warnings.filterwarnings(

View File

@@ -6,7 +6,7 @@ import unittest
from test_nightly_gsm8k_eval import launch_server, parse_models from test_nightly_gsm8k_eval import launch_server, parse_models
from sglang.srt.utils import kill_child_process from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP1, DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP1,
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP2, DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP2,
@@ -32,9 +32,9 @@ class TestEvalAccuracyLarge(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
if cls.process: if cls.process:
kill_child_process(cls.process.pid) kill_process_tree(cls.process.pid)
if cls.eval_process: if cls.eval_process:
kill_child_process(cls.eval_process.pid) kill_process_tree(cls.eval_process.pid)
def run_evalplus(self, model): def run_evalplus(self, model):
print("Delete evalplus results") print("Delete evalplus results")

View File

@@ -11,7 +11,7 @@ import unittest
import openai import openai
from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.utils import kill_child_process from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST, DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
@@ -37,7 +37,7 @@ class TestOpenAIServer(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
kill_child_process(cls.process.pid, include_self=True) kill_process_tree(cls.process.pid)
def run_completion( def run_completion(
self, echo, logprobs, use_list_input, parallel_sample_num, token_input self, echo, logprobs, use_list_input, parallel_sample_num, token_input

View File

@@ -3,7 +3,7 @@ from types import SimpleNamespace
import requests import requests
from sglang.srt.utils import kill_child_process from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval from sglang.test.run_eval import run_eval
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_MODEL_NAME_FOR_TEST,
@@ -27,7 +27,7 @@ class TestPyTorchSamplingBackend(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
kill_child_process(cls.process.pid, include_self=True) kill_process_tree(cls.process.pid)
def test_mmlu(self): def test_mmlu(self):
args = SimpleNamespace( args = SimpleNamespace(

View File

@@ -8,7 +8,7 @@ from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST, DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST, DEFAULT_URL_FOR_TEST,
kill_child_process, kill_process_tree,
popen_launch_server, popen_launch_server,
) )
@@ -80,7 +80,7 @@ class TestRadixCacheFCFS(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
kill_child_process(cls.process.pid, include_self=True) kill_process_tree(cls.process.pid)
def test_radix_attention(self): def test_radix_attention(self):
nodes = gen_radix_tree() nodes = gen_radix_tree()

View File

@@ -1,7 +1,7 @@
import unittest import unittest
from types import SimpleNamespace from types import SimpleNamespace
from sglang.srt.utils import kill_child_process from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval from sglang.test.run_eval import run_eval
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_MODEL_NAME_FOR_TEST,
@@ -22,7 +22,7 @@ class TestRetractDecode(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
kill_child_process(cls.process.pid, include_self=True) kill_process_tree(cls.process.pid)
def test_mmlu(self): def test_mmlu(self):
args = SimpleNamespace( args = SimpleNamespace(

View File

@@ -9,7 +9,7 @@ import unittest
import requests import requests
from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.utils import kill_child_process from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST, DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
@@ -29,7 +29,7 @@ class TestSessionControl(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
kill_child_process(cls.process.pid, include_self=True) kill_process_tree(cls.process.pid)
def test_session_control(self): def test_session_control(self):
chunks = [ chunks = [
@@ -191,7 +191,7 @@ class TestSessionControlVision(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
kill_child_process(cls.process.pid, include_self=True) kill_process_tree(cls.process.pid)
def test_session_control(self): def test_session_control(self):
text_chunks = [ text_chunks = [

View File

@@ -7,7 +7,7 @@ import unittest
import requests import requests
from sglang.srt.utils import kill_child_process from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST, DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
@@ -30,7 +30,7 @@ class TestSkipTokenizerInit(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
kill_child_process(cls.process.pid, include_self=True) kill_process_tree(cls.process.pid)
def run_decode(self, return_logprob=False, top_logprobs_num=0, n=1): def run_decode(self, return_logprob=False, top_logprobs_num=0, n=1):
max_new_tokens = 32 max_new_tokens = 32

View File

@@ -9,7 +9,7 @@ import unittest
import numpy as np import numpy as np
import requests import requests
from sglang.srt.utils import kill_child_process from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST, DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
@@ -29,7 +29,7 @@ class TestSRTEndpoint(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
kill_child_process(cls.process.pid, include_self=True) kill_process_tree(cls.process.pid)
def run_decode( def run_decode(
self, self,

View File

@@ -4,7 +4,7 @@ from types import SimpleNamespace
import requests import requests
from sglang.srt.utils import kill_child_process from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval from sglang.test.run_eval import run_eval
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_MODEL_NAME_FOR_TEST,
@@ -28,7 +28,7 @@ class TestTorchCompile(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
kill_child_process(cls.process.pid, include_self=True) kill_process_tree(cls.process.pid)
def test_mmlu(self): def test_mmlu(self):
args = SimpleNamespace( args = SimpleNamespace(

View File

@@ -4,7 +4,7 @@ from types import SimpleNamespace
import requests import requests
from sglang.srt.utils import kill_child_process from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval from sglang.test.run_eval import run_eval
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST, DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST,
@@ -28,7 +28,7 @@ class TestTorchCompile(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
kill_child_process(cls.process.pid, include_self=True) kill_process_tree(cls.process.pid)
def test_mmlu(self): def test_mmlu(self):
args = SimpleNamespace( args = SimpleNamespace(

View File

@@ -3,7 +3,7 @@ from types import SimpleNamespace
import requests import requests
from sglang.srt.utils import kill_child_process from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval from sglang.test.run_eval import run_eval
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_MODEL_NAME_FOR_TEST,
@@ -27,7 +27,7 @@ class TestTorchAO(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
kill_child_process(cls.process.pid, include_self=True) kill_process_tree(cls.process.pid)
def test_mmlu(self): def test_mmlu(self):
args = SimpleNamespace( args = SimpleNamespace(

View File

@@ -6,7 +6,7 @@ python3 -m unittest test_triton_attention_backend.TestTritonAttnBackend.test_mml
import unittest import unittest
from types import SimpleNamespace from types import SimpleNamespace
from sglang.srt.utils import kill_child_process from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval from sglang.test.run_eval import run_eval
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_MODEL_NAME_FOR_TEST,
@@ -54,7 +54,7 @@ class TestTritonAttnBackend(unittest.TestCase):
metrics = run_eval(args) metrics = run_eval(args)
self.assertGreaterEqual(metrics["score"], 0.65) self.assertGreaterEqual(metrics["score"], 0.65)
finally: finally:
kill_child_process(process.pid, include_self=True) kill_process_tree(process.pid)
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -3,7 +3,7 @@ import unittest
import requests import requests
from sglang.srt.utils import kill_child_process from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST, DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
@@ -23,7 +23,7 @@ class TestUpdateWeights(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
kill_child_process(cls.process.pid, include_self=True) kill_process_tree(cls.process.pid)
def run_decode(self): def run_decode(self):
response = requests.post( response = requests.post(

View File

@@ -17,7 +17,7 @@ import requests
from decord import VideoReader, cpu from decord import VideoReader, cpu
from PIL import Image from PIL import Image
from sglang.srt.utils import kill_child_process from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST, DEFAULT_URL_FOR_TEST,
@@ -46,7 +46,7 @@ class TestOpenAIVisionServer(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
kill_child_process(cls.process.pid, include_self=True) kill_process_tree(cls.process.pid)
def test_chat_completion(self): def test_chat_completion(self):
client = openai.Client(api_key=self.api_key, base_url=self.base_url) client = openai.Client(api_key=self.api_key, base_url=self.base_url)
@@ -387,7 +387,7 @@ class TestQWen2VLServerContextLengthIssue(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
kill_child_process(cls.process.pid, include_self=True) kill_process_tree(cls.process.pid)
def test_chat_completion(self): def test_chat_completion(self):
client = openai.Client(api_key=self.api_key, base_url=self.base_url) client = openai.Client(api_key=self.api_key, base_url=self.base_url)