[Feat] Add support for optional start len of logprobs (#1035)

Co-authored-by: Ying Sheng <sqy1415@gmail.com>
Co-authored-by: Yineng Zhang <me@zhyncs.com>
Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com>
Author: yichuan~
Committed: 2024-08-18 23:45:41 -07:00 (committed by GitHub)
Parent: d8627ed16d
Commit: b997a18d74
8 changed files with 113 additions and 31 deletions


@@ -20,6 +20,7 @@ import json
 import os
 import time
 import uuid
+import warnings
 from http import HTTPStatus
 from typing import Dict, List, Optional
@@ -383,20 +384,33 @@ async def v1_retrieve_file_content(file_id: str):
     return StreamingResponse(iter_file(), media_type="application/octet-stream")
 
 
-def v1_generate_request(all_requests):
+def v1_generate_request(all_requests: List[CompletionRequest]):
     prompts = []
     sampling_params_list = []
     return_logprobs = []
+    logprob_start_lens = []
     top_logprobs_nums = []
+    # NOTE: with openai API, the prompt's logprobs are always not computed
     first_prompt_type = type(all_requests[0].prompt)
     for request in all_requests:
-        prompt = request.prompt
-        assert (
-            type(prompt) == first_prompt_type
-        ), "All prompts must be of the same type in file input settings"
-        prompts.append(prompt)
+        assert (
+            type(request.prompt) == first_prompt_type
+        ), "All prompts must be of the same type in file input settings"
+        if len(all_requests) > 1 and request.n > 1:
+            raise ValueError(
+                "Parallel sampling is not supported for completions from files"
+            )
+        if request.echo and request.logprobs:
+            warnings.warn(
+                "Echo is not compatible with logprobs. "
+                "To compute logprobs of input prompt, please use SGLang /request API."
+            )
+
+    for request in all_requests:
+        prompts.append(request.prompt)
         return_logprobs.append(request.logprobs is not None and request.logprobs > 0)
+        logprob_start_lens.append(-1)
         top_logprobs_nums.append(
             request.logprobs if request.logprobs is not None else 0
         )
         sampling_params_list.append(
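
These hunks appear to come from the OpenAI-compatible adapter (the openai_api module). Note that the adapter pins logprob_start_len to -1 for every request, which, per the NOTE line above, skips prompt logprob computation entirely; the new validation loop also warns when echo is combined with logprobs. A minimal sketch of a client call that would now trigger that warning; the port, model name, and API key are placeholder assumptions, not part of this commit:

import openai

# Assumes an SGLang server already running locally on port 30000;
# "default" and "EMPTY" are conventional placeholders, not from this diff.
client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

# echo=True together with logprobs now emits the warning added above,
# because the OpenAI-compatible path never computes prompt logprobs.
response = client.completions.create(
    model="default",
    prompt="The capital of France is",
    max_tokens=8,
    logprobs=3,  # top-3 logprobs per generated token
    echo=True,   # incompatible with logprobs -> server-side warning
)
print(response.choices[0].logprobs)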
@@ -416,14 +430,11 @@ def v1_generate_request(all_requests):
                 "ignore_eos": request.ignore_eos,
             }
         )
-        if len(all_requests) > 1 and request.n > 1:
-            raise ValueError(
-                "Parallel sampling is not supported for completions from files"
-            )
 
     if len(all_requests) == 1:
         prompt = prompts[0]
         sampling_params_list = sampling_params_list[0]
+        logprob_start_lens = logprob_start_lens[0]
         return_logprobs = return_logprobs[0]
         top_logprobs_nums = top_logprobs_nums[0]
         if isinstance(prompt, str) or isinstance(prompt[0], str):
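
The flattening above implies that GenerateReqInput accepts either a single value or one value per prompt for each of these fields. A hedged sketch of a batched call to the native /generate endpoint in which each prompt carries its own logprob_start_len; the field names mirror this diff, while the host, port, and response shape are assumptions:

import requests

# Batched native-API call: one entry per prompt for each per-request field.
resp = requests.post(
    "http://127.0.0.1:30000/generate",  # assumed local server address
    json={
        "text": ["Hello, my name is", "The capital of France is"],
        "sampling_params": [
            {"max_new_tokens": 8, "temperature": 0},
            {"max_new_tokens": 8, "temperature": 0},
        ],
        "return_logprob": [True, True],
        "logprob_start_len": [0, 2],  # per-prompt start positions
        "top_logprobs_num": [2, 2],
    },
)
print(resp.json())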
@@ -441,6 +452,7 @@ def v1_generate_request(all_requests):
         sampling_params=sampling_params_list,
         return_logprob=return_logprobs,
         top_logprobs_num=top_logprobs_nums,
+        logprob_start_len=logprob_start_lens,
         return_text_in_logprobs=True,
         stream=all_requests[0].stream,
     )
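
This logprob_start_len keyword is the field the commit title refers to: native-API callers can now choose where logprob computation starts in the prompt (0 covers the whole prompt, while the adapter's -1 disables prompt logprobs). A minimal single-request sketch under the same assumptions as above:

import requests

# Single request: scalar fields instead of per-prompt lists.
# logprob_start_len=0 requests logprobs over the full prompt, unlike the
# OpenAI-compatible adapter, which always sends -1 (prompt logprobs off).
resp = requests.post(
    "http://127.0.0.1:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"max_new_tokens": 8, "temperature": 0},
        "return_logprob": True,
        "logprob_start_len": 0,
        "top_logprobs_num": 2,
    },
)
print(resp.json())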
@@ -694,12 +706,18 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
     return response
 
 
-def v1_chat_generate_request(all_requests, tokenizer_manager):
+def v1_chat_generate_request(
+    all_requests: List[ChatCompletionRequest], tokenizer_manager
+):
     input_ids = []
     sampling_params_list = []
     image_data_list = []
     return_logprobs = []
+    logprob_start_lens = []
     top_logprobs_nums = []
+
+    # NOTE: with openai API, the prompt's logprobs are always not computed
+
     for request in all_requests:
         # Prep the data needed for the underlying GenerateReqInput:
         # - prompt: The full prompt string.
@@ -732,6 +750,7 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
             image_data = None
         input_ids.append(prompt_ids)
         return_logprobs.append(request.logprobs)
+        logprob_start_lens.append(-1)
         top_logprobs_nums.append(request.top_logprobs)
         sampling_params_list.append(
             {
@@ -758,17 +777,20 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
         sampling_params_list = sampling_params_list[0]
         image_data = image_data_list[0]
         return_logprobs = return_logprobs[0]
+        logprob_start_lens = logprob_start_lens[0]
         top_logprobs_nums = top_logprobs_nums[0]
     else:
         if isinstance(input_ids[0], str):
             prompt_kwargs = {"text": input_ids}
         else:
             prompt_kwargs = {"input_ids": input_ids}
+
     adapted_request = GenerateReqInput(
         **prompt_kwargs,
         image_data=image_data,
         sampling_params=sampling_params_list,
         return_logprob=return_logprobs,
+        logprob_start_len=logprob_start_lens,
         top_logprobs_num=top_logprobs_nums,
         stream=all_requests[0].stream,
         return_text_in_logprobs=True,
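
The chat path wires the same field through: request.logprobs maps to return_logprob, request.top_logprobs to top_logprobs_num, and logprob_start_len stays fixed at -1. From the client side nothing changes; a sketch with the OpenAI client (model name, port, and key are again placeholders):

import openai

client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

# logprobs/top_logprobs flow into GenerateReqInput as shown in the diff;
# prompt logprobs remain disabled (start length -1) on this path.
response = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "Name three colors."}],
    max_tokens=16,
    logprobs=True,
    top_logprobs=2,
)
print(response.choices[0].logprobs)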