Add support for logprobs in OpenAI chat API (#852)
@@ -106,13 +106,12 @@ response = client.chat.completions.create(
        {"role": "user", "content": "List 3 countries and their capitals."},
    ],
    temperature=0.8,
    max_tokens=64,
    max_tokens=1,
    logprobs=True,
    n=1,
    top_logprobs=3,
)
print(response)


# Chat completion
response = client.chat.completions.create(
    model="default",

@@ -121,8 +120,34 @@ response = client.chat.completions.create(
        {"role": "user", "content": "List 3 countries and their capitals."},
    ],
    temperature=0.8,
    max_tokens=64,
    max_tokens=1,
    n=1,
)
print(response)

# Chat completion
response = client.chat.completions.create(
    model="default",
    messages=[
        {"role": "system", "content": "You are a helpful AI assistant"},
        {"role": "user", "content": "List 3 countries and their capitals."},
    ],
    temperature=0.8,
    max_tokens=1,
    logprobs=True,
    top_logprobs=3,
)
print(response)

# Chat completion
response = client.chat.completions.create(
    model="default",
    messages=[
        {"role": "system", "content": "You are a helpful AI assistant"},
        {"role": "user", "content": "List 3 countries and their capitals."},
    ],
    temperature=0.8,
    max_tokens=1,
    n=4,
)
print(response)

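The logprobs=True, top_logprobs=3 example above only prints the raw response. As a minimal sketch of reading the per-token log probabilities back out of it, assuming the OpenAI Python client and the ChoiceLogprobs / ChatCompletionTokenLogprob fields introduced by this commit:

# Sketch: inspect the logprobs returned by the chat completion above.
choice = response.choices[0]
if choice.logprobs is not None:
    for entry in choice.logprobs.content:        # one entry per generated token
        print(entry.token, entry.logprob)        # sampled token and its log probability
        for alt in entry.top_logprobs:           # up to top_logprobs alternatives
            print("  candidate:", alt.token, alt.logprob)
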
@@ -43,7 +43,9 @@ from sglang.srt.openai_api.protocol import (
    ChatCompletionResponseChoice,
    ChatCompletionResponseStreamChoice,
    ChatCompletionStreamResponse,
    ChatCompletionTokenLogprob,
    ChatMessage,
    ChoiceLogprobs,
    CompletionRequest,
    CompletionResponse,
    CompletionResponseChoice,

@@ -54,6 +56,7 @@ from sglang.srt.openai_api.protocol import (
    FileRequest,
    FileResponse,
    LogProbs,
    TopLogprob,
    UsageInfo,
)

@@ -70,7 +73,7 @@ class FileMetadata:
batch_storage: Dict[str, BatchResponse] = {}
file_id_request: Dict[str, FileMetadata] = {}
file_id_response: Dict[str, FileResponse] = {}
## map file id to file path in SGlang backend
# map file id to file path in SGlang backend
file_id_storage: Dict[str, str] = {}

@@ -261,7 +264,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
            failed_requests += len(file_request_list)

        for idx, response in enumerate(responses):
            ## the batch_req here can be changed to be named within a batch granularity
            # the batch_req here can be changed to be named within a batch granularity
            response_json = {
                "id": f"batch_req_{uuid.uuid4()}",
                "custom_id": file_request_list[idx].get("custom_id"),

@@ -333,6 +336,8 @@ def v1_generate_request(all_requests):

    prompts = []
    sampling_params_list = []
    return_logprobs = []
    top_logprobs_nums = []
    first_prompt_type = type(all_requests[0].prompt)
    for request in all_requests:
        prompt = request.prompt

@@ -340,6 +345,10 @@ def v1_generate_request(all_requests):
            type(prompt) == first_prompt_type
        ), "All prompts must be of the same type in file input settings"
        prompts.append(prompt)
        return_logprobs.append(request.logprobs is not None and request.logprobs > 0)
        top_logprobs_nums.append(
            request.logprobs if request.logprobs is not None else 0
        )
        sampling_params_list.append(
            {
                "temperature": request.temperature,

@@ -361,6 +370,8 @@ def v1_generate_request(all_requests):
    if len(all_requests) == 1:
        prompt = prompts[0]
        sampling_params_list = sampling_params_list[0]
        return_logprobs = return_logprobs[0]
        top_logprobs_nums = top_logprobs_nums[0]
        if isinstance(prompt, str) or isinstance(prompt[0], str):
            prompt_kwargs = {"text": prompt}
        else:

@@ -370,15 +381,11 @@ def v1_generate_request(all_requests):
            prompt_kwargs = {"text": prompts}
        else:
            prompt_kwargs = {"input_ids": prompts}

    adapted_request = GenerateReqInput(
        **prompt_kwargs,
        sampling_params=sampling_params_list,
        return_logprob=all_requests[0].logprobs is not None
        and all_requests[0].logprobs > 0,
        top_logprobs_num=(
            all_requests[0].logprobs if all_requests[0].logprobs is not None else 0
        ),
        return_logprob=return_logprobs,
        top_logprobs_num=top_logprobs_nums,
        return_text_in_logprobs=True,
        stream=all_requests[0].stream,
    )

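v1_generate_request now records logprob settings per request instead of reusing the first request's values for the whole batch. A rough illustration of the lists handed to GenerateReqInput, assuming a hypothetical batch of three /v1/completions requests (values made up):

# logprobs=None -> no logprobs; logprobs=0 -> disabled; logprobs=5 -> top-5 requested
request_logprobs = [None, 0, 5]
return_logprobs = [lp is not None and lp > 0 for lp in request_logprobs]      # [False, False, True]
top_logprobs_nums = [lp if lp is not None else 0 for lp in request_logprobs]  # [0, 0, 5]
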
@@ -430,7 +437,7 @@ def v1_generate_response(request, ret, to_file=False):
        logprobs = None

        if to_file:
            ## to make the choise data json serializable
            # to make the choise data json serializable
            choice_data = {
                "index": 0,
                "text": text,

@@ -454,7 +461,7 @@ def v1_generate_response(request, ret, to_file=False):
                "status_code": 200,
                "request_id": ret[i]["meta_info"]["id"],
                "body": {
                    ## remain the same but if needed we can change that
                    # remain the same but if needed we can change that
                    "id": ret[i]["meta_info"]["id"],
                    "object": "text_completion",
                    "created": int(time.time()),

@@ -590,6 +597,8 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
    texts = []
    sampling_params_list = []
    image_data_list = []
    return_logprobs = []
    top_logprobs_nums = []
    for request in all_requests:
        # Prep the data needed for the underlying GenerateReqInput:
        # - prompt: The full prompt string.

@@ -620,6 +629,8 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
            stop = request.stop
            image_data = None
        texts.append(prompt)
        return_logprobs.append(request.logprobs)
        top_logprobs_nums.append(request.top_logprobs)
        sampling_params_list.append(
            {
                "temperature": request.temperature,

@@ -637,11 +648,16 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
        texts = texts[0]
        sampling_params_list = sampling_params_list[0]
        image_data = image_data_list[0]
        return_logprobs = return_logprobs[0]
        top_logprobs_nums = top_logprobs_nums[0]
    adapted_request = GenerateReqInput(
        text=texts,
        image_data=image_data,
        sampling_params=sampling_params_list,
        stream=request.stream,
        return_logprob=return_logprobs,
        top_logprobs_num=top_logprobs_nums,
        stream=all_requests[0].stream,
        return_text_in_logprobs=True,
    )
    if len(all_requests) == 1:
        return adapted_request, all_requests[0]

@@ -654,26 +670,63 @@ def v1_chat_generate_response(request, ret, to_file=False):
    total_completion_tokens = 0

    for idx, ret_item in enumerate(ret):
        logprobs = False
        if isinstance(request, List) and request[idx].logprobs:
            logprobs = True
        elif (not isinstance(request, List)) and request.logprobs:
            logprobs = True
        if logprobs:
            logprobs = to_openai_style_logprobs(
                output_token_logprobs=ret_item["meta_info"]["output_token_logprobs"],
                output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"],
            )
            token_logprobs = []
            for token, logprob in zip(logprobs.tokens, logprobs.token_logprobs):
                token_bytes = list(token.encode("utf-8"))
                top_logprobs = []
                if logprobs.top_logprobs:
                    for top_token, top_logprob in logprobs.top_logprobs[0].items():
                        top_token_bytes = list(top_token.encode("utf-8"))
                        top_logprobs.append(
                            TopLogprob(
                                token=top_token,
                                bytes=top_token_bytes,
                                logprob=top_logprob,
                            )
                        )
                token_logprobs.append(
                    ChatCompletionTokenLogprob(
                        token=token,
                        bytes=token_bytes,
                        logprob=logprob,
                        top_logprobs=top_logprobs,
                    )
                )

            choice_logprobs = ChoiceLogprobs(content=token_logprobs)
        else:
            choice_logprobs = None
        prompt_tokens = ret_item["meta_info"]["prompt_tokens"]
        completion_tokens = ret_item["meta_info"]["completion_tokens"]

        if to_file:
            ## to make the choice data json serializable
            # to make the choice data json serializable
            choice_data = {
                "index": 0,
                "message": {"role": "assistant", "content": ret_item["text"]},
                "logprobs": None,
                "logprobs": choice_logprobs,
                "finish_reason": ret_item["meta_info"]["finish_reason"],
            }
        else:
            choice_data = ChatCompletionResponseChoice(
                index=idx,
                message=ChatMessage(role="assistant", content=ret_item["text"]),
                logprobs=choice_logprobs,
                finish_reason=ret_item["meta_info"]["finish_reason"],
            )

        choices.append(choice_data)
        total_prompt_tokens = prompt_tokens
        total_prompt_tokens += prompt_tokens
        total_completion_tokens += completion_tokens
    if to_file:
        responses = []

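The loop above repacks the engine's OpenAI-style logprobs (parallel tokens / token_logprobs lists plus optional top_logprobs dicts) into chat-style entries, storing each token's UTF-8 bytes alongside it. A small illustration of that bytes encoding, including a multi-byte token (example values only):

# The bytes fields hold list(token.encode("utf-8")), so clients can rejoin
# tokens that are not valid standalone UTF-8 strings on their own.
assert list("Paris".encode("utf-8")) == [80, 97, 114, 105, 115]
assert list("东京".encode("utf-8")) == [228, 184, 156, 228, 186, 172]
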
@@ -683,7 +736,7 @@ def v1_chat_generate_response(request, ret, to_file=False):
                "status_code": 200,
                "request_id": ret[i]["meta_info"]["id"],
                "body": {
                    ## remain the same but if needed we can change that
                    # remain the same but if needed we can change that
                    "id": ret[i]["meta_info"]["id"],
                    "object": "chat.completion",
                    "created": int(time.time()),

@@ -54,6 +54,24 @@ class LogProbs(BaseModel):
    top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)


class TopLogprob(BaseModel):
    token: str
    bytes: List[int]
    logprob: float


class ChatCompletionTokenLogprob(BaseModel):
    token: str
    bytes: List[int]
    logprob: float
    top_logprobs: List[TopLogprob]


class ChoiceLogprobs(BaseModel):
    # build for v1/chat/completions response
    content: List[ChatCompletionTokenLogprob]


class UsageInfo(BaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0

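The new models mirror OpenAI's chat logprobs schema: each generated token becomes one ChatCompletionTokenLogprob carrying its alternatives as TopLogprob entries, all wrapped in ChoiceLogprobs.content. A construction sketch with made-up values (not taken from the commit):

entry = ChatCompletionTokenLogprob(
    token="Paris",
    bytes=[80, 97, 114, 105, 115],
    logprob=-0.12,
    top_logprobs=[
        TopLogprob(token="Paris", bytes=[80, 97, 114, 105, 115], logprob=-0.12),
        TopLogprob(token="London", bytes=[76, 111, 110, 100, 111, 110], logprob=-2.31),
    ],
)
choice_logprobs = ChoiceLogprobs(content=[entry])
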
@@ -239,8 +257,8 @@ class ChatMessage(BaseModel):
class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: ChatMessage
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[str] = None
    logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
    finish_reason: str


class ChatCompletionResponse(BaseModel):