"""
Usage:
python3 -m unittest test_openai_server.TestOpenAIServer.test_batch
python3 -m unittest test_openai_server.TestOpenAIServer.test_completion
"""
import json
import re
import time
import unittest

import openai

from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
    DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST,
    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    popen_launch_server,
)

class TestOpenAIServer(CustomTestCase):
    """End-to-end tests for the OpenAI-compatible completion, chat, and batch APIs.

    One server process is launched for the whole class and shared by all tests.
    """

    @classmethod
    def setUpClass(cls):
        # Launch a single shared server; every test talks to it via the
        # OpenAI python client pointed at base_url + "/v1".
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
        )
        cls.base_url += "/v1"
        cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def run_completion(
        self, echo, logprobs, use_list_input, parallel_sample_num, token_input
    ):
        """Issue one non-streaming /v1/completions request and validate the response.

        Args:
            echo: whether the prompt should be echoed back in the output text.
            logprobs: number of top logprobs to request, or None.
            use_list_input: send the prompt as a two-element list (batch of 2).
            parallel_sample_num: value of the `n` sampling parameter.
            token_input: send token ids instead of a text prompt.
        """
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        prompt = "The capital of France is"
        if token_input:
            prompt_input = self.tokenizer.encode(prompt)
            num_prompt_tokens = len(prompt_input)
        else:
            prompt_input = prompt
            num_prompt_tokens = len(self.tokenizer.encode(prompt))

        if use_list_input:
            prompt_arg = [prompt_input, prompt_input]
            num_choices = len(prompt_arg)
            num_prompt_tokens *= 2
        else:
            prompt_arg = prompt_input
            num_choices = 1

        response = client.completions.create(
            model=self.model,
            prompt=prompt_arg,
            temperature=0,
            max_tokens=32,
            echo=echo,
            logprobs=logprobs,
            n=parallel_sample_num,
        )

        # One choice per prompt per parallel sample.
        assert len(response.choices) == num_choices * parallel_sample_num

        if echo:
            text = response.choices[0].text
            assert text.startswith(prompt)

        if logprobs:
            assert response.choices[0].logprobs
            assert isinstance(response.choices[0].logprobs.tokens[0], str)
            assert isinstance(response.choices[0].logprobs.top_logprobs[1], dict)
            ret_num_top_logprobs = len(response.choices[0].logprobs.top_logprobs[1])

            # FIXME: Sometimes, some top_logprobs are missing in the return value.
            # The reason is that some output ids map to the same output token and
            # duplicate in the map.
            # assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
            assert ret_num_top_logprobs > 0

            # When echo=True and request.logprobs>0, logprob_start_len is 0, so
            # the first token's logprob would be None.
            if not echo:
                assert response.choices[0].logprobs.token_logprobs[0]

        assert response.id
        assert response.created
        assert (
            response.usage.prompt_tokens == num_prompt_tokens
        ), f"{response.usage.prompt_tokens} vs {num_prompt_tokens}"
        assert response.usage.completion_tokens > 0
        assert response.usage.total_tokens > 0

    def run_completion_stream(
        self, echo, logprobs, use_list_input, parallel_sample_num, token_input
    ):
        """Issue one streaming /v1/completions request and validate every chunk.

        Mirrors run_completion; additionally checks the usage-only final chunk
        (stream_options include_usage) and that every expected choice index
        produced at least one chunk.
        """
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        prompt = "The capital of France is"
        if token_input:
            prompt_input = self.tokenizer.encode(prompt)
            num_prompt_tokens = len(prompt_input)
        else:
            prompt_input = prompt
            num_prompt_tokens = len(self.tokenizer.encode(prompt))
        if use_list_input:
            prompt_arg = [prompt_input, prompt_input]
            num_choices = len(prompt_arg)
            num_prompt_tokens *= 2
        else:
            prompt_arg = prompt_input
            num_choices = 1

        generator = client.completions.create(
            model=self.model,
            prompt=prompt_arg,
            temperature=0,
            max_tokens=32,
            echo=echo,
            logprobs=logprobs,
            stream=True,
            stream_options={"include_usage": True},
            n=parallel_sample_num,
        )

        # Tracks, per choice index, whether we are still waiting for its first chunk.
        is_firsts = {}
        for response in generator:
            usage = response.usage
            if usage is not None:
                # The trailing usage-only chunk has no choices.
                assert usage.prompt_tokens > 0
                assert usage.completion_tokens > 0
                assert usage.total_tokens > 0
                continue

            index = response.choices[0].index
            is_first = is_firsts.get(index, True)

            if logprobs:
                assert response.choices[0].logprobs
                assert isinstance(response.choices[0].logprobs.tokens[0], str)
                if not (is_first and echo):
                    assert isinstance(
                        response.choices[0].logprobs.top_logprobs[0], dict
                    )
                    ret_num_top_logprobs = len(
                        response.choices[0].logprobs.top_logprobs[0]
                    )
                    # FIXME: Sometimes, some top_logprobs are missing in the return
                    # value. The reason is that some output ids map to the same
                    # output token and duplicate in the map.
                    # assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
                    assert ret_num_top_logprobs > 0

            if is_first:
                if echo:
                    assert response.choices[0].text.startswith(
                        prompt
                    ), f"{response.choices[0].text} and all args {echo} {logprobs} {token_input} {is_first}"
                is_firsts[index] = False
            assert response.id
            assert response.created

        # Every expected choice index must have produced at least one chunk.
        for index in range(parallel_sample_num * num_choices):
            assert not is_firsts.get(
                index, True
            ), f"index {index} is not found in the response"

    def run_chat_completion(self, logprobs, parallel_sample_num):
        """Issue one non-streaming /v1/chat/completions request and validate it."""
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        response = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful AI assistant"},
                {
                    "role": "user",
                    "content": "What is the capital of France? Answer in a few words.",
                },
            ],
            temperature=0,
            logprobs=logprobs is not None and logprobs > 0,
            top_logprobs=logprobs,
            n=parallel_sample_num,
        )

        if logprobs:
            assert isinstance(
                response.choices[0].logprobs.content[0].top_logprobs[0].token, str
            )
            ret_num_top_logprobs = len(
                response.choices[0].logprobs.content[0].top_logprobs
            )
            assert (
                ret_num_top_logprobs == logprobs
            ), f"{ret_num_top_logprobs} vs {logprobs}"

        assert len(response.choices) == parallel_sample_num
        assert response.choices[0].message.role == "assistant"
        assert isinstance(response.choices[0].message.content, str)
        assert response.id
        assert response.created
        assert response.usage.prompt_tokens > 0
        assert response.usage.completion_tokens > 0
        assert response.usage.total_tokens > 0

    def run_chat_completion_stream(self, logprobs, parallel_sample_num=1):
        """Issue one streaming /v1/chat/completions request and validate chunks."""
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        generator = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful AI assistant"},
                {"role": "user", "content": "What is the capital of France?"},
            ],
            temperature=0,
            logprobs=logprobs is not None and logprobs > 0,
            top_logprobs=logprobs,
            stream=True,
            stream_options={"include_usage": True},
            n=parallel_sample_num,
        )

        is_firsts = {}
        for response in generator:
            usage = response.usage
            if usage is not None:
                # The trailing usage-only chunk has no choices.
                assert usage.prompt_tokens > 0
                assert usage.completion_tokens > 0
                assert usage.total_tokens > 0
                continue

            index = response.choices[0].index
            data = response.choices[0].delta

            if is_firsts.get(index, True):
                # The first chunk for each choice only carries the role.
                assert data.role == "assistant"
                is_firsts[index] = False
                continue
            if logprobs:
                assert response.choices[0].logprobs
                assert isinstance(
                    response.choices[0].logprobs.content[0].top_logprobs[0].token, str
                )
                assert isinstance(
                    response.choices[0].logprobs.content[0].top_logprobs, list
                )
                ret_num_top_logprobs = len(
                    response.choices[0].logprobs.content[0].top_logprobs
                )
                assert (
                    ret_num_top_logprobs == logprobs
                ), f"{ret_num_top_logprobs} vs {logprobs}"

            # A non-first chunk must carry content, reasoning content, a tool
            # call, or a finish reason.
            assert (
                isinstance(data.content, str)
                or isinstance(data.reasoning_content, str)
                or len(data.tool_calls) > 0
                or response.choices[0].finish_reason
            )
            assert response.id
            assert response.created

        for index in range(parallel_sample_num):
            assert not is_firsts.get(
                index, True
            ), f"index {index} is not found in the response"

    def _create_batch(self, mode, client):
        """Write a batch input .jsonl file, upload it, and create a batch job.

        Returns (batch_job, content, uploaded_file) so callers can poll the job
        and clean up the uploaded file.
        """
        if mode == "completion":
            input_file_path = "complete_input.jsonl"
            # write content to input file
            content = [
                {
                    "custom_id": "request-1",
                    "method": "POST",
                    "url": "/v1/completions",
                    "body": {
                        "model": "gpt-3.5-turbo-instruct",
                        "prompt": "List 3 names of famous soccer player: ",
                        "max_tokens": 20,
                    },
                },
                {
                    "custom_id": "request-2",
                    "method": "POST",
                    "url": "/v1/completions",
                    "body": {
                        "model": "gpt-3.5-turbo-instruct",
                        "prompt": "List 6 names of famous basketball player: ",
                        "max_tokens": 40,
                    },
                },
                {
                    "custom_id": "request-3",
                    "method": "POST",
                    "url": "/v1/completions",
                    "body": {
                        "model": "gpt-3.5-turbo-instruct",
                        "prompt": "List 6 names of famous tenniss player: ",
                        "max_tokens": 40,
                    },
                },
            ]
        else:
            input_file_path = "chat_input.jsonl"
            content = [
                {
                    "custom_id": "request-1",
                    "method": "POST",
                    "url": "/v1/chat/completions",
                    "body": {
                        "model": "gpt-3.5-turbo-0125",
                        "messages": [
                            {
                                "role": "system",
                                "content": "You are a helpful assistant.",
                            },
                            {
                                "role": "user",
                                "content": "Hello! List 3 NBA players and tell a story",
                            },
                        ],
                        "max_tokens": 30,
                    },
                },
                {
                    "custom_id": "request-2",
                    "method": "POST",
                    "url": "/v1/chat/completions",
                    "body": {
                        "model": "gpt-3.5-turbo-0125",
                        "messages": [
                            {"role": "system", "content": "You are an assistant. "},
                            {
                                "role": "user",
                                "content": "Hello! List three capital and tell a story",
                            },
                        ],
                        "max_tokens": 50,
                    },
                },
            ]

        with open(input_file_path, "w") as file:
            for line in content:
                file.write(json.dumps(line) + "\n")

        with open(input_file_path, "rb") as file:
            uploaded_file = client.files.create(file=file, purpose="batch")
        if mode == "completion":
            endpoint = "/v1/completions"
        elif mode == "chat":
            endpoint = "/v1/chat/completions"
        completion_window = "24h"
        batch_job = client.batches.create(
            input_file_id=uploaded_file.id,
            endpoint=endpoint,
            completion_window=completion_window,
        )

        return batch_job, content, uploaded_file

    def run_batch(self, mode):
        """Create a batch job, wait for completion, and validate its results."""
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        batch_job, content, uploaded_file = self._create_batch(mode=mode, client=client)

        while batch_job.status not in ["completed", "failed", "cancelled"]:
            time.sleep(3)
            print(
                f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
            )
            batch_job = client.batches.retrieve(batch_job.id)
        assert (
            batch_job.status == "completed"
        ), f"Batch job status is not completed: {batch_job.status}"
        assert batch_job.request_counts.completed == len(content)
        assert batch_job.request_counts.failed == 0
        assert batch_job.request_counts.total == len(content)

        result_file_id = batch_job.output_file_id
        file_response = client.files.content(result_file_id)
        result_content = file_response.read().decode("utf-8")  # Decode bytes to string
        results = [
            json.loads(line)
            for line in result_content.split("\n")
            if line.strip() != ""
        ]
        assert len(results) == len(content)

        # Clean up both the uploaded input file and the result file.
        for delete_fid in [uploaded_file.id, result_file_id]:
            del_response = client.files.delete(delete_fid)
            assert del_response.deleted

    def run_cancel_batch(self, mode):
        """Create a batch job, cancel it, and verify it reaches 'cancelled'."""
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        batch_job, _, uploaded_file = self._create_batch(mode=mode, client=client)

        assert batch_job.status not in ["cancelling", "cancelled"]
        batch_job = client.batches.cancel(batch_id=batch_job.id)
        assert batch_job.status == "cancelling"
        while batch_job.status not in ["failed", "cancelled"]:
            batch_job = client.batches.retrieve(batch_job.id)
            print(
                f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
            )
            time.sleep(3)
        assert batch_job.status == "cancelled"

        del_response = client.files.delete(uploaded_file.id)
        assert del_response.deleted

    def test_completion(self):
        for echo in [False, True]:
            for logprobs in [None, 5]:
                for use_list_input in [True, False]:
                    for parallel_sample_num in [1, 2]:
                        for token_input in [False, True]:
                            self.run_completion(
                                echo,
                                logprobs,
                                use_list_input,
                                parallel_sample_num,
                                token_input,
                            )

    def test_completion_stream(self):
        for echo in [False, True]:
            for logprobs in [None, 5]:
                for use_list_input in [True, False]:
                    for parallel_sample_num in [1, 2]:
                        for token_input in [False, True]:
                            self.run_completion_stream(
                                echo,
                                logprobs,
                                use_list_input,
                                parallel_sample_num,
                                token_input,
                            )

    def test_chat_completion(self):
        for logprobs in [None, 5]:
            for parallel_sample_num in [1, 2]:
                self.run_chat_completion(logprobs, parallel_sample_num)

    def test_chat_completion_stream(self):
        for logprobs in [None, 5]:
            for parallel_sample_num in [1, 2]:
                self.run_chat_completion_stream(logprobs, parallel_sample_num)

    def test_batch(self):
        for mode in ["completion", "chat"]:
            self.run_batch(mode)

    def test_cancel_batch(self):
        for mode in ["completion", "chat"]:
            self.run_cancel_batch(mode)

    def test_regex(self):
        """Constrained decoding via the `regex` extra body parameter."""
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        regex = (
            r"""\{\n"""
            + r""""name": "[\w]+",\n"""
            + r""""population": [\d]+\n"""
            + r"""\}"""
        )
        response = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful AI assistant"},
                {"role": "user", "content": "Introduce the capital of France."},
            ],
            temperature=0,
            max_tokens=128,
            extra_body={"regex": regex},
        )
        text = response.choices[0].message.content
        try:
            js_obj = json.loads(text)
        except (TypeError, json.decoder.JSONDecodeError):
            print("JSONDecodeError", text)
            raise
        assert isinstance(js_obj["name"], str)
        assert isinstance(js_obj["population"], int)

    def test_penalty(self):
        """A request with frequency_penalty should still return a string."""
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        response = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful AI assistant"},
                {"role": "user", "content": "Introduce the capital of France."},
            ],
            temperature=0,
            max_tokens=32,
            frequency_penalty=1.0,
        )
        text = response.choices[0].message.content
        assert isinstance(text, str)

    def test_response_prefill(self):
        """continue_final_message lets the model continue a prefilled assistant turn."""
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        response = client.chat.completions.create(
            model="meta-llama/Llama-3.1-8B-Instruct",
            messages=[
                {"role": "system", "content": "You are a helpful AI assistant"},
                {
                    "role": "user",
                    "content": """
Extract the name, size, price, and color from this product description as a JSON object:
<description>
The SmartHome Mini is a compact smart home assistant available in black or white for only $49.99. At just 5 inches wide, it lets you control lights, thermostats, and other connected devices via voice or app - no matter where you place it in your home. This affordable little hub brings convenient hands-free control to your smart devices.
</description>
""",
                },
                {
                    "role": "assistant",
                    "content": "{\n",
                },
            ],
            temperature=0,
            extra_body={"continue_final_message": True},
        )
        assert (
            response.choices[0]
            .message.content.strip()
            .startswith('"name": "SmartHome Mini",')
        )

    def test_model_list(self):
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        models = list(client.models.list())
        assert len(models) == 1
        assert isinstance(getattr(models[0], "max_model_len", None), int)

# -------------------------------------------------------------------------
# EBNF Test Class: TestOpenAIServerEBNF
# Launches the server with xgrammar, has only EBNF tests
# -------------------------------------------------------------------------
class TestOpenAIServerEBNF(CustomTestCase):
    """Launches the server with the xgrammar backend; contains only EBNF tests."""

    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"
        # passing xgrammar specifically
        other_args = ["--grammar-backend", "xgrammar"]
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
            other_args=other_args,
        )
        cls.base_url += "/v1"
        cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_ebnf(self):
        """
        Ensure we can pass `ebnf` to the local openai server
        and that it enforces the grammar.
        """
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        ebnf_grammar = r"""
        root ::= "Hello" | "Hi" | "Hey"
        """
        pattern = re.compile(r"^(Hello|Hi|Hey)[.!?]*\s*$")
        response = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful EBNF test bot."},
                {"role": "user", "content": "Say a greeting (Hello, Hi, or Hey)."},
            ],
            temperature=0,
            max_tokens=32,
            extra_body={"ebnf": ebnf_grammar},
        )
        text = response.choices[0].message.content.strip()
        print("EBNF test output:", repr(text))
        self.assertTrue(len(text) > 0, "Got empty text from EBNF generation")
        self.assertRegex(text, pattern, f"Text '{text}' doesn't match EBNF choices")

    def test_ebnf_strict_json(self):
        """
        A stricter EBNF that produces exactly {"name": "Alice"} format
        with no trailing punctuation or extra fields.
        """
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        ebnf_grammar = r"""
        root ::= "{" pair "}"
        pair ::= "\"name\"" ":" string
        string ::= "\"" [A-Za-z]+ "\""
        """
        pattern = re.compile(r'^\{"name":"[A-Za-z]+"\}$')
        response = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "EBNF mini-JSON generator."},
                {
                    "role": "user",
                    "content": "Generate single key JSON with only letters.",
                },
            ],
            temperature=0,
            max_tokens=64,
            extra_body={"ebnf": ebnf_grammar},
        )
        text = response.choices[0].message.content.strip()
        print("EBNF strict JSON test output:", repr(text))
        self.assertTrue(len(text) > 0, "Got empty text from EBNF strict JSON test")
        self.assertRegex(
            text, pattern, f"Text '{text}' not matching the EBNF strict JSON shape"
        )

class TestOpenAIEmbedding(CustomTestCase):
    """Tests for the OpenAI-compatible /v1/embeddings endpoint."""

    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"
        # Configure embedding-specific args
        other_args = ["--is-embedding", "--enable-metrics"]
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
            other_args=other_args,
        )
        cls.base_url += "/v1"

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_embedding_single(self):
        """Test single embedding request"""
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        response = client.embeddings.create(model=self.model, input="Hello world")
        self.assertEqual(len(response.data), 1)
        self.assertTrue(len(response.data[0].embedding) > 0)

    def test_embedding_batch(self):
        """Test batch embedding request"""
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        response = client.embeddings.create(
            model=self.model, input=["Hello world", "Test text"]
        )
        self.assertEqual(len(response.data), 2)
        self.assertTrue(len(response.data[0].embedding) > 0)
        self.assertTrue(len(response.data[1].embedding) > 0)

    def test_empty_string_embedding(self):
        """Test embedding an empty string."""
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        # Text embedding example with empty string
        text = ""
        # Expect a BadRequestError for empty input
        with self.assertRaises(openai.BadRequestError) as cm:
            client.embeddings.create(
                model=self.model,
                input=text,
            )
        # check the status code
        self.assertEqual(cm.exception.status_code, 400)

class TestOpenAIServerIgnoreEOS(CustomTestCase):
    """Tests for the `ignore_eos` sampling extension."""

    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
        )
        cls.base_url += "/v1"
        cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_ignore_eos(self):
        """
        Test that ignore_eos=True allows generation to continue beyond EOS token
        and reach the max_tokens limit.
        """
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        max_tokens = 200

        response_default = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Count from 1 to 20."},
            ],
            temperature=0,
            max_tokens=max_tokens,
            extra_body={"ignore_eos": False},
        )
        response_ignore_eos = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Count from 1 to 20."},
            ],
            temperature=0,
            max_tokens=max_tokens,
            extra_body={"ignore_eos": True},
        )

        default_tokens = len(
            self.tokenizer.encode(response_default.choices[0].message.content)
        )
        ignore_eos_tokens = len(
            self.tokenizer.encode(response_ignore_eos.choices[0].message.content)
        )

        # Check if ignore_eos resulted in more tokens or exactly max_tokens
        # The ignore_eos response should either:
        # 1. Have more tokens than the default response (if default stopped at EOS before max_tokens)
        # 2. Have exactly max_tokens (if it reached the max_tokens limit)
        self.assertTrue(
            ignore_eos_tokens > default_tokens or ignore_eos_tokens >= max_tokens,
            f"ignore_eos did not generate more tokens: {ignore_eos_tokens} vs {default_tokens}",
        )
        self.assertEqual(
            response_ignore_eos.choices[0].finish_reason,
            "length",
            f"Expected finish_reason='length' for ignore_eos=True, got {response_ignore_eos.choices[0].finish_reason}",
        )

if __name__ == "__main__":
    unittest.main()