Support more OpenAI API test (#916)
This commit is contained in:
@@ -92,7 +92,7 @@ class GenerateReqInput:
|
||||
for element in parallel_sample_num_list
|
||||
)
|
||||
if parallel_sample_num > 1 and (not all_equal):
|
||||
## TODO cope with the case that the parallel_sample_num is different for different samples
|
||||
# TODO cope with the case that the parallel_sample_num is different for different samples
|
||||
raise ValueError(
|
||||
"The parallel_sample_num should be the same for all samples in sample params."
|
||||
)
|
||||
@@ -103,14 +103,19 @@ class GenerateReqInput:
|
||||
if parallel_sample_num != 1:
|
||||
# parallel sampling +1 represents the original prefill stage
|
||||
num = parallel_sample_num + 1
|
||||
if isinstance(self.text, List):
|
||||
## suppot batch operation
|
||||
if isinstance(self.text, list):
|
||||
# suppot batch operation
|
||||
self.batch_size = len(self.text)
|
||||
num = num * len(self.text)
|
||||
elif isinstance(self.input_ids, list) and isinstance(
|
||||
self.input_ids[0], list
|
||||
):
|
||||
self.batch_size = len(self.input_ids)
|
||||
num = num * len(self.input_ids)
|
||||
else:
|
||||
self.batch_size = 1
|
||||
else:
|
||||
## support select operation
|
||||
# support select operation
|
||||
num = len(self.text) if self.text is not None else len(self.input_ids)
|
||||
self.batch_size = num
|
||||
|
||||
|
||||
@@ -153,8 +153,9 @@ class TokenizerManager:
|
||||
async def _handle_single_request(
|
||||
self, obj, request, index=None, is_cache_for_prefill=False
|
||||
):
|
||||
if not is_cache_for_prefill:
|
||||
not_use_index = not (index is not None)
|
||||
if not is_cache_for_prefill: # The normal case with a single prompt
|
||||
not_use_index = index is None
|
||||
|
||||
rid = obj.rid if not_use_index else obj.rid[index]
|
||||
input_text = obj.text if not_use_index else obj.text[index]
|
||||
input_ids = (
|
||||
@@ -182,14 +183,27 @@ class TokenizerManager:
|
||||
top_logprobs_num = (
|
||||
obj.top_logprobs_num if not_use_index else obj.top_logprobs_num[index]
|
||||
)
|
||||
else:
|
||||
if isinstance(obj.text, list):
|
||||
input_text = obj.text[index]
|
||||
rid = obj.rid[index]
|
||||
else: # A prefill request to cache the common prompt for parallel sampling
|
||||
if obj.text is not None:
|
||||
if isinstance(obj.text, list):
|
||||
input_text = obj.text[index]
|
||||
rid = obj.rid[index]
|
||||
else:
|
||||
input_text = obj.text
|
||||
rid = obj.rid[0]
|
||||
input_ids = self.tokenizer.encode(input_text)
|
||||
else:
|
||||
input_text = obj.text
|
||||
rid = obj.rid[0]
|
||||
input_ids = self.tokenizer.encode(input_text)
|
||||
input_text = None
|
||||
if isinstance(obj.input_ids, list) and isinstance(
|
||||
obj.input_ids[0], list
|
||||
):
|
||||
# when obj["input_ids"] is List[List[int]]
|
||||
input_ids = obj.input_ids[index]
|
||||
rid = obj.rid[index]
|
||||
else:
|
||||
input_ids = obj.input_ids
|
||||
rid = obj.rid[0]
|
||||
|
||||
sampling_params = SamplingParams(**obj.sampling_params[0])
|
||||
sampling_params.max_new_tokens = 0
|
||||
pixel_values, image_hash, image_size = await self._get_pixel_values(
|
||||
@@ -240,11 +254,11 @@ class TokenizerManager:
|
||||
):
|
||||
if input_id_result is not None:
|
||||
input_id_result.append(input_id)
|
||||
pass
|
||||
if len(input_id_result) > 1 and input_id_result is not None:
|
||||
if input_id_result is not None and len(input_id_result) > 1:
|
||||
obj.input_ids = input_id_result
|
||||
elif input_id_result is not None:
|
||||
obj.input_ids = input_id_result[0]
|
||||
|
||||
# First send out all requests
|
||||
for i in range(batch_size):
|
||||
for j in range(parallel_sample_num):
|
||||
@@ -264,11 +278,12 @@ class TokenizerManager:
|
||||
input_text = None
|
||||
input_ids = obj.input_ids[i]
|
||||
else:
|
||||
assert obj.input_ids is not None
|
||||
if batch_size == 1:
|
||||
input_text = obj.text
|
||||
input_text = None
|
||||
input_ids = obj.input_ids
|
||||
else:
|
||||
input_text = obj.text[i]
|
||||
input_text = None
|
||||
input_ids = obj.input_ids[i]
|
||||
sampling_params = self._get_sampling_params(obj.sampling_params[index])
|
||||
pixel_values, image_hash, image_size = await self._get_pixel_values(
|
||||
|
||||
@@ -251,7 +251,9 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
|
||||
if end_point == "/v1/chat/completions":
|
||||
responses = v1_chat_generate_response(request, ret, to_file=True)
|
||||
else:
|
||||
responses = v1_generate_response(request, ret, to_file=True)
|
||||
responses = v1_generate_response(
|
||||
request, ret, tokenizer_manager, to_file=True
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_json = {
|
||||
@@ -339,6 +341,7 @@ def v1_generate_request(all_requests):
|
||||
return_logprobs = []
|
||||
top_logprobs_nums = []
|
||||
first_prompt_type = type(all_requests[0].prompt)
|
||||
|
||||
for request in all_requests:
|
||||
prompt = request.prompt
|
||||
assert (
|
||||
@@ -364,7 +367,7 @@ def v1_generate_request(all_requests):
|
||||
)
|
||||
if len(all_requests) > 1 and request.n > 1:
|
||||
raise ValueError(
|
||||
"Batch operation is not supported for completions from files"
|
||||
"Parallel sampling is not supported for completions from files"
|
||||
)
|
||||
|
||||
if len(all_requests) == 1:
|
||||
@@ -377,10 +380,11 @@ def v1_generate_request(all_requests):
|
||||
else:
|
||||
prompt_kwargs = {"input_ids": prompt}
|
||||
else:
|
||||
if isinstance(prompts[0], str):
|
||||
if isinstance(prompts[0], str) or isinstance(propmt[0][0], str):
|
||||
prompt_kwargs = {"text": prompts}
|
||||
else:
|
||||
prompt_kwargs = {"input_ids": prompts}
|
||||
|
||||
adapted_request = GenerateReqInput(
|
||||
**prompt_kwargs,
|
||||
sampling_params=sampling_params_list,
|
||||
@@ -389,35 +393,52 @@ def v1_generate_request(all_requests):
|
||||
return_text_in_logprobs=True,
|
||||
stream=all_requests[0].stream,
|
||||
)
|
||||
|
||||
if len(all_requests) == 1:
|
||||
return adapted_request, all_requests[0]
|
||||
return adapted_request, all_requests
|
||||
|
||||
|
||||
def v1_generate_response(request, ret, to_file=False):
|
||||
def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
|
||||
choices = []
|
||||
echo = False
|
||||
|
||||
if (not isinstance(request, List)) and request.echo:
|
||||
if (not isinstance(request, list)) and request.echo:
|
||||
# TODO: handle the case propmt is token ids
|
||||
if isinstance(request.prompt, list):
|
||||
if isinstance(request.prompt, list) and isinstance(request.prompt[0], str):
|
||||
# for the case of multiple str prompts
|
||||
prompts = request.prompt
|
||||
elif isinstance(request.prompt, list) and isinstance(request.prompt[0], list):
|
||||
# for the case of multiple token ids prompts
|
||||
prompts = [
|
||||
tokenizer_manager.tokenizer.decode(prompt, skip_special_tokens=True)
|
||||
for prompt in request.prompt
|
||||
]
|
||||
elif isinstance(request.prompt, list) and isinstance(request.prompt[0], int):
|
||||
# for the case of single token ids prompt
|
||||
prompts = [
|
||||
tokenizer_manager.tokenizer.decode(
|
||||
request.prompt, skip_special_tokens=True
|
||||
)
|
||||
]
|
||||
else:
|
||||
# for the case of single str prompt
|
||||
prompts = [request.prompt]
|
||||
echo = True
|
||||
|
||||
for idx, ret_item in enumerate(ret):
|
||||
text = ret_item["text"]
|
||||
if isinstance(request, List) and request[idx].echo:
|
||||
if isinstance(request, list) and request[idx].echo:
|
||||
echo = True
|
||||
text = request[idx].prompt + text
|
||||
if (not isinstance(request, List)) and echo:
|
||||
text = prompts[idx] + text
|
||||
if (not isinstance(request, list)) and echo:
|
||||
prompt_index = idx // request.n
|
||||
text = prompts[prompt_index] + text
|
||||
|
||||
logprobs = False
|
||||
if isinstance(request, List) and request[idx].logprobs:
|
||||
if isinstance(request, list) and request[idx].logprobs:
|
||||
logprobs = True
|
||||
elif (not isinstance(request, List)) and request.logprobs:
|
||||
elif (not isinstance(request, list)) and request.logprobs:
|
||||
logprobs = True
|
||||
if logprobs:
|
||||
if echo:
|
||||
@@ -479,15 +500,16 @@ def v1_generate_response(request, ret, to_file=False):
|
||||
responses.append(response)
|
||||
return responses
|
||||
else:
|
||||
prompt_tokens = sum(item["meta_info"]["prompt_tokens"] for item in ret)
|
||||
completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
|
||||
response = CompletionResponse(
|
||||
id=ret[0]["meta_info"]["id"],
|
||||
model=request.model,
|
||||
choices=choices,
|
||||
usage=UsageInfo(
|
||||
prompt_tokens=ret[0]["meta_info"]["prompt_tokens"],
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=ret[0]["meta_info"]["prompt_tokens"] + completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
),
|
||||
)
|
||||
return response
|
||||
@@ -513,8 +535,18 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
|
||||
|
||||
if not stream_buffer: # The first chunk
|
||||
if request.echo:
|
||||
if isinstance(request.prompt, str):
|
||||
# for the case of single str prompts
|
||||
prompts = request.prompt
|
||||
elif isinstance(request.prompt, list) and isinstance(
|
||||
request.prompt[0], int
|
||||
):
|
||||
prompts = tokenizer_manager.tokenizer.decode(
|
||||
request.prompt, skip_special_tokens=True
|
||||
)
|
||||
|
||||
# Prepend prompt in response text.
|
||||
text = request.prompt + text
|
||||
text = prompts + text
|
||||
|
||||
if request.logprobs:
|
||||
# The first chunk and echo is enabled.
|
||||
@@ -539,7 +571,6 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
|
||||
"output_top_logprobs"
|
||||
][n_prev_token:],
|
||||
)
|
||||
|
||||
n_prev_token = len(
|
||||
content["meta_info"]["output_token_logprobs"]
|
||||
)
|
||||
@@ -588,7 +619,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
|
||||
if not isinstance(ret, list):
|
||||
ret = [ret]
|
||||
|
||||
response = v1_generate_response(request, ret)
|
||||
response = v1_generate_response(request, ret, tokenizer_manager)
|
||||
return response
|
||||
|
||||
|
||||
@@ -626,7 +657,7 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
|
||||
prompt_ids = tokenizer_manager.tokenizer.encode(prompt)
|
||||
else:
|
||||
# Use the raw prompt and stop strings if the messages is already a string.
|
||||
prompt = request.messages
|
||||
prompt_ids = request.messages
|
||||
stop = request.stop
|
||||
image_data = None
|
||||
input_ids.append(prompt_ids)
|
||||
@@ -647,12 +678,21 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
|
||||
image_data_list.append(image_data)
|
||||
if len(all_requests) == 1:
|
||||
input_ids = input_ids[0]
|
||||
if isinstance(input_ids, str):
|
||||
prompt_kwargs = {"text": input_ids}
|
||||
else:
|
||||
prompt_kwargs = {"input_ids": input_ids}
|
||||
sampling_params_list = sampling_params_list[0]
|
||||
image_data = image_data_list[0]
|
||||
return_logprobs = return_logprobs[0]
|
||||
top_logprobs_nums = top_logprobs_nums[0]
|
||||
else:
|
||||
if isinstance(input_ids[0], str):
|
||||
prompt_kwargs = {"text": input_ids}
|
||||
else:
|
||||
prompt_kwargs = {"input_ids": input_ids}
|
||||
adapted_request = GenerateReqInput(
|
||||
input_ids=input_ids,
|
||||
**prompt_kwargs,
|
||||
image_data=image_data,
|
||||
sampling_params=sampling_params_list,
|
||||
return_logprob=return_logprobs,
|
||||
@@ -672,9 +712,9 @@ def v1_chat_generate_response(request, ret, to_file=False):
|
||||
|
||||
for idx, ret_item in enumerate(ret):
|
||||
logprobs = False
|
||||
if isinstance(request, List) and request[idx].logprobs:
|
||||
if isinstance(request, list) and request[idx].logprobs:
|
||||
logprobs = True
|
||||
elif (not isinstance(request, List)) and request.logprobs:
|
||||
elif (not isinstance(request, list)) and request.logprobs:
|
||||
logprobs = True
|
||||
if logprobs:
|
||||
logprobs = to_openai_style_logprobs(
|
||||
@@ -779,10 +819,58 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
|
||||
is_first = True
|
||||
|
||||
stream_buffer = ""
|
||||
n_prev_token = 0
|
||||
try:
|
||||
async for content in tokenizer_manager.generate_request(
|
||||
adapted_request, raw_request
|
||||
):
|
||||
prompt_tokens = content["meta_info"]["prompt_tokens"]
|
||||
completion_tokens = content["meta_info"]["completion_tokens"]
|
||||
if request.logprobs:
|
||||
logprobs = to_openai_style_logprobs(
|
||||
output_token_logprobs=content["meta_info"][
|
||||
"output_token_logprobs"
|
||||
][n_prev_token:],
|
||||
output_top_logprobs=content["meta_info"][
|
||||
"output_top_logprobs"
|
||||
][n_prev_token:],
|
||||
)
|
||||
|
||||
n_prev_token = len(
|
||||
content["meta_info"]["output_token_logprobs"]
|
||||
)
|
||||
token_logprobs = []
|
||||
for token, logprob in zip(
|
||||
logprobs.tokens, logprobs.token_logprobs
|
||||
):
|
||||
token_bytes = list(token.encode("utf-8"))
|
||||
top_logprobs = []
|
||||
if logprobs.top_logprobs:
|
||||
for top_token, top_logprob in logprobs.top_logprobs[
|
||||
0
|
||||
].items():
|
||||
top_token_bytes = list(top_token.encode("utf-8"))
|
||||
top_logprobs.append(
|
||||
TopLogprob(
|
||||
token=top_token,
|
||||
bytes=top_token_bytes,
|
||||
logprob=top_logprob,
|
||||
)
|
||||
)
|
||||
token_logprobs.append(
|
||||
ChatCompletionTokenLogprob(
|
||||
token=token,
|
||||
bytes=token_bytes,
|
||||
logprob=logprob,
|
||||
top_logprobs=top_logprobs,
|
||||
)
|
||||
)
|
||||
|
||||
choice_logprobs = ChoiceLogprobs(content=token_logprobs)
|
||||
|
||||
else:
|
||||
choice_logprobs = None
|
||||
|
||||
if is_first:
|
||||
# First chunk with role
|
||||
is_first = False
|
||||
@@ -790,11 +878,17 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
|
||||
index=0,
|
||||
delta=DeltaMessage(role="assistant"),
|
||||
finish_reason=content["meta_info"]["finish_reason"],
|
||||
logprobs=choice_logprobs,
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(
|
||||
id=content["meta_info"]["id"],
|
||||
choices=[choice_data],
|
||||
model=request.model,
|
||||
usage=UsageInfo(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
),
|
||||
)
|
||||
yield f"data: {chunk.model_dump_json()}\n\n"
|
||||
|
||||
@@ -805,11 +899,17 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
|
||||
index=0,
|
||||
delta=DeltaMessage(content=delta),
|
||||
finish_reason=content["meta_info"]["finish_reason"],
|
||||
logprobs=choice_logprobs,
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(
|
||||
id=content["meta_info"]["id"],
|
||||
choices=[choice_data],
|
||||
model=request.model,
|
||||
usage=UsageInfo(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
),
|
||||
)
|
||||
yield f"data: {chunk.model_dump_json()}\n\n"
|
||||
except ValueError as e:
|
||||
|
||||
@@ -278,7 +278,7 @@ class DeltaMessage(BaseModel):
|
||||
class ChatCompletionResponseStreamChoice(BaseModel):
|
||||
index: int
|
||||
delta: DeltaMessage
|
||||
logprobs: Optional[LogProbs] = None
|
||||
logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
|
||||
finish_reason: Optional[str] = None
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user