Decoder-only Scoring API (#6460)
Co-authored-by: Chanh Nguyen <cnguyen@linkedin.com>
This commit is contained in:
@@ -10,6 +10,7 @@ import time
|
||||
import unittest
|
||||
|
||||
import openai
|
||||
import requests
|
||||
|
||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
@@ -599,7 +600,6 @@ class TestOpenAIServerEBNF(CustomTestCase):
|
||||
extra_body={"ebnf": ebnf_grammar},
|
||||
)
|
||||
text = response.choices[0].message.content.strip()
|
||||
print("EBNF test output:", repr(text))
|
||||
self.assertTrue(len(text) > 0, "Got empty text from EBNF generation")
|
||||
self.assertRegex(text, pattern, f"Text '{text}' doesn't match EBNF choices")
|
||||
|
||||
@@ -630,7 +630,6 @@ class TestOpenAIServerEBNF(CustomTestCase):
|
||||
extra_body={"ebnf": ebnf_grammar},
|
||||
)
|
||||
text = response.choices[0].message.content.strip()
|
||||
print("EBNF strict JSON test output:", repr(text))
|
||||
self.assertTrue(len(text) > 0, "Got empty text from EBNF strict JSON test")
|
||||
self.assertRegex(
|
||||
text, pattern, f"Text '{text}' not matching the EBNF strict JSON shape"
|
||||
@@ -766,5 +765,168 @@ class TestOpenAIServerIgnoreEOS(CustomTestCase):
|
||||
)
|
||||
|
||||
|
||||
class TestOpenAIV1Score(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.api_key = "sk-123456"
|
||||
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
api_key=cls.api_key,
|
||||
)
|
||||
cls.base_url += "/v1/score"
|
||||
cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def run_score(
|
||||
self, query, items, label_token_ids, apply_softmax=False, item_first=False
|
||||
):
|
||||
response = requests.post(
|
||||
self.base_url,
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={
|
||||
"model": self.model,
|
||||
"query": query,
|
||||
"items": items,
|
||||
"label_token_ids": label_token_ids,
|
||||
"apply_softmax": apply_softmax,
|
||||
"item_first": item_first,
|
||||
},
|
||||
)
|
||||
return response.json()
|
||||
|
||||
def test_score_text_input(self):
|
||||
"""Test scoring with text input"""
|
||||
query = "The capital of France is"
|
||||
items = ["Paris", "London", "Berlin"]
|
||||
|
||||
# Get valid token IDs from the tokenizer
|
||||
label_token_ids = []
|
||||
for item in items:
|
||||
token_ids = self.tokenizer.encode(item, add_special_tokens=False)
|
||||
if not token_ids:
|
||||
self.fail(f"Failed to encode item: {item}")
|
||||
label_token_ids.append(token_ids[0])
|
||||
|
||||
response = self.run_score(query, items, label_token_ids, apply_softmax=True)
|
||||
|
||||
# Handle error responses
|
||||
if response.get("type") == "BadRequestError":
|
||||
self.fail(f"Score request failed with error: {response['message']}")
|
||||
|
||||
# Verify response structure
|
||||
self.assertIn("scores", response, "Response should have a 'scores' field")
|
||||
self.assertIsInstance(response["scores"], list, "scores should be a list")
|
||||
self.assertEqual(
|
||||
len(response["scores"]),
|
||||
len(items),
|
||||
"Number of scores should match number of items",
|
||||
)
|
||||
|
||||
# Each score should be a list of floats in the order of label_token_ids
|
||||
for i, score_list in enumerate(response["scores"]):
|
||||
self.assertIsInstance(score_list, list, f"Score {i} should be a list")
|
||||
self.assertEqual(
|
||||
len(score_list),
|
||||
len(label_token_ids),
|
||||
f"Score {i} length should match label_token_ids",
|
||||
)
|
||||
self.assertTrue(
|
||||
all(isinstance(v, float) for v in score_list),
|
||||
f"Score {i} values should be floats",
|
||||
)
|
||||
self.assertAlmostEqual(
|
||||
sum(score_list),
|
||||
1.0,
|
||||
places=6,
|
||||
msg=f"Score {i} probabilities should sum to 1",
|
||||
)
|
||||
|
||||
def test_score_token_input(self):
|
||||
"""Test scoring with token IDs input"""
|
||||
query = "The capital of France is"
|
||||
items = ["Paris", "London", "Berlin"]
|
||||
|
||||
# Get valid token IDs
|
||||
query_ids = self.tokenizer.encode(query, add_special_tokens=False)
|
||||
item_ids = [
|
||||
self.tokenizer.encode(item, add_special_tokens=False) for item in items
|
||||
]
|
||||
label_token_ids = [
|
||||
ids[0] for ids in item_ids if ids
|
||||
] # Get first token ID of each item
|
||||
|
||||
response = self.run_score(
|
||||
query_ids, item_ids, label_token_ids, apply_softmax=True
|
||||
)
|
||||
|
||||
# Handle error responses
|
||||
if response.get("type") == "BadRequestError":
|
||||
self.fail(f"Score request failed with error: {response['message']}")
|
||||
|
||||
# Verify response structure
|
||||
self.assertIn("scores", response, "Response should have a 'scores' field")
|
||||
self.assertIsInstance(response["scores"], list, "scores should be a list")
|
||||
self.assertEqual(
|
||||
len(response["scores"]),
|
||||
len(items),
|
||||
"Number of scores should match number of items",
|
||||
)
|
||||
|
||||
# Each score should be a list of floats in the order of label_token_ids
|
||||
for i, score_list in enumerate(response["scores"]):
|
||||
self.assertIsInstance(score_list, list, f"Score {i} should be a list")
|
||||
self.assertEqual(
|
||||
len(score_list),
|
||||
len(label_token_ids),
|
||||
f"Score {i} length should match label_token_ids",
|
||||
)
|
||||
self.assertTrue(
|
||||
all(isinstance(v, float) for v in score_list),
|
||||
f"Score {i} values should be floats",
|
||||
)
|
||||
self.assertAlmostEqual(
|
||||
sum(score_list),
|
||||
1.0,
|
||||
places=6,
|
||||
msg=f"Score {i} probabilities should sum to 1",
|
||||
)
|
||||
|
||||
def test_score_error_handling(self):
|
||||
"""Test error handling for invalid inputs"""
|
||||
query = "The capital of France is"
|
||||
items = ["Paris", "London", "Berlin"]
|
||||
|
||||
# Test with invalid token ID
|
||||
response = requests.post(
|
||||
self.base_url,
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={
|
||||
"model": self.model,
|
||||
"query": query,
|
||||
"items": items,
|
||||
"label_token_ids": [999999], # Invalid token ID
|
||||
"apply_softmax": True,
|
||||
},
|
||||
)
|
||||
self.assertEqual(response.status_code, 400)
|
||||
error_response = response.json()
|
||||
self.assertEqual(error_response["type"], "BadRequestError")
|
||||
self.assertIn("Token ID 999999 is out of vocabulary", error_response["message"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user