Add a watch dog thread (#1816)
This commit is contained in:
@@ -550,4 +550,4 @@ if __name__ == "__main__":
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise e
|
raise e
|
||||||
finally:
|
finally:
|
||||||
kill_child_process(os.getpid(), including_parent=False)
|
kill_child_process()
|
||||||
|
|||||||
@@ -15,7 +15,6 @@ import dataclasses
|
|||||||
import itertools
|
import itertools
|
||||||
import json
|
import json
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
import os
|
|
||||||
import time
|
import time
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
@@ -70,7 +69,7 @@ def launch_server_internal(server_args):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise e
|
raise e
|
||||||
finally:
|
finally:
|
||||||
kill_child_process(os.getpid(), including_parent=False)
|
kill_child_process()
|
||||||
|
|
||||||
|
|
||||||
def launch_server_process(server_args: ServerArgs):
|
def launch_server_process(server_args: ServerArgs):
|
||||||
@@ -176,7 +175,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
|
|||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
if proc:
|
if proc:
|
||||||
kill_child_process(proc.pid)
|
kill_child_process(proc.pid, include_self=True)
|
||||||
|
|
||||||
print(f"\nResults are saved to {bench_args.result_filename}")
|
print(f"\nResults are saved to {bench_args.result_filename}")
|
||||||
|
|
||||||
|
|||||||
@@ -15,4 +15,4 @@ if __name__ == "__main__":
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise e
|
raise e
|
||||||
finally:
|
finally:
|
||||||
kill_child_process(os.getpid(), including_parent=False)
|
kill_child_process()
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ limitations under the License.
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import threading
|
||||||
import time
|
import time
|
||||||
import warnings
|
import warnings
|
||||||
from collections import deque
|
from collections import deque
|
||||||
@@ -222,10 +223,11 @@ class Scheduler:
|
|||||||
self.waiting_queue: List[Req] = []
|
self.waiting_queue: List[Req] = []
|
||||||
self.running_batch: Optional[ScheduleBatch] = None
|
self.running_batch: Optional[ScheduleBatch] = None
|
||||||
self.cur_batch: Optional[ScheduleBatch] = None
|
self.cur_batch: Optional[ScheduleBatch] = None
|
||||||
self.decode_forward_ct = 0
|
self.forward_ct = 0
|
||||||
self.stream_interval = server_args.stream_interval
|
self.forward_ct_decode = 0
|
||||||
self.num_generated_tokens = 0
|
self.num_generated_tokens = 0
|
||||||
self.last_stats_tic = time.time()
|
self.last_stats_tic = time.time()
|
||||||
|
self.stream_interval = server_args.stream_interval
|
||||||
|
|
||||||
# Init chunked prefill
|
# Init chunked prefill
|
||||||
self.chunked_prefill_size = server_args.chunked_prefill_size
|
self.chunked_prefill_size = server_args.chunked_prefill_size
|
||||||
@@ -272,6 +274,11 @@ class Scheduler:
|
|||||||
|
|
||||||
self.batch_is_full = False
|
self.batch_is_full = False
|
||||||
|
|
||||||
|
# Init watchdog thread
|
||||||
|
self.watchdog_timeout = server_args.watchdog_timeout
|
||||||
|
t = threading.Thread(target=self.watchdog_thread, daemon=True)
|
||||||
|
t.start()
|
||||||
|
|
||||||
# Init profiler
|
# Init profiler
|
||||||
if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "":
|
if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "":
|
||||||
self.profiler = None
|
self.profiler = None
|
||||||
@@ -289,6 +296,23 @@ class Scheduler:
|
|||||||
with_stack=True,
|
with_stack=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def watchdog_thread(self):
|
||||||
|
self.watchdog_last_forward_ct = 0
|
||||||
|
self.watchdog_last_time = time.time()
|
||||||
|
|
||||||
|
while True:
|
||||||
|
if self.cur_batch is not None:
|
||||||
|
if self.watchdog_last_forward_ct == self.forward_ct:
|
||||||
|
if time.time() > self.watchdog_last_time + self.watchdog_timeout:
|
||||||
|
logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
self.watchdog_last_forward_ct = self.forward_ct
|
||||||
|
self.watchdog_last_time = time.time()
|
||||||
|
time.sleep(self.watchdog_timeout / 2)
|
||||||
|
|
||||||
|
kill_parent_process()
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def event_loop_normal(self):
|
def event_loop_normal(self):
|
||||||
"""A normal blocking scheduler loop."""
|
"""A normal blocking scheduler loop."""
|
||||||
@@ -299,6 +323,7 @@ class Scheduler:
|
|||||||
self.process_input_requests(recv_reqs)
|
self.process_input_requests(recv_reqs)
|
||||||
|
|
||||||
batch = self.get_next_batch_to_run()
|
batch = self.get_next_batch_to_run()
|
||||||
|
self.cur_batch = batch
|
||||||
|
|
||||||
if batch:
|
if batch:
|
||||||
result = self.run_batch(batch)
|
result = self.run_batch(batch)
|
||||||
@@ -746,6 +771,8 @@ class Scheduler:
|
|||||||
|
|
||||||
def run_batch(self, batch: ScheduleBatch):
|
def run_batch(self, batch: ScheduleBatch):
|
||||||
"""Run a batch."""
|
"""Run a batch."""
|
||||||
|
self.forward_ct += 1
|
||||||
|
|
||||||
if self.is_generation:
|
if self.is_generation:
|
||||||
if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
|
if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
|
||||||
model_worker_batch = batch.get_model_worker_batch()
|
model_worker_batch = batch.get_model_worker_batch()
|
||||||
@@ -778,6 +805,7 @@ class Scheduler:
|
|||||||
self.process_batch_result_prefill(batch, result)
|
self.process_batch_result_prefill(batch, result)
|
||||||
|
|
||||||
def process_batch_result_prefill(self, batch: ScheduleBatch, result):
|
def process_batch_result_prefill(self, batch: ScheduleBatch, result):
|
||||||
|
|
||||||
if self.is_generation:
|
if self.is_generation:
|
||||||
logits_output, next_token_ids, bid = result
|
logits_output, next_token_ids, bid = result
|
||||||
|
|
||||||
@@ -890,8 +918,8 @@ class Scheduler:
|
|||||||
|
|
||||||
self.token_to_kv_pool.free_group_end()
|
self.token_to_kv_pool.free_group_end()
|
||||||
|
|
||||||
self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
|
self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
|
||||||
if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
|
if self.tp_rank == 0 and self.forward_ct_decode % 40 == 0:
|
||||||
self.print_decode_stats()
|
self.print_decode_stats()
|
||||||
|
|
||||||
def add_logprob_return_values(
|
def add_logprob_return_values(
|
||||||
@@ -984,7 +1012,7 @@ class Scheduler:
|
|||||||
else: # embedding or reward model
|
else: # embedding or reward model
|
||||||
output_embeddings = []
|
output_embeddings = []
|
||||||
|
|
||||||
is_stream_iter = self.decode_forward_ct % self.stream_interval == 0
|
is_stream_iter = self.forward_ct_decode % self.stream_interval == 0
|
||||||
|
|
||||||
for req in reqs:
|
for req in reqs:
|
||||||
if req.finished() or (
|
if req.finished() or (
|
||||||
|
|||||||
@@ -441,7 +441,7 @@ def launch_server(
|
|||||||
|
|
||||||
# Send a warmup request
|
# Send a warmup request
|
||||||
t = threading.Thread(
|
t = threading.Thread(
|
||||||
target=_wait_and_warmup, args=(server_args, pipe_finish_writer, os.getpid())
|
target=_wait_and_warmup, args=(server_args, pipe_finish_writer)
|
||||||
)
|
)
|
||||||
t.start()
|
t.start()
|
||||||
|
|
||||||
@@ -496,7 +496,7 @@ def _set_envs_and_config(server_args: ServerArgs):
|
|||||||
mp.set_start_method("spawn", force=True)
|
mp.set_start_method("spawn", force=True)
|
||||||
|
|
||||||
|
|
||||||
def _wait_and_warmup(server_args, pipe_finish_writer, pid):
|
def _wait_and_warmup(server_args, pipe_finish_writer):
|
||||||
headers = {}
|
headers = {}
|
||||||
url = server_args.url()
|
url = server_args.url()
|
||||||
if server_args.api_key:
|
if server_args.api_key:
|
||||||
@@ -519,7 +519,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
|
|||||||
if pipe_finish_writer is not None:
|
if pipe_finish_writer is not None:
|
||||||
pipe_finish_writer.send(last_traceback)
|
pipe_finish_writer.send(last_traceback)
|
||||||
logger.error(f"Initialization failed. warmup error: {last_traceback}")
|
logger.error(f"Initialization failed. warmup error: {last_traceback}")
|
||||||
kill_child_process(pid, including_parent=False)
|
kill_child_process(include_self=True)
|
||||||
return
|
return
|
||||||
|
|
||||||
model_info = res.json()
|
model_info = res.json()
|
||||||
@@ -551,7 +551,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
|
|||||||
if pipe_finish_writer is not None:
|
if pipe_finish_writer is not None:
|
||||||
pipe_finish_writer.send(last_traceback)
|
pipe_finish_writer.send(last_traceback)
|
||||||
logger.error(f"Initialization failed. warmup error: {last_traceback}")
|
logger.error(f"Initialization failed. warmup error: {last_traceback}")
|
||||||
kill_child_process(pid, including_parent=False)
|
kill_child_process(include_self=True)
|
||||||
return
|
return
|
||||||
|
|
||||||
# logger.info(f"{res.json()=}")
|
# logger.info(f"{res.json()=}")
|
||||||
@@ -617,7 +617,7 @@ class Runtime:
|
|||||||
|
|
||||||
def shutdown(self):
|
def shutdown(self):
|
||||||
if self.pid is not None:
|
if self.pid is not None:
|
||||||
kill_child_process(self.pid)
|
kill_child_process(self.pid, include_self=True)
|
||||||
self.pid = None
|
self.pid = None
|
||||||
|
|
||||||
def cache_prefix(self, prefix: str):
|
def cache_prefix(self, prefix: str):
|
||||||
@@ -834,7 +834,7 @@ class Engine:
|
|||||||
return ret
|
return ret
|
||||||
|
|
||||||
def shutdown(self):
|
def shutdown(self):
|
||||||
kill_child_process(os.getpid(), including_parent=False)
|
kill_child_process(include_self=True)
|
||||||
|
|
||||||
def get_tokenizer(self):
|
def get_tokenizer(self):
|
||||||
global tokenizer_manager
|
global tokenizer_manager
|
||||||
|
|||||||
@@ -74,6 +74,7 @@ class ServerArgs:
|
|||||||
api_key: Optional[str] = None
|
api_key: Optional[str] = None
|
||||||
file_storage_pth: str = "SGLang_storage"
|
file_storage_pth: str = "SGLang_storage"
|
||||||
enable_cache_report: bool = False
|
enable_cache_report: bool = False
|
||||||
|
watchdog_timeout: float = 600
|
||||||
|
|
||||||
# Data parallelism
|
# Data parallelism
|
||||||
dp_size: int = 1
|
dp_size: int = 1
|
||||||
@@ -429,6 +430,12 @@ class ServerArgs:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Return number of cached tokens in usage.prompt_tokens_details for each openai request.",
|
help="Return number of cached tokens in usage.prompt_tokens_details for each openai request.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--watchdog-timeout",
|
||||||
|
type=float,
|
||||||
|
default=ServerArgs.watchdog_timeout,
|
||||||
|
help="Set watchdog timeout in seconds. If a forward batch takes longer than this, the server will crash to prevent hanging.",
|
||||||
|
)
|
||||||
|
|
||||||
# Data parallelism
|
# Data parallelism
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|||||||
@@ -398,17 +398,26 @@ def kill_parent_process():
|
|||||||
"""Kill the parent process and all children of the parent process."""
|
"""Kill the parent process and all children of the parent process."""
|
||||||
current_process = psutil.Process()
|
current_process = psutil.Process()
|
||||||
parent_process = current_process.parent()
|
parent_process = current_process.parent()
|
||||||
kill_child_process(parent_process.pid, skip_pid=current_process.pid)
|
kill_child_process(
|
||||||
|
parent_process.pid, include_self=True, skip_pid=current_process.pid
|
||||||
|
)
|
||||||
def kill_child_process(pid, including_parent=True, skip_pid=None):
|
|
||||||
"""Kill the process and all its children process."""
|
|
||||||
try:
|
try:
|
||||||
parent = psutil.Process(pid)
|
current_process.kill()
|
||||||
|
except psutil.NoSuchProcess:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def kill_child_process(pid=None, include_self=False, skip_pid=None):
|
||||||
|
"""Kill the process and all its children process."""
|
||||||
|
if pid is None:
|
||||||
|
pid = os.getpid()
|
||||||
|
|
||||||
|
try:
|
||||||
|
itself = psutil.Process(pid)
|
||||||
except psutil.NoSuchProcess:
|
except psutil.NoSuchProcess:
|
||||||
return
|
return
|
||||||
|
|
||||||
children = parent.children(recursive=True)
|
children = itself.children(recursive=True)
|
||||||
for child in children:
|
for child in children:
|
||||||
if child.pid == skip_pid:
|
if child.pid == skip_pid:
|
||||||
continue
|
continue
|
||||||
@@ -417,9 +426,9 @@ def kill_child_process(pid, including_parent=True, skip_pid=None):
|
|||||||
except psutil.NoSuchProcess:
|
except psutil.NoSuchProcess:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if including_parent:
|
if include_self:
|
||||||
try:
|
try:
|
||||||
parent.kill()
|
itself.kill()
|
||||||
except psutil.NoSuchProcess:
|
except psutil.NoSuchProcess:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@@ -495,7 +495,7 @@ def run_unittest_files(files: List[str], timeout_per_file: float):
|
|||||||
)
|
)
|
||||||
assert ret_code == 0
|
assert ret_code == 0
|
||||||
except TimeoutError:
|
except TimeoutError:
|
||||||
kill_child_process(process.pid)
|
kill_child_process(process.pid, include_self=True)
|
||||||
time.sleep(5)
|
time.sleep(5)
|
||||||
print(
|
print(
|
||||||
f"\nTimeout after {timeout_per_file} seconds when running {filename}\n",
|
f"\nTimeout after {timeout_per_file} seconds when running {filename}\n",
|
||||||
@@ -563,7 +563,7 @@ def run_bench_serving(
|
|||||||
try:
|
try:
|
||||||
res = run_benchmark(args)
|
res = run_benchmark(args)
|
||||||
finally:
|
finally:
|
||||||
kill_child_process(process.pid)
|
kill_child_process(process.pid, include_self=True)
|
||||||
|
|
||||||
assert res["completed"] == num_prompts
|
assert res["completed"] == num_prompts
|
||||||
return res
|
return res
|
||||||
@@ -596,7 +596,7 @@ def run_bench_latency(model, other_args):
|
|||||||
lastline = output.split("\n")[-3]
|
lastline = output.split("\n")[-3]
|
||||||
output_throughput = float(lastline.split(" ")[-2])
|
output_throughput = float(lastline.split(" ")[-2])
|
||||||
finally:
|
finally:
|
||||||
kill_child_process(process.pid)
|
kill_child_process(process.pid, include_self=True)
|
||||||
|
|
||||||
return output_throughput
|
return output_throughput
|
||||||
|
|
||||||
@@ -707,8 +707,8 @@ def run_mmlu_test(
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
# Clean up everything
|
# Clean up everything
|
||||||
kill_child_process(process.pid)
|
kill_child_process(process.pid, include_self=True)
|
||||||
kill_child_process(process.pid)
|
kill_child_process(process.pid, include_self=True)
|
||||||
stdout.close()
|
stdout.close()
|
||||||
stderr.close()
|
stderr.close()
|
||||||
if os.path.exists(STDOUT_FILENAME):
|
if os.path.exists(STDOUT_FILENAME):
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ class TestBatchPenalizerE2E(unittest.TestCase):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def tearDownClass(cls):
|
def tearDownClass(cls):
|
||||||
kill_child_process(cls.process.pid)
|
kill_child_process(cls.process.pid, include_self=True)
|
||||||
|
|
||||||
def run_decode(
|
def run_decode(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ class TestCacheReport(unittest.TestCase):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def tearDownClass(cls):
|
def tearDownClass(cls):
|
||||||
kill_child_process(cls.process.pid)
|
kill_child_process(cls.process.pid, include_self=True)
|
||||||
|
|
||||||
def run_decode(self, return_logprob=False, top_logprobs_num=0, n=1):
|
def run_decode(self, return_logprob=False, top_logprobs_num=0, n=1):
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ class TestDataParallelism(unittest.TestCase):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def tearDownClass(cls):
|
def tearDownClass(cls):
|
||||||
kill_child_process(cls.process.pid)
|
kill_child_process(cls.process.pid, include_self=True)
|
||||||
|
|
||||||
def test_mmlu(self):
|
def test_mmlu(self):
|
||||||
args = SimpleNamespace(
|
args = SimpleNamespace(
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ class TestDoubleSparsity(unittest.TestCase):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def tearDownClass(cls):
|
def tearDownClass(cls):
|
||||||
kill_child_process(cls.process.pid)
|
kill_child_process(cls.process.pid, include_self=True)
|
||||||
|
|
||||||
def test_mmlu(self):
|
def test_mmlu(self):
|
||||||
args = SimpleNamespace(
|
args = SimpleNamespace(
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ class TestOpenAIServer(unittest.TestCase):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def tearDownClass(cls):
|
def tearDownClass(cls):
|
||||||
kill_child_process(cls.process.pid)
|
kill_child_process(cls.process.pid, include_self=True)
|
||||||
|
|
||||||
def run_embedding(self, use_list_input, token_input):
|
def run_embedding(self, use_list_input, token_input):
|
||||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ class TestEvalAccuracyLarge(unittest.TestCase):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def tearDownClass(cls):
|
def tearDownClass(cls):
|
||||||
kill_child_process(cls.process.pid)
|
kill_child_process(cls.process.pid, include_self=True)
|
||||||
|
|
||||||
def test_mmlu(self):
|
def test_mmlu(self):
|
||||||
args = SimpleNamespace(
|
args = SimpleNamespace(
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ class TestEvalAccuracyLargeChunkedPrefill(unittest.TestCase):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def tearDownClass(cls):
|
def tearDownClass(cls):
|
||||||
kill_child_process(cls.process.pid)
|
kill_child_process(cls.process.pid, include_self=True)
|
||||||
|
|
||||||
def test_mmlu(self):
|
def test_mmlu(self):
|
||||||
args = SimpleNamespace(
|
args = SimpleNamespace(
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ class TestEvalAccuracyLargeChunkedPrefill(unittest.TestCase):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def tearDownClass(cls):
|
def tearDownClass(cls):
|
||||||
kill_child_process(cls.process.pid)
|
kill_child_process(cls.process.pid, include_self=True)
|
||||||
|
|
||||||
def test_mmlu(self):
|
def test_mmlu(self):
|
||||||
args = SimpleNamespace(
|
args = SimpleNamespace(
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ class TestEvalAccuracyMini(unittest.TestCase):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def tearDownClass(cls):
|
def tearDownClass(cls):
|
||||||
kill_child_process(cls.process.pid)
|
kill_child_process(cls.process.pid, include_self=True)
|
||||||
|
|
||||||
def test_mmlu(self):
|
def test_mmlu(self):
|
||||||
args = SimpleNamespace(
|
args = SimpleNamespace(
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ class TestJSONConstrained(unittest.TestCase):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def tearDownClass(cls):
|
def tearDownClass(cls):
|
||||||
kill_child_process(cls.process.pid)
|
kill_child_process(cls.process.pid, include_self=True)
|
||||||
|
|
||||||
def run_decode(self, json_schema, return_logprob=False, top_logprobs_num=0, n=1):
|
def run_decode(self, json_schema, return_logprob=False, top_logprobs_num=0, n=1):
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ class TestLargeMaxNewTokens(unittest.TestCase):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def tearDownClass(cls):
|
def tearDownClass(cls):
|
||||||
kill_child_process(cls.process.pid)
|
kill_child_process(cls.process.pid, include_self=True)
|
||||||
cls.stdout.close()
|
cls.stdout.close()
|
||||||
cls.stderr.close()
|
cls.stderr.close()
|
||||||
os.remove("stdout.txt")
|
os.remove("stdout.txt")
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ class TestMatchedStop(unittest.TestCase):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def tearDownClass(cls):
|
def tearDownClass(cls):
|
||||||
kill_child_process(cls.process.pid)
|
kill_child_process(cls.process.pid, include_self=True)
|
||||||
|
|
||||||
def run_completions_generation(
|
def run_completions_generation(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ class TestMLA(unittest.TestCase):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def tearDownClass(cls):
|
def tearDownClass(cls):
|
||||||
kill_child_process(cls.process.pid)
|
kill_child_process(cls.process.pid, include_self=True)
|
||||||
|
|
||||||
def test_mmlu(self):
|
def test_mmlu(self):
|
||||||
args = SimpleNamespace(
|
args = SimpleNamespace(
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ class TestMLA(unittest.TestCase):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def tearDownClass(cls):
|
def tearDownClass(cls):
|
||||||
kill_child_process(cls.process.pid)
|
kill_child_process(cls.process.pid, include_self=True)
|
||||||
|
|
||||||
def test_mgsm_en(self):
|
def test_mgsm_en(self):
|
||||||
args = SimpleNamespace(
|
args = SimpleNamespace(
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ class TestMoEEvalAccuracyLarge(unittest.TestCase):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def tearDownClass(cls):
|
def tearDownClass(cls):
|
||||||
kill_child_process(cls.process.pid)
|
kill_child_process(cls.process.pid, include_self=True)
|
||||||
|
|
||||||
def test_mmlu(self):
|
def test_mmlu(self):
|
||||||
args = SimpleNamespace(
|
args = SimpleNamespace(
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ class TestEvalAccuracyLarge(unittest.TestCase):
|
|||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
if self.process:
|
if self.process:
|
||||||
kill_child_process(self.process.pid)
|
kill_child_process(self.process.pid, include_self=True)
|
||||||
|
|
||||||
def launch_server(self, model, is_fp8, is_tp2):
|
def launch_server(self, model, is_fp8, is_tp2):
|
||||||
other_args = ["--log-level-http", "warning", "--trust-remote-code"]
|
other_args = ["--log-level-http", "warning", "--trust-remote-code"]
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ class TestOpenAIServer(unittest.TestCase):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def tearDownClass(cls):
|
def tearDownClass(cls):
|
||||||
kill_child_process(cls.process.pid)
|
kill_child_process(cls.process.pid, include_self=True)
|
||||||
|
|
||||||
def run_completion(
|
def run_completion(
|
||||||
self, echo, logprobs, use_list_input, parallel_sample_num, token_input
|
self, echo, logprobs, use_list_input, parallel_sample_num, token_input
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ class TestPyTorchSamplingBackend(unittest.TestCase):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def tearDownClass(cls):
|
def tearDownClass(cls):
|
||||||
kill_child_process(cls.process.pid)
|
kill_child_process(cls.process.pid, include_self=True)
|
||||||
|
|
||||||
def test_mmlu(self):
|
def test_mmlu(self):
|
||||||
args = SimpleNamespace(
|
args = SimpleNamespace(
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ class TestRetractDecode(unittest.TestCase):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def tearDownClass(cls):
|
def tearDownClass(cls):
|
||||||
kill_child_process(cls.process.pid)
|
kill_child_process(cls.process.pid, include_self=True)
|
||||||
|
|
||||||
def test_mmlu(self):
|
def test_mmlu(self):
|
||||||
args = SimpleNamespace(
|
args = SimpleNamespace(
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ class TestSkipTokenizerInit(unittest.TestCase):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def tearDownClass(cls):
|
def tearDownClass(cls):
|
||||||
kill_child_process(cls.process.pid)
|
kill_child_process(cls.process.pid, include_self=True)
|
||||||
|
|
||||||
def run_decode(self, return_logprob=False, top_logprobs_num=0, n=1):
|
def run_decode(self, return_logprob=False, top_logprobs_num=0, n=1):
|
||||||
max_new_tokens = 32
|
max_new_tokens = 32
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ class TestSRTEndpoint(unittest.TestCase):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def tearDownClass(cls):
|
def tearDownClass(cls):
|
||||||
kill_child_process(cls.process.pid)
|
kill_child_process(cls.process.pid, include_self=True)
|
||||||
|
|
||||||
def run_decode(
|
def run_decode(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ class TestTorchCompile(unittest.TestCase):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def tearDownClass(cls):
|
def tearDownClass(cls):
|
||||||
kill_child_process(cls.process.pid)
|
kill_child_process(cls.process.pid, include_self=True)
|
||||||
|
|
||||||
def test_mmlu(self):
|
def test_mmlu(self):
|
||||||
args = SimpleNamespace(
|
args = SimpleNamespace(
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ class TestTorchCompile(unittest.TestCase):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def tearDownClass(cls):
|
def tearDownClass(cls):
|
||||||
kill_child_process(cls.process.pid)
|
kill_child_process(cls.process.pid, include_self=True)
|
||||||
|
|
||||||
def test_mmlu(self):
|
def test_mmlu(self):
|
||||||
args = SimpleNamespace(
|
args = SimpleNamespace(
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ class TestTritonAttnBackend(unittest.TestCase):
|
|||||||
metrics = run_eval(args)
|
metrics = run_eval(args)
|
||||||
assert metrics["score"] >= 0.65
|
assert metrics["score"] >= 0.65
|
||||||
finally:
|
finally:
|
||||||
kill_child_process(process.pid)
|
kill_child_process(process.pid, include_self=True)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ class TestUpdateWeights(unittest.TestCase):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def tearDownClass(cls):
|
def tearDownClass(cls):
|
||||||
kill_child_process(cls.process.pid)
|
kill_child_process(cls.process.pid, include_self=True)
|
||||||
|
|
||||||
def run_decode(self):
|
def run_decode(self):
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ class TestOpenAIVisionServer(unittest.TestCase):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def tearDownClass(cls):
|
def tearDownClass(cls):
|
||||||
kill_child_process(cls.process.pid)
|
kill_child_process(cls.process.pid, include_self=True)
|
||||||
|
|
||||||
def test_chat_completion(self):
|
def test_chat_completion(self):
|
||||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||||
|
|||||||
Reference in New Issue
Block a user