Add support for logprobs in OpenAI chat API (#852)
This commit is contained in:
@@ -106,13 +106,12 @@ response = client.chat.completions.create(
|
|||||||
{"role": "user", "content": "List 3 countries and their capitals."},
|
{"role": "user", "content": "List 3 countries and their capitals."},
|
||||||
],
|
],
|
||||||
temperature=0.8,
|
temperature=0.8,
|
||||||
max_tokens=64,
|
max_tokens=1,
|
||||||
logprobs=True,
|
logprobs=True,
|
||||||
n=1,
|
top_logprobs=3,
|
||||||
)
|
)
|
||||||
print(response)
|
print(response)
|
||||||
|
|
||||||
|
|
||||||
# Chat completion
|
# Chat completion
|
||||||
response = client.chat.completions.create(
|
response = client.chat.completions.create(
|
||||||
model="default",
|
model="default",
|
||||||
@@ -121,8 +120,34 @@ response = client.chat.completions.create(
|
|||||||
{"role": "user", "content": "List 3 countries and their capitals."},
|
{"role": "user", "content": "List 3 countries and their capitals."},
|
||||||
],
|
],
|
||||||
temperature=0.8,
|
temperature=0.8,
|
||||||
max_tokens=64,
|
max_tokens=1,
|
||||||
|
n=1,
|
||||||
|
)
|
||||||
|
print(response)
|
||||||
|
|
||||||
|
# Chat completion
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
model="default",
|
||||||
|
messages=[
|
||||||
|
{"role": "system", "content": "You are a helpful AI assistant"},
|
||||||
|
{"role": "user", "content": "List 3 countries and their capitals."},
|
||||||
|
],
|
||||||
|
temperature=0.8,
|
||||||
|
max_tokens=1,
|
||||||
logprobs=True,
|
logprobs=True,
|
||||||
|
top_logprobs=3,
|
||||||
|
)
|
||||||
|
print(response)
|
||||||
|
|
||||||
|
# Chat completion
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
model="default",
|
||||||
|
messages=[
|
||||||
|
{"role": "system", "content": "You are a helpful AI assistant"},
|
||||||
|
{"role": "user", "content": "List 3 countries and their capitals."},
|
||||||
|
],
|
||||||
|
temperature=0.8,
|
||||||
|
max_tokens=1,
|
||||||
n=4,
|
n=4,
|
||||||
)
|
)
|
||||||
print(response)
|
print(response)
|
||||||
|
|||||||
@@ -43,7 +43,9 @@ from sglang.srt.openai_api.protocol import (
|
|||||||
ChatCompletionResponseChoice,
|
ChatCompletionResponseChoice,
|
||||||
ChatCompletionResponseStreamChoice,
|
ChatCompletionResponseStreamChoice,
|
||||||
ChatCompletionStreamResponse,
|
ChatCompletionStreamResponse,
|
||||||
|
ChatCompletionTokenLogprob,
|
||||||
ChatMessage,
|
ChatMessage,
|
||||||
|
ChoiceLogprobs,
|
||||||
CompletionRequest,
|
CompletionRequest,
|
||||||
CompletionResponse,
|
CompletionResponse,
|
||||||
CompletionResponseChoice,
|
CompletionResponseChoice,
|
||||||
@@ -54,6 +56,7 @@ from sglang.srt.openai_api.protocol import (
|
|||||||
FileRequest,
|
FileRequest,
|
||||||
FileResponse,
|
FileResponse,
|
||||||
LogProbs,
|
LogProbs,
|
||||||
|
TopLogprob,
|
||||||
UsageInfo,
|
UsageInfo,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -70,7 +73,7 @@ class FileMetadata:
|
|||||||
batch_storage: Dict[str, BatchResponse] = {}
|
batch_storage: Dict[str, BatchResponse] = {}
|
||||||
file_id_request: Dict[str, FileMetadata] = {}
|
file_id_request: Dict[str, FileMetadata] = {}
|
||||||
file_id_response: Dict[str, FileResponse] = {}
|
file_id_response: Dict[str, FileResponse] = {}
|
||||||
## map file id to file path in SGlang backend
|
# map file id to file path in SGlang backend
|
||||||
file_id_storage: Dict[str, str] = {}
|
file_id_storage: Dict[str, str] = {}
|
||||||
|
|
||||||
|
|
||||||
@@ -261,7 +264,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
|
|||||||
failed_requests += len(file_request_list)
|
failed_requests += len(file_request_list)
|
||||||
|
|
||||||
for idx, response in enumerate(responses):
|
for idx, response in enumerate(responses):
|
||||||
## the batch_req here can be changed to be named within a batch granularity
|
# the batch_req here can be changed to be named within a batch granularity
|
||||||
response_json = {
|
response_json = {
|
||||||
"id": f"batch_req_{uuid.uuid4()}",
|
"id": f"batch_req_{uuid.uuid4()}",
|
||||||
"custom_id": file_request_list[idx].get("custom_id"),
|
"custom_id": file_request_list[idx].get("custom_id"),
|
||||||
@@ -333,6 +336,8 @@ def v1_generate_request(all_requests):
|
|||||||
|
|
||||||
prompts = []
|
prompts = []
|
||||||
sampling_params_list = []
|
sampling_params_list = []
|
||||||
|
return_logprobs = []
|
||||||
|
top_logprobs_nums = []
|
||||||
first_prompt_type = type(all_requests[0].prompt)
|
first_prompt_type = type(all_requests[0].prompt)
|
||||||
for request in all_requests:
|
for request in all_requests:
|
||||||
prompt = request.prompt
|
prompt = request.prompt
|
||||||
@@ -340,6 +345,10 @@ def v1_generate_request(all_requests):
|
|||||||
type(prompt) == first_prompt_type
|
type(prompt) == first_prompt_type
|
||||||
), "All prompts must be of the same type in file input settings"
|
), "All prompts must be of the same type in file input settings"
|
||||||
prompts.append(prompt)
|
prompts.append(prompt)
|
||||||
|
return_logprobs.append(request.logprobs is not None and request.logprobs > 0)
|
||||||
|
top_logprobs_nums.append(
|
||||||
|
request.logprobs if request.logprobs is not None else 0
|
||||||
|
)
|
||||||
sampling_params_list.append(
|
sampling_params_list.append(
|
||||||
{
|
{
|
||||||
"temperature": request.temperature,
|
"temperature": request.temperature,
|
||||||
@@ -361,6 +370,8 @@ def v1_generate_request(all_requests):
|
|||||||
if len(all_requests) == 1:
|
if len(all_requests) == 1:
|
||||||
prompt = prompts[0]
|
prompt = prompts[0]
|
||||||
sampling_params_list = sampling_params_list[0]
|
sampling_params_list = sampling_params_list[0]
|
||||||
|
return_logprobs = return_logprobs[0]
|
||||||
|
top_logprobs_nums = top_logprobs_nums[0]
|
||||||
if isinstance(prompt, str) or isinstance(prompt[0], str):
|
if isinstance(prompt, str) or isinstance(prompt[0], str):
|
||||||
prompt_kwargs = {"text": prompt}
|
prompt_kwargs = {"text": prompt}
|
||||||
else:
|
else:
|
||||||
@@ -370,15 +381,11 @@ def v1_generate_request(all_requests):
|
|||||||
prompt_kwargs = {"text": prompts}
|
prompt_kwargs = {"text": prompts}
|
||||||
else:
|
else:
|
||||||
prompt_kwargs = {"input_ids": prompts}
|
prompt_kwargs = {"input_ids": prompts}
|
||||||
|
|
||||||
adapted_request = GenerateReqInput(
|
adapted_request = GenerateReqInput(
|
||||||
**prompt_kwargs,
|
**prompt_kwargs,
|
||||||
sampling_params=sampling_params_list,
|
sampling_params=sampling_params_list,
|
||||||
return_logprob=all_requests[0].logprobs is not None
|
return_logprob=return_logprobs,
|
||||||
and all_requests[0].logprobs > 0,
|
top_logprobs_num=top_logprobs_nums,
|
||||||
top_logprobs_num=(
|
|
||||||
all_requests[0].logprobs if all_requests[0].logprobs is not None else 0
|
|
||||||
),
|
|
||||||
return_text_in_logprobs=True,
|
return_text_in_logprobs=True,
|
||||||
stream=all_requests[0].stream,
|
stream=all_requests[0].stream,
|
||||||
)
|
)
|
||||||
@@ -430,7 +437,7 @@ def v1_generate_response(request, ret, to_file=False):
|
|||||||
logprobs = None
|
logprobs = None
|
||||||
|
|
||||||
if to_file:
|
if to_file:
|
||||||
## to make the choise data json serializable
|
# to make the choise data json serializable
|
||||||
choice_data = {
|
choice_data = {
|
||||||
"index": 0,
|
"index": 0,
|
||||||
"text": text,
|
"text": text,
|
||||||
@@ -454,7 +461,7 @@ def v1_generate_response(request, ret, to_file=False):
|
|||||||
"status_code": 200,
|
"status_code": 200,
|
||||||
"request_id": ret[i]["meta_info"]["id"],
|
"request_id": ret[i]["meta_info"]["id"],
|
||||||
"body": {
|
"body": {
|
||||||
## remain the same but if needed we can change that
|
# remain the same but if needed we can change that
|
||||||
"id": ret[i]["meta_info"]["id"],
|
"id": ret[i]["meta_info"]["id"],
|
||||||
"object": "text_completion",
|
"object": "text_completion",
|
||||||
"created": int(time.time()),
|
"created": int(time.time()),
|
||||||
@@ -590,6 +597,8 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
|
|||||||
texts = []
|
texts = []
|
||||||
sampling_params_list = []
|
sampling_params_list = []
|
||||||
image_data_list = []
|
image_data_list = []
|
||||||
|
return_logprobs = []
|
||||||
|
top_logprobs_nums = []
|
||||||
for request in all_requests:
|
for request in all_requests:
|
||||||
# Prep the data needed for the underlying GenerateReqInput:
|
# Prep the data needed for the underlying GenerateReqInput:
|
||||||
# - prompt: The full prompt string.
|
# - prompt: The full prompt string.
|
||||||
@@ -620,6 +629,8 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
|
|||||||
stop = request.stop
|
stop = request.stop
|
||||||
image_data = None
|
image_data = None
|
||||||
texts.append(prompt)
|
texts.append(prompt)
|
||||||
|
return_logprobs.append(request.logprobs)
|
||||||
|
top_logprobs_nums.append(request.top_logprobs)
|
||||||
sampling_params_list.append(
|
sampling_params_list.append(
|
||||||
{
|
{
|
||||||
"temperature": request.temperature,
|
"temperature": request.temperature,
|
||||||
@@ -637,11 +648,16 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
|
|||||||
texts = texts[0]
|
texts = texts[0]
|
||||||
sampling_params_list = sampling_params_list[0]
|
sampling_params_list = sampling_params_list[0]
|
||||||
image_data = image_data_list[0]
|
image_data = image_data_list[0]
|
||||||
|
return_logprobs = return_logprobs[0]
|
||||||
|
top_logprobs_nums = top_logprobs_nums[0]
|
||||||
adapted_request = GenerateReqInput(
|
adapted_request = GenerateReqInput(
|
||||||
text=texts,
|
text=texts,
|
||||||
image_data=image_data,
|
image_data=image_data,
|
||||||
sampling_params=sampling_params_list,
|
sampling_params=sampling_params_list,
|
||||||
stream=request.stream,
|
return_logprob=return_logprobs,
|
||||||
|
top_logprobs_num=top_logprobs_nums,
|
||||||
|
stream=all_requests[0].stream,
|
||||||
|
return_text_in_logprobs=True,
|
||||||
)
|
)
|
||||||
if len(all_requests) == 1:
|
if len(all_requests) == 1:
|
||||||
return adapted_request, all_requests[0]
|
return adapted_request, all_requests[0]
|
||||||
@@ -654,26 +670,63 @@ def v1_chat_generate_response(request, ret, to_file=False):
|
|||||||
total_completion_tokens = 0
|
total_completion_tokens = 0
|
||||||
|
|
||||||
for idx, ret_item in enumerate(ret):
|
for idx, ret_item in enumerate(ret):
|
||||||
|
logprobs = False
|
||||||
|
if isinstance(request, List) and request[idx].logprobs:
|
||||||
|
logprobs = True
|
||||||
|
elif (not isinstance(request, List)) and request.logprobs:
|
||||||
|
logprobs = True
|
||||||
|
if logprobs:
|
||||||
|
logprobs = to_openai_style_logprobs(
|
||||||
|
output_token_logprobs=ret_item["meta_info"]["output_token_logprobs"],
|
||||||
|
output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"],
|
||||||
|
)
|
||||||
|
token_logprobs = []
|
||||||
|
for token, logprob in zip(logprobs.tokens, logprobs.token_logprobs):
|
||||||
|
token_bytes = list(token.encode("utf-8"))
|
||||||
|
top_logprobs = []
|
||||||
|
if logprobs.top_logprobs:
|
||||||
|
for top_token, top_logprob in logprobs.top_logprobs[0].items():
|
||||||
|
top_token_bytes = list(top_token.encode("utf-8"))
|
||||||
|
top_logprobs.append(
|
||||||
|
TopLogprob(
|
||||||
|
token=top_token,
|
||||||
|
bytes=top_token_bytes,
|
||||||
|
logprob=top_logprob,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
token_logprobs.append(
|
||||||
|
ChatCompletionTokenLogprob(
|
||||||
|
token=token,
|
||||||
|
bytes=token_bytes,
|
||||||
|
logprob=logprob,
|
||||||
|
top_logprobs=top_logprobs,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
choice_logprobs = ChoiceLogprobs(content=token_logprobs)
|
||||||
|
else:
|
||||||
|
choice_logprobs = None
|
||||||
prompt_tokens = ret_item["meta_info"]["prompt_tokens"]
|
prompt_tokens = ret_item["meta_info"]["prompt_tokens"]
|
||||||
completion_tokens = ret_item["meta_info"]["completion_tokens"]
|
completion_tokens = ret_item["meta_info"]["completion_tokens"]
|
||||||
|
|
||||||
if to_file:
|
if to_file:
|
||||||
## to make the choice data json serializable
|
# to make the choice data json serializable
|
||||||
choice_data = {
|
choice_data = {
|
||||||
"index": 0,
|
"index": 0,
|
||||||
"message": {"role": "assistant", "content": ret_item["text"]},
|
"message": {"role": "assistant", "content": ret_item["text"]},
|
||||||
"logprobs": None,
|
"logprobs": choice_logprobs,
|
||||||
"finish_reason": ret_item["meta_info"]["finish_reason"],
|
"finish_reason": ret_item["meta_info"]["finish_reason"],
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
choice_data = ChatCompletionResponseChoice(
|
choice_data = ChatCompletionResponseChoice(
|
||||||
index=idx,
|
index=idx,
|
||||||
message=ChatMessage(role="assistant", content=ret_item["text"]),
|
message=ChatMessage(role="assistant", content=ret_item["text"]),
|
||||||
|
logprobs=choice_logprobs,
|
||||||
finish_reason=ret_item["meta_info"]["finish_reason"],
|
finish_reason=ret_item["meta_info"]["finish_reason"],
|
||||||
)
|
)
|
||||||
|
|
||||||
choices.append(choice_data)
|
choices.append(choice_data)
|
||||||
total_prompt_tokens = prompt_tokens
|
total_prompt_tokens += prompt_tokens
|
||||||
total_completion_tokens += completion_tokens
|
total_completion_tokens += completion_tokens
|
||||||
if to_file:
|
if to_file:
|
||||||
responses = []
|
responses = []
|
||||||
@@ -683,7 +736,7 @@ def v1_chat_generate_response(request, ret, to_file=False):
|
|||||||
"status_code": 200,
|
"status_code": 200,
|
||||||
"request_id": ret[i]["meta_info"]["id"],
|
"request_id": ret[i]["meta_info"]["id"],
|
||||||
"body": {
|
"body": {
|
||||||
## remain the same but if needed we can change that
|
# remain the same but if needed we can change that
|
||||||
"id": ret[i]["meta_info"]["id"],
|
"id": ret[i]["meta_info"]["id"],
|
||||||
"object": "chat.completion",
|
"object": "chat.completion",
|
||||||
"created": int(time.time()),
|
"created": int(time.time()),
|
||||||
|
|||||||
@@ -54,6 +54,24 @@ class LogProbs(BaseModel):
|
|||||||
top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)
|
top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class TopLogprob(BaseModel):
|
||||||
|
token: str
|
||||||
|
bytes: List[int]
|
||||||
|
logprob: float
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionTokenLogprob(BaseModel):
|
||||||
|
token: str
|
||||||
|
bytes: List[int]
|
||||||
|
logprob: float
|
||||||
|
top_logprobs: List[TopLogprob]
|
||||||
|
|
||||||
|
|
||||||
|
class ChoiceLogprobs(BaseModel):
|
||||||
|
# build for v1/chat/completions response
|
||||||
|
content: List[ChatCompletionTokenLogprob]
|
||||||
|
|
||||||
|
|
||||||
class UsageInfo(BaseModel):
|
class UsageInfo(BaseModel):
|
||||||
prompt_tokens: int = 0
|
prompt_tokens: int = 0
|
||||||
total_tokens: int = 0
|
total_tokens: int = 0
|
||||||
@@ -239,8 +257,8 @@ class ChatMessage(BaseModel):
|
|||||||
class ChatCompletionResponseChoice(BaseModel):
|
class ChatCompletionResponseChoice(BaseModel):
|
||||||
index: int
|
index: int
|
||||||
message: ChatMessage
|
message: ChatMessage
|
||||||
logprobs: Optional[LogProbs] = None
|
logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
|
||||||
finish_reason: Optional[str] = None
|
finish_reason: str
|
||||||
|
|
||||||
|
|
||||||
class ChatCompletionResponse(BaseModel):
|
class ChatCompletionResponse(BaseModel):
|
||||||
|
|||||||
Reference in New Issue
Block a user