Decoder-only Scoring API (#6460)
Co-authored-by: Chanh Nguyen <cnguyen@linkedin.com>
@@ -472,6 +472,79 @@ class Engine(EngineBase):

    def save_sharded_model(self, **kwargs):
        self.collective_rpc("save_sharded_model", **kwargs)

    def score(
        self,
        query: Optional[Union[str, List[int]]] = None,
        items: Optional[Union[str, List[str], List[List[int]]]] = None,
        label_token_ids: Optional[List[int]] = None,
        apply_softmax: bool = False,
        item_first: bool = False,
    ) -> List[List[float]]:
        """
        Score the probability of specified token IDs appearing after the given (query + item) pair. For example:

        query = "<|user|>Is the following city the capital of France? "
        items = ["Paris <|assistant|>", "London <|assistant|>", "Berlin <|assistant|>"]
        label_token_ids = [2332, 1223]  # Token IDs for "Yes" and "No"
        item_first = False

        This would pass the following prompts to the model:
        "<|user|>Is the following city the capital of France? Paris <|assistant|>"
        "<|user|>Is the following city the capital of France? London <|assistant|>"
        "<|user|>Is the following city the capital of France? Berlin <|assistant|>"
        The API would then return the probabilities of the model producing "Yes" and "No" as the next token.
        The output would look like:
        [[0.9, 0.1], [0.2, 0.8], [0.1, 0.9]]

        Args:
            query: The query text or pre-tokenized query token IDs. Must be provided.
            items: The item text(s) or pre-tokenized item token IDs. Must be provided.
            label_token_ids: List of token IDs to compute probabilities for. Must be provided.
            apply_softmax: Whether to normalize the label-token probabilities with a softmax.
            item_first: If True, prepend each item to the query; otherwise append it.

        Returns:
            A list with one entry per item; each entry is the list of probabilities
            of the label tokens, in the order of label_token_ids.

        Raises:
            ValueError: If query or items is not provided, if a label token ID is
                out of vocabulary, or if logprobs are not available for the
                specified tokens.
        """
        loop = asyncio.get_event_loop()
        return loop.run_until_complete(
            self.tokenizer_manager.score_request(
                query=query,
                items=items,
                label_token_ids=label_token_ids,
                apply_softmax=apply_softmax,
                item_first=item_first,
                request=None,
            )
        )

    async def async_score(
        self,
        query: Optional[Union[str, List[int]]] = None,
        items: Optional[Union[str, List[str], List[List[int]]]] = None,
        label_token_ids: Optional[List[int]] = None,
        apply_softmax: bool = False,
        item_first: bool = False,
    ) -> List[List[float]]:
        """
        Asynchronous version of score(); see score() for detailed documentation.
        """
        return await self.tokenizer_manager.score_request(
            query=query,
            items=items,
            label_token_ids=label_token_ids,
            apply_softmax=apply_softmax,
            item_first=item_first,
            request=None,
        )


def _set_envs_and_config(server_args: ServerArgs):
    # Set global environments
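
For reference, a minimal offline usage sketch of the new API (not part of this commit; the model path is a placeholder, and the label token IDs are looked up from the tokenizer rather than hard-coded, since the IDs in the docstring example are tokenizer-specific):

import sglang as sgl

# Placeholder model; any decoder-only chat model should work.
engine = sgl.Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")

# Resolve the label token IDs for this tokenizer instead of hard-coding them.
# (Assumes the engine exposes its tokenizer via tokenizer_manager, as the
# hunk above suggests.)
tokenizer = engine.tokenizer_manager.tokenizer
label_token_ids = [
    tokenizer.convert_tokens_to_ids("Yes"),
    tokenizer.convert_tokens_to_ids("No"),
]

scores = engine.score(
    query="<|user|>Is the following city the capital of France? ",
    items=["Paris <|assistant|>", "London <|assistant|>", "Berlin <|assistant|>"],
    label_token_ids=label_token_ids,
    apply_softmax=True,  # normalize over the two label tokens
)
print(scores)  # e.g. [[0.9, 0.1], [0.2, 0.8], [0.1, 0.9]]
engine.shutdown()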
@@ -82,6 +82,7 @@ from sglang.srt.openai_api.adapter import (
    v1_retrieve_batch,
    v1_retrieve_file,
    v1_retrieve_file_content,
    v1_score,
)
from sglang.srt.openai_api.protocol import ModelCard, ModelList
from sglang.srt.reasoning_parser import ReasoningParser

@@ -720,6 +721,12 @@ async def vertex_generate(vertex_req: VertexGenerateReqInput, raw_request: Reque
    return ORJSONResponse({"predictions": ret})


@app.post("/v1/score")
async def v1_score_request(raw_request: Request):
    """Endpoint for the decoder-only scoring API. See Engine.score() for detailed documentation."""
    return await v1_score(_global_state.tokenizer_manager, raw_request)


def _create_error_response(e):
    return ORJSONResponse(
        {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
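
A sketch of how the new endpoint might be exercised once a server is running (host, port, model name, and token IDs below are illustrative, not fixed by this commit):

import requests

resp = requests.post(
    "http://localhost:30000/v1/score",  # default sglang port assumed
    json={
        "query": "<|user|>Is the following city the capital of France? ",
        "items": ["Paris <|assistant|>", "London <|assistant|>"],
        "label_token_ids": [2332, 1223],  # placeholder IDs for "Yes"/"No"
        "apply_softmax": True,
        "item_first": False,
        "model": "my-model",
    },
)
print(resp.json()["scores"])  # e.g. [[0.9, 0.1], [0.2, 0.8]]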
@@ -18,6 +18,7 @@ import copy
import dataclasses
import json
import logging
import math
import os
import pickle
import signal

@@ -42,6 +43,7 @@ from typing import (
)

import fastapi
import torch
import uvloop
import zmq
import zmq.asyncio

@@ -1433,6 +1435,100 @@ class TokenizerManager:
            if len(self.model_update_tmp) == self.server_args.dp_size:
                self.model_update_result.set_result(self.model_update_tmp)

    async def score_request(
        self,
        query: Optional[Union[str, List[int]]] = None,
        items: Optional[Union[str, List[str], List[List[int]]]] = None,
        label_token_ids: Optional[List[int]] = None,
        apply_softmax: bool = False,
        item_first: bool = False,
        request: Optional[Any] = None,
    ) -> List[List[float]]:
        """
        See Engine.score() for more details.
        """
        if label_token_ids is None:
            raise ValueError("label_token_ids must be provided")

        if self.tokenizer is not None:
            vocab_size = self.tokenizer.vocab_size
            for token_id in label_token_ids:
                if token_id >= vocab_size:
                    raise ValueError(
                        f"Token ID {token_id} is out of vocabulary (vocab size: {vocab_size})"
                    )

        # Handle string or tokenized query/items
        if isinstance(query, str) and (
            isinstance(items, str)
            or (isinstance(items, list) and (not items or isinstance(items[0], str)))
        ):
            # Both query and items are text
            items_list = [items] if isinstance(items, str) else items
            if item_first:
                prompts = [f"{item}{query}" for item in items_list]
            else:
                prompts = [f"{query}{item}" for item in items_list]
            batch_request = GenerateReqInput(
                text=prompts,
                return_logprob=True,
                token_ids_logprob=label_token_ids,
                stream=False,
                sampling_params={"max_new_tokens": 1},
            )
        elif (
            isinstance(query, list)
            and isinstance(items, list)
            and items
            and isinstance(items[0], list)
        ):
            # Both query and items are token IDs
            if item_first:
                input_ids_list = [item + query for item in items]
            else:
                input_ids_list = [query + item for item in items]
            batch_request = GenerateReqInput(
                input_ids=input_ids_list,
                return_logprob=True,
                token_ids_logprob=label_token_ids,
                stream=False,
                sampling_params={"max_new_tokens": 1},
            )
        else:
            raise ValueError(
                "Invalid combination of query/items types for score_request."
            )

        results = await self.generate_request(batch_request, request).__anext__()
        scores = []

        for result in results:
            # Collect the logprobs of the label tokens at the first (and only)
            # generated position.
            logprobs = {}
            for logprob, token_id, _ in result["meta_info"].get(
                "output_token_ids_logprobs", []
            )[0]:
                if token_id in label_token_ids:
                    logprobs[token_id] = logprob

            # Get scores in order of label_token_ids
            score_list = [
                logprobs.get(token_id, float("-inf")) for token_id in label_token_ids
            ]

            # Apply softmax to logprobs if needed
            if apply_softmax:
                score_list = torch.softmax(torch.tensor(score_list), dim=0).tolist()
            else:
                # Convert logprobs to probabilities if not using softmax
                score_list = [
                    math.exp(x) if x != float("-inf") else 0.0 for x in score_list
                ]

            scores.append(score_list)

        return scores


async def print_exception_wrapper(func):
    """
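
The final conversion from label-token logprobs to scores can be checked in isolation; a small sketch of the two modes with made-up logprob values:

import math

import torch

label_logprobs = [-0.105, -2.303]  # roughly log(0.9) and log(0.1)

# apply_softmax=True: renormalize over just the label tokens, so the
# scores for each item sum to 1.
softmaxed = torch.softmax(torch.tensor(label_logprobs), dim=0).tolist()

# apply_softmax=False: raw next-token probabilities, which need not sum to 1
# because mass can fall on tokens outside label_token_ids.
raw = [math.exp(x) for x in label_logprobs]

print(softmaxed, raw)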
@@ -69,6 +69,8 @@ from sglang.srt.openai_api.protocol import (
    FunctionResponse,
    LogProbs,
    MultimodalEmbeddingInput,
    ScoringRequest,
    ScoringResponse,
    ToolCall,
    TopLogprob,
    UsageInfo,

@@ -1928,3 +1930,31 @@ def to_openai_style_logprobs(
        append_top_logprobs(output_top_logprobs)

    return ret_logprobs


async def v1_score(tokenizer_manager, raw_request):
    try:
        # Parse request
        request_data = await raw_request.json()
        request = ScoringRequest(**request_data)

        # Use tokenizer_manager's score_request method directly
        scores = await tokenizer_manager.score_request(
            query=request.query,
            items=request.items,
            label_token_ids=request.label_token_ids,
            apply_softmax=request.apply_softmax,
            item_first=request.item_first,
            request=request,
        )

        # Create response with just the scores, without usage info
        response = ScoringResponse(
            scores=scores,
            model=request.model,
        )
        return response

    except Exception as e:
        logger.error(f"Error in v1_score: {str(e)}")
        return create_error_response(str(e))
@@ -489,3 +489,27 @@ class EmbeddingResponse(BaseModel):
    model: str
    object: str = "list"
    usage: Optional[UsageInfo] = None


class ScoringRequest(BaseModel):
    query: Optional[Union[str, List[int]]] = (
        None  # Query text or pre-tokenized token IDs
    )
    items: Optional[Union[str, List[str], List[List[int]]]] = (
        None  # Item text(s) or pre-tokenized token IDs
    )
    label_token_ids: Optional[List[int]] = (
        None  # Token IDs to compute probabilities for
    )
    apply_softmax: bool = False
    item_first: bool = False
    model: str


class ScoringResponse(BaseModel):
    scores: List[
        List[float]
    ]  # List of lists of probabilities, each in the order of label_token_ids
    model: str
    usage: Optional[UsageInfo] = None
    object: str = "scoring"
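
A round-trip sketch of the two new models with illustrative values (not part of this commit):

from sglang.srt.openai_api.protocol import ScoringRequest, ScoringResponse

req = ScoringRequest(
    query="<|user|>Is the following city the capital of France? ",
    items=["Paris <|assistant|>", "London <|assistant|>"],
    label_token_ids=[2332, 1223],  # placeholder IDs
    apply_softmax=True,
    model="my-model",
)

# scores: one inner list per item, ordered by label_token_ids.
resp = ScoringResponse(scores=[[0.9, 0.1], [0.2, 0.8]], model=req.model)
print(resp.model_dump())  # .dict() on pydantic v1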