"""
python3 -m unittest openai_server.basic.test_openai_server.TestOpenAIServer.test_completion
python3 -m unittest openai_server.basic.test_openai_server.TestOpenAIServer.test_completion_stream
python3 -m unittest openai_server.basic.test_openai_server.TestOpenAIServer.test_chat_completion
python3 -m unittest openai_server.basic.test_openai_server.TestOpenAIServer.test_chat_completion_stream
"""
2024-11-07 15:42:47 -08:00
2024-08-01 16:01:30 -07:00
import json
2024-12-26 18:42:41 +05:30
import re
2024-08-01 14:34:55 -07:00
import unittest
2024-01-18 17:00:56 -08:00
2025-06-10 17:37:29 -04:00
import numpy as np
2024-01-18 17:00:56 -08:00
import openai
2025-06-04 14:14:54 -07:00
import requests
2024-08-01 14:34:55 -07:00
2024-08-05 07:43:09 +08:00
from sglang . srt . hf_transformers_utils import get_tokenizer
2024-11-28 00:22:39 -08:00
from sglang . srt . utils import kill_process_tree
2025-06-17 01:50:01 +08:00
from sglang . test . runners import TEST_RERANK_QUERY_DOCS
2024-08-11 18:27:33 -07:00
from sglang . test . test_utils import (
2025-06-17 01:50:01 +08:00
DEFAULT_SMALL_CROSS_ENCODER_MODEL_NAME_FOR_TEST ,
2024-11-09 15:43:20 -08:00
DEFAULT_SMALL_MODEL_NAME_FOR_TEST ,
2024-08-25 19:02:08 -07:00
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH ,
DEFAULT_URL_FOR_TEST ,
2025-03-26 07:53:12 +08:00
CustomTestCase ,
2024-08-11 18:27:33 -07:00
popen_launch_server ,
)
2024-08-01 14:34:55 -07:00
2025-03-26 07:53:12 +08:00
class TestOpenAIServer(CustomTestCase):
    """End-to-end tests for the OpenAI-compatible /v1 completion and chat APIs.

    A single server process is launched for the whole class; each test talks
    to it through the official ``openai`` python client.
    """

    @classmethod
    def setUpClass(cls):
        # Launch one shared server for all tests in this class.
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
        )
        cls.base_url += "/v1"
        cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def run_completion(
        self, echo, logprobs, use_list_input, parallel_sample_num, token_input
    ):
        """Issue a non-streaming /v1/completions request and validate the reply.

        Args:
            echo: forward the prompt in the completion text.
            logprobs: number of top logprobs to request (or None).
            use_list_input: send the prompt as a 2-element batch.
            parallel_sample_num: value for the ``n`` sampling parameter.
            token_input: send token ids instead of a text prompt.
        """
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        prompt = "The capital of France is"
        if token_input:
            prompt_input = self.tokenizer.encode(prompt)
            num_prompt_tokens = len(prompt_input)
        else:
            prompt_input = prompt
            num_prompt_tokens = len(self.tokenizer.encode(prompt))

        if use_list_input:
            prompt_arg = [prompt_input, prompt_input]
            num_choices = len(prompt_arg)
            num_prompt_tokens *= 2
        else:
            prompt_arg = prompt_input
            num_choices = 1

        response = client.completions.create(
            model=self.model,
            prompt=prompt_arg,
            temperature=0,
            max_tokens=32,
            echo=echo,
            logprobs=logprobs,
            n=parallel_sample_num,
        )

        # One choice per batch element per parallel sample.
        assert len(response.choices) == num_choices * parallel_sample_num

        if echo:
            text = response.choices[0].text
            assert text.startswith(prompt)

        if logprobs:
            assert response.choices[0].logprobs
            assert isinstance(response.choices[0].logprobs.tokens[0], str)
            assert isinstance(response.choices[0].logprobs.top_logprobs[1], dict)
            ret_num_top_logprobs = len(response.choices[0].logprobs.top_logprobs[1])

            # FIXME: Sometimes, some top_logprobs are missing in the return value. The reason is that some output id maps to the same output token and duplicate in the map
            # assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
            assert ret_num_top_logprobs > 0

            # when echo=True and request.logprobs>0, logprob_start_len is 0, so the first token's logprob would be None.
            if not echo:
                assert response.choices[0].logprobs.token_logprobs[0]

        assert response.id
        assert response.created
        assert (
            response.usage.prompt_tokens == num_prompt_tokens
        ), f"{response.usage.prompt_tokens} vs {num_prompt_tokens}"
        assert response.usage.completion_tokens > 0
        assert response.usage.total_tokens > 0

    def run_completion_stream(
        self, echo, logprobs, use_list_input, parallel_sample_num, token_input
    ):
        """Issue a streaming /v1/completions request and validate every chunk.

        Tracks per-choice state in ``is_firsts`` so that the first chunk of
        each choice index can be checked for the echoed prompt.
        """
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        prompt = "The capital of France is"
        if token_input:
            prompt_input = self.tokenizer.encode(prompt)
            num_prompt_tokens = len(prompt_input)
        else:
            prompt_input = prompt
            num_prompt_tokens = len(self.tokenizer.encode(prompt))
        if use_list_input:
            prompt_arg = [prompt_input, prompt_input]
            num_choices = len(prompt_arg)
            num_prompt_tokens *= 2
        else:
            prompt_arg = prompt_input
            num_choices = 1

        generator = client.completions.create(
            model=self.model,
            prompt=prompt_arg,
            temperature=0,
            max_tokens=32,
            echo=echo,
            logprobs=logprobs,
            stream=True,
            stream_options={"include_usage": True},
            n=parallel_sample_num,
        )

        is_firsts = {}
        for response in generator:
            usage = response.usage
            if usage is not None:
                # The usage-only chunk carries no choices; validate and skip.
                assert usage.prompt_tokens > 0, "usage.prompt_tokens was zero"
                assert usage.completion_tokens > 0, "usage.completion_tokens was zero"
                assert usage.total_tokens > 0, "usage.total_tokens was zero"
                continue

            index = response.choices[0].index
            is_first = is_firsts.get(index, True)
            if logprobs:
                assert response.choices[0].logprobs, "no logprobs in response"
                assert isinstance(
                    response.choices[0].logprobs.tokens[0], str
                ), f"{response.choices[0].logprobs.tokens[0]} is not a string"
                # The first chunk with echo=True carries the prompt, whose
                # leading token has no top_logprobs.
                if not (is_first and echo):
                    assert isinstance(
                        response.choices[0].logprobs.top_logprobs[0], dict
                    ), "top_logprobs was not a dictionary"
                    ret_num_top_logprobs = len(
                        response.choices[0].logprobs.top_logprobs[0]
                    )
                    # FIXME: Sometimes, some top_logprobs are missing in the return value. The reason is that some output id maps to the same output token and duplicate in the map
                    # assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
                    assert ret_num_top_logprobs > 0, "ret_num_top_logprobs was 0"

            if is_first:
                if echo:
                    assert response.choices[0].text.startswith(
                        prompt
                    ), f"{response.choices[0].text} and all args {echo} {logprobs} {token_input} {is_first}"
                is_firsts[index] = False
            assert response.id, "no id in response"
            assert response.created, "no created in response"

        # Every expected choice index must have produced at least one chunk.
        for index in range(parallel_sample_num * num_choices):
            assert not is_firsts.get(
                index, True
            ), f"index {index} is not found in the response"

    def run_chat_completion(self, logprobs, parallel_sample_num):
        """Issue a non-streaming /v1/chat/completions request and validate it."""
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        response = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful AI assistant"},
                {
                    "role": "user",
                    "content": "What is the capital of France? Answer in a few words.",
                },
            ],
            temperature=0,
            logprobs=logprobs is not None and logprobs > 0,
            top_logprobs=logprobs,
            n=parallel_sample_num,
        )

        if logprobs:
            assert isinstance(
                response.choices[0].logprobs.content[0].top_logprobs[0].token, str
            )
            ret_num_top_logprobs = len(
                response.choices[0].logprobs.content[0].top_logprobs
            )
            assert (
                ret_num_top_logprobs == logprobs
            ), f"{ret_num_top_logprobs} vs {logprobs}"

        assert len(response.choices) == parallel_sample_num
        assert response.choices[0].message.role == "assistant"
        assert isinstance(response.choices[0].message.content, str)
        assert response.id
        assert response.created
        assert response.usage.prompt_tokens > 0
        assert response.usage.completion_tokens > 0
        assert response.usage.total_tokens > 0

    def run_chat_completion_stream(self, logprobs, parallel_sample_num=1):
        """Issue a streaming chat request; validate chunks and finish_reason.

        Verifies that each choice's first chunk carries the assistant role,
        that logprob chunks are well-formed, and that every choice receives
        exactly one finish_reason chunk.
        """
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        generator = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful AI assistant"},
                {"role": "user", "content": "What is the capital of France?"},
            ],
            temperature=0,
            logprobs=logprobs is not None and logprobs > 0,
            top_logprobs=logprobs,
            stream=True,
            stream_options={"include_usage": True},
            n=parallel_sample_num,
        )

        is_firsts = {}
        is_finished = {}
        finish_reason_counts = {}
        for response in generator:
            usage = response.usage
            if usage is not None:
                # The usage-only chunk carries no choices; validate and skip.
                assert usage.prompt_tokens > 0, "usage.prompt_tokens was zero"
                assert usage.completion_tokens > 0, "usage.completion_tokens was zero"
                assert usage.total_tokens > 0, "usage.total_tokens was zero"
                continue

            index = response.choices[0].index
            finish_reason = response.choices[0].finish_reason
            if finish_reason is not None:
                is_finished[index] = True
                finish_reason_counts[index] = finish_reason_counts.get(index, 0) + 1

            data = response.choices[0].delta

            if is_firsts.get(index, True):
                assert (
                    data.role == "assistant"
                ), "data.role was not 'assistant' for first chunk"
                is_firsts[index] = False
                continue

            # The final (finish_reason) chunk may legitimately lack logprobs.
            if logprobs and not is_finished.get(index, False):
                assert response.choices[0].logprobs, "logprobs was not returned"
                assert isinstance(
                    response.choices[0].logprobs.content[0].top_logprobs[0].token, str
                ), "top_logprobs token was not a string"
                assert isinstance(
                    response.choices[0].logprobs.content[0].top_logprobs, list
                ), "top_logprobs was not a list"
                ret_num_top_logprobs = len(
                    response.choices[0].logprobs.content[0].top_logprobs
                )
                assert (
                    ret_num_top_logprobs == logprobs
                ), f"{ret_num_top_logprobs} vs {logprobs}"

            assert (
                isinstance(data.content, str)
                or isinstance(data.reasoning_content, str)
                or (isinstance(data.tool_calls, list) and len(data.tool_calls) > 0)
                or response.choices[0].finish_reason
            )
            assert response.id
            assert response.created

        # Every expected choice index must have produced at least one chunk.
        for index in range(parallel_sample_num):
            assert not is_firsts.get(
                index, True
            ), f"index {index} is not found in the response"

        # Verify that each choice gets exactly one finish_reason chunk
        for index in range(parallel_sample_num):
            assert (
                index in finish_reason_counts
            ), f"No finish_reason found for index {index}"
            assert (
                finish_reason_counts[index] == 1
            ), f"Expected 1 finish_reason chunk for index {index}, got {finish_reason_counts[index]}"

    def test_completion(self):
        for echo in [False, True]:
            for logprobs in [None, 5]:
                for use_list_input in [True, False]:
                    for parallel_sample_num in [1, 2]:
                        for token_input in [False, True]:
                            self.run_completion(
                                echo,
                                logprobs,
                                use_list_input,
                                parallel_sample_num,
                                token_input,
                            )

    def test_completion_stream(self):
        # Exercise the same parameter grid as the non-streaming test,
        # including list inputs and parallel sampling.
        for echo in [False, True]:
            for logprobs in [None, 5]:
                for use_list_input in [True, False]:
                    for parallel_sample_num in [1, 2]:
                        for token_input in [False, True]:
                            self.run_completion_stream(
                                echo,
                                logprobs,
                                use_list_input,
                                parallel_sample_num,
                                token_input,
                            )

    def test_chat_completion(self):
        for logprobs in [None, 5]:
            for parallel_sample_num in [1, 2]:
                self.run_chat_completion(logprobs, parallel_sample_num)

    def test_chat_completion_stream(self):
        for logprobs in [None, 5]:
            for parallel_sample_num in [1, 2]:
                self.run_chat_completion_stream(logprobs, parallel_sample_num)

    def test_regex(self):
        """Constrained decoding via the ``regex`` extra body parameter."""
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)

        regex = (
            r"""\{\n"""
            + r"""    "name": "[\w]+",\n"""
            + r"""    "population": [\d]+\n"""
            + r"""\}"""
        )

        response = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful AI assistant"},
                {"role": "user", "content": "Introduce the capital of France."},
            ],
            temperature=0,
            max_tokens=128,
            extra_body={"regex": regex},
        )
        text = response.choices[0].message.content

        try:
            js_obj = json.loads(text)
        except (TypeError, json.decoder.JSONDecodeError):
            print("JSONDecodeError", text)
            raise
        assert isinstance(js_obj["name"], str)
        assert isinstance(js_obj["population"], int)

    def test_penalty(self):
        """A request with frequency_penalty should still return valid text."""
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        response = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful AI assistant"},
                {"role": "user", "content": "Introduce the capital of France."},
            ],
            temperature=0,
            max_tokens=32,
            frequency_penalty=1.0,
        )
        text = response.choices[0].message.content
        assert isinstance(text, str)

    def test_response_prefill(self):
        """continue_final_message: the model continues a prefilled assistant turn."""
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        response = client.chat.completions.create(
            # Use the model the server was launched with, consistent with
            # every other test in this class.
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful AI assistant"},
                {
                    "role": "user",
                    "content": """
Extract the name, size, price, and color from this product description as a JSON object:
<description>
The SmartHome Mini is a compact smart home assistant available in black or white for only $49.99. At just 5 inches wide, it lets you control lights, thermostats, and other connected devices via voice or app—no matter where you place it in your home. This affordable little hub brings convenient hands-free control to your smart devices.
</description>
""",
                },
                {
                    "role": "assistant",
                    "content": "{\n",
                },
            ],
            temperature=0,
            extra_body={"continue_final_message": True},
        )

        assert (
            response.choices[0]
            .message.content.strip()
            .startswith('"name": "SmartHome Mini",')
        )

    def test_model_list(self):
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        models = list(client.models.list())
        assert len(models) == 1
        assert isinstance(getattr(models[0], "max_model_len", None), int)

    def test_retrieve_model(self):
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)

        # Test retrieving an existing model
        retrieved_model = client.models.retrieve(self.model)
        self.assertEqual(retrieved_model.id, self.model)
        self.assertEqual(retrieved_model.root, self.model)

        # Test retrieving a non-existent model
        with self.assertRaises(openai.NotFoundError):
            client.models.retrieve("non-existent-model")
2024-02-10 17:21:33 -08:00
2025-06-17 01:50:01 +08:00
class TestOpenAIV1Rerank(CustomTestCase):
    """Tests for the /v1/rerank endpoint served by a cross-encoder model."""

    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_SMALL_CROSS_ENCODER_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"
        cls.score_tolerance = 1e-2
        # Configure embedding-specific args
        other_args = [
            "--is-embedding",
            "--enable-metrics",
            "--disable-radix-cache",
            "--chunked-prefill-size",
            "-1",
            "--attention-backend",
            "torch_native",
        ]
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
            other_args=other_args,
        )
        cls.base_url += "/v1/rerank"

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def run_rerank(self, query, docs):
        """POST a rerank request and return the decoded JSON response body."""
        response = requests.post(
            self.base_url,
            headers={
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            },
            json={"query": query, "documents": docs},
        )
        return response.json()

    def _assert_rerank_entry(self, entry):
        """Structural checks shared by all rerank result entries."""
        self.assertIsInstance(entry["score"], float)
        self.assertIsInstance(entry["document"], str)
        self.assertIsInstance(entry["index"], int)

    def test_rerank_single(self):
        """Test single rerank request"""
        query = TEST_RERANK_QUERY_DOCS[0]["query"]
        docs = TEST_RERANK_QUERY_DOCS[0]["documents"]

        response = self.run_rerank(query, docs)
        self.assertEqual(len(response), 1)
        self._assert_rerank_entry(response[0])

    def test_rerank_batch(self):
        """Test batch rerank request"""
        query = TEST_RERANK_QUERY_DOCS[1]["query"]
        docs = TEST_RERANK_QUERY_DOCS[1]["documents"]

        response = self.run_rerank(query, docs)
        self.assertEqual(len(response), 2)
        for entry in response:
            self._assert_rerank_entry(entry)
2025-06-04 14:14:54 -07:00
class TestOpenAIV1Score(CustomTestCase):
    """Tests for the /v1/score endpoint (label-token probability scoring)."""

    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
        )
        cls.base_url += "/v1/score"
        cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def run_score(
        self, query, items, label_token_ids, apply_softmax=False, item_first=False
    ):
        """POST a score request and return the decoded JSON response body."""
        response = requests.post(
            self.base_url,
            headers={
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            },
            json={
                "model": self.model,
                "query": query,
                "items": items,
                "label_token_ids": label_token_ids,
                "apply_softmax": apply_softmax,
                "item_first": item_first,
            },
        )
        return response.json()

    def _assert_score_response(self, response, num_items, num_labels):
        """Shared structural checks for a softmax-normalized score response."""
        # Handle error responses
        if response.get("type") == "BadRequestError":
            self.fail(f"Score request failed with error: {response['message']}")

        # Verify response structure
        self.assertIn("scores", response, "Response should have a 'scores' field")
        self.assertIsInstance(response["scores"], list, "scores should be a list")
        self.assertEqual(
            len(response["scores"]),
            num_items,
            "Number of scores should match number of items",
        )

        # Each score should be a list of floats in the order of label_token_ids
        for i, score_list in enumerate(response["scores"]):
            self.assertIsInstance(score_list, list, f"Score {i} should be a list")
            self.assertEqual(
                len(score_list),
                num_labels,
                f"Score {i} length should match label_token_ids",
            )
            self.assertTrue(
                all(isinstance(v, float) for v in score_list),
                f"Score {i} values should be floats",
            )
            self.assertAlmostEqual(
                sum(score_list),
                1.0,
                places=6,
                msg=f"Score {i} probabilities should sum to 1",
            )

    def test_score_text_input(self):
        """Test scoring with text input"""
        query = "The capital of France is"
        items = ["Paris", "London", "Berlin"]

        # Get valid token IDs from the tokenizer
        label_token_ids = []
        for item in items:
            token_ids = self.tokenizer.encode(item, add_special_tokens=False)
            if not token_ids:
                self.fail(f"Failed to encode item: {item}")
            label_token_ids.append(token_ids[0])

        response = self.run_score(query, items, label_token_ids, apply_softmax=True)
        self._assert_score_response(response, len(items), len(label_token_ids))

    def test_score_token_input(self):
        """Test scoring with token IDs input"""
        query = "The capital of France is"
        items = ["Paris", "London", "Berlin"]

        # Get valid token IDs
        query_ids = self.tokenizer.encode(query, add_special_tokens=False)
        item_ids = [
            self.tokenizer.encode(item, add_special_tokens=False) for item in items
        ]
        label_token_ids = [
            ids[0] for ids in item_ids if ids
        ]  # Get first token ID of each item

        response = self.run_score(
            query_ids, item_ids, label_token_ids, apply_softmax=True
        )
        self._assert_score_response(response, len(items), len(label_token_ids))

    def test_score_error_handling(self):
        """Test error handling for invalid inputs"""
        query = "The capital of France is"
        items = ["Paris", "London", "Berlin"]

        # Test with invalid token ID
        response = requests.post(
            self.base_url,
            headers={
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            },
            json={
                "model": self.model,
                "query": query,
                "items": items,
                "label_token_ids": [999999],  # Invalid token ID
                "apply_softmax": True,
            },
        )
        self.assertEqual(response.status_code, 400)
        error_response = response.json()
        self.assertEqual(error_response["type"], "BadRequestError")
        self.assertIn("Token ID 999999 is out of vocabulary", error_response["message"])
2024-01-18 17:00:56 -08:00
# Allow running this module directly: `python test_openai_server.py`.
if __name__ == "__main__":
    unittest.main()