Add support for logprobs in OpenAI chat API (#852)

yichuan~
2024-08-01 15:08:21 +08:00
committed by GitHub
parent 0c0c81372e
commit ca600e8cd6
3 changed files with 117 additions and 21 deletions
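
For reference, the feature added here can be exercised with the standard OpenAI Python client. A minimal sketch, assuming a local SGLang server on port 30000 and a placeholder model name:

# Minimal sketch (assumed server URL and model name; any
# OpenAI-compatible client works the same way).
from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="default",  # placeholder model name
    messages=[{"role": "user", "content": "What is the capital of France?"}],
    logprobs=True,   # ask for per-token logprobs
    top_logprobs=2,  # and the top-2 alternatives for each token
)

# Each entry carries token, bytes, logprob, and top_logprobs,
# matching the ChatCompletionTokenLogprob model added below.
for entry in response.choices[0].logprobs.content:
    print(entry.token, entry.logprob)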

View File

@@ -43,7 +43,9 @@ from sglang.srt.openai_api.protocol import (
     ChatCompletionResponseChoice,
     ChatCompletionResponseStreamChoice,
     ChatCompletionStreamResponse,
+    ChatCompletionTokenLogprob,
     ChatMessage,
+    ChoiceLogprobs,
     CompletionRequest,
     CompletionResponse,
     CompletionResponseChoice,
@@ -54,6 +56,7 @@ from sglang.srt.openai_api.protocol import (
     FileRequest,
     FileResponse,
     LogProbs,
+    TopLogprob,
     UsageInfo,
 )
@@ -70,7 +73,7 @@ class FileMetadata:
 batch_storage: Dict[str, BatchResponse] = {}
 file_id_request: Dict[str, FileMetadata] = {}
 file_id_response: Dict[str, FileResponse] = {}
-## map file id to file path in SGlang backend
+# map file id to file path in SGlang backend
 file_id_storage: Dict[str, str] = {}
@@ -261,7 +264,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
         failed_requests += len(file_request_list)
         for idx, response in enumerate(responses):
-            ## the batch_req here can be changed to be named within a batch granularity
+            # the batch_req here can be changed to be named within a batch granularity
             response_json = {
                 "id": f"batch_req_{uuid.uuid4()}",
                 "custom_id": file_request_list[idx].get("custom_id"),
@@ -333,6 +336,8 @@ def v1_generate_request(all_requests):
     prompts = []
     sampling_params_list = []
+    return_logprobs = []
+    top_logprobs_nums = []
     first_prompt_type = type(all_requests[0].prompt)
     for request in all_requests:
         prompt = request.prompt
@@ -340,6 +345,10 @@ def v1_generate_request(all_requests):
             type(prompt) == first_prompt_type
         ), "All prompts must be of the same type in file input settings"
         prompts.append(prompt)
+        return_logprobs.append(request.logprobs is not None and request.logprobs > 0)
+        top_logprobs_nums.append(
+            request.logprobs if request.logprobs is not None else 0
+        )
         sampling_params_list.append(
             {
                 "temperature": request.temperature,
@@ -361,6 +370,8 @@ def v1_generate_request(all_requests):
     if len(all_requests) == 1:
         prompt = prompts[0]
         sampling_params_list = sampling_params_list[0]
+        return_logprobs = return_logprobs[0]
+        top_logprobs_nums = top_logprobs_nums[0]
         if isinstance(prompt, str) or isinstance(prompt[0], str):
             prompt_kwargs = {"text": prompt}
         else:
@@ -370,15 +381,11 @@ def v1_generate_request(all_requests):
             prompt_kwargs = {"text": prompts}
         else:
             prompt_kwargs = {"input_ids": prompts}
     adapted_request = GenerateReqInput(
         **prompt_kwargs,
         sampling_params=sampling_params_list,
-        return_logprob=all_requests[0].logprobs is not None
-        and all_requests[0].logprobs > 0,
-        top_logprobs_num=(
-            all_requests[0].logprobs if all_requests[0].logprobs is not None else 0
-        ),
+        return_logprob=return_logprobs,
+        top_logprobs_num=top_logprobs_nums,
         return_text_in_logprobs=True,
         stream=all_requests[0].stream,
     )
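
This replaces the earlier behavior, where the first request's `logprobs` setting was applied to every request in a file batch, with per-request values. A small self-contained sketch of the bookkeeping above, using stand-in objects rather than the real CompletionRequest model:

# Sketch with stand-ins for CompletionRequest; mirrors the loop above.
from types import SimpleNamespace

all_requests = [
    SimpleNamespace(logprobs=3),     # wants top-3 logprobs
    SimpleNamespace(logprobs=None),  # wants none
]

return_logprobs = []
top_logprobs_nums = []
for request in all_requests:
    return_logprobs.append(request.logprobs is not None and request.logprobs > 0)
    top_logprobs_nums.append(request.logprobs if request.logprobs is not None else 0)

print(return_logprobs)    # [True, False]
print(top_logprobs_nums)  # [3, 0]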
@@ -430,7 +437,7 @@ def v1_generate_response(request, ret, to_file=False):
         logprobs = None
         if to_file:
-            ## to make the choice data json serializable
+            # to make the choice data json serializable
             choice_data = {
                 "index": 0,
                 "text": text,
@@ -454,7 +461,7 @@ def v1_generate_response(request, ret, to_file=False):
"status_code": 200,
"request_id": ret[i]["meta_info"]["id"],
"body": {
## remain the same but if needed we can change that
# remain the same but if needed we can change that
"id": ret[i]["meta_info"]["id"],
"object": "text_completion",
"created": int(time.time()),
@@ -590,6 +597,8 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
     texts = []
     sampling_params_list = []
     image_data_list = []
+    return_logprobs = []
+    top_logprobs_nums = []
     for request in all_requests:
         # Prep the data needed for the underlying GenerateReqInput:
         #  - prompt: The full prompt string.
@@ -620,6 +629,8 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
             stop = request.stop
             image_data = None
         texts.append(prompt)
+        return_logprobs.append(request.logprobs)
+        top_logprobs_nums.append(request.top_logprobs)
         sampling_params_list.append(
             {
                 "temperature": request.temperature,
@@ -637,11 +648,16 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
         texts = texts[0]
         sampling_params_list = sampling_params_list[0]
         image_data = image_data_list[0]
+        return_logprobs = return_logprobs[0]
+        top_logprobs_nums = top_logprobs_nums[0]
     adapted_request = GenerateReqInput(
         text=texts,
         image_data=image_data,
         sampling_params=sampling_params_list,
-        stream=request.stream,
+        return_logprob=return_logprobs,
+        top_logprobs_num=top_logprobs_nums,
+        stream=all_requests[0].stream,
+        return_text_in_logprobs=True,
     )
     if len(all_requests) == 1:
         return adapted_request, all_requests[0]
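
For a single chat request the unwrapping above collapses the lists to scalars, while a batch keeps parallel per-request lists. A hedged illustration of the logprob arguments GenerateReqInput receives (its full definition is not part of this diff):

# Illustration only: a batch of two chat requests, where just the first
# sets logprobs=True and top_logprobs=2 (chat defaults: False / None).
return_logprobs = [True, False]  # one entry per request.logprobs
top_logprobs_nums = [2, None]    # one entry per request.top_logprobs
# With a single request, index [0] is unwrapped above, so the backend
# sees scalars: return_logprob=True, top_logprobs_num=2.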
@@ -654,26 +670,63 @@ def v1_chat_generate_response(request, ret, to_file=False):
     total_completion_tokens = 0
     for idx, ret_item in enumerate(ret):
+        logprobs = False
+        if isinstance(request, List) and request[idx].logprobs:
+            logprobs = True
+        elif (not isinstance(request, List)) and request.logprobs:
+            logprobs = True
+        if logprobs:
+            logprobs = to_openai_style_logprobs(
+                output_token_logprobs=ret_item["meta_info"]["output_token_logprobs"],
+                output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"],
+            )
+            token_logprobs = []
+            for token, logprob in zip(logprobs.tokens, logprobs.token_logprobs):
+                token_bytes = list(token.encode("utf-8"))
+                top_logprobs = []
+                if logprobs.top_logprobs:
+                    for top_token, top_logprob in logprobs.top_logprobs[0].items():
+                        top_token_bytes = list(top_token.encode("utf-8"))
+                        top_logprobs.append(
+                            TopLogprob(
+                                token=top_token,
+                                bytes=top_token_bytes,
+                                logprob=top_logprob,
+                            )
+                        )
+                token_logprobs.append(
+                    ChatCompletionTokenLogprob(
+                        token=token,
+                        bytes=token_bytes,
+                        logprob=logprob,
+                        top_logprobs=top_logprobs,
+                    )
+                )
+            choice_logprobs = ChoiceLogprobs(content=token_logprobs)
+        else:
+            choice_logprobs = None
         prompt_tokens = ret_item["meta_info"]["prompt_tokens"]
         completion_tokens = ret_item["meta_info"]["completion_tokens"]
         if to_file:
-            ## to make the choice data json serializable
+            # to make the choice data json serializable
             choice_data = {
                 "index": 0,
                 "message": {"role": "assistant", "content": ret_item["text"]},
-                "logprobs": None,
+                "logprobs": choice_logprobs,
                 "finish_reason": ret_item["meta_info"]["finish_reason"],
             }
         else:
             choice_data = ChatCompletionResponseChoice(
                 index=idx,
                 message=ChatMessage(role="assistant", content=ret_item["text"]),
+                logprobs=choice_logprobs,
                 finish_reason=ret_item["meta_info"]["finish_reason"],
             )
         choices.append(choice_data)
-        total_prompt_tokens = prompt_tokens
+        total_prompt_tokens += prompt_tokens
         total_completion_tokens += completion_tokens
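
The loop above converts SGLang's internal logprob records into the OpenAI chat schema. A hedged illustration of the shape a serialized choice's logprobs take; every token string, byte value, and probability below is invented for the example:

# Illustrative shape only; values are made up.
example_choice_logprobs = {
    "content": [
        {
            "token": "Paris",
            "bytes": [80, 97, 114, 105, 115],
            "logprob": -0.12,
            "top_logprobs": [
                {"token": "Paris", "bytes": [80, 97, 114, 105, 115], "logprob": -0.12},
                {"token": "London", "bytes": [76, 111, 110, 100, 111, 110], "logprob": -2.31},
            ],
        },
    ],
}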
@@ -683,7 +736,7 @@ def v1_chat_generate_response(request, ret, to_file=False):
"status_code": 200,
"request_id": ret[i]["meta_info"]["id"],
"body": {
## remain the same but if needed we can change that
# remain the same but if needed we can change that
"id": ret[i]["meta_info"]["id"],
"object": "chat.completion",
"created": int(time.time()),

View File

@@ -54,6 +54,24 @@ class LogProbs(BaseModel):
     top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)


+class TopLogprob(BaseModel):
+    token: str
+    bytes: List[int]
+    logprob: float
+
+
+class ChatCompletionTokenLogprob(BaseModel):
+    token: str
+    bytes: List[int]
+    logprob: float
+    top_logprobs: List[TopLogprob]
+
+
+class ChoiceLogprobs(BaseModel):
+    # built for the v1/chat/completions response
+    content: List[ChatCompletionTokenLogprob]
+
+
 class UsageInfo(BaseModel):
     prompt_tokens: int = 0
     total_tokens: int = 0
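
A quick sketch of how the three new models nest; the values are made up, and serialization assumes pydantic v2 (`model_dump`), with `.dict()` as the v1 equivalent:

# Made-up values; shows the nesting of the new models.
top = TopLogprob(token="Hi", bytes=[72, 105], logprob=-0.01)
tok = ChatCompletionTokenLogprob(
    token="Hi",
    bytes=[72, 105],
    logprob=-0.01,
    top_logprobs=[top],
)
choice_logprobs = ChoiceLogprobs(content=[tok])
print(choice_logprobs.model_dump())  # .dict() on pydantic v1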
@@ -239,8 +257,8 @@ class ChatMessage(BaseModel):
 class ChatCompletionResponseChoice(BaseModel):
     index: int
     message: ChatMessage
-    logprobs: Optional[LogProbs] = None
-    finish_reason: Optional[str] = None
+    logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
+    finish_reason: str


 class ChatCompletionResponse(BaseModel):