Make scripts under /test/srt as unit tests (#875)
This commit is contained in:
13
.github/workflows/unit-test.yml
vendored
13
.github/workflows/unit-test.yml
vendored
@@ -20,8 +20,6 @@ concurrency:
|
|||||||
jobs:
|
jobs:
|
||||||
unit-test:
|
unit-test:
|
||||||
runs-on: self-hosted
|
runs-on: self-hosted
|
||||||
env:
|
|
||||||
CUDA_VISIBLE_DEVICES: 6
|
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
@@ -30,6 +28,7 @@ jobs:
|
|||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
cd /data/zhyncs/venv && source ./bin/activate && cd -
|
cd /data/zhyncs/venv && source ./bin/activate && cd -
|
||||||
|
|
||||||
pip cache purge
|
pip cache purge
|
||||||
pip install --upgrade pip
|
pip install --upgrade pip
|
||||||
pip install -e "python[all]"
|
pip install -e "python[all]"
|
||||||
@@ -39,6 +38,14 @@ jobs:
|
|||||||
- name: Test OpenAI Backend
|
- name: Test OpenAI Backend
|
||||||
run: |
|
run: |
|
||||||
cd /data/zhyncs/venv && source ./bin/activate && cd -
|
cd /data/zhyncs/venv && source ./bin/activate && cd -
|
||||||
cd test/lang
|
|
||||||
export OPENAI_API_KEY=${{ secrets.OPENAI_API_KEY }}
|
export OPENAI_API_KEY=${{ secrets.OPENAI_API_KEY }}
|
||||||
|
|
||||||
|
cd test/lang
|
||||||
python3 test_openai_backend.py
|
python3 test_openai_backend.py
|
||||||
|
|
||||||
|
- name: Test SRT Backend
|
||||||
|
run: |
|
||||||
|
cd /data/zhyncs/venv && source ./bin/activate && cd -
|
||||||
|
|
||||||
|
cd test/lang
|
||||||
|
python3 test_srt_backend.py
|
||||||
|
|||||||
@@ -73,6 +73,7 @@ from sglang.srt.utils import (
|
|||||||
assert_pkg_version,
|
assert_pkg_version,
|
||||||
enable_show_time_cost,
|
enable_show_time_cost,
|
||||||
maybe_set_triton_cache_manager,
|
maybe_set_triton_cache_manager,
|
||||||
|
kill_child_process,
|
||||||
set_ulimit,
|
set_ulimit,
|
||||||
)
|
)
|
||||||
from sglang.utils import get_exception_traceback
|
from sglang.utils import get_exception_traceback
|
||||||
@@ -467,16 +468,7 @@ class Runtime:
|
|||||||
|
|
||||||
def shutdown(self):
|
def shutdown(self):
|
||||||
if self.pid is not None:
|
if self.pid is not None:
|
||||||
try:
|
kill_child_process(self.pid)
|
||||||
parent = psutil.Process(self.pid)
|
|
||||||
except psutil.NoSuchProcess:
|
|
||||||
return
|
|
||||||
children = parent.children(recursive=True)
|
|
||||||
for child in children:
|
|
||||||
child.kill()
|
|
||||||
psutil.wait_procs(children, timeout=5)
|
|
||||||
parent.kill()
|
|
||||||
parent.wait(timeout=5)
|
|
||||||
self.pid = None
|
self.pid = None
|
||||||
|
|
||||||
def cache_prefix(self, prefix: str):
|
def cache_prefix(self, prefix: str):
|
||||||
|
|||||||
@@ -366,6 +366,26 @@ def kill_parent_process():
|
|||||||
os.kill(parent_process.pid, 9)
|
os.kill(parent_process.pid, 9)
|
||||||
|
|
||||||
|
|
||||||
|
def kill_child_process(pid, including_parent=True):
    """Kill every descendant of process ``pid`` and, optionally, ``pid`` itself.

    Args:
        pid: PID of the root process whose descendants are killed.
        including_parent: When true (the default), the root process is killed
            too, after all of its descendants.

    Processes that have already exited are skipped silently. The function
    sends the kill and returns; it does not wait for the processes to exit.
    """
    try:
        parent = psutil.Process(pid)
        # The root may exit between the lookup and the enumeration, so both
        # calls sit under the same NoSuchProcess guard.
        children = parent.children(recursive=True)
    except psutil.NoSuchProcess:
        return

    # Descendants first, then (optionally) the root itself.
    targets = children + [parent] if including_parent else children
    for proc in targets:
        try:
            proc.kill()
        except psutil.NoSuchProcess:
            # Already gone: nothing left to kill.
            pass
|
||||||
|
|
||||||
|
|
||||||
def monkey_patch_vllm_p2p_access_check(gpu_id: int):
|
def monkey_patch_vllm_p2p_access_check(gpu_id: int):
|
||||||
"""
|
"""
|
||||||
Monkey patch the slow p2p access check in vllm.
|
Monkey patch the slow p2p access check in vllm.
|
||||||
|
|||||||
@@ -105,15 +105,14 @@ def test_decode_json_regex():
|
|||||||
def decode_json(s):
|
def decode_json(s):
|
||||||
from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STRING
|
from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STRING
|
||||||
|
|
||||||
s += "Generate a JSON object to describe the basic information of a city.\n"
|
s += "Generate a JSON object to describe the basic city information of Paris.\n"
|
||||||
|
|
||||||
with s.var_scope("json_output"):
|
with s.var_scope("json_output"):
|
||||||
s += "{\n"
|
s += "{\n"
|
||||||
s += ' "name": ' + sgl.gen(regex=REGEX_STRING + ",") + "\n"
|
s += ' "name": ' + sgl.gen(regex=REGEX_STRING + ",") + "\n"
|
||||||
s += ' "population": ' + sgl.gen(regex=REGEX_INT + ",") + "\n"
|
s += ' "population": ' + sgl.gen(regex=REGEX_INT + ",") + "\n"
|
||||||
s += ' "area": ' + sgl.gen(regex=REGEX_INT + ",") + "\n"
|
s += ' "area": ' + sgl.gen(regex=REGEX_INT + ",") + "\n"
|
||||||
s += ' "latitude": ' + sgl.gen(regex=REGEX_FLOAT + ",") + "\n"
|
s += ' "latitude": ' + sgl.gen(regex=REGEX_FLOAT) + "\n"
|
||||||
s += ' "country": ' + sgl.gen(regex=REGEX_STRING) + "\n"
|
|
||||||
s += "}"
|
s += "}"
|
||||||
|
|
||||||
ret = decode_json.run(temperature=0.0)
|
ret = decode_json.run(temperature=0.0)
|
||||||
@@ -129,7 +128,7 @@ def test_decode_json_regex():
|
|||||||
def test_decode_json():
|
def test_decode_json():
|
||||||
@sgl.function
|
@sgl.function
|
||||||
def decode_json(s):
|
def decode_json(s):
|
||||||
s += "Generate a JSON object to describe the basic information of a city.\n"
|
s += "Generate a JSON object to describe the basic city information of Paris.\n"
|
||||||
|
|
||||||
with s.var_scope("json_output"):
|
with s.var_scope("json_output"):
|
||||||
s += "{\n"
|
s += "{\n"
|
||||||
@@ -264,6 +263,7 @@ def test_parallel_decoding():
|
|||||||
s += "\nIn summary," + sgl.gen("summary", max_tokens=512)
|
s += "\nIn summary," + sgl.gen("summary", max_tokens=512)
|
||||||
|
|
||||||
ret = parallel_decoding.run(topic="writing a good blog post", temperature=0.3)
|
ret = parallel_decoding.run(topic="writing a good blog post", temperature=0.3)
|
||||||
|
assert isinstance(ret["summary"], str)
|
||||||
|
|
||||||
|
|
||||||
def test_parallel_encoding(check_answer=True):
|
def test_parallel_encoding(check_answer=True):
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ class TestSRTBackend(unittest.TestCase):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
cls.backend = sgl.Runtime(model_path="meta-llama/Meta-Llama-3-8B-Instruct")
|
cls.backend = sgl.Runtime(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct")
|
||||||
sgl.set_default_backend(cls.backend)
|
sgl.set_default_backend(cls.backend)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -1 +0,0 @@
|
|||||||
../lang/example_image.png
|
|
||||||
209
test/srt/old/test_openai_server.py
Normal file
209
test/srt/old/test_openai_server.py
Normal file
@@ -0,0 +1,209 @@
|
|||||||
|
"""
|
||||||
|
First run the following command to launch the server.
|
||||||
|
Note that TinyLlama adopts different chat templates in different versions.
|
||||||
|
For v0.4, the chat template is chatml.
|
||||||
|
|
||||||
|
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 \
|
||||||
|
--port 30000 --chat-template chatml
|
||||||
|
|
||||||
|
Output example:
|
||||||
|
The capital of France is Paris.
|
||||||
|
The capital of the United States is Washington, D.C.
|
||||||
|
The capital of Canada is Ottawa.
|
||||||
|
The capital of Japan is Tokyo
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
|
||||||
|
import openai
|
||||||
|
|
||||||
|
|
||||||
|
def test_completion(args, echo, logprobs):
    """Check a non-streaming completion request against ``args.base_url``.

    Args:
        args: Parsed CLI namespace; only ``args.base_url`` is read.
        echo: Forwarded as ``echo``. When truthy, the returned text must
            start with the prompt.
        logprobs: Forwarded as ``logprobs``. When truthy, the logprob
            payload is printed and checked.
    """
    client = openai.Client(api_key="EMPTY", base_url=args.base_url)
    response = client.completions.create(
        model="default",
        prompt="The capital of France is",
        temperature=0,
        max_tokens=32,
        echo=echo,
        logprobs=logprobs,
    )
    text = response.choices[0].text
    print(text)
    if echo:
        assert text.startswith("The capital of France is")
    if logprobs:
        print(response.choices[0].logprobs.top_logprobs)
        assert response.choices[0].logprobs
        # With echo the first token logprob is expected to be None;
        # without echo it must be present.
        if echo:
            assert response.choices[0].logprobs.token_logprobs[0] is None
        else:
            assert response.choices[0].logprobs.token_logprobs[0] is not None
    assert response.id
    assert response.created
    assert response.usage.prompt_tokens > 0
    assert response.usage.completion_tokens > 0
    assert response.usage.total_tokens > 0
    print("=" * 100)
|
||||||
|
|
||||||
|
|
||||||
|
def test_completion_stream(args, echo, logprobs):
    """Check a streaming completion request against ``args.base_url``.

    Args:
        args: Parsed CLI namespace; only ``args.base_url`` is read.
        echo: Forwarded as ``echo``. When truthy, the first streamed chunk
            must start with the prompt.
        logprobs: Forwarded as ``logprobs``. When truthy, each chunk's
            logprobs are printed instead of the bare text.
    """
    client = openai.Client(api_key="EMPTY", base_url=args.base_url)
    stream = client.completions.create(
        model="default",
        prompt="The capital of France is",
        temperature=0,
        max_tokens=32,
        stream=True,
        echo=echo,
        logprobs=logprobs,
    )
    for position, chunk in enumerate(stream):
        choice = chunk.choices[0]
        # Only the very first chunk carries the echoed prompt.
        if position == 0 and echo:
            assert choice.text.startswith("The capital of France is")
        if logprobs:
            print(
                f"{choice.text:12s}\t{choice.logprobs.token_logprobs}",
                flush=True,
            )
            print(choice.logprobs.top_logprobs)
        else:
            print(choice.text, end="", flush=True)
        assert chunk.id
        usage = chunk.usage
        assert usage.prompt_tokens > 0
        assert usage.completion_tokens > 0
        assert usage.total_tokens > 0
    print("=" * 100)
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_completion(args):
    """Check a non-streaming chat completion against ``args.base_url``.

    Prints the assistant reply, then asserts that the response has an id, a
    creation timestamp, and positive token counts.
    """
    conversation = [
        {"role": "system", "content": "You are a helpful AI assistant"},
        {"role": "user", "content": "What is the capital of France?"},
    ]
    client = openai.Client(api_key="EMPTY", base_url=args.base_url)
    reply = client.chat.completions.create(
        model="default",
        messages=conversation,
        temperature=0,
        max_tokens=32,
    )
    print(reply.choices[0].message.content)
    assert reply.id
    assert reply.created
    usage = reply.usage
    assert usage.prompt_tokens > 0
    assert usage.completion_tokens > 0
    assert usage.total_tokens > 0
    print("=" * 100)
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_completion_image(args):
    """Check a chat completion whose user message mixes text and an image URL.

    Prints the assistant reply, then asserts that the response has an id, a
    creation timestamp, and positive token counts.
    """
    image_part = {
        "type": "image_url",
        "image_url": {
            "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/mixtral_8x7b.jpg"
        },
    }
    text_part = {"type": "text", "text": "Describe this image"}
    conversation = [
        {"role": "system", "content": "You are a helpful AI assistant"},
        {"role": "user", "content": [text_part, image_part]},
    ]
    client = openai.Client(api_key="EMPTY", base_url=args.base_url)
    reply = client.chat.completions.create(
        model="default",
        messages=conversation,
        temperature=0,
        max_tokens=32,
    )
    print(reply.choices[0].message.content)
    assert reply.id
    assert reply.created
    usage = reply.usage
    assert usage.prompt_tokens > 0
    assert usage.completion_tokens > 0
    assert usage.total_tokens > 0
    print("=" * 100)
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_completion_stream(args):
    """Check a streaming chat completion against ``args.base_url``.

    The first chunk must announce the ``assistant`` role. Every later chunk
    with non-empty content is printed as it arrives.
    """
    conversation = [
        {"role": "system", "content": "You are a helpful AI assistant"},
        {"role": "user", "content": "List 3 countries and their capitals."},
    ]
    client = openai.Client(api_key="EMPTY", base_url=args.base_url)
    stream = client.chat.completions.create(
        model="default",
        messages=conversation,
        temperature=0,
        max_tokens=64,
        stream=True,
    )
    for position, chunk in enumerate(stream):
        delta = chunk.choices[0].delta
        if position == 0:
            # The opening chunk is only checked for the role, never printed.
            assert delta.role == "assistant"
            continue
        if delta.content:
            print(delta.content, end="", flush=True)
    print("=" * 100)
|
||||||
|
|
||||||
|
|
||||||
|
def test_regex(args):
    """Check regex-constrained chat completion against ``args.base_url``.

    The pattern is passed through ``extra_body``. The reply is parsed with
    ``json.loads`` (which raises on invalid JSON) and the parsed value is
    printed.
    """
    # Adjacent raw literals concatenate into one pattern.
    pattern = (
        r"""\{\n"""
        r""" "name": "[\w]+",\n"""
        r""" "population": [\d]+\n"""
        r"""\}"""
    )
    conversation = [
        {"role": "system", "content": "You are a helpful AI assistant"},
        {"role": "user", "content": "Introduce the capital of France."},
    ]
    client = openai.Client(api_key="EMPTY", base_url=args.base_url)
    reply = client.chat.completions.create(
        model="default",
        messages=conversation,
        temperature=0,
        max_tokens=128,
        extra_body={"regex": pattern},
    )
    parsed = json.loads(reply.choices[0].message.content)
    print(parsed)
    print("=" * 100)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Command-line driver: runs every check against an already-running server.
    cli = argparse.ArgumentParser()
    cli.add_argument("--base-url", type=str, default="http://127.0.0.1:30000/v1")
    cli.add_argument(
        "--test-image", action="store_true", help="Enables testing image inputs"
    )
    args = cli.parse_args()

    # For each endpoint flavour, sweep logprobs in the outer loop and echo in
    # the inner loop: (False, False), (True, False), (False, True), ...
    for completion_check in (test_completion, test_completion_stream):
        for logprobs in (False, True, 3):
            for echo in (False, True):
                completion_check(args, echo=echo, logprobs=logprobs)

    test_chat_completion(args)
    test_chat_completion_stream(args)
    test_regex(args)
    # Image input is opt-in via --test-image.
    if args.test_image:
        test_chat_completion_image(args)
|
||||||
@@ -1,209 +1,123 @@
|
|||||||
"""
|
import subprocess
|
||||||
First run the following command to launch the server.
|
import time
|
||||||
Note that TinyLlama adopts different chat templates in different versions.
|
import unittest
|
||||||
For v0.4, the chat template is chatml.
|
|
||||||
|
|
||||||
python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 \
|
|
||||||
--port 30000 --chat-template chatml
|
|
||||||
|
|
||||||
Output example:
|
|
||||||
The capital of France is Paris.
|
|
||||||
The capital of the United States is Washington, D.C.
|
|
||||||
The capital of Canada is Ottawa.
|
|
||||||
The capital of Japan is Tokyo
|
|
||||||
"""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import json
|
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from sglang.srt.utils import kill_child_process
|
||||||
|
|
||||||
|
|
||||||
def test_completion(args, echo, logprobs):
|
class TestOpenAIServer(unittest.TestCase):
|
||||||
client = openai.Client(api_key="EMPTY", base_url=args.base_url)
|
|
||||||
response = client.completions.create(
|
@classmethod
|
||||||
model="default",
|
def setUpClass(cls):
|
||||||
prompt="The capital of France is",
|
model = "meta-llama/Meta-Llama-3.1-8B-Instruct"
|
||||||
temperature=0,
|
port = 30000
|
||||||
max_tokens=32,
|
timeout = 300
|
||||||
echo=echo,
|
|
||||||
logprobs=logprobs,
|
command = [
|
||||||
)
|
"python3", "-m", "sglang.launch_server",
|
||||||
text = response.choices[0].text
|
"--model-path", model,
|
||||||
print(response.choices[0].text)
|
"--host", "localhost",
|
||||||
if echo:
|
"--port", str(port),
|
||||||
assert text.startswith("The capital of France is")
|
]
|
||||||
if logprobs:
|
cls.process = subprocess.Popen(command, stdout=None, stderr=None)
|
||||||
print(response.choices[0].logprobs.top_logprobs)
|
cls.base_url = f"http://localhost:{port}/v1"
|
||||||
assert response.choices[0].logprobs
|
cls.model = model
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
while time.time() - start_time < timeout:
|
||||||
|
try:
|
||||||
|
response = requests.get(f"{cls.base_url}/models")
|
||||||
|
if response.status_code == 200:
|
||||||
|
return
|
||||||
|
except requests.RequestException:
|
||||||
|
pass
|
||||||
|
time.sleep(10)
|
||||||
|
raise TimeoutError("Server failed to start within the timeout period.")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
kill_child_process(cls.process.pid)
|
||||||
|
|
||||||
|
def run_completion(self, echo, logprobs):
|
||||||
|
client = openai.Client(api_key="EMPTY", base_url=self.base_url)
|
||||||
|
prompt = "The capital of France is"
|
||||||
|
response = client.completions.create(
|
||||||
|
model=self.model,
|
||||||
|
prompt=prompt,
|
||||||
|
temperature=0.1,
|
||||||
|
max_tokens=32,
|
||||||
|
echo=echo,
|
||||||
|
logprobs=logprobs,
|
||||||
|
)
|
||||||
|
text = response.choices[0].text
|
||||||
if echo:
|
if echo:
|
||||||
assert response.choices[0].logprobs.token_logprobs[0] == None
|
assert text.startswith(prompt)
|
||||||
else:
|
|
||||||
assert response.choices[0].logprobs.token_logprobs[0] != None
|
|
||||||
assert response.id
|
|
||||||
assert response.created
|
|
||||||
assert response.usage.prompt_tokens > 0
|
|
||||||
assert response.usage.completion_tokens > 0
|
|
||||||
assert response.usage.total_tokens > 0
|
|
||||||
print("=" * 100)
|
|
||||||
|
|
||||||
|
|
||||||
def test_completion_stream(args, echo, logprobs):
|
|
||||||
client = openai.Client(api_key="EMPTY", base_url=args.base_url)
|
|
||||||
response = client.completions.create(
|
|
||||||
model="default",
|
|
||||||
prompt="The capital of France is",
|
|
||||||
temperature=0,
|
|
||||||
max_tokens=32,
|
|
||||||
stream=True,
|
|
||||||
echo=echo,
|
|
||||||
logprobs=logprobs,
|
|
||||||
)
|
|
||||||
first = True
|
|
||||||
for r in response:
|
|
||||||
if first:
|
|
||||||
if echo:
|
|
||||||
assert r.choices[0].text.startswith("The capital of France is")
|
|
||||||
first = False
|
|
||||||
if logprobs:
|
if logprobs:
|
||||||
print(
|
assert response.choices[0].logprobs
|
||||||
f"{r.choices[0].text:12s}\t" f"{r.choices[0].logprobs.token_logprobs}",
|
assert isinstance(response.choices[0].logprobs.tokens[0], str)
|
||||||
flush=True,
|
assert isinstance(response.choices[0].logprobs.top_logprobs[1], dict)
|
||||||
)
|
assert len(response.choices[0].logprobs.top_logprobs[1]) == logprobs
|
||||||
print(r.choices[0].logprobs.top_logprobs)
|
if echo:
|
||||||
else:
|
assert response.choices[0].logprobs.token_logprobs[0] == None
|
||||||
print(r.choices[0].text, end="", flush=True)
|
else:
|
||||||
assert r.id
|
assert response.choices[0].logprobs.token_logprobs[0] != None
|
||||||
assert r.usage.prompt_tokens > 0
|
assert response.id
|
||||||
assert r.usage.completion_tokens > 0
|
assert response.created
|
||||||
assert r.usage.total_tokens > 0
|
assert response.usage.prompt_tokens > 0
|
||||||
print("=" * 100)
|
assert response.usage.completion_tokens > 0
|
||||||
|
assert response.usage.total_tokens > 0
|
||||||
|
|
||||||
|
def run_completion_stream(self, echo, logprobs):
|
||||||
|
client = openai.Client(api_key="EMPTY", base_url=self.base_url)
|
||||||
|
prompt = "The capital of France is"
|
||||||
|
generator = client.completions.create(
|
||||||
|
model=self.model,
|
||||||
|
prompt=prompt,
|
||||||
|
temperature=0.1,
|
||||||
|
max_tokens=32,
|
||||||
|
echo=echo,
|
||||||
|
logprobs=logprobs,
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
|
||||||
def test_chat_completion(args):
|
first = True
|
||||||
client = openai.Client(api_key="EMPTY", base_url=args.base_url)
|
for response in generator:
|
||||||
response = client.chat.completions.create(
|
if logprobs:
|
||||||
model="default",
|
assert response.choices[0].logprobs
|
||||||
messages=[
|
assert isinstance(response.choices[0].logprobs.tokens[0], str)
|
||||||
{"role": "system", "content": "You are a helpful AI assistant"},
|
if not (first and echo):
|
||||||
{"role": "user", "content": "What is the capital of France?"},
|
assert isinstance(response.choices[0].logprobs.top_logprobs[0], dict)
|
||||||
],
|
#assert len(response.choices[0].logprobs.top_logprobs[0]) == logprobs
|
||||||
temperature=0,
|
|
||||||
max_tokens=32,
|
|
||||||
)
|
|
||||||
print(response.choices[0].message.content)
|
|
||||||
assert response.id
|
|
||||||
assert response.created
|
|
||||||
assert response.usage.prompt_tokens > 0
|
|
||||||
assert response.usage.completion_tokens > 0
|
|
||||||
assert response.usage.total_tokens > 0
|
|
||||||
print("=" * 100)
|
|
||||||
|
|
||||||
|
if first:
|
||||||
|
if echo:
|
||||||
|
assert response.choices[0].text.startswith(prompt)
|
||||||
|
first = False
|
||||||
|
|
||||||
def test_chat_completion_image(args):
|
assert response.id
|
||||||
client = openai.Client(api_key="EMPTY", base_url=args.base_url)
|
assert response.created
|
||||||
response = client.chat.completions.create(
|
assert response.usage.prompt_tokens > 0
|
||||||
model="default",
|
assert response.usage.completion_tokens > 0
|
||||||
messages=[
|
assert response.usage.total_tokens > 0
|
||||||
{"role": "system", "content": "You are a helpful AI assistant"},
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": [
|
|
||||||
{"type": "text", "text": "Describe this image"},
|
|
||||||
{
|
|
||||||
"type": "image_url",
|
|
||||||
"image_url": {
|
|
||||||
"url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/mixtral_8x7b.jpg"
|
|
||||||
},
|
|
||||||
},
|
|
||||||
],
|
|
||||||
},
|
|
||||||
],
|
|
||||||
temperature=0,
|
|
||||||
max_tokens=32,
|
|
||||||
)
|
|
||||||
print(response.choices[0].message.content)
|
|
||||||
assert response.id
|
|
||||||
assert response.created
|
|
||||||
assert response.usage.prompt_tokens > 0
|
|
||||||
assert response.usage.completion_tokens > 0
|
|
||||||
assert response.usage.total_tokens > 0
|
|
||||||
print("=" * 100)
|
|
||||||
|
|
||||||
|
def test_completion(self):
|
||||||
|
for echo in [False, True]:
|
||||||
|
for logprobs in [None, 5]:
|
||||||
|
self.run_completion(echo, logprobs)
|
||||||
|
|
||||||
def test_chat_completion_stream(args):
|
def test_completion_stream(self):
|
||||||
client = openai.Client(api_key="EMPTY", base_url=args.base_url)
|
for echo in [True]:
|
||||||
response = client.chat.completions.create(
|
for logprobs in [5]:
|
||||||
model="default",
|
self.run_completion_stream(echo, logprobs)
|
||||||
messages=[
|
|
||||||
{"role": "system", "content": "You are a helpful AI assistant"},
|
|
||||||
{"role": "user", "content": "List 3 countries and their capitals."},
|
|
||||||
],
|
|
||||||
temperature=0,
|
|
||||||
max_tokens=64,
|
|
||||||
stream=True,
|
|
||||||
)
|
|
||||||
is_first = True
|
|
||||||
for chunk in response:
|
|
||||||
if is_first:
|
|
||||||
is_first = False
|
|
||||||
assert chunk.choices[0].delta.role == "assistant"
|
|
||||||
continue
|
|
||||||
|
|
||||||
data = chunk.choices[0].delta
|
|
||||||
if not data.content:
|
|
||||||
continue
|
|
||||||
print(data.content, end="", flush=True)
|
|
||||||
print("=" * 100)
|
|
||||||
|
|
||||||
|
|
||||||
def test_regex(args):
|
|
||||||
client = openai.Client(api_key="EMPTY", base_url=args.base_url)
|
|
||||||
|
|
||||||
regex = (
|
|
||||||
r"""\{\n"""
|
|
||||||
+ r""" "name": "[\w]+",\n"""
|
|
||||||
+ r""" "population": [\d]+\n"""
|
|
||||||
+ r"""\}"""
|
|
||||||
)
|
|
||||||
|
|
||||||
response = client.chat.completions.create(
|
|
||||||
model="default",
|
|
||||||
messages=[
|
|
||||||
{"role": "system", "content": "You are a helpful AI assistant"},
|
|
||||||
{"role": "user", "content": "Introduce the capital of France."},
|
|
||||||
],
|
|
||||||
temperature=0,
|
|
||||||
max_tokens=128,
|
|
||||||
extra_body={"regex": regex},
|
|
||||||
)
|
|
||||||
text = response.choices[0].message.content
|
|
||||||
print(json.loads(text))
|
|
||||||
print("=" * 100)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
# unittest.main(warnings="ignore")
|
||||||
parser.add_argument("--base-url", type=str, default="http://127.0.0.1:30000/v1")
|
|
||||||
parser.add_argument(
|
|
||||||
"--test-image", action="store_true", help="Enables testing image inputs"
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
test_completion(args, echo=False, logprobs=False)
|
t = TestOpenAIServer()
|
||||||
test_completion(args, echo=True, logprobs=False)
|
t.setUpClass()
|
||||||
test_completion(args, echo=False, logprobs=True)
|
t.test_completion_stream()
|
||||||
test_completion(args, echo=True, logprobs=True)
|
t.tearDownClass()
|
||||||
test_completion(args, echo=False, logprobs=3)
|
|
||||||
test_completion(args, echo=True, logprobs=3)
|
|
||||||
test_completion_stream(args, echo=False, logprobs=False)
|
|
||||||
test_completion_stream(args, echo=True, logprobs=False)
|
|
||||||
test_completion_stream(args, echo=False, logprobs=True)
|
|
||||||
test_completion_stream(args, echo=True, logprobs=True)
|
|
||||||
test_completion_stream(args, echo=False, logprobs=3)
|
|
||||||
test_completion_stream(args, echo=True, logprobs=3)
|
|
||||||
test_chat_completion(args)
|
|
||||||
test_chat_completion_stream(args)
|
|
||||||
test_regex(args)
|
|
||||||
if args.test_image:
|
|
||||||
test_chat_completion_image(args)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user