Support more OpenAI API tests (#916)
@@ -92,7 +92,7 @@ class GenerateReqInput:
                     for element in parallel_sample_num_list
                 )
                 if parallel_sample_num > 1 and (not all_equal):
-                    ## TODO cope with the case that the parallel_sample_num is different for different samples
+                    # TODO cope with the case that the parallel_sample_num is different for different samples
                     raise ValueError(
                         "The parallel_sample_num should be the same for all samples in sample params."
                     )
@@ -103,14 +103,19 @@ class GenerateReqInput:
         if parallel_sample_num != 1:
             # parallel sampling +1 represents the original prefill stage
             num = parallel_sample_num + 1
-            if isinstance(self.text, List):
-                ## suppot batch operation
+            if isinstance(self.text, list):
+                # support batch operation
                 self.batch_size = len(self.text)
                 num = num * len(self.text)
+            elif isinstance(self.input_ids, list) and isinstance(
+                self.input_ids[0], list
+            ):
+                self.batch_size = len(self.input_ids)
+                num = num * len(self.input_ids)
             else:
                 self.batch_size = 1
         else:
-            ## support select operation
+            # support select operation
             num = len(self.text) if self.text is not None else len(self.input_ids)
             self.batch_size = num

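Note: a quick standalone restatement of the request-count arithmetic above (a sketch, not part of the patch; the names mirror GenerateReqInput's fields). With parallel sampling, each prompt costs one extra cache-warming prefill request, so a batch of B prompts with n samples each expands to B * (n + 1) generate requests:

    # Sketch of GenerateReqInput's bookkeeping; `text`/`input_ids` mirror the fields above.
    def count_requests(text, input_ids, parallel_sample_num):
        if parallel_sample_num != 1:
            num = parallel_sample_num + 1  # +1 for the shared prefill stage
            if isinstance(text, list):
                batch_size = len(text)
            elif isinstance(input_ids, list) and isinstance(input_ids[0], list):
                batch_size = len(input_ids)
            else:
                batch_size = 1
            return batch_size, num * batch_size
        num = len(text) if text is not None else len(input_ids)
        return num, num  # n == 1: one request per prompt

    assert count_requests(["a", "b"], None, 2) == (2, 6)  # 2 prompts x (2 samples + 1 prefill)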
@@ -153,8 +153,9 @@ class TokenizerManager:
     async def _handle_single_request(
         self, obj, request, index=None, is_cache_for_prefill=False
     ):
-        if not is_cache_for_prefill:
-            not_use_index = not (index is not None)
+        if not is_cache_for_prefill:  # The normal case with a single prompt
+            not_use_index = index is None
+
             rid = obj.rid if not_use_index else obj.rid[index]
             input_text = obj.text if not_use_index else obj.text[index]
             input_ids = (
@@ -182,14 +183,27 @@ class TokenizerManager:
             top_logprobs_num = (
                 obj.top_logprobs_num if not_use_index else obj.top_logprobs_num[index]
             )
-        else:
-            if isinstance(obj.text, list):
-                input_text = obj.text[index]
-                rid = obj.rid[index]
+        else:  # A prefill request to cache the common prompt for parallel sampling
+            if obj.text is not None:
+                if isinstance(obj.text, list):
+                    input_text = obj.text[index]
+                    rid = obj.rid[index]
+                else:
+                    input_text = obj.text
+                    rid = obj.rid[0]
+                input_ids = self.tokenizer.encode(input_text)
             else:
-                input_text = obj.text
-                rid = obj.rid[0]
-            input_ids = self.tokenizer.encode(input_text)
+                input_text = None
+                if isinstance(obj.input_ids, list) and isinstance(
+                    obj.input_ids[0], list
+                ):
+                    # when obj["input_ids"] is List[List[int]]
+                    input_ids = obj.input_ids[index]
+                    rid = obj.rid[index]
+                else:
+                    input_ids = obj.input_ids
+                    rid = obj.rid[0]
+
         sampling_params = SamplingParams(**obj.sampling_params[0])
         sampling_params.max_new_tokens = 0
         pixel_values, image_hash, image_size = await self._get_pixel_values(
@@ -240,11 +254,11 @@ class TokenizerManager:
                 ):
                     if input_id_result is not None:
                         input_id_result.append(input_id)
-                    pass
-            if len(input_id_result) > 1 and input_id_result is not None:
+            if input_id_result is not None and len(input_id_result) > 1:
                 obj.input_ids = input_id_result
             elif input_id_result is not None:
                 obj.input_ids = input_id_result[0]
+
         # First send out all requests
         for i in range(batch_size):
             for j in range(parallel_sample_num):
@@ -264,11 +278,12 @@ class TokenizerManager:
                     input_text = None
                     input_ids = obj.input_ids[i]
                 else:
+                    assert obj.input_ids is not None
                     if batch_size == 1:
-                        input_text = obj.text
+                        input_text = None
                         input_ids = obj.input_ids
                     else:
-                        input_text = obj.text[i]
+                        input_text = None
                         input_ids = obj.input_ids[i]
                 sampling_params = self._get_sampling_params(obj.sampling_params[index])
                 pixel_values, image_hash, image_size = await self._get_pixel_values(
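Note: the zero-token request built above (`sampling_params.max_new_tokens = 0`) is the cache-warming prefill: it runs the shared prompt through prefill only, so the n real sampling requests can reuse the cached prefix. A rough sketch of the intended flow, with hypothetical stand-in callables (the real TokenizerManager threads rids and per-request sampling params through each sub-request):

    import asyncio

    # Hypothetical sketch: prefill_one warms the cache, sample_one draws one sample.
    async def parallel_sample(prefill_one, sample_one, batch_size, parallel_sample_num):
        for i in range(batch_size):
            await prefill_one(i)  # max_new_tokens=0, is_cache_for_prefill=True
        return [await sample_one(i, j)
                for i in range(batch_size)
                for j in range(parallel_sample_num)]

    async def _demo():
        prefill = lambda i: asyncio.sleep(0)
        sample = lambda i, j: asyncio.sleep(0, result=(i, j))
        print(await parallel_sample(prefill, sample, 2, 3))

    asyncio.run(_demo())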
@@ -251,7 +251,9 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
             if end_point == "/v1/chat/completions":
                 responses = v1_chat_generate_response(request, ret, to_file=True)
             else:
-                responses = v1_generate_response(request, ret, to_file=True)
+                responses = v1_generate_response(
+                    request, ret, tokenizer_manager, to_file=True
+                )

     except Exception as e:
         error_json = {
@@ -339,6 +341,7 @@ def v1_generate_request(all_requests):
     return_logprobs = []
     top_logprobs_nums = []
     first_prompt_type = type(all_requests[0].prompt)
+
     for request in all_requests:
         prompt = request.prompt
         assert (
@@ -364,7 +367,7 @@ def v1_generate_request(all_requests):
         )
         if len(all_requests) > 1 and request.n > 1:
             raise ValueError(
-                "Batch operation is not supported for completions from files"
+                "Parallel sampling is not supported for completions from files"
             )

     if len(all_requests) == 1:
@@ -377,10 +380,11 @@ def v1_generate_request(all_requests):
         else:
             prompt_kwargs = {"input_ids": prompt}
     else:
-        if isinstance(prompts[0], str):
+        if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
             prompt_kwargs = {"text": prompts}
         else:
             prompt_kwargs = {"input_ids": prompts}
+
     adapted_request = GenerateReqInput(
         **prompt_kwargs,
         sampling_params=sampling_params_list,
@@ -389,35 +393,52 @@ def v1_generate_request(all_requests):
         return_text_in_logprobs=True,
         stream=all_requests[0].stream,
     )

     if len(all_requests) == 1:
         return adapted_request, all_requests[0]
     return adapted_request, all_requests


-def v1_generate_response(request, ret, to_file=False):
+def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
     choices = []
     echo = False

-    if (not isinstance(request, List)) and request.echo:
+    if (not isinstance(request, list)) and request.echo:
         # TODO: handle the case propmt is token ids
-        if isinstance(request.prompt, list):
+        if isinstance(request.prompt, list) and isinstance(request.prompt[0], str):
+            # for the case of multiple str prompts
             prompts = request.prompt
+        elif isinstance(request.prompt, list) and isinstance(request.prompt[0], list):
+            # for the case of multiple token ids prompts
+            prompts = [
+                tokenizer_manager.tokenizer.decode(prompt, skip_special_tokens=True)
+                for prompt in request.prompt
+            ]
+        elif isinstance(request.prompt, list) and isinstance(request.prompt[0], int):
+            # for the case of single token ids prompt
+            prompts = [
+                tokenizer_manager.tokenizer.decode(
+                    request.prompt, skip_special_tokens=True
+                )
+            ]
         else:
+            # for the case of single str prompt
             prompts = [request.prompt]
         echo = True

     for idx, ret_item in enumerate(ret):
         text = ret_item["text"]
-        if isinstance(request, List) and request[idx].echo:
+        if isinstance(request, list) and request[idx].echo:
             echo = True
             text = request[idx].prompt + text
-        if (not isinstance(request, List)) and echo:
-            text = prompts[idx] + text
+        if (not isinstance(request, list)) and echo:
+            prompt_index = idx // request.n
+            text = prompts[prompt_index] + text
+
         logprobs = False
-        if isinstance(request, List) and request[idx].logprobs:
+        if isinstance(request, list) and request[idx].logprobs:
             logprobs = True
-        elif (not isinstance(request, List)) and request.logprobs:
+        elif (not isinstance(request, list)) and request.logprobs:
             logprobs = True
         if logprobs:
             if echo:
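Note: with n parallel samples per prompt, `ret` holds n consecutive completions per prompt, which is why echo maps a choice back to its prompt via `prompt_index = idx // request.n`. A toy example of the ordering assumed here:

    prompts = ["Paris is", "Berlin is"]
    n = 2  # request.n
    completions = [" a city.", " lovely.", " a city.", " historic."]  # [p0s0, p0s1, p1s0, p1s1]
    texts = [prompts[idx // n] + completions[idx] for idx in range(len(completions))]
    assert texts[2] == "Berlin is a city."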
@@ -479,15 +500,16 @@ def v1_generate_response(request, ret, to_file=False):
             responses.append(response)
         return responses
     else:
+        prompt_tokens = sum(item["meta_info"]["prompt_tokens"] for item in ret)
         completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
         response = CompletionResponse(
             id=ret[0]["meta_info"]["id"],
             model=request.model,
             choices=choices,
             usage=UsageInfo(
-                prompt_tokens=ret[0]["meta_info"]["prompt_tokens"],
+                prompt_tokens=prompt_tokens,
                 completion_tokens=completion_tokens,
-                total_tokens=ret[0]["meta_info"]["prompt_tokens"] + completion_tokens,
+                total_tokens=prompt_tokens + completion_tokens,
             ),
         )
     return response
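Note: summing over `ret` matters once one API request fans out into several return items (list input or n > 1); before this change only `ret[0]`'s prompt tokens were reported. A toy check of the aggregation:

    ret = [
        {"meta_info": {"prompt_tokens": 5, "completion_tokens": 7}},
        {"meta_info": {"prompt_tokens": 5, "completion_tokens": 9}},
    ]
    prompt_tokens = sum(item["meta_info"]["prompt_tokens"] for item in ret)
    completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
    assert (prompt_tokens, completion_tokens) == (10, 16)  # total_tokens = 26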
@@ -513,8 +535,18 @@ async def v1_completions(tokenizer_manager, raw_request: Request):

                     if not stream_buffer:  # The first chunk
                         if request.echo:
+                            if isinstance(request.prompt, str):
+                                # for the case of single str prompts
+                                prompts = request.prompt
+                            elif isinstance(request.prompt, list) and isinstance(
+                                request.prompt[0], int
+                            ):
+                                prompts = tokenizer_manager.tokenizer.decode(
+                                    request.prompt, skip_special_tokens=True
+                                )
+
                             # Prepend prompt in response text.
-                            text = request.prompt + text
+                            text = prompts + text

                         if request.logprobs:
                             # The first chunk and echo is enabled.
@@ -539,7 +571,6 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                                     "output_top_logprobs"
                                 ][n_prev_token:],
                             )
-
                             n_prev_token = len(
                                 content["meta_info"]["output_token_logprobs"]
                             )
@@ -588,7 +619,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
     if not isinstance(ret, list):
         ret = [ret]

-    response = v1_generate_response(request, ret)
+    response = v1_generate_response(request, ret, tokenizer_manager)
     return response


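Note: together with the decode branch above, echo now works in streaming completions when the prompt is token ids. From the client it looks roughly like this (URL, api_key, model, and token ids are placeholders; assumes a running sglang server):

    import openai

    client = openai.Client(api_key="EMPTY", base_url="http://127.0.0.1:30000/v1")
    response = client.completions.create(
        model="default",
        prompt=[9906, 11, 1917],  # token ids instead of a string
        max_tokens=16,
        echo=True,  # the server decodes the ids to prepend the prompt text
    )
    print(response.choices[0].text)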
@@ -626,7 +657,7 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
             prompt_ids = tokenizer_manager.tokenizer.encode(prompt)
         else:
             # Use the raw prompt and stop strings if the messages is already a string.
-            prompt = request.messages
+            prompt_ids = request.messages
         stop = request.stop
         image_data = None
         input_ids.append(prompt_ids)
@@ -647,12 +678,21 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
         image_data_list.append(image_data)
     if len(all_requests) == 1:
         input_ids = input_ids[0]
+        if isinstance(input_ids, str):
+            prompt_kwargs = {"text": input_ids}
+        else:
+            prompt_kwargs = {"input_ids": input_ids}
         sampling_params_list = sampling_params_list[0]
         image_data = image_data_list[0]
         return_logprobs = return_logprobs[0]
         top_logprobs_nums = top_logprobs_nums[0]
+    else:
+        if isinstance(input_ids[0], str):
+            prompt_kwargs = {"text": input_ids}
+        else:
+            prompt_kwargs = {"input_ids": input_ids}
     adapted_request = GenerateReqInput(
-        input_ids=input_ids,
+        **prompt_kwargs,
         image_data=image_data,
         sampling_params=sampling_params_list,
         return_logprob=return_logprobs,
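Note: the new prompt_kwargs dispatch sends raw-string messages down the `text` path and tokenized messages down the `input_ids` path. In miniature (standalone sketch, not the patch's exact code):

    # Miniature version of the dispatch above.
    def make_prompt_kwargs(prompt):
        if isinstance(prompt, str) or (isinstance(prompt, list) and isinstance(prompt[0], str)):
            return {"text": prompt}
        return {"input_ids": prompt}

    assert make_prompt_kwargs("hi") == {"text": "hi"}
    assert make_prompt_kwargs([[1, 2], [3]]) == {"input_ids": [[1, 2], [3]]}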
@@ -672,9 +712,9 @@ def v1_chat_generate_response(request, ret, to_file=False):

     for idx, ret_item in enumerate(ret):
         logprobs = False
-        if isinstance(request, List) and request[idx].logprobs:
+        if isinstance(request, list) and request[idx].logprobs:
             logprobs = True
-        elif (not isinstance(request, List)) and request.logprobs:
+        elif (not isinstance(request, list)) and request.logprobs:
             logprobs = True
         if logprobs:
             logprobs = to_openai_style_logprobs(
@@ -779,10 +819,58 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
         is_first = True

         stream_buffer = ""
+        n_prev_token = 0
         try:
             async for content in tokenizer_manager.generate_request(
                 adapted_request, raw_request
             ):
+                prompt_tokens = content["meta_info"]["prompt_tokens"]
+                completion_tokens = content["meta_info"]["completion_tokens"]
+                if request.logprobs:
+                    logprobs = to_openai_style_logprobs(
+                        output_token_logprobs=content["meta_info"][
+                            "output_token_logprobs"
+                        ][n_prev_token:],
+                        output_top_logprobs=content["meta_info"][
+                            "output_top_logprobs"
+                        ][n_prev_token:],
+                    )
+
+                    n_prev_token = len(
+                        content["meta_info"]["output_token_logprobs"]
+                    )
+                    token_logprobs = []
+                    for token, logprob in zip(
+                        logprobs.tokens, logprobs.token_logprobs
+                    ):
+                        token_bytes = list(token.encode("utf-8"))
+                        top_logprobs = []
+                        if logprobs.top_logprobs:
+                            for top_token, top_logprob in logprobs.top_logprobs[
+                                0
+                            ].items():
+                                top_token_bytes = list(top_token.encode("utf-8"))
+                                top_logprobs.append(
+                                    TopLogprob(
+                                        token=top_token,
+                                        bytes=top_token_bytes,
+                                        logprob=top_logprob,
+                                    )
+                                )
+                        token_logprobs.append(
+                            ChatCompletionTokenLogprob(
+                                token=token,
+                                bytes=token_bytes,
+                                logprob=logprob,
+                                top_logprobs=top_logprobs,
+                            )
+                        )
+
+                    choice_logprobs = ChoiceLogprobs(content=token_logprobs)
+
+                else:
+                    choice_logprobs = None
+
                 if is_first:
                     # First chunk with role
                     is_first = False
@@ -790,11 +878,17 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                     index=0,
                     delta=DeltaMessage(role="assistant"),
                     finish_reason=content["meta_info"]["finish_reason"],
+                    logprobs=choice_logprobs,
                 )
                 chunk = ChatCompletionStreamResponse(
                     id=content["meta_info"]["id"],
                     choices=[choice_data],
                     model=request.model,
+                    usage=UsageInfo(
+                        prompt_tokens=prompt_tokens,
+                        completion_tokens=completion_tokens,
+                        total_tokens=prompt_tokens + completion_tokens,
+                    ),
                 )
                 yield f"data: {chunk.model_dump_json()}\n\n"

@@ -805,11 +899,17 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                     index=0,
                     delta=DeltaMessage(content=delta),
                     finish_reason=content["meta_info"]["finish_reason"],
+                    logprobs=choice_logprobs,
                 )
                 chunk = ChatCompletionStreamResponse(
                     id=content["meta_info"]["id"],
                     choices=[choice_data],
                     model=request.model,
+                    usage=UsageInfo(
+                        prompt_tokens=prompt_tokens,
+                        completion_tokens=completion_tokens,
+                        total_tokens=prompt_tokens + completion_tokens,
+                    ),
                 )
                 yield f"data: {chunk.model_dump_json()}\n\n"
         except ValueError as e:
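Note: streamed chat chunks now carry per-token logprobs and a populated usage field on every event (stock OpenAI sends usage only on a final chunk). Consuming them from the client looks roughly like this (placeholder URL/api_key/model; assumes a running sglang server):

    import openai

    client = openai.Client(api_key="EMPTY", base_url="http://127.0.0.1:30000/v1")
    stream = client.chat.completions.create(
        model="default",
        messages=[{"role": "user", "content": "Say hi"}],
        max_tokens=8,
        logprobs=True,
        top_logprobs=5,
        stream=True,
    )
    for chunk in stream:
        choice = chunk.choices[0]
        if choice.logprobs and choice.logprobs.content:
            for item in choice.logprobs.content:
                print(item.token, item.logprob, len(item.top_logprobs))
        if chunk.usage:  # filled on every chunk by this server
            print(chunk.usage.total_tokens)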
@@ -278,7 +278,7 @@ class DeltaMessage(BaseModel):
 class ChatCompletionResponseStreamChoice(BaseModel):
     index: int
     delta: DeltaMessage
-    logprobs: Optional[LogProbs] = None
+    logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
     finish_reason: Optional[str] = None


@@ -3,6 +3,7 @@ import unittest

 import openai

+from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.utils import kill_child_process
 from sglang.test.test_utils import MODEL_NAME_FOR_TEST, popen_launch_server

@@ -18,60 +19,85 @@ class TestOpenAIServer(unittest.TestCase):
             cls.model, cls.base_url, timeout=300, api_key=cls.api_key
         )
         cls.base_url += "/v1"
+        cls.tokenizer = get_tokenizer(MODEL_NAME_FOR_TEST)

     @classmethod
     def tearDownClass(cls):
         kill_child_process(cls.process.pid)

-    def run_completion(self, echo, logprobs, use_list_input):
+    def run_completion(
+        self, echo, logprobs, use_list_input, parallel_sample_num, token_input
+    ):
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)
         prompt = "The capital of France is"
+        if token_input:
+            prompt_input = self.tokenizer.encode(prompt)
+            num_prompt_tokens = len(prompt_input)
+        else:
+            prompt_input = prompt
+            num_prompt_tokens = len(self.tokenizer.encode(prompt))
+
         if use_list_input:
-            prompt_arg = [prompt, prompt]
+            prompt_arg = [prompt_input, prompt_input]
             num_choices = len(prompt_arg)
+            num_prompt_tokens *= 2
         else:
-            prompt_arg = prompt
+            prompt_arg = prompt_input
             num_choices = 1

+        if parallel_sample_num:
+            # FIXME: This is wrong. We should not count the prompt tokens multiple times for
+            # parallel sampling.
+            num_prompt_tokens *= parallel_sample_num
+
         response = client.completions.create(
             model=self.model,
             prompt=prompt_arg,
-            temperature=0.1,
+            temperature=0,
             max_tokens=32,
             echo=echo,
             logprobs=logprobs,
+            n=parallel_sample_num,
         )

-        assert len(response.choices) == num_choices
+        assert len(response.choices) == num_choices * parallel_sample_num
+
         if echo:
             text = response.choices[0].text
             assert text.startswith(prompt)

         if logprobs:
             assert response.choices[0].logprobs
             assert isinstance(response.choices[0].logprobs.tokens[0], str)
             assert isinstance(response.choices[0].logprobs.top_logprobs[1], dict)
             ret_num_top_logprobs = len(response.choices[0].logprobs.top_logprobs[1])
-            # FIXME: Fix this bug. Sometimes, some top_logprobs are missing in the return value.
+            # FIXME: Sometimes, some top_logprobs are missing in the return value. The reason is that some output ids map to the same output token and get deduplicated in the map.
             # assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
+            assert ret_num_top_logprobs > 0
             if echo:
                 assert response.choices[0].logprobs.token_logprobs[0] == None
             else:
                 assert response.choices[0].logprobs.token_logprobs[0] != None

         assert response.id
         assert response.created
-        assert response.usage.prompt_tokens > 0
+        assert (
+            response.usage.prompt_tokens == num_prompt_tokens
+        ), f"{response.usage.prompt_tokens} vs {num_prompt_tokens}"
         assert response.usage.completion_tokens > 0
         assert response.usage.total_tokens > 0

-    def run_completion_stream(self, echo, logprobs):
+    def run_completion_stream(self, echo, logprobs, token_input):
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)
         prompt = "The capital of France is"
+        if token_input:
+            prompt_arg = self.tokenizer.encode(prompt)
+        else:
+            prompt_arg = prompt
         generator = client.completions.create(
             model=self.model,
-            prompt=prompt,
-            temperature=0.1,
+            prompt=prompt_arg,
+            temperature=0,
             max_tokens=32,
             echo=echo,
             logprobs=logprobs,
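Note: the expected `num_prompt_tokens` deliberately mirrors the server-side accounting flagged by the FIXME above: the base token count is doubled for list input and multiplied again by n. For example (the base count of 5 is illustrative, not the real tokenizer output):

    base = 5  # e.g. len(tokenizer.encode("The capital of France is")), say
    use_list_input, parallel_sample_num = True, 2
    num_prompt_tokens = base * (2 if use_list_input else 1) * parallel_sample_num
    assert num_prompt_tokens == 20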
@@ -90,12 +116,15 @@ class TestOpenAIServer(unittest.TestCase):
                     ret_num_top_logprobs = len(
                         response.choices[0].logprobs.top_logprobs[0]
                     )
-                    # FIXME: Fix this bug. Sometimes, some top_logprobs are missing in the return value.
+                    # FIXME: Sometimes, some top_logprobs are missing in the return value. The reason is that some output ids map to the same output token and get deduplicated in the map.
                     # assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
+                    assert ret_num_top_logprobs > 0

             if first:
                 if echo:
-                    assert response.choices[0].text.startswith(prompt)
+                    assert response.choices[0].text.startswith(
+                        prompt
+                    ), f"{response.choices[0].text} and all args {echo} {logprobs} {token_input} {first}"
                 first = False

         assert response.id
@@ -104,7 +133,7 @@ class TestOpenAIServer(unittest.TestCase):
         assert response.usage.completion_tokens > 0
         assert response.usage.total_tokens > 0

-    def run_chat_completion(self, logprobs):
+    def run_chat_completion(self, logprobs, parallel_sample_num):
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)
         response = client.chat.completions.create(
             model=self.model,
@@ -116,6 +145,7 @@ class TestOpenAIServer(unittest.TestCase):
             max_tokens=32,
             logprobs=logprobs is not None and logprobs > 0,
             top_logprobs=logprobs,
+            n=parallel_sample_num,
         )
         if logprobs:
             assert isinstance(
@@ -128,7 +158,7 @@ class TestOpenAIServer(unittest.TestCase):
             assert (
                 ret_num_top_logprobs == logprobs
             ), f"{ret_num_top_logprobs} vs {logprobs}"
+        assert len(response.choices) == parallel_sample_num
         assert response.choices[0].message.role == "assistant"
         assert isinstance(response.choices[0].message.content, str)
         assert response.id
@@ -161,11 +191,21 @@ class TestOpenAIServer(unittest.TestCase):
                 continue

             if logprobs:
-                # FIXME: Fix this bug. Return top_logprobs in the streaming mode.
-                pass
-
+                assert response.choices[0].logprobs
+                assert isinstance(
+                    response.choices[0].logprobs.content[0].top_logprobs[0].token, str
+                )
+                assert isinstance(
+                    response.choices[0].logprobs.content[0].top_logprobs, list
+                )
+                ret_num_top_logprobs = len(
+                    response.choices[0].logprobs.content[0].top_logprobs
+                )
+                assert (
+                    ret_num_top_logprobs == logprobs
+                ), f"{ret_num_top_logprobs} vs {logprobs}"
             assert isinstance(data.content, str)

         assert response.id
         assert response.created

@@ -173,16 +213,27 @@ class TestOpenAIServer(unittest.TestCase):
         for echo in [False, True]:
             for logprobs in [None, 5]:
                 for use_list_input in [True, False]:
-                    self.run_completion(echo, logprobs, use_list_input)
+                    for parallel_sample_num in [1, 2]:
+                        for token_input in [False, True]:
+                            self.run_completion(
+                                echo,
+                                logprobs,
+                                use_list_input,
+                                parallel_sample_num,
+                                token_input,
+                            )

     def test_completion_stream(self):
+        # parallel sampling and list input are not supported in streaming mode
         for echo in [False, True]:
             for logprobs in [None, 5]:
-                self.run_completion_stream(echo, logprobs)
+                for token_input in [False, True]:
+                    self.run_completion_stream(echo, logprobs, token_input)

     def test_chat_completion(self):
         for logprobs in [None, 5]:
-            self.run_chat_completion(logprobs)
+            for parallel_sample_num in [1, 2]:
+                self.run_chat_completion(logprobs, parallel_sample_num)

     def test_chat_completion_stream(self):
         for logprobs in [None, 5]:
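Note: test_completion now sweeps the full grid of echo x logprobs x list-input x n x token-input, i.e. 2^5 = 32 combinations:

    from itertools import product

    combos = list(product([False, True], [None, 5], [True, False], [1, 2], [False, True]))
    assert len(combos) == 32  # one run_completion call per combination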
@@ -224,5 +275,5 @@ if __name__ == "__main__":

     # t = TestOpenAIServer()
     # t.setUpClass()
-    # t.test_chat_completion_stream()
+    # t.test_completion()
     # t.tearDownClass()