Simplify tokenizer manager (#1904)

Lianmin Zheng
2024-11-03 08:38:26 -08:00
committed by GitHub
parent 916b3cdddc
commit c17c578108
11 changed files with 261 additions and 443 deletions


@@ -71,6 +71,7 @@ from sglang.srt.openai_api.protocol import (
TopLogprob,
UsageInfo,
)
from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__)
@@ -314,6 +315,8 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
)
except Exception as e:
logger.error(f"error: {get_exception_traceback()}")
responses = []
error_json = {
"id": f"batch_req_{uuid.uuid4()}",
"custom_id": request_data.get("custom_id"),
@@ -363,7 +366,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
}
except Exception as e:
logger.error("error in SGLang:", e)
logger.error(f"error: {e}")
# Update batch status to "failed"
retrieve_batch = batch_storage[batch_id]
retrieve_batch.status = "failed"
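
The two hunks above tighten error handling in process_batch: failures log the full traceback via get_exception_traceback (imported near the top of the file in the first hunk), and the malformed two-argument logger.error call is replaced with a plain f-string. A minimal sketch of the traceback-logging pattern; process_one is a hypothetical stand-in for handling a single batch entry:

    import logging

    from sglang.utils import get_exception_traceback

    logger = logging.getLogger(__name__)

    def process_one(request_data):
        # Hypothetical stand-in for building and running one batch entry.
        raise RuntimeError("schema validation failed")

    try:
        process_one({"custom_id": "req-1"})
    except Exception:
        # Log the formatted traceback, not just the exception message,
        # so batch failures remain debuggable.
        logger.error(f"error: {get_exception_traceback()}")
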
@@ -469,80 +472,67 @@ async def v1_retrieve_file_content(file_id: str):
def v1_generate_request(
all_requests: List[CompletionRequest], request_ids: List[str] = None
):
if len(all_requests) > 1:
first_prompt_type = type(all_requests[0].prompt)
for request in all_requests:
assert (
type(request.prompt) is first_prompt_type
), "All prompts must be of the same type in file input settings"
if request.n > 1:
raise ValueError(
"Parallel sampling is not supported for completions from files"
)
prompts = []
sampling_params_list = []
return_logprobs = []
logprob_start_lens = []
top_logprobs_nums = []
# NOTE: with openai API, the prompt's logprobs are always not computed
first_prompt_type = type(all_requests[0].prompt)
for request in all_requests:
assert (
type(request.prompt) is first_prompt_type
), "All prompts must be of the same type in file input settings"
if len(all_requests) > 1 and request.n > 1:
raise ValueError(
"Parallel sampling is not supported for completions from files"
)
# NOTE: with openai API, the prompt's logprobs are always not computed
if request.echo and request.logprobs:
logger.warning(
"Echo is not compatible with logprobs. "
"To compute logprobs of input prompt, please use SGLang /request API."
"To compute logprobs of input prompt, please use the native /generate API."
)
for request in all_requests:
prompts.append(request.prompt)
sampling_params_list.append(
{
"temperature": request.temperature,
"max_new_tokens": request.max_tokens,
"min_new_tokens": request.min_tokens,
"stop": request.stop,
"stop_token_ids": request.stop_token_ids,
"top_p": request.top_p,
"presence_penalty": request.presence_penalty,
"frequency_penalty": request.frequency_penalty,
"repetition_penalty": request.repetition_penalty,
"regex": request.regex,
"json_schema": request.json_schema,
"n": request.n,
"ignore_eos": request.ignore_eos,
"no_stop_trim": request.no_stop_trim,
}
)
return_logprobs.append(request.logprobs is not None and request.logprobs > 0)
logprob_start_lens.append(-1)
top_logprobs_nums.append(
request.logprobs if request.logprobs is not None else 0
)
sampling_params = []
if isinstance(request.no_stop_trim, list):
num_reqs = len(request.prompt)
else:
num_reqs = 1
for i in range(num_reqs):
sampling_params.append(
{
"temperature": request.temperature,
"max_new_tokens": request.max_tokens,
"min_new_tokens": request.min_tokens,
"stop": request.stop,
"stop_token_ids": request.stop_token_ids,
"top_p": request.top_p,
"presence_penalty": request.presence_penalty,
"frequency_penalty": request.frequency_penalty,
"repetition_penalty": request.repetition_penalty,
"regex": request.regex,
"json_schema": request.json_schema,
"n": request.n,
"ignore_eos": request.ignore_eos,
"no_stop_trim": (
request.no_stop_trim
if not isinstance(request.no_stop_trim, list)
else request.no_stop_trim[i]
),
}
)
if num_reqs == 1:
sampling_params_list.append(sampling_params[0])
else:
sampling_params_list.append(sampling_params)
if len(all_requests) == 1:
prompt = prompts[0]
sampling_params_list = sampling_params_list[0]
logprob_start_lens = logprob_start_lens[0]
return_logprobs = return_logprobs[0]
top_logprobs_nums = top_logprobs_nums[0]
if isinstance(prompt, str) or isinstance(prompt[0], str):
prompt_kwargs = {"text": prompt}
if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
prompt_kwargs = {"text": prompts[0]}
else:
prompt_kwargs = {"input_ids": prompt}
prompt_kwargs = {"input_ids": prompts[0]}
sampling_params_list = sampling_params_list[0]
return_logprobs = return_logprobs[0]
logprob_start_lens = logprob_start_lens[0]
top_logprobs_nums = top_logprobs_nums[0]
else:
if isinstance(prompts[0], str):
if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
prompt_kwargs = {"text": prompts}
else:
prompt_kwargs = {"input_ids": prompts}
@@ -558,9 +548,7 @@ def v1_generate_request(
rid=request_ids,
)
if len(all_requests) == 1:
return adapted_request, all_requests[0]
return adapted_request, all_requests
return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]
def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
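
The collapsed return here (and again in v1_chat_generate_request below) relies on Python's conditional-expression precedence: in "return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]", the conditional binds tighter than the tuple comma, so the function always returns a 2-tuple whose second element is either the whole list or the single unwrapped request. A small check of that parsing:

    def pick(adapted_request, all_requests):
        # Same shape as the simplified return statements in the adapter.
        return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]

    assert pick("req", ["a"]) == ("req", "a")               # single request: unwrapped
    assert pick("req", ["a", "b"]) == ("req", ["a", "b"])   # batch: list kept
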
@@ -595,7 +583,7 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
if isinstance(request, list) and request[idx].echo:
echo = True
text = request[idx].prompt + text
if (not isinstance(request, list)) and echo:
if echo and not isinstance(request, list):
prompt_index = idx // request.n
text = prompts[prompt_index] + text
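
The echo branches above prepend the original prompt to each returned completion. When the request is a single CompletionRequest, the choice index is mapped back to its prompt with prompt_index = idx // request.n, because n parallel samples per prompt yield n consecutive choices. A worked toy example with hypothetical numbers:

    n = 3
    prompts = ["first prompt", "second prompt"]

    # Choice indices 0..5 map onto prompt indices [0, 0, 0, 1, 1, 1].
    for idx in range(len(prompts) * n):
        prompt_index = idx // n
        # For echo, the original prompt is prepended to this choice's text.
        print(idx, "->", prompts[prompt_index])
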
@@ -709,7 +697,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
async for content in tokenizer_manager.generate_request(
adapted_request, raw_request
):
index = content["index"]
index = content.get("index", 0)
stream_buffer = stream_buffers.get(index, "")
n_prev_token = n_prev_tokens.get(index, 0)
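
Streamed chunks are keyed by choice index, and the handler now falls back to choice 0 with content.get("index", 0) when the field is absent (the same change appears in v1_chat_completions below); per-index state is kept in dictionaries with the same defaulting style. A compact sketch of that bookkeeping with made-up chunk data:

    # Made-up chunks; real ones come from tokenizer_manager.generate_request.
    chunks = [
        {"text": "Hello"},                  # no "index": treated as choice 0
        {"index": 1, "text": "Bonjour"},
        {"index": 0, "text": ", world"},
    ]

    stream_buffers = {}
    for content in chunks:
        index = content.get("index", 0)     # default to choice 0 when missing
        stream_buffer = stream_buffers.get(index, "")
        stream_buffers[index] = stream_buffer + content["text"]

    assert stream_buffers == {0: "Hello, world", 1: "Bonjour"}
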
@@ -945,19 +933,18 @@ def v1_chat_generate_request(
sampling_params_list.append(sampling_params)
image_data_list.append(image_data)
modalities_list.extend(modalities)
modalities_list.append(modalities)
if len(all_requests) == 1:
input_ids = input_ids[0]
if isinstance(input_ids, str):
prompt_kwargs = {"text": input_ids}
if isinstance(input_ids[0], str):
prompt_kwargs = {"text": input_ids[0]}
else:
prompt_kwargs = {"input_ids": input_ids}
prompt_kwargs = {"input_ids": input_ids[0]}
sampling_params_list = sampling_params_list[0]
image_data_list = image_data_list[0]
return_logprobs = return_logprobs[0]
logprob_start_lens = logprob_start_lens[0]
top_logprobs_nums = top_logprobs_nums[0]
modalities_list = modalities_list[:1]
modalities_list = modalities_list[0]
else:
if isinstance(input_ids[0], str):
prompt_kwargs = {"text": input_ids}
@@ -976,9 +963,8 @@ def v1_chat_generate_request(
rid=request_ids,
modalities=modalities_list,
)
if len(all_requests) == 1:
return adapted_request, all_requests[0]
return adapted_request, all_requests
return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]
def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
@@ -1116,7 +1102,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
async for content in tokenizer_manager.generate_request(
adapted_request, raw_request
):
index = content["index"]
index = content.get("index", 0)
is_first = is_firsts.get(index, True)
stream_buffer = stream_buffers.get(index, "")