Decoder-only Scoring API (#6460)
Co-authored-by: Chanh Nguyen <cnguyen@linkedin.com>
This commit is contained in:
@@ -472,6 +472,79 @@ class Engine(EngineBase):
|
|||||||
def save_sharded_model(self, **kwargs):
|
def save_sharded_model(self, **kwargs):
|
||||||
self.collective_rpc("save_sharded_model", **kwargs)
|
self.collective_rpc("save_sharded_model", **kwargs)
|
||||||
|
|
||||||
|
def score(
    self,
    query: Optional[Union[str, List[int]]] = None,
    items: Optional[Union[str, List[str], List[List[int]]]] = None,
    label_token_ids: Optional[List[int]] = None,
    apply_softmax: bool = False,
    item_first: bool = False,
) -> List[List[float]]:
    """
    Score the probability of specified token IDs appearing after the given (query + item) pair. For example:

        query = "<|user|>Is the following city the capital of France? "
        items = ["Paris <|assistant|>", "London <|assistant|>", "Berlin <|assistant|>"]
        label_token_ids = [2332, 1223]  # Token IDs for "Yes" and "No"
        item_first = False

    This would pass the following prompts to the model:

        "<|user|>Is the following city the capital of France? Paris <|assistant|>"
        "<|user|>Is the following city the capital of France? London <|assistant|>"
        "<|user|>Is the following city the capital of France? Berlin <|assistant|>"

    The API would then return the probabilities of the model producing "Yes" and
    "No" as the next token. The output would look like:

        [[0.9, 0.1], [0.2, 0.8], [0.1, 0.9]]

    Args:
        query: The query text or pre-tokenized query token IDs. Must be provided.
        items: The item text(s) or pre-tokenized item token IDs. Must be provided.
        label_token_ids: List of token IDs to compute probabilities for.
            Required; the underlying request raises ValueError when omitted.
        apply_softmax: Whether to normalize probabilities using softmax over
            the label tokens.
        item_first: If True, prepend items to query. Otherwise append items to query.

    Returns:
        One list per item, each containing the probabilities of the label
        tokens in the same order as ``label_token_ids``.

    Raises:
        ValueError: If query is not provided, or if items is not provided,
            or if token IDs are out of vocabulary, or if logprobs are not
            available for the specified tokens.
    """
    # Synchronous wrapper: drive the async tokenizer-manager call to
    # completion on the current event loop.
    loop = asyncio.get_event_loop()
    return loop.run_until_complete(
        self.tokenizer_manager.score_request(
            query=query,
            items=items,
            label_token_ids=label_token_ids,
            apply_softmax=apply_softmax,
            item_first=item_first,
            request=None,
        )
    )
|
||||||
|
|
||||||
|
async def async_score(
    self,
    query: Optional[Union[str, List[int]]] = None,
    items: Optional[Union[str, List[str], List[List[int]]]] = None,
    label_token_ids: Optional[List[int]] = None,
    apply_softmax: bool = False,
    item_first: bool = False,
) -> List[List[float]]:
    """Asynchronous version of the score method.

    Delegates straight to the tokenizer manager; see score() for the full
    parameter documentation.
    """
    scoring_kwargs = dict(
        query=query,
        items=items,
        label_token_ids=label_token_ids,
        apply_softmax=apply_softmax,
        item_first=item_first,
        request=None,
    )
    return await self.tokenizer_manager.score_request(**scoring_kwargs)
|
||||||
|
|
||||||
|
|
||||||
def _set_envs_and_config(server_args: ServerArgs):
|
def _set_envs_and_config(server_args: ServerArgs):
|
||||||
# Set global environments
|
# Set global environments
|
||||||
|
|||||||
@@ -82,6 +82,7 @@ from sglang.srt.openai_api.adapter import (
|
|||||||
v1_retrieve_batch,
|
v1_retrieve_batch,
|
||||||
v1_retrieve_file,
|
v1_retrieve_file,
|
||||||
v1_retrieve_file_content,
|
v1_retrieve_file_content,
|
||||||
|
v1_score,
|
||||||
)
|
)
|
||||||
from sglang.srt.openai_api.protocol import ModelCard, ModelList
|
from sglang.srt.openai_api.protocol import ModelCard, ModelList
|
||||||
from sglang.srt.reasoning_parser import ReasoningParser
|
from sglang.srt.reasoning_parser import ReasoningParser
|
||||||
@@ -720,6 +721,12 @@ async def vertex_generate(vertex_req: VertexGenerateReqInput, raw_request: Reque
|
|||||||
return ORJSONResponse({"predictions": ret})
|
return ORJSONResponse({"predictions": ret})
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/v1/score")
async def v1_score_request(raw_request: Request):
    """Endpoint for the decoder-only scoring API. See Engine.score() for detailed documentation."""
    # Delegate to the shared OpenAI-adapter handler, which parses the body
    # into a ScoringRequest and runs it through the tokenizer manager.
    return await v1_score(_global_state.tokenizer_manager, raw_request)
|
||||||
|
|
||||||
|
|
||||||
def _create_error_response(e):
|
def _create_error_response(e):
|
||||||
return ORJSONResponse(
|
return ORJSONResponse(
|
||||||
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
|
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import copy
|
|||||||
import dataclasses
|
import dataclasses
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import math
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
import signal
|
import signal
|
||||||
@@ -42,6 +43,7 @@ from typing import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
import fastapi
|
import fastapi
|
||||||
|
import torch
|
||||||
import uvloop
|
import uvloop
|
||||||
import zmq
|
import zmq
|
||||||
import zmq.asyncio
|
import zmq.asyncio
|
||||||
@@ -1433,6 +1435,100 @@ class TokenizerManager:
|
|||||||
if len(self.model_update_tmp) == self.server_args.dp_size:
|
if len(self.model_update_tmp) == self.server_args.dp_size:
|
||||||
self.model_update_result.set_result(self.model_update_tmp)
|
self.model_update_result.set_result(self.model_update_tmp)
|
||||||
|
|
||||||
|
async def score_request(
    self,
    query: Optional[Union[str, List[int]]] = None,
    items: Optional[Union[str, List[str], List[List[int]]]] = None,
    label_token_ids: Optional[List[int]] = None,
    apply_softmax: bool = False,
    item_first: bool = False,
    request: Optional[Any] = None,
) -> List[List[float]]:
    """
    Compute next-token probabilities for ``label_token_ids`` after each
    (query + item) prompt. See Engine.score() for more details.

    Args:
        query: Query text or pre-tokenized query token IDs.
        items: Item text(s) or pre-tokenized item token IDs.
        label_token_ids: Token IDs whose probabilities are returned, in order.
        apply_softmax: Normalize the label logprobs with softmax if True;
            otherwise exponentiate them individually.
        item_first: If True, prepend items to query; otherwise append.
        request: Optional originating request object, forwarded to
            generate_request.

    Returns:
        One list of probabilities per item, ordered like ``label_token_ids``.

    Raises:
        ValueError: If label_token_ids is missing, a token ID is out of
            vocabulary, the query/items combination is invalid, or the engine
            did not return logprobs for the requested tokens.
    """
    if label_token_ids is None:
        raise ValueError("label_token_ids must be provided")

    if self.tokenizer is not None:
        vocab_size = self.tokenizer.vocab_size
        for token_id in label_token_ids:
            # Negative IDs are also out of vocabulary; reject them explicitly
            # instead of letting them silently index from the end later on.
            if token_id < 0 or token_id >= vocab_size:
                raise ValueError(
                    f"Token ID {token_id} is out of vocabulary (vocab size: {vocab_size})"
                )

    # Handle string or tokenized query/items
    if isinstance(query, str) and (
        isinstance(items, str)
        or (isinstance(items, list) and (not items or isinstance(items[0], str)))
    ):
        # Both query and items are text
        items_list = [items] if isinstance(items, str) else items
        if item_first:
            prompts = [f"{item}{query}" for item in items_list]
        else:
            prompts = [f"{query}{item}" for item in items_list]
        batch_request = GenerateReqInput(
            text=prompts,
            return_logprob=True,
            token_ids_logprob=label_token_ids,
            stream=False,
            # Only a single decode step is needed to read the next-token
            # distribution.
            sampling_params={"max_new_tokens": 1},
        )
    elif (
        isinstance(query, list)
        and isinstance(items, list)
        and items
        and isinstance(items[0], list)
    ):
        # Both query and items are token IDs
        if item_first:
            input_ids_list = [item + query for item in items]
        else:
            input_ids_list = [query + item for item in items]
        batch_request = GenerateReqInput(
            input_ids=input_ids_list,
            return_logprob=True,
            token_ids_logprob=label_token_ids,
            stream=False,
            sampling_params={"max_new_tokens": 1},
        )
    else:
        raise ValueError(
            "Invalid combination of query/items types for score_request."
        )

    results = await self.generate_request(batch_request, request).__anext__()
    scores = []

    for result in results:
        output_logprobs = result["meta_info"].get("output_token_ids_logprobs")
        if not output_logprobs:
            # Previously this path raised an opaque IndexError; surface the
            # documented ValueError instead.
            raise ValueError(
                "Logprobs are not available for the specified tokens."
            )

        # Map token ID -> logprob at the first generated position.
        logprobs = {}
        for logprob, token_id, _ in output_logprobs[0]:
            if token_id in label_token_ids:
                logprobs[token_id] = logprob

        # Get scores in order of label_token_ids; tokens the engine did not
        # report fall back to -inf (probability 0 after exponentiation).
        score_list = [
            logprobs.get(token_id, float("-inf")) for token_id in label_token_ids
        ]

        # Apply softmax to logprobs if needed
        if apply_softmax:
            score_list = torch.softmax(torch.tensor(score_list), dim=0).tolist()
        else:
            # Convert logprobs to probabilities if not using softmax
            score_list = [
                math.exp(x) if x != float("-inf") else 0.0 for x in score_list
            ]

        scores.append(score_list)

    return scores
|
||||||
|
|
||||||
|
|
||||||
async def print_exception_wrapper(func):
|
async def print_exception_wrapper(func):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -69,6 +69,8 @@ from sglang.srt.openai_api.protocol import (
|
|||||||
FunctionResponse,
|
FunctionResponse,
|
||||||
LogProbs,
|
LogProbs,
|
||||||
MultimodalEmbeddingInput,
|
MultimodalEmbeddingInput,
|
||||||
|
ScoringRequest,
|
||||||
|
ScoringResponse,
|
||||||
ToolCall,
|
ToolCall,
|
||||||
TopLogprob,
|
TopLogprob,
|
||||||
UsageInfo,
|
UsageInfo,
|
||||||
@@ -1928,3 +1930,31 @@ def to_openai_style_logprobs(
|
|||||||
append_top_logprobs(output_top_logprobs)
|
append_top_logprobs(output_top_logprobs)
|
||||||
|
|
||||||
return ret_logprobs
|
return ret_logprobs
|
||||||
|
|
||||||
|
|
||||||
|
async def v1_score(tokenizer_manager, raw_request):
    """Handle a /v1/score request by delegating to TokenizerManager.score_request."""
    try:
        # Parse and validate the request body.
        payload = await raw_request.json()
        scoring_request = ScoringRequest(**payload)

        # Use tokenizer_manager's score_request method directly.
        scores = await tokenizer_manager.score_request(
            query=scoring_request.query,
            items=scoring_request.items,
            label_token_ids=scoring_request.label_token_ids,
            apply_softmax=scoring_request.apply_softmax,
            item_first=scoring_request.item_first,
            request=scoring_request,
        )

        # Respond with just the scores; usage info is intentionally omitted.
        return ScoringResponse(scores=scores, model=scoring_request.model)

    except Exception as e:
        logger.error(f"Error in v1_score: {str(e)}")
        return create_error_response(str(e))
|
||||||
|
|||||||
@@ -489,3 +489,27 @@ class EmbeddingResponse(BaseModel):
|
|||||||
model: str
|
model: str
|
||||||
object: str = "list"
|
object: str = "list"
|
||||||
usage: Optional[UsageInfo] = None
|
usage: Optional[UsageInfo] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ScoringRequest(BaseModel):
    """Request body for the /v1/score decoder-only scoring endpoint."""

    query: Optional[Union[str, List[int]]] = (
        None  # Query text or pre-tokenized token IDs
    )
    items: Optional[Union[str, List[str], List[List[int]]]] = (
        None  # Item text(s) or pre-tokenized token IDs
    )
    label_token_ids: Optional[List[int]] = (
        None  # Token IDs to compute probabilities for
    )
    apply_softmax: bool = False  # Normalize probabilities over label_token_ids
    item_first: bool = False  # If True, prepend items to query
    model: str
|
||||||
|
|
||||||
|
|
||||||
|
class ScoringResponse(BaseModel):
    """Response body for the /v1/score endpoint."""

    scores: List[
        List[float]
    ]  # List of lists of probabilities, each in the order of label_token_ids
    model: str
    usage: Optional[UsageInfo] = None  # Unset for scoring responses
    object: str = "scoring"
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import time
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
|
import requests
|
||||||
|
|
||||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||||
from sglang.srt.utils import kill_process_tree
|
from sglang.srt.utils import kill_process_tree
|
||||||
@@ -599,7 +600,6 @@ class TestOpenAIServerEBNF(CustomTestCase):
|
|||||||
extra_body={"ebnf": ebnf_grammar},
|
extra_body={"ebnf": ebnf_grammar},
|
||||||
)
|
)
|
||||||
text = response.choices[0].message.content.strip()
|
text = response.choices[0].message.content.strip()
|
||||||
print("EBNF test output:", repr(text))
|
|
||||||
self.assertTrue(len(text) > 0, "Got empty text from EBNF generation")
|
self.assertTrue(len(text) > 0, "Got empty text from EBNF generation")
|
||||||
self.assertRegex(text, pattern, f"Text '{text}' doesn't match EBNF choices")
|
self.assertRegex(text, pattern, f"Text '{text}' doesn't match EBNF choices")
|
||||||
|
|
||||||
@@ -630,7 +630,6 @@ class TestOpenAIServerEBNF(CustomTestCase):
|
|||||||
extra_body={"ebnf": ebnf_grammar},
|
extra_body={"ebnf": ebnf_grammar},
|
||||||
)
|
)
|
||||||
text = response.choices[0].message.content.strip()
|
text = response.choices[0].message.content.strip()
|
||||||
print("EBNF strict JSON test output:", repr(text))
|
|
||||||
self.assertTrue(len(text) > 0, "Got empty text from EBNF strict JSON test")
|
self.assertTrue(len(text) > 0, "Got empty text from EBNF strict JSON test")
|
||||||
self.assertRegex(
|
self.assertRegex(
|
||||||
text, pattern, f"Text '{text}' not matching the EBNF strict JSON shape"
|
text, pattern, f"Text '{text}' not matching the EBNF strict JSON shape"
|
||||||
@@ -766,5 +765,168 @@ class TestOpenAIServerIgnoreEOS(CustomTestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestOpenAIV1Score(CustomTestCase):
    """End-to-end tests for the /v1/score HTTP endpoint against a live server."""

    @classmethod
    def setUpClass(cls):
        # Launch one shared server process for every test in this class.
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"

        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
        )
        # From here on, base_url points at the scoring endpoint itself.
        cls.base_url += "/v1/score"
        cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def run_score(
        self, query, items, label_token_ids, apply_softmax=False, item_first=False
    ):
        """POST a scoring request and return the parsed JSON response body."""
        response = requests.post(
            self.base_url,
            headers={
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            },
            json={
                "model": self.model,
                "query": query,
                "items": items,
                "label_token_ids": label_token_ids,
                "apply_softmax": apply_softmax,
                "item_first": item_first,
            },
        )
        return response.json()

    def test_score_text_input(self):
        """Test scoring with text input"""
        query = "The capital of France is"
        items = ["Paris", "London", "Berlin"]

        # Get valid token IDs from the tokenizer
        label_token_ids = []
        for item in items:
            token_ids = self.tokenizer.encode(item, add_special_tokens=False)
            if not token_ids:
                self.fail(f"Failed to encode item: {item}")
            label_token_ids.append(token_ids[0])

        response = self.run_score(query, items, label_token_ids, apply_softmax=True)

        # Handle error responses
        if response.get("type") == "BadRequestError":
            self.fail(f"Score request failed with error: {response['message']}")

        # Verify response structure
        self.assertIn("scores", response, "Response should have a 'scores' field")
        self.assertIsInstance(response["scores"], list, "scores should be a list")
        self.assertEqual(
            len(response["scores"]),
            len(items),
            "Number of scores should match number of items",
        )

        # Each score should be a list of floats in the order of label_token_ids
        for i, score_list in enumerate(response["scores"]):
            self.assertIsInstance(score_list, list, f"Score {i} should be a list")
            self.assertEqual(
                len(score_list),
                len(label_token_ids),
                f"Score {i} length should match label_token_ids",
            )
            self.assertTrue(
                all(isinstance(v, float) for v in score_list),
                f"Score {i} values should be floats",
            )
            # apply_softmax=True, so each distribution must sum to 1.
            self.assertAlmostEqual(
                sum(score_list),
                1.0,
                places=6,
                msg=f"Score {i} probabilities should sum to 1",
            )

    def test_score_token_input(self):
        """Test scoring with token IDs input"""
        query = "The capital of France is"
        items = ["Paris", "London", "Berlin"]

        # Get valid token IDs
        query_ids = self.tokenizer.encode(query, add_special_tokens=False)
        item_ids = [
            self.tokenizer.encode(item, add_special_tokens=False) for item in items
        ]
        label_token_ids = [
            ids[0] for ids in item_ids if ids
        ]  # Get first token ID of each item

        response = self.run_score(
            query_ids, item_ids, label_token_ids, apply_softmax=True
        )

        # Handle error responses
        if response.get("type") == "BadRequestError":
            self.fail(f"Score request failed with error: {response['message']}")

        # Verify response structure
        self.assertIn("scores", response, "Response should have a 'scores' field")
        self.assertIsInstance(response["scores"], list, "scores should be a list")
        self.assertEqual(
            len(response["scores"]),
            len(items),
            "Number of scores should match number of items",
        )

        # Each score should be a list of floats in the order of label_token_ids
        for i, score_list in enumerate(response["scores"]):
            self.assertIsInstance(score_list, list, f"Score {i} should be a list")
            self.assertEqual(
                len(score_list),
                len(label_token_ids),
                f"Score {i} length should match label_token_ids",
            )
            self.assertTrue(
                all(isinstance(v, float) for v in score_list),
                f"Score {i} values should be floats",
            )
            self.assertAlmostEqual(
                sum(score_list),
                1.0,
                places=6,
                msg=f"Score {i} probabilities should sum to 1",
            )

    def test_score_error_handling(self):
        """Test error handling for invalid inputs"""
        query = "The capital of France is"
        items = ["Paris", "London", "Berlin"]

        # Test with invalid token ID
        response = requests.post(
            self.base_url,
            headers={
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            },
            json={
                "model": self.model,
                "query": query,
                "items": items,
                "label_token_ids": [999999],  # Invalid token ID
                "apply_softmax": True,
            },
        )
        # Out-of-vocabulary tokens must surface as an HTTP 400, not a crash.
        self.assertEqual(response.status_code, 400)
        error_response = response.json()
        self.assertEqual(error_response["type"], "BadRequestError")
        self.assertIn("Token ID 999999 is out of vocabulary", error_response["message"])
|
||||||
|
|
||||||
|
|
||||||
# Allow running this test module directly.
if __name__ == "__main__":
    unittest.main()
|
||||||
|
|||||||
218
test/srt/test_score_api.py
Normal file
218
test/srt/test_score_api.py
Normal file
@@ -0,0 +1,218 @@
|
|||||||
|
import unittest
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
|
||||||
|
from sglang.srt.entrypoints.engine import Engine
|
||||||
|
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST, CustomTestCase
|
||||||
|
|
||||||
|
TEST_MODEL_NAME = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||||
|
|
||||||
|
|
||||||
|
class TestScoreAPI(CustomTestCase):
    """Test the scoring API functionality.

    Each test builds a fresh SGLang Engine and (where needed) a HuggingFace
    reference model to compare against.  # NOTE(review): presumably requires a
    CUDA device, given the torch.cuda.empty_cache() calls — confirm in CI.
    """

    def setUp(self):
        """Set up each test case."""
        self.engine = Engine(model_path=TEST_MODEL_NAME)

    def tearDown(self):
        """Clean up after each test case."""
        if self.engine is not None:
            self.engine.shutdown()
        torch.cuda.empty_cache()

    def compute_hf_scores(
        self, query, items, label_token_ids, apply_softmax=False, item_first=False
    ):
        """Compute scores using direct HuggingFace model inference.
        Returns probabilities for each token ID, optionally normalized with softmax.

        Args:
            query: The query text
            items: List of item texts
            label_token_ids: List of token IDs to compute probabilities for
            apply_softmax: Whether to normalize probabilities using softmax
            item_first: If True, prepend items to query. Otherwise append items to query.
        """
        # Initialize HF model and tokenizer
        tokenizer = AutoTokenizer.from_pretrained(
            TEST_MODEL_NAME, trust_remote_code=True
        )
        model = AutoModelForCausalLM.from_pretrained(
            TEST_MODEL_NAME, trust_remote_code=True
        )

        try:
            scores = []
            for item in items:
                # Construct full text based on item_first parameter
                full_text = f"{item}{query}" if item_first else f"{query}{item}"
                inputs = tokenizer(full_text, return_tensors="pt").to(model.device)

                # Get logits for the last token
                with torch.no_grad():
                    outputs = model(**inputs)
                last_token_logits = outputs.logits[0, -1]

                # Get logits for just our target tokens
                target_logits = last_token_logits[label_token_ids]

                # Apply softmax over just the target tokens
                target_probs = torch.softmax(target_logits, dim=-1)

                # Convert to list of probabilities in order of label_token_ids
                probs = [target_probs[i].item() for i in range(len(label_token_ids))]

                scores.append(probs)

            return scores
        finally:
            # Clean up HF resources
            model.cpu()
            del model
            del tokenizer
            torch.cuda.empty_cache()

    def _get_token_ids(self, tokens):
        """Helper method to get token IDs for a list of tokens."""
        tokenizer = AutoTokenizer.from_pretrained(
            TEST_MODEL_NAME, trust_remote_code=True
        )
        try:
            label_token_ids = []
            for token in tokens:
                encoding = tokenizer.encode_plus(token, add_special_tokens=False)
                token_ids = encoding["input_ids"]
                # Only the first token of each string is used as its label ID.
                label_token_ids.append(token_ids[0])
            return label_token_ids
        finally:
            del tokenizer

    def _compare_scores(self, hf_scores, sglang_scores, label_token_ids, case_name=""):
        """Helper method to compare scores between HF and SGLang using relative tolerance."""
        self.assertEqual(
            len(hf_scores),
            len(sglang_scores),
            f"Score lengths don't match for {case_name}",
        )

        # Use a relative tolerance of 1%
        TOLERANCE = 0.01

        for hf_score_list, sglang_score_list in zip(hf_scores, sglang_scores):
            self.assertEqual(
                len(hf_score_list),
                len(sglang_score_list),
                f"Score list lengths don't match for {case_name}",
            )

            for hf_score, sglang_score in zip(hf_score_list, sglang_score_list):
                diff = abs(hf_score - sglang_score)
                self.assertLessEqual(
                    diff,
                    TOLERANCE,
                    msg=f"Scores differ by {diff:.2%} ({case_name}): "
                    f"HF={hf_score:.6f}, SGLang={sglang_score:.6f}",
                )

                # Probabilities must lie in [0, 1].
                self.assertGreaterEqual(
                    sglang_score, 0, f"SGLang score {sglang_score:.6f} not in [0,1]"
                )
                self.assertLessEqual(
                    sglang_score, 1, f"SGLang score {sglang_score:.6f} not in [0,1]"
                )

            self.assertAlmostEqual(
                sum(sglang_score_list),
                1.0,
                places=6,
                msg=f"SGLang scores don't sum to 1 ({case_name}): {sum(sglang_score_list):.6f}",
            )

    def test_score_consistency(self):
        """Test that SGLang scoring matches direct HuggingFace model scoring."""
        # Define test cases
        test_cases = [
            {
                "name": "default case",
                "query": "I pledge allegiance",
                "items": ["", " to"],
                "item_first": False,
            },
            {
                "name": "item_first case",
                "query": " is a city",
                "items": ["Tokyo", "Japan"],
                "item_first": True,
            },
        ]

        # Common tokens to test for all cases
        tokens = [" to", " the"]
        label_token_ids = self._get_token_ids(tokens)

        # Run each test case
        for case in test_cases:
            # Get scores from SGLang
            sglang_scores = self.engine.score(
                query=case["query"],
                items=case["items"],
                label_token_ids=label_token_ids,
                apply_softmax=True,
                item_first=case["item_first"],
            )

            # Get scores from HuggingFace using the same parameters
            hf_scores = self.compute_hf_scores(
                query=case["query"],
                items=case["items"],
                label_token_ids=label_token_ids,
                apply_softmax=True,
                item_first=case["item_first"],
            )

            # Compare scores
            self._compare_scores(
                hf_scores, sglang_scores, label_token_ids, case["name"]
            )

    def test_score_batch_handling(self):
        """Test that batch scoring works correctly."""
        # Test with different batch sizes
        batch_sizes = [1, 2, 4, 8]
        label_token_ids = [1, 2, 3]

        for batch_size in batch_sizes:
            texts = [f"test {i}" for i in range(batch_size)]
            scores = self.engine.score(
                query="The test was",
                items=texts,
                label_token_ids=label_token_ids,
                apply_softmax=True,
            )

            self.assertEqual(
                len(scores),
                batch_size,
                f"Expected {batch_size} scores, got {len(scores)}",
            )

            # Verify each score list has the correct length
            for score_list in scores:
                self.assertEqual(
                    len(score_list),
                    len(label_token_ids),
                    f"Score list length {len(score_list)} doesn't match label_token_ids length {len(label_token_ids)}",
                )
                self.assertTrue(
                    all(isinstance(v, float) for v in score_list),
                    "All scores should be floats",
                )
                self.assertAlmostEqual(
                    1.0, sum(score_list), 6, "Scores should sum to 1"
                )
|
||||||
|
|
||||||
|
|
||||||
|
# Allow running this test module directly.
if __name__ == "__main__":
    unittest.main()
|
||||||
Reference in New Issue
Block a user