Make scripts under /test/srt as unit tests (#875)
This commit is contained in:
@@ -73,6 +73,7 @@ from sglang.srt.utils import (
|
||||
assert_pkg_version,
|
||||
enable_show_time_cost,
|
||||
maybe_set_triton_cache_manager,
|
||||
kill_child_process,
|
||||
set_ulimit,
|
||||
)
|
||||
from sglang.utils import get_exception_traceback
|
||||
@@ -467,16 +468,7 @@ class Runtime:
|
||||
|
||||
def shutdown(self):
|
||||
if self.pid is not None:
|
||||
try:
|
||||
parent = psutil.Process(self.pid)
|
||||
except psutil.NoSuchProcess:
|
||||
return
|
||||
children = parent.children(recursive=True)
|
||||
for child in children:
|
||||
child.kill()
|
||||
psutil.wait_procs(children, timeout=5)
|
||||
parent.kill()
|
||||
parent.wait(timeout=5)
|
||||
kill_child_process(self.pid)
|
||||
self.pid = None
|
||||
|
||||
def cache_prefix(self, prefix: str):
|
||||
|
||||
@@ -366,6 +366,26 @@ def kill_parent_process():
|
||||
os.kill(parent_process.pid, 9)
|
||||
|
||||
|
||||
def kill_child_process(pid, including_parent=True):
    """Kill every descendant of process ``pid`` and, optionally, ``pid`` itself.

    Args:
        pid: PID of the process whose child processes should be killed.
        including_parent: When true (the default), the process identified by
            ``pid`` is killed as well, after its descendants.

    Processes that no longer exist are skipped silently. The function does
    not wait for the killed processes to terminate.
    """
    try:
        parent = psutil.Process(pid)
        # Snapshot the whole process tree up front. The lookup lives inside
        # the ``try`` because the parent may exit between these two calls,
        # in which case there is nothing left to enumerate.
        children = parent.children(recursive=True)
    except psutil.NoSuchProcess:
        return

    for child in children:
        try:
            child.kill()
        except psutil.NoSuchProcess:
            # The child already exited; nothing to do for it.
            pass

    if including_parent:
        try:
            parent.kill()
        except psutil.NoSuchProcess:
            # The parent exited while its children were being killed.
            pass
|
||||
|
||||
|
||||
def monkey_patch_vllm_p2p_access_check(gpu_id: int):
|
||||
"""
|
||||
Monkey patch the slow p2p access check in vllm.
|
||||
|
||||
@@ -105,15 +105,14 @@ def test_decode_json_regex():
|
||||
def decode_json(s):
|
||||
from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STRING
|
||||
|
||||
s += "Generate a JSON object to describe the basic information of a city.\n"
|
||||
s += "Generate a JSON object to describe the basic city information of Paris.\n"
|
||||
|
||||
with s.var_scope("json_output"):
|
||||
s += "{\n"
|
||||
s += ' "name": ' + sgl.gen(regex=REGEX_STRING + ",") + "\n"
|
||||
s += ' "population": ' + sgl.gen(regex=REGEX_INT + ",") + "\n"
|
||||
s += ' "area": ' + sgl.gen(regex=REGEX_INT + ",") + "\n"
|
||||
s += ' "latitude": ' + sgl.gen(regex=REGEX_FLOAT + ",") + "\n"
|
||||
s += ' "country": ' + sgl.gen(regex=REGEX_STRING) + "\n"
|
||||
s += ' "latitude": ' + sgl.gen(regex=REGEX_FLOAT) + "\n"
|
||||
s += "}"
|
||||
|
||||
ret = decode_json.run(temperature=0.0)
|
||||
@@ -129,7 +128,7 @@ def test_decode_json_regex():
|
||||
def test_decode_json():
|
||||
@sgl.function
|
||||
def decode_json(s):
|
||||
s += "Generate a JSON object to describe the basic information of a city.\n"
|
||||
s += "Generate a JSON object to describe the basic city information of Paris.\n"
|
||||
|
||||
with s.var_scope("json_output"):
|
||||
s += "{\n"
|
||||
@@ -264,6 +263,7 @@ def test_parallel_decoding():
|
||||
s += "\nIn summary," + sgl.gen("summary", max_tokens=512)
|
||||
|
||||
ret = parallel_decoding.run(topic="writing a good blog post", temperature=0.3)
|
||||
assert isinstance(ret["summary"], str)
|
||||
|
||||
|
||||
def test_parallel_encoding(check_answer=True):
|
||||
|
||||
Reference in New Issue
Block a user