Make scripts under /test/srt as unit tests (#875)

This commit is contained in:
Ying Sheng
2024-08-01 14:34:55 -07:00
committed by GitHub
parent e4d3333c6c
commit 72b6ea88b4
18 changed files with 353 additions and 212 deletions

View File

@@ -20,8 +20,6 @@ concurrency:
jobs: jobs:
unit-test: unit-test:
runs-on: self-hosted runs-on: self-hosted
env:
CUDA_VISIBLE_DEVICES: 6
steps: steps:
- name: Checkout code - name: Checkout code
@@ -30,6 +28,7 @@ jobs:
- name: Install dependencies - name: Install dependencies
run: | run: |
cd /data/zhyncs/venv && source ./bin/activate && cd - cd /data/zhyncs/venv && source ./bin/activate && cd -
pip cache purge pip cache purge
pip install --upgrade pip pip install --upgrade pip
pip install -e "python[all]" pip install -e "python[all]"
@@ -39,6 +38,14 @@ jobs:
- name: Test OpenAI Backend - name: Test OpenAI Backend
run: | run: |
cd /data/zhyncs/venv && source ./bin/activate && cd - cd /data/zhyncs/venv && source ./bin/activate && cd -
cd test/lang
export OPENAI_API_KEY=${{ secrets.OPENAI_API_KEY }} export OPENAI_API_KEY=${{ secrets.OPENAI_API_KEY }}
cd test/lang
python3 test_openai_backend.py python3 test_openai_backend.py
- name: Test SRT Backend
run: |
cd /data/zhyncs/venv && source ./bin/activate && cd -
cd test/lang
python3 test_srt_backend.py

View File

@@ -73,6 +73,7 @@ from sglang.srt.utils import (
assert_pkg_version, assert_pkg_version,
enable_show_time_cost, enable_show_time_cost,
maybe_set_triton_cache_manager, maybe_set_triton_cache_manager,
kill_child_process,
set_ulimit, set_ulimit,
) )
from sglang.utils import get_exception_traceback from sglang.utils import get_exception_traceback
@@ -467,16 +468,7 @@ class Runtime:
def shutdown(self): def shutdown(self):
if self.pid is not None: if self.pid is not None:
try: kill_child_process(self.pid)
parent = psutil.Process(self.pid)
except psutil.NoSuchProcess:
return
children = parent.children(recursive=True)
for child in children:
child.kill()
psutil.wait_procs(children, timeout=5)
parent.kill()
parent.wait(timeout=5)
self.pid = None self.pid = None
def cache_prefix(self, prefix: str): def cache_prefix(self, prefix: str):

View File

@@ -366,6 +366,26 @@ def kill_parent_process():
os.kill(parent_process.pid, 9) os.kill(parent_process.pid, 9)
def kill_child_process(pid, including_parent=True):
    """Kill every descendant of the process ``pid`` and, optionally, ``pid`` itself.

    Args:
        pid: Id of the process whose descendants should be killed.
        including_parent: When true (the default), the process ``pid`` is
            killed as well, after all of its descendants.

    Returns without doing anything if ``pid`` does not exist. A process that
    has already disappeared by the time it is killed is ignored.
    """
    try:
        root = psutil.Process(pid)
    except psutil.NoSuchProcess:
        # The process is already gone, so there is nothing to clean up.
        return

    # Descendants come first; the root process (if requested) goes last.
    targets = root.children(recursive=True)
    if including_parent:
        targets.append(root)

    for proc in targets:
        try:
            proc.kill()
        except psutil.NoSuchProcess:
            continue
def monkey_patch_vllm_p2p_access_check(gpu_id: int): def monkey_patch_vllm_p2p_access_check(gpu_id: int):
""" """
Monkey patch the slow p2p access check in vllm. Monkey patch the slow p2p access check in vllm.

View File

@@ -105,15 +105,14 @@ def test_decode_json_regex():
def decode_json(s): def decode_json(s):
from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STRING from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STRING
s += "Generate a JSON object to describe the basic information of a city.\n" s += "Generate a JSON object to describe the basic city information of Paris.\n"
with s.var_scope("json_output"): with s.var_scope("json_output"):
s += "{\n" s += "{\n"
s += ' "name": ' + sgl.gen(regex=REGEX_STRING + ",") + "\n" s += ' "name": ' + sgl.gen(regex=REGEX_STRING + ",") + "\n"
s += ' "population": ' + sgl.gen(regex=REGEX_INT + ",") + "\n" s += ' "population": ' + sgl.gen(regex=REGEX_INT + ",") + "\n"
s += ' "area": ' + sgl.gen(regex=REGEX_INT + ",") + "\n" s += ' "area": ' + sgl.gen(regex=REGEX_INT + ",") + "\n"
s += ' "latitude": ' + sgl.gen(regex=REGEX_FLOAT + ",") + "\n" s += ' "latitude": ' + sgl.gen(regex=REGEX_FLOAT) + "\n"
s += ' "country": ' + sgl.gen(regex=REGEX_STRING) + "\n"
s += "}" s += "}"
ret = decode_json.run(temperature=0.0) ret = decode_json.run(temperature=0.0)
@@ -129,7 +128,7 @@ def test_decode_json_regex():
def test_decode_json(): def test_decode_json():
@sgl.function @sgl.function
def decode_json(s): def decode_json(s):
s += "Generate a JSON object to describe the basic information of a city.\n" s += "Generate a JSON object to describe the basic city information of Paris.\n"
with s.var_scope("json_output"): with s.var_scope("json_output"):
s += "{\n" s += "{\n"
@@ -264,6 +263,7 @@ def test_parallel_decoding():
s += "\nIn summary," + sgl.gen("summary", max_tokens=512) s += "\nIn summary," + sgl.gen("summary", max_tokens=512)
ret = parallel_decoding.run(topic="writing a good blog post", temperature=0.3) ret = parallel_decoding.run(topic="writing a good blog post", temperature=0.3)
assert isinstance(ret["summary"], str)
def test_parallel_encoding(check_answer=True): def test_parallel_encoding(check_answer=True):

View File

@@ -21,7 +21,7 @@ class TestSRTBackend(unittest.TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
cls.backend = sgl.Runtime(model_path="meta-llama/Meta-Llama-3-8B-Instruct") cls.backend = sgl.Runtime(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct")
sgl.set_default_backend(cls.backend) sgl.set_default_backend(cls.backend)
@classmethod @classmethod

View File

@@ -1 +0,0 @@
../lang/example_image.png

View File

@@ -0,0 +1,209 @@
"""
First run the following command to launch the server.
Note that TinyLlama adopts different chat templates in different versions.
For v0.4, the chat template is chatml.
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 \
--port 30000 --chat-template chatml
Output example:
The capital of France is Paris.
The capital of the United States is Washington, D.C.
The capital of Canada is Ottawa.
The capital of Japan is Tokyo
"""
import argparse
import json
import openai
def test_completion(args, echo, logprobs):
    """Send one non-streaming completion request and sanity-check the reply.

    Args:
        args: Parsed command-line arguments; only ``args.base_url`` is read.
        echo: Forwarded as the request's ``echo`` option. When truthy, the
            returned text must start with the prompt.
        logprobs: Forwarded as the request's ``logprobs`` option. When truthy,
            the reply must carry log-probability data.

    Raises:
        AssertionError: If any check on the reply fails.
    """
    prompt = "The capital of France is"
    client = openai.Client(api_key="EMPTY", base_url=args.base_url)
    response = client.completions.create(
        model="default",
        prompt=prompt,
        temperature=0,
        max_tokens=32,
        echo=echo,
        logprobs=logprobs,
    )
    text = response.choices[0].text
    print(text)
    if echo:
        assert text.startswith(prompt)
    if logprobs:
        print(response.choices[0].logprobs.top_logprobs)
        assert response.choices[0].logprobs
        # The first log-probability entry is expected to be missing exactly
        # when the prompt is echoed back. Compare against the None singleton
        # by identity rather than with ==/!=.
        if echo:
            assert response.choices[0].logprobs.token_logprobs[0] is None
        else:
            assert response.choices[0].logprobs.token_logprobs[0] is not None
    assert response.id
    assert response.created
    assert response.usage.prompt_tokens > 0
    assert response.usage.completion_tokens > 0
    assert response.usage.total_tokens > 0
    print("=" * 100)
def test_completion_stream(args, echo, logprobs):
    """Stream a completion request and sanity-check every chunk received.

    Args:
        args: Parsed command-line arguments; only ``args.base_url`` is read.
        echo: Forwarded as the request's ``echo`` option. When truthy, the
            first chunk's text must start with the prompt.
        logprobs: Forwarded as the request's ``logprobs`` option. When truthy,
            each chunk's log-probability data is printed next to its text.

    Raises:
        AssertionError: If any check on a chunk fails.
    """
    prompt = "The capital of France is"
    api = openai.Client(api_key="EMPTY", base_url=args.base_url)
    stream = api.completions.create(
        model="default",
        prompt=prompt,
        temperature=0,
        max_tokens=32,
        stream=True,
        echo=echo,
        logprobs=logprobs,
    )
    for index, chunk in enumerate(stream):
        choice = chunk.choices[0]
        # Only the very first chunk is required to begin with the prompt.
        if index == 0 and echo:
            assert choice.text.startswith(prompt)
        if not logprobs:
            print(choice.text, end="", flush=True)
        else:
            print(
                f"{choice.text:12s}\t{choice.logprobs.token_logprobs}",
                flush=True,
            )
            print(choice.logprobs.top_logprobs)
        assert chunk.id
        assert chunk.usage.prompt_tokens > 0
        assert chunk.usage.completion_tokens > 0
        assert chunk.usage.total_tokens > 0
    print("=" * 100)
def test_chat_completion(args):
    """Send one non-streaming chat request and sanity-check the reply.

    Args:
        args: Parsed command-line arguments; only ``args.base_url`` is read.

    Raises:
        AssertionError: If the reply lacks an id, a creation time, or
            positive token-usage counters.
    """
    conversation = [
        {"role": "system", "content": "You are a helpful AI assistant"},
        {"role": "user", "content": "What is the capital of France?"},
    ]
    api = openai.Client(api_key="EMPTY", base_url=args.base_url)
    reply = api.chat.completions.create(
        model="default",
        messages=conversation,
        temperature=0,
        max_tokens=32,
    )
    print(reply.choices[0].message.content)

    assert reply.id
    assert reply.created
    usage = reply.usage
    assert usage.prompt_tokens > 0
    assert usage.completion_tokens > 0
    assert usage.total_tokens > 0
    print("=" * 100)
def test_chat_completion_image(args):
    """Send a chat request with text plus an image URL and check the reply.

    Args:
        args: Parsed command-line arguments; only ``args.base_url`` is read.

    Raises:
        AssertionError: If the reply lacks an id, a creation time, or
            positive token-usage counters.
    """
    image_part = {
        "type": "image_url",
        "image_url": {
            "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/mixtral_8x7b.jpg"
        },
    }
    conversation = [
        {"role": "system", "content": "You are a helpful AI assistant"},
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this image"},
                image_part,
            ],
        },
    ]
    api = openai.Client(api_key="EMPTY", base_url=args.base_url)
    reply = api.chat.completions.create(
        model="default",
        messages=conversation,
        temperature=0,
        max_tokens=32,
    )
    print(reply.choices[0].message.content)

    assert reply.id
    assert reply.created
    usage = reply.usage
    assert usage.prompt_tokens > 0
    assert usage.completion_tokens > 0
    assert usage.total_tokens > 0
    print("=" * 100)
def test_chat_completion_stream(args):
    """Stream a chat request and print the generated content as it arrives.

    The first chunk must announce the ``assistant`` role and is not printed;
    later chunks are printed only when they carry non-empty content.

    Args:
        args: Parsed command-line arguments; only ``args.base_url`` is read.

    Raises:
        AssertionError: If the first chunk's role is not ``assistant``.
    """
    conversation = [
        {"role": "system", "content": "You are a helpful AI assistant"},
        {"role": "user", "content": "List 3 countries and their capitals."},
    ]
    api = openai.Client(api_key="EMPTY", base_url=args.base_url)
    stream = api.chat.completions.create(
        model="default",
        messages=conversation,
        temperature=0,
        max_tokens=64,
        stream=True,
    )
    for position, chunk in enumerate(stream):
        delta = chunk.choices[0].delta
        if position == 0:
            assert delta.role == "assistant"
            continue
        if delta.content:
            print(delta.content, end="", flush=True)
    print("=" * 100)
def test_regex(args):
    """Request a regex-constrained chat reply and parse it as JSON.

    The pattern is passed through the request's ``extra_body`` under the
    ``regex`` key; the reply text is decoded with ``json.loads`` and printed.

    Args:
        args: Parsed command-line arguments; only ``args.base_url`` is read.
    """
    # Fragments of the pattern, concatenated in order into a single regex.
    pattern = "".join(
        [
            r"""\{\n""",
            r""" "name": "[\w]+",\n""",
            r""" "population": [\d]+\n""",
            r"""\}""",
        ]
    )
    conversation = [
        {"role": "system", "content": "You are a helpful AI assistant"},
        {"role": "user", "content": "Introduce the capital of France."},
    ]
    api = openai.Client(api_key="EMPTY", base_url=args.base_url)
    reply = api.chat.completions.create(
        model="default",
        messages=conversation,
        temperature=0,
        max_tokens=128,
        extra_body={"regex": pattern},
    )
    parsed = json.loads(reply.choices[0].message.content)
    print(parsed)
    print("=" * 100)
if __name__ == "__main__":
    # Command-line entry point: run every check against the server at --base-url.
    cli = argparse.ArgumentParser()
    cli.add_argument("--base-url", type=str, default="http://127.0.0.1:30000/v1")
    cli.add_argument(
        "--test-image", action="store_true", help="Enables testing image inputs"
    )
    args = cli.parse_args()

    # Every echo/logprobs combination: all non-streaming runs first, then all
    # streaming runs, with echo toggling fastest.
    for run in (test_completion, test_completion_stream):
        for logprobs in (False, True, 3):
            for echo in (False, True):
                run(args, echo=echo, logprobs=logprobs)

    test_chat_completion(args)
    test_chat_completion_stream(args)
    test_regex(args)

    # Image input is opt-in via --test-image.
    if args.test_image:
        test_chat_completion_image(args)

View File

@@ -1,209 +1,123 @@
""" import subprocess
First run the following command to launch the server. import time
Note that TinyLlama adopts different chat templates in different versions. import unittest
For v0.4, the chat template is chatml.
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 \
--port 30000 --chat-template chatml
Output example:
The capital of France is Paris.
The capital of the United States is Washington, D.C.
The capital of Canada is Ottawa.
The capital of Japan is Tokyo
"""
import argparse
import json
import openai import openai
import requests
from sglang.srt.utils import kill_child_process
def test_completion(args, echo, logprobs): class TestOpenAIServer(unittest.TestCase):
client = openai.Client(api_key="EMPTY", base_url=args.base_url)
response = client.completions.create( @classmethod
model="default", def setUpClass(cls):
prompt="The capital of France is", model = "meta-llama/Meta-Llama-3.1-8B-Instruct"
temperature=0, port = 30000
max_tokens=32, timeout = 300
echo=echo,
logprobs=logprobs, command = [
) "python3", "-m", "sglang.launch_server",
text = response.choices[0].text "--model-path", model,
print(response.choices[0].text) "--host", "localhost",
if echo: "--port", str(port),
assert text.startswith("The capital of France is") ]
if logprobs: cls.process = subprocess.Popen(command, stdout=None, stderr=None)
print(response.choices[0].logprobs.top_logprobs) cls.base_url = f"http://localhost:{port}/v1"
assert response.choices[0].logprobs cls.model = model
start_time = time.time()
while time.time() - start_time < timeout:
try:
response = requests.get(f"{cls.base_url}/models")
if response.status_code == 200:
return
except requests.RequestException:
pass
time.sleep(10)
raise TimeoutError("Server failed to start within the timeout period.")
@classmethod
def tearDownClass(cls):
kill_child_process(cls.process.pid)
def run_completion(self, echo, logprobs):
client = openai.Client(api_key="EMPTY", base_url=self.base_url)
prompt = "The capital of France is"
response = client.completions.create(
model=self.model,
prompt=prompt,
temperature=0.1,
max_tokens=32,
echo=echo,
logprobs=logprobs,
)
text = response.choices[0].text
if echo: if echo:
assert response.choices[0].logprobs.token_logprobs[0] == None assert text.startswith(prompt)
else:
assert response.choices[0].logprobs.token_logprobs[0] != None
assert response.id
assert response.created
assert response.usage.prompt_tokens > 0
assert response.usage.completion_tokens > 0
assert response.usage.total_tokens > 0
print("=" * 100)
def test_completion_stream(args, echo, logprobs):
client = openai.Client(api_key="EMPTY", base_url=args.base_url)
response = client.completions.create(
model="default",
prompt="The capital of France is",
temperature=0,
max_tokens=32,
stream=True,
echo=echo,
logprobs=logprobs,
)
first = True
for r in response:
if first:
if echo:
assert r.choices[0].text.startswith("The capital of France is")
first = False
if logprobs: if logprobs:
print( assert response.choices[0].logprobs
f"{r.choices[0].text:12s}\t" f"{r.choices[0].logprobs.token_logprobs}", assert isinstance(response.choices[0].logprobs.tokens[0], str)
flush=True, assert isinstance(response.choices[0].logprobs.top_logprobs[1], dict)
) assert len(response.choices[0].logprobs.top_logprobs[1]) == logprobs
print(r.choices[0].logprobs.top_logprobs) if echo:
else: assert response.choices[0].logprobs.token_logprobs[0] == None
print(r.choices[0].text, end="", flush=True) else:
assert r.id assert response.choices[0].logprobs.token_logprobs[0] != None
assert r.usage.prompt_tokens > 0 assert response.id
assert r.usage.completion_tokens > 0 assert response.created
assert r.usage.total_tokens > 0 assert response.usage.prompt_tokens > 0
print("=" * 100) assert response.usage.completion_tokens > 0
assert response.usage.total_tokens > 0
def run_completion_stream(self, echo, logprobs):
client = openai.Client(api_key="EMPTY", base_url=self.base_url)
prompt = "The capital of France is"
generator = client.completions.create(
model=self.model,
prompt=prompt,
temperature=0.1,
max_tokens=32,
echo=echo,
logprobs=logprobs,
stream=True,
)
def test_chat_completion(args): first = True
client = openai.Client(api_key="EMPTY", base_url=args.base_url) for response in generator:
response = client.chat.completions.create( if logprobs:
model="default", assert response.choices[0].logprobs
messages=[ assert isinstance(response.choices[0].logprobs.tokens[0], str)
{"role": "system", "content": "You are a helpful AI assistant"}, if not (first and echo):
{"role": "user", "content": "What is the capital of France?"}, assert isinstance(response.choices[0].logprobs.top_logprobs[0], dict)
], #assert len(response.choices[0].logprobs.top_logprobs[0]) == logprobs
temperature=0,
max_tokens=32,
)
print(response.choices[0].message.content)
assert response.id
assert response.created
assert response.usage.prompt_tokens > 0
assert response.usage.completion_tokens > 0
assert response.usage.total_tokens > 0
print("=" * 100)
if first:
if echo:
assert response.choices[0].text.startswith(prompt)
first = False
def test_chat_completion_image(args): assert response.id
client = openai.Client(api_key="EMPTY", base_url=args.base_url) assert response.created
response = client.chat.completions.create( assert response.usage.prompt_tokens > 0
model="default", assert response.usage.completion_tokens > 0
messages=[ assert response.usage.total_tokens > 0
{"role": "system", "content": "You are a helpful AI assistant"},
{
"role": "user",
"content": [
{"type": "text", "text": "Describe this image"},
{
"type": "image_url",
"image_url": {
"url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/mixtral_8x7b.jpg"
},
},
],
},
],
temperature=0,
max_tokens=32,
)
print(response.choices[0].message.content)
assert response.id
assert response.created
assert response.usage.prompt_tokens > 0
assert response.usage.completion_tokens > 0
assert response.usage.total_tokens > 0
print("=" * 100)
def test_completion(self):
for echo in [False, True]:
for logprobs in [None, 5]:
self.run_completion(echo, logprobs)
def test_chat_completion_stream(args): def test_completion_stream(self):
client = openai.Client(api_key="EMPTY", base_url=args.base_url) for echo in [True]:
response = client.chat.completions.create( for logprobs in [5]:
model="default", self.run_completion_stream(echo, logprobs)
messages=[
{"role": "system", "content": "You are a helpful AI assistant"},
{"role": "user", "content": "List 3 countries and their capitals."},
],
temperature=0,
max_tokens=64,
stream=True,
)
is_first = True
for chunk in response:
if is_first:
is_first = False
assert chunk.choices[0].delta.role == "assistant"
continue
data = chunk.choices[0].delta
if not data.content:
continue
print(data.content, end="", flush=True)
print("=" * 100)
def test_regex(args):
client = openai.Client(api_key="EMPTY", base_url=args.base_url)
regex = (
r"""\{\n"""
+ r""" "name": "[\w]+",\n"""
+ r""" "population": [\d]+\n"""
+ r"""\}"""
)
response = client.chat.completions.create(
model="default",
messages=[
{"role": "system", "content": "You are a helpful AI assistant"},
{"role": "user", "content": "Introduce the capital of France."},
],
temperature=0,
max_tokens=128,
extra_body={"regex": regex},
)
text = response.choices[0].message.content
print(json.loads(text))
print("=" * 100)
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() # unittest.main(warnings="ignore")
parser.add_argument("--base-url", type=str, default="http://127.0.0.1:30000/v1")
parser.add_argument(
"--test-image", action="store_true", help="Enables testing image inputs"
)
args = parser.parse_args()
test_completion(args, echo=False, logprobs=False) t = TestOpenAIServer()
test_completion(args, echo=True, logprobs=False) t.setUpClass()
test_completion(args, echo=False, logprobs=True) t.test_completion_stream()
test_completion(args, echo=True, logprobs=True) t.tearDownClass()
test_completion(args, echo=False, logprobs=3)
test_completion(args, echo=True, logprobs=3)
test_completion_stream(args, echo=False, logprobs=False)
test_completion_stream(args, echo=True, logprobs=False)
test_completion_stream(args, echo=False, logprobs=True)
test_completion_stream(args, echo=True, logprobs=True)
test_completion_stream(args, echo=False, logprobs=3)
test_completion_stream(args, echo=True, logprobs=3)
test_chat_completion(args)
test_chat_completion_stream(args)
test_regex(args)
if args.test_image:
test_chat_completion_image(args)