[Feat] Add support for optional start len of logprobs (#1035)
Co-authored-by: Ying Sheng <sqy1415@gmail.com>
Co-authored-by: Yineng Zhang <me@zhyncs.com>
Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com>
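Below is a minimal sketch of how the new optional logprob_start_len can be exercised through SGLang's native generate endpoint, which is where the GenerateReqInput fields touched in this diff (return_logprob, logprob_start_len, top_logprobs_num) end up. The server address, prompt, and response handling are placeholder assumptions for illustration, not part of this change.

import requests

# Assumed address of a locally launched SGLang server.
BASE_URL = "http://localhost:30000"

payload = {
    "text": "The capital of France is",
    "sampling_params": {"temperature": 0, "max_new_tokens": 8},
    # Request logprobs, but only starting from token position 5
    # rather than from the beginning of the prompt.
    "return_logprob": True,
    "logprob_start_len": 5,
    "top_logprobs_num": 2,
}

resp = requests.post(f"{BASE_URL}/generate", json=payload)
print(resp.json())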
@@ -20,6 +20,7 @@ import json
import os
import time
import uuid
import warnings
from http import HTTPStatus
from typing import Dict, List, Optional
@@ -383,20 +384,33 @@ async def v1_retrieve_file_content(file_id: str):
    return StreamingResponse(iter_file(), media_type="application/octet-stream")


def v1_generate_request(all_requests):
def v1_generate_request(all_requests: List[CompletionRequest]):
    prompts = []
    sampling_params_list = []
    return_logprobs = []
    logprob_start_lens = []
    top_logprobs_nums = []

    # NOTE: with openai API, the prompt's logprobs are always not computed
    first_prompt_type = type(all_requests[0].prompt)
    for request in all_requests:
        assert (
            type(request.prompt) == first_prompt_type
        ), "All prompts must be of the same type in file input settings"
        if len(all_requests) > 1 and request.n > 1:
            raise ValueError(
                "Parallel sampling is not supported for completions from files"
            )
        if request.echo and request.logprobs:
            warnings.warn(
                "Echo is not compatible with logprobs. "
                "To compute logprobs of input prompt, please use SGLang /request API."
            )

    for request in all_requests:
        prompt = request.prompt
        assert (
            type(prompt) == first_prompt_type
        ), "All prompts must be of the same type in file input settings"
        prompts.append(prompt)
        prompts.append(request.prompt)
        return_logprobs.append(request.logprobs is not None and request.logprobs > 0)
        logprob_start_lens.append(-1)
        top_logprobs_nums.append(
            request.logprobs if request.logprobs is not None else 0
        )
@@ -416,14 +430,11 @@ def v1_generate_request(all_requests):
                "ignore_eos": request.ignore_eos,
            }
        )
        if len(all_requests) > 1 and request.n > 1:
            raise ValueError(
                "Parallel sampling is not supported for completions from files"
            )

    if len(all_requests) == 1:
        prompt = prompts[0]
        sampling_params_list = sampling_params_list[0]
        logprob_start_lens = logprob_start_lens[0]
        return_logprobs = return_logprobs[0]
        top_logprobs_nums = top_logprobs_nums[0]
        if isinstance(prompt, str) or isinstance(prompt[0], str):
@@ -441,6 +452,7 @@ def v1_generate_request(all_requests):
        sampling_params=sampling_params_list,
        return_logprob=return_logprobs,
        top_logprobs_num=top_logprobs_nums,
        logprob_start_len=logprob_start_lens,
        return_text_in_logprobs=True,
        stream=all_requests[0].stream,
    )
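As a side note, the adapter keeps one entry per request in the parallel lists above (return_logprobs, logprob_start_lens, top_logprobs_nums) and collapses each list to a plain scalar when only a single request is adapted. A small illustrative sketch of that convention follows; the helper name is hypothetical and not part of the diff.

from typing import Any, Dict, List

def collapse_if_single(per_request: Dict[str, List[Any]]) -> Dict[str, Any]:
    # Illustrative only: mirror the adapter's convention of passing lists
    # for batched requests and plain scalars for a single request.
    n = len(next(iter(per_request.values())))
    if n == 1:
        return {key: values[0] for key, values in per_request.items()}
    return per_request

batched = {
    "return_logprob": [True, False],
    "logprob_start_len": [-1, -1],  # -1: skip prompt logprobs on the OpenAI path
    "top_logprobs_num": [2, 0],
}
single = {key: values[:1] for key, values in batched.items()}

print(collapse_if_single(batched))  # lists kept as-is
print(collapse_if_single(single))   # collapsed to scalars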
@@ -694,12 +706,18 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
    return response


def v1_chat_generate_request(all_requests, tokenizer_manager):
def v1_chat_generate_request(
    all_requests: List[ChatCompletionRequest], tokenizer_manager
):
    input_ids = []
    sampling_params_list = []
    image_data_list = []
    return_logprobs = []
    logprob_start_lens = []
    top_logprobs_nums = []

    # NOTE: with openai API, the prompt's logprobs are always not computed

    for request in all_requests:
        # Prep the data needed for the underlying GenerateReqInput:
        # - prompt: The full prompt string.
@@ -732,6 +750,7 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
            image_data = None
        input_ids.append(prompt_ids)
        return_logprobs.append(request.logprobs)
        logprob_start_lens.append(-1)
        top_logprobs_nums.append(request.top_logprobs)
        sampling_params_list.append(
            {
@@ -758,17 +777,20 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
        sampling_params_list = sampling_params_list[0]
        image_data = image_data_list[0]
        return_logprobs = return_logprobs[0]
        logprob_start_lens = logprob_start_lens[0]
        top_logprobs_nums = top_logprobs_nums[0]
    else:
        if isinstance(input_ids[0], str):
            prompt_kwargs = {"text": input_ids}
        else:
            prompt_kwargs = {"input_ids": input_ids}

    adapted_request = GenerateReqInput(
        **prompt_kwargs,
        image_data=image_data,
        sampling_params=sampling_params_list,
        return_logprob=return_logprobs,
        logprob_start_len=logprob_start_lens,
        top_logprobs_num=top_logprobs_nums,
        stream=all_requests[0].stream,
        return_text_in_logprobs=True,
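For completeness, a sketch of calling the OpenAI-compatible completions endpoint that this adapter backs. Per the NOTE above, the adapter pins logprob_start_len to -1 on this path, so prompt logprobs are not computed; the base_url, api_key, and model name below are placeholder assumptions.

import openai

# Placeholder connection details for a locally launched SGLang server.
client = openai.OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")

completion = client.completions.create(
    model="default",
    prompt="The capital of France is",
    max_tokens=8,
    logprobs=2,  # maps to return_logprob / top_logprobs_num in the adapter
)
print(completion.choices[0].logprobs)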