2024-08-01 16:01:30 -07:00
import json
2024-08-01 14:34:55 -07:00
import unittest
2024-01-18 17:00:56 -08:00
import openai
2024-08-01 14:34:55 -07:00
2024-08-05 07:43:09 +08:00
from sglang . srt . hf_transformers_utils import get_tokenizer
2024-08-01 14:34:55 -07:00
from sglang . srt . utils import kill_child_process
2024-08-01 21:20:17 -07:00
from sglang . test . test_utils import MODEL_NAME_FOR_TEST , popen_launch_server
2024-08-01 14:34:55 -07:00
class TestOpenAIServer ( unittest . TestCase ) :
@classmethod
def setUpClass ( cls ) :
2024-08-01 21:20:17 -07:00
cls . model = MODEL_NAME_FOR_TEST
2024-08-04 16:02:05 -07:00
cls . base_url = f " http://localhost:8157 "
2024-08-04 13:35:44 -07:00
cls . api_key = " sk-123456 "
cls . process = popen_launch_server (
cls . model , cls . base_url , timeout = 300 , api_key = cls . api_key
)
2024-08-03 23:09:21 -07:00
cls . base_url + = " /v1 "
2024-08-05 07:43:09 +08:00
cls . tokenizer = get_tokenizer ( MODEL_NAME_FOR_TEST )
2024-08-01 14:34:55 -07:00
@classmethod
def tearDownClass ( cls ) :
kill_child_process ( cls . process . pid )
2024-08-05 07:43:09 +08:00
def run_completion (
self , echo , logprobs , use_list_input , parallel_sample_num , token_input
) :
2024-08-04 13:35:44 -07:00
client = openai . Client ( api_key = self . api_key , base_url = self . base_url )
2024-08-01 14:34:55 -07:00
prompt = " The capital of France is "
2024-08-05 07:43:09 +08:00
if token_input :
prompt_input = self . tokenizer . encode ( prompt )
num_prompt_tokens = len ( prompt_input )
else :
prompt_input = prompt
num_prompt_tokens = len ( self . tokenizer . encode ( prompt ) )
2024-08-01 16:01:30 -07:00
if use_list_input :
2024-08-05 07:43:09 +08:00
prompt_arg = [ prompt_input , prompt_input ]
2024-08-01 16:01:30 -07:00
num_choices = len ( prompt_arg )
2024-08-05 07:43:09 +08:00
num_prompt_tokens * = 2
2024-08-01 16:01:30 -07:00
else :
2024-08-05 07:43:09 +08:00
prompt_arg = prompt_input
2024-08-01 16:01:30 -07:00
num_choices = 1
2024-08-05 07:43:09 +08:00
if parallel_sample_num :
# FIXME: This is wrong. We should not count the prompt tokens multiple times for
# parallel sampling.
num_prompt_tokens * = parallel_sample_num
2024-08-01 14:34:55 -07:00
response = client . completions . create (
model = self . model ,
2024-08-01 16:01:30 -07:00
prompt = prompt_arg ,
2024-08-05 07:43:09 +08:00
temperature = 0 ,
2024-08-01 14:34:55 -07:00
max_tokens = 32 ,
echo = echo ,
logprobs = logprobs ,
2024-08-05 07:43:09 +08:00
n = parallel_sample_num ,
2024-08-01 14:34:55 -07:00
)
2024-08-01 16:01:30 -07:00
2024-08-05 07:43:09 +08:00
assert len ( response . choices ) == num_choices * parallel_sample_num
2024-08-01 16:01:30 -07:00
2024-02-06 12:24:55 -08:00
if echo :
2024-08-01 16:01:30 -07:00
text = response . choices [ 0 ] . text
2024-08-01 14:34:55 -07:00
assert text . startswith ( prompt )
2024-08-05 07:43:09 +08:00
2024-02-06 12:24:55 -08:00
if logprobs :
2024-08-01 14:34:55 -07:00
assert response . choices [ 0 ] . logprobs
assert isinstance ( response . choices [ 0 ] . logprobs . tokens [ 0 ] , str )
assert isinstance ( response . choices [ 0 ] . logprobs . top_logprobs [ 1 ] , dict )
2024-08-01 16:01:30 -07:00
ret_num_top_logprobs = len ( response . choices [ 0 ] . logprobs . top_logprobs [ 1 ] )
2024-08-05 07:43:09 +08:00
# FIXME: Sometimes, some top_logprobs are missing in the return value. The reason is that some out_put id maps to the same output token and duplicate in the map
2024-08-01 16:01:30 -07:00
# assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
2024-08-05 07:43:09 +08:00
assert ret_num_top_logprobs > 0
2024-08-01 14:34:55 -07:00
if echo :
assert response . choices [ 0 ] . logprobs . token_logprobs [ 0 ] == None
else :
assert response . choices [ 0 ] . logprobs . token_logprobs [ 0 ] != None
2024-08-05 07:43:09 +08:00
2024-08-01 14:34:55 -07:00
assert response . id
assert response . created
2024-08-05 07:43:09 +08:00
assert (
response . usage . prompt_tokens == num_prompt_tokens
) , f " { response . usage . prompt_tokens } vs { num_prompt_tokens } "
2024-08-01 14:34:55 -07:00
assert response . usage . completion_tokens > 0
assert response . usage . total_tokens > 0
2024-08-05 07:43:09 +08:00
def run_completion_stream ( self , echo , logprobs , token_input ) :
2024-08-04 13:35:44 -07:00
client = openai . Client ( api_key = self . api_key , base_url = self . base_url )
2024-08-01 14:34:55 -07:00
prompt = " The capital of France is "
2024-08-05 07:43:09 +08:00
if token_input :
prompt_arg = self . tokenizer . encode ( prompt )
else :
prompt_arg = prompt
2024-08-01 14:34:55 -07:00
generator = client . completions . create (
model = self . model ,
2024-08-05 07:43:09 +08:00
prompt = prompt_arg ,
temperature = 0 ,
2024-08-01 14:34:55 -07:00
max_tokens = 32 ,
echo = echo ,
logprobs = logprobs ,
stream = True ,
)
first = True
for response in generator :
if logprobs :
assert response . choices [ 0 ] . logprobs
assert isinstance ( response . choices [ 0 ] . logprobs . tokens [ 0 ] , str )
if not ( first and echo ) :
2024-08-01 16:01:30 -07:00
assert isinstance (
response . choices [ 0 ] . logprobs . top_logprobs [ 0 ] , dict
)
ret_num_top_logprobs = len (
response . choices [ 0 ] . logprobs . top_logprobs [ 0 ]
)
2024-08-05 07:43:09 +08:00
# FIXME: Sometimes, some top_logprobs are missing in the return value. The reason is that some out_put id maps to the same output token and duplicate in the map
2024-08-01 16:01:30 -07:00
# assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
2024-08-05 07:43:09 +08:00
assert ret_num_top_logprobs > 0
2024-08-01 14:34:55 -07:00
if first :
if echo :
2024-08-05 07:43:09 +08:00
assert response . choices [ 0 ] . text . startswith (
prompt
) , f " { response . choices [ 0 ] . text } and all args { echo } { logprobs } { token_input } { first } "
2024-08-01 14:34:55 -07:00
first = False
assert response . id
assert response . created
assert response . usage . prompt_tokens > 0
assert response . usage . completion_tokens > 0
assert response . usage . total_tokens > 0
2024-08-05 07:43:09 +08:00
def run_chat_completion ( self , logprobs , parallel_sample_num ) :
2024-08-04 13:35:44 -07:00
client = openai . Client ( api_key = self . api_key , base_url = self . base_url )
2024-08-01 16:01:30 -07:00
response = client . chat . completions . create (
model = self . model ,
messages = [
{ " role " : " system " , " content " : " You are a helpful AI assistant " } ,
{ " role " : " user " , " content " : " What is the capital of France? " } ,
] ,
temperature = 0 ,
max_tokens = 32 ,
logprobs = logprobs is not None and logprobs > 0 ,
top_logprobs = logprobs ,
2024-08-05 07:43:09 +08:00
n = parallel_sample_num ,
2024-08-01 16:01:30 -07:00
)
if logprobs :
assert isinstance (
response . choices [ 0 ] . logprobs . content [ 0 ] . top_logprobs [ 0 ] . token , str
)
ret_num_top_logprobs = len (
response . choices [ 0 ] . logprobs . content [ 0 ] . top_logprobs
)
assert (
ret_num_top_logprobs == logprobs
) , f " { ret_num_top_logprobs } vs { logprobs } "
2024-08-05 07:43:09 +08:00
assert len ( response . choices ) == parallel_sample_num
2024-08-01 16:01:30 -07:00
assert response . choices [ 0 ] . message . role == " assistant "
assert isinstance ( response . choices [ 0 ] . message . content , str )
assert response . id
assert response . created
assert response . usage . prompt_tokens > 0
assert response . usage . completion_tokens > 0
assert response . usage . total_tokens > 0
def run_chat_completion_stream ( self , logprobs ) :
2024-08-04 13:35:44 -07:00
client = openai . Client ( api_key = self . api_key , base_url = self . base_url )
2024-08-01 16:01:30 -07:00
generator = client . chat . completions . create (
model = self . model ,
messages = [
{ " role " : " system " , " content " : " You are a helpful AI assistant " } ,
{ " role " : " user " , " content " : " What is the capital of France? " } ,
] ,
temperature = 0 ,
max_tokens = 32 ,
logprobs = logprobs is not None and logprobs > 0 ,
top_logprobs = logprobs ,
stream = True ,
)
is_first = True
for response in generator :
data = response . choices [ 0 ] . delta
if is_first :
data . role == " assistant "
is_first = False
continue
if logprobs :
2024-08-05 07:43:09 +08:00
assert response . choices [ 0 ] . logprobs
assert isinstance (
response . choices [ 0 ] . logprobs . content [ 0 ] . top_logprobs [ 0 ] . token , str
)
assert isinstance (
response . choices [ 0 ] . logprobs . content [ 0 ] . top_logprobs , list
)
ret_num_top_logprobs = len (
response . choices [ 0 ] . logprobs . content [ 0 ] . top_logprobs
)
assert (
ret_num_top_logprobs == logprobs
) , f " { ret_num_top_logprobs } vs { logprobs } "
2024-08-01 16:01:30 -07:00
assert isinstance ( data . content , str )
assert response . id
assert response . created
2024-08-01 14:34:55 -07:00
def test_completion ( self ) :
for echo in [ False , True ] :
for logprobs in [ None , 5 ] :
2024-08-01 16:01:30 -07:00
for use_list_input in [ True , False ] :
2024-08-05 07:43:09 +08:00
for parallel_sample_num in [ 1 , 2 ] :
for token_input in [ False , True ] :
self . run_completion (
echo ,
logprobs ,
use_list_input ,
parallel_sample_num ,
token_input ,
)
2024-08-01 14:34:55 -07:00
def test_completion_stream ( self ) :
2024-08-05 07:43:09 +08:00
# parallel sampling adn list input are not supported in streaming mode
2024-08-01 16:01:30 -07:00
for echo in [ False , True ] :
for logprobs in [ None , 5 ] :
2024-08-05 07:43:09 +08:00
for token_input in [ False , True ] :
self . run_completion_stream ( echo , logprobs , token_input )
2024-02-10 17:21:33 -08:00
2024-08-01 16:01:30 -07:00
def test_chat_completion ( self ) :
for logprobs in [ None , 5 ] :
2024-08-05 07:43:09 +08:00
for parallel_sample_num in [ 1 , 2 ] :
self . run_chat_completion ( logprobs , parallel_sample_num )
2024-08-01 16:01:30 -07:00
def test_chat_completion_stream ( self ) :
for logprobs in [ None , 5 ] :
self . run_chat_completion_stream ( logprobs )
def test_regex ( self ) :
2024-08-04 13:35:44 -07:00
client = openai . Client ( api_key = self . api_key , base_url = self . base_url )
2024-08-01 16:01:30 -07:00
regex = (
r """ \ { \ n """
+ r """ " name " : " [ \ w]+ " , \ n """
+ r """ " population " : [ \ d]+ \ n """
+ r """ \ } """
)
response = client . chat . completions . create (
model = self . model ,
messages = [
{ " role " : " system " , " content " : " You are a helpful AI assistant " } ,
{ " role " : " user " , " content " : " Introduce the capital of France. " } ,
] ,
temperature = 0 ,
max_tokens = 128 ,
extra_body = { " regex " : regex } ,
)
text = response . choices [ 0 ] . message . content
try :
js_obj = json . loads ( text )
except ( TypeError , json . decoder . JSONDecodeError ) :
print ( " JSONDecodeError " , text )
raise
assert isinstance ( js_obj [ " name " ] , str )
assert isinstance ( js_obj [ " population " ] , int )
2024-02-10 17:21:33 -08:00
2024-01-18 17:00:56 -08:00
if __name__ == " __main__ " :
2024-08-01 16:01:30 -07:00
unittest . main ( warnings = " ignore " )
2024-08-01 14:34:55 -07:00
2024-08-01 16:01:30 -07:00
# t = TestOpenAIServer()
# t.setUpClass()
2024-08-05 07:43:09 +08:00
# t.test_completion()
2024-08-01 16:01:30 -07:00
# t.tearDownClass()