Fix prompt len in parallel sampling (#928)
@@ -500,7 +500,9 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
             responses.append(response)
         return responses
     else:
-        prompt_tokens = sum(item["meta_info"]["prompt_tokens"] for item in ret)
+        prompt_tokens = sum(
+            ret[i]["meta_info"]["prompt_tokens"] for i in range(0, len(ret), request.n)
+        )
         completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
         response = CompletionResponse(
             id=ret[0]["meta_info"]["id"],
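
Why the stride: with parallel sampling, ret holds request.n generations for each prompt, so the old sum counted the shared prompt's tokens once per sampled completion. Taking every request.n-th entry counts each prompt exactly once, while completion tokens are still summed over every generation. A minimal sketch of the arithmetic, assuming (as the stride implies) that the n samples for a prompt sit consecutively in ret; the token counts below are invented for illustration:

    # One prompt sampled with n=2; field names follow the diff, numbers are made up.
    ret = [
        {"meta_info": {"prompt_tokens": 10, "completion_tokens": 5}},
        {"meta_info": {"prompt_tokens": 10, "completion_tokens": 7}},
    ]

    class Request:  # stand-in for the real request object
        n = 2

    request = Request()

    # Old: the shared prompt is counted once per sampled completion.
    old = sum(item["meta_info"]["prompt_tokens"] for item in ret)  # 20

    # New: every n-th entry, so each prompt is counted exactly once.
    new = sum(
        ret[i]["meta_info"]["prompt_tokens"] for i in range(0, len(ret), request.n)
    )  # 10
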
@@ -707,8 +709,6 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
 
 def v1_chat_generate_response(request, ret, to_file=False):
     choices = []
-    total_prompt_tokens = 0
-    total_completion_tokens = 0
 
     for idx, ret_item in enumerate(ret):
         logprobs = False
@@ -747,8 +747,6 @@ def v1_chat_generate_response(request, ret, to_file=False):
             choice_logprobs = ChoiceLogprobs(content=token_logprobs)
         else:
             choice_logprobs = None
-        prompt_tokens = ret_item["meta_info"]["prompt_tokens"]
-        completion_tokens = ret_item["meta_info"]["completion_tokens"]
 
         if to_file:
             # to make the choice data json serializable
@@ -767,8 +765,7 @@ def v1_chat_generate_response(request, ret, to_file=False):
             )
 
         choices.append(choice_data)
-        total_prompt_tokens += prompt_tokens
-        total_completion_tokens += completion_tokens
     if to_file:
         responses = []
 
@@ -795,14 +792,18 @@ def v1_chat_generate_response(request, ret, to_file=False):
             responses.append(response)
         return responses
     else:
+        prompt_tokens = sum(
+            ret[i]["meta_info"]["prompt_tokens"] for i in range(0, len(ret), request.n)
+        )
+        completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
         response = ChatCompletionResponse(
             id=ret[0]["meta_info"]["id"],
             model=request.model,
             choices=choices,
             usage=UsageInfo(
-                prompt_tokens=total_prompt_tokens,
-                completion_tokens=total_completion_tokens,
-                total_tokens=total_prompt_tokens + total_completion_tokens,
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=prompt_tokens + completion_tokens,
             ),
         )
         return response
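
The chat path gets the same treatment: the per-choice accumulators total_prompt_tokens / total_completion_tokens removed in the earlier hunks inflated prompt usage by a factor of n, so usage is now computed once from ret with the same stride. A worked example with invented numbers (a single chat request, n=3, a 12-token prompt, completions of 5, 6 and 7 tokens):

    # Before: usage.prompt_tokens = 12 + 12 + 12 = 36 (prompt counted per choice)
    # After:  usage.prompt_tokens = 12
    #         usage.completion_tokens = 5 + 6 + 7 = 18
    #         usage.total_tokens = 12 + 18 = 30
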
@@ -45,11 +45,6 @@ class TestOpenAIServer(unittest.TestCase):
         prompt_arg = prompt_input
         num_choices = 1
 
-        if parallel_sample_num:
-            # FIXME: This is wrong. We should not count the prompt tokens multiple times for
-            # parallel sampling.
-            num_prompt_tokens *= parallel_sample_num
-
         response = client.completions.create(
             model=self.model,
             prompt=prompt_arg,
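
The removed block was a workaround for the old behavior: the test multiplied its expected prompt-token count by parallel_sample_num just to match the server's inflated usage (hence the FIXME). With the server fixed, the expectation stays the per-prompt value no matter how many samples are requested. Roughly the assertion this enables, as a sketch only: the n and max_tokens arguments and the assert are assumptions about the test's shape, not the actual test body; num_prompt_tokens and parallel_sample_num are the variables the test already uses.

    response = client.completions.create(
        model=self.model,
        prompt=prompt_arg,
        n=parallel_sample_num,  # e.g. 2 parallel samples for the same prompt
        max_tokens=32,
    )
    # usage.prompt_tokens should now equal the single prompt's length,
    # no longer multiplied by parallel_sample_num.
    assert response.usage.prompt_tokens == num_prompt_tokens
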