Simplify tokenizer manager (#1904)
@@ -71,6 +71,7 @@ from sglang.srt.openai_api.protocol import (
TopLogprob,
UsageInfo,
)
from sglang.utils import get_exception_traceback

logger = logging.getLogger(__name__)

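For context, the newly imported traceback helper is what the batch error path below logs. A minimal, self-contained sketch of that pattern, with a local stand-in for sglang.utils.get_exception_traceback (its exact implementation is assumed here, not taken from the source):

import logging
import traceback

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.ERROR)

def get_exception_traceback() -> str:
    # Assumed behavior: return the current exception's formatted traceback.
    return traceback.format_exc()

try:
    raise ValueError("demo failure")
except Exception:
    # Same pattern as process_batch below: log the full traceback, not just str(e).
    logger.error(f"error: {get_exception_traceback()}")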
@@ -314,6 +315,8 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
)

except Exception as e:
logger.error(f"error: {get_exception_traceback()}")
responses = []
error_json = {
"id": f"batch_req_{uuid.uuid4()}",
"custom_id": request_data.get("custom_id"),
@@ -363,7 +366,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
}

except Exception as e:
logger.error("error in SGLang:", e)
logger.error(f"error: {e}")
# Update batch status to "failed"
retrieve_batch = batch_storage[batch_id]
retrieve_batch.status = "failed"
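A side note on the replaced call: logger.error("error in SGLang:", e) passes the exception as a %-style formatting argument, but the message has no placeholder, so the standard logging module reports an internal formatting error instead of the exception. A small sketch of the working alternatives (standard library only):

import logging

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.ERROR)

try:
    raise RuntimeError("batch failed")
except Exception as e:
    # logger.error("error in SGLang:", e)   # broken: no %s placeholder for e
    logger.error("error in SGLang: %s", e)  # lazy %-style formatting
    logger.error(f"error: {e}")             # what the new code does
    logger.exception("error in SGLang")     # message plus full traceback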
@@ -469,80 +472,67 @@ async def v1_retrieve_file_content(file_id: str):
def v1_generate_request(
all_requests: List[CompletionRequest], request_ids: List[str] = None
):
if len(all_requests) > 1:
first_prompt_type = type(all_requests[0].prompt)
for request in all_requests:
assert (
type(request.prompt) is first_prompt_type
), "All prompts must be of the same type in file input settings"
if request.n > 1:
raise ValueError(
"Parallel sampling is not supported for completions from files"
)

prompts = []
sampling_params_list = []
return_logprobs = []
logprob_start_lens = []
top_logprobs_nums = []

# NOTE: with openai API, the prompt's logprobs are always not computed
first_prompt_type = type(all_requests[0].prompt)
for request in all_requests:
assert (
type(request.prompt) is first_prompt_type
), "All prompts must be of the same type in file input settings"
if len(all_requests) > 1 and request.n > 1:
raise ValueError(
"Parallel sampling is not supported for completions from files"
)
# NOTE: with openai API, the prompt's logprobs are always not computed
if request.echo and request.logprobs:
logger.warning(
"Echo is not compatible with logprobs. "
"To compute logprobs of input prompt, please use SGLang /request API."
"To compute logprobs of input prompt, please use the native /generate API."
)

for request in all_requests:
prompts.append(request.prompt)
sampling_params_list.append(
{
"temperature": request.temperature,
"max_new_tokens": request.max_tokens,
"min_new_tokens": request.min_tokens,
"stop": request.stop,
"stop_token_ids": request.stop_token_ids,
"top_p": request.top_p,
"presence_penalty": request.presence_penalty,
"frequency_penalty": request.frequency_penalty,
"repetition_penalty": request.repetition_penalty,
"regex": request.regex,
"json_schema": request.json_schema,
"n": request.n,
"ignore_eos": request.ignore_eos,
"no_stop_trim": request.no_stop_trim,
}
)
return_logprobs.append(request.logprobs is not None and request.logprobs > 0)
logprob_start_lens.append(-1)
top_logprobs_nums.append(
request.logprobs if request.logprobs is not None else 0
)
sampling_params = []
if isinstance(request.no_stop_trim, list):
num_reqs = len(request.prompt)
else:
num_reqs = 1
for i in range(num_reqs):
sampling_params.append(
{
"temperature": request.temperature,
"max_new_tokens": request.max_tokens,
"min_new_tokens": request.min_tokens,
"stop": request.stop,
"stop_token_ids": request.stop_token_ids,
"top_p": request.top_p,
"presence_penalty": request.presence_penalty,
"frequency_penalty": request.frequency_penalty,
"repetition_penalty": request.repetition_penalty,
"regex": request.regex,
"json_schema": request.json_schema,
"n": request.n,
"ignore_eos": request.ignore_eos,
"no_stop_trim": (
request.no_stop_trim
if not isinstance(request.no_stop_trim, list)
else request.no_stop_trim[i]
),
}
)
if num_reqs == 1:
sampling_params_list.append(sampling_params[0])
else:
sampling_params_list.append(sampling_params)

if len(all_requests) == 1:
prompt = prompts[0]
sampling_params_list = sampling_params_list[0]
logprob_start_lens = logprob_start_lens[0]
return_logprobs = return_logprobs[0]
top_logprobs_nums = top_logprobs_nums[0]
if isinstance(prompt, str) or isinstance(prompt[0], str):
prompt_kwargs = {"text": prompt}
if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
prompt_kwargs = {"text": prompts[0]}
else:
prompt_kwargs = {"input_ids": prompt}
prompt_kwargs = {"input_ids": prompts[0]}
sampling_params_list = sampling_params_list[0]
return_logprobs = return_logprobs[0]
logprob_start_lens = logprob_start_lens[0]
top_logprobs_nums = top_logprobs_nums[0]
else:
if isinstance(prompts[0], str):
if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
prompt_kwargs = {"text": prompts}
else:
prompt_kwargs = {"input_ids": prompts}
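To make the restructured v1_generate_request easier to follow, here is a hedged, self-contained sketch of the pattern the new code uses: one pass over the requests that enforces a uniform prompt type, collects one sampling-params dict per request, and unwraps the per-request lists when the batch holds a single request. The names below are illustrative stand-ins, not SGLang's API.

from dataclasses import dataclass
from typing import List, Union

@dataclass
class ToyCompletionRequest:  # illustrative stand-in for CompletionRequest
    prompt: Union[str, List[int]]
    temperature: float = 1.0
    max_tokens: int = 16
    n: int = 1

def build_generate_args(all_requests: List[ToyCompletionRequest]):
    prompts, sampling_params_list = [], []

    first_prompt_type = type(all_requests[0].prompt)
    for request in all_requests:
        # File batches must not mix text prompts with token-id prompts.
        assert type(request.prompt) is first_prompt_type
        if len(all_requests) > 1 and request.n > 1:
            raise ValueError("Parallel sampling is not supported for completions from files")

        prompts.append(request.prompt)
        sampling_params_list.append(
            {"temperature": request.temperature, "max_new_tokens": request.max_tokens, "n": request.n}
        )

    if len(all_requests) == 1:
        # Single request: hand back scalars instead of one-element lists.
        if isinstance(prompts[0], str):
            prompt_kwargs = {"text": prompts[0]}
        else:
            prompt_kwargs = {"input_ids": prompts[0]}
        return prompt_kwargs, sampling_params_list[0]

    # Batch: keep the parallel lists.
    if isinstance(prompts[0], str):
        prompt_kwargs = {"text": prompts}
    else:
        prompt_kwargs = {"input_ids": prompts}
    return prompt_kwargs, sampling_params_list

# One text request collapses to scalar params; two requests keep the lists.
kwargs, params = build_generate_args([ToyCompletionRequest(prompt="Hello", max_tokens=8)])
assert kwargs == {"text": "Hello"} and params["max_new_tokens"] == 8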
@@ -558,9 +548,7 @@ def v1_generate_request(
rid=request_ids,
)

if len(all_requests) == 1:
return adapted_request, all_requests[0]
return adapted_request, all_requests
return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]
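The two-branch return collapses into a single conditional expression; both forms return the bare request for a one-element batch and the whole list otherwise. A tiny equivalence check with placeholder values:

def old_style(all_requests):
    if len(all_requests) == 1:
        return all_requests[0]
    return all_requests

def new_style(all_requests):
    return all_requests if len(all_requests) > 1 else all_requests[0]

assert old_style(["only"]) == new_style(["only"]) == "only"
assert old_style(["a", "b"]) == new_style(["a", "b"]) == ["a", "b"]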
def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
@@ -595,7 +583,7 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
if isinstance(request, list) and request[idx].echo:
echo = True
text = request[idx].prompt + text
if (not isinstance(request, list)) and echo:
if echo and not isinstance(request, list):
prompt_index = idx // request.n
text = prompts[prompt_index] + text

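In the non-list branch, a single request with n > 1 produces n completions per prompt, so integer division by n maps a completion's index back to the prompt that generated it (assuming results are ordered prompt by prompt). A quick illustration with made-up prompts:

prompts = ["first prompt", "second prompt"]
n = 3  # completions requested per prompt

# Six completions arrive in prompt-major order: indices 0-2 echo prompt 0, 3-5 echo prompt 1.
for idx in range(len(prompts) * n):
    prompt_index = idx // n
    text = prompts[prompt_index] + " ... generated text"
    assert prompt_index == (0 if idx < n else 1)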
@@ -709,7 +697,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
async for content in tokenizer_manager.generate_request(
adapted_request, raw_request
):
index = content["index"]
index = content.get("index", 0)

stream_buffer = stream_buffers.get(index, "")
n_prev_token = n_prev_tokens.get(index, 0)
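Switching to content.get("index", 0) lets the streaming loop tolerate chunks that carry no index field while still keeping one buffer per sample. A small sketch of that per-index accumulation, using hypothetical streamed chunks:

stream_buffers = {}
n_prev_tokens = {}

chunks = [                      # hypothetical chunks from generate_request
    {"index": 0, "text": "Hel"},
    {"index": 1, "text": "Bon"},
    {"text": "lo"},             # no "index" key -> treated as sample 0
    {"index": 1, "text": "jour"},
]

for content in chunks:
    index = content.get("index", 0)
    stream_buffer = stream_buffers.get(index, "")
    stream_buffers[index] = stream_buffer + content["text"]
    n_prev_tokens[index] = n_prev_tokens.get(index, 0) + 1

assert stream_buffers == {0: "Hello", 1: "Bonjour"}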
@@ -945,19 +933,18 @@ def v1_chat_generate_request(
sampling_params_list.append(sampling_params)

image_data_list.append(image_data)
modalities_list.extend(modalities)
modalities_list.append(modalities)
if len(all_requests) == 1:
input_ids = input_ids[0]
if isinstance(input_ids, str):
prompt_kwargs = {"text": input_ids}
if isinstance(input_ids[0], str):
prompt_kwargs = {"text": input_ids[0]}
else:
prompt_kwargs = {"input_ids": input_ids}
prompt_kwargs = {"input_ids": input_ids[0]}
sampling_params_list = sampling_params_list[0]
image_data_list = image_data_list[0]
return_logprobs = return_logprobs[0]
logprob_start_lens = logprob_start_lens[0]
top_logprobs_nums = top_logprobs_nums[0]
modalities_list = modalities_list[:1]
modalities_list = modalities_list[0]
else:
if isinstance(input_ids[0], str):
prompt_kwargs = {"text": input_ids}
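The chat path switches modalities_list from extend (which splices every request's modalities into one flat list) to append (one entry per request), which is why the single-request branch can now index [0] instead of slicing [:1]. The difference in plain Python, with hypothetical per-request modalities:

per_request_modalities = [["image", "image"], []]  # two hypothetical chat requests

flat, per_request = [], []
for modalities in per_request_modalities:
    flat.extend(modalities)        # old: one flat list across all requests
    per_request.append(modalities) # new: one entry per request

assert flat == ["image", "image"]
assert per_request == [["image", "image"], []]
assert per_request[0] == ["image", "image"]  # new single-request branch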
@@ -976,9 +963,8 @@ def v1_chat_generate_request(
rid=request_ids,
modalities=modalities_list,
)
if len(all_requests) == 1:
return adapted_request, all_requests[0]
return adapted_request, all_requests

return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]

def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
@@ -1116,7 +1102,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
async for content in tokenizer_manager.generate_request(
adapted_request, raw_request
):
index = content["index"]
index = content.get("index", 0)

is_first = is_firsts.get(index, True)
stream_buffer = stream_buffers.get(index, "")