Fix prompt len in parallel sampling (#928)

This commit is contained in:
yichuan~
2024-08-05 15:56:08 +08:00
committed by GitHub
parent 399cad91f3
commit fd7926e46e
2 changed files with 11 additions and 15 deletions

View File

@@ -500,7 +500,9 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
responses.append(response)
return responses
else:
prompt_tokens = sum(item["meta_info"]["prompt_tokens"] for item in ret)
prompt_tokens = sum(
ret[i]["meta_info"]["prompt_tokens"] for i in range(0, len(ret), request.n)
)
completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
response = CompletionResponse(
id=ret[0]["meta_info"]["id"],
@@ -707,8 +709,6 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
def v1_chat_generate_response(request, ret, to_file=False):
choices = []
total_prompt_tokens = 0
total_completion_tokens = 0
for idx, ret_item in enumerate(ret):
logprobs = False
@@ -747,8 +747,6 @@ def v1_chat_generate_response(request, ret, to_file=False):
choice_logprobs = ChoiceLogprobs(content=token_logprobs)
else:
choice_logprobs = None
prompt_tokens = ret_item["meta_info"]["prompt_tokens"]
completion_tokens = ret_item["meta_info"]["completion_tokens"]
if to_file:
# to make the choice data json serializable
@@ -767,8 +765,7 @@ def v1_chat_generate_response(request, ret, to_file=False):
)
choices.append(choice_data)
total_prompt_tokens += prompt_tokens
total_completion_tokens += completion_tokens
if to_file:
responses = []
@@ -795,14 +792,18 @@ def v1_chat_generate_response(request, ret, to_file=False):
responses.append(response)
return responses
else:
prompt_tokens = sum(
ret[i]["meta_info"]["prompt_tokens"] for i in range(0, len(ret), request.n)
)
completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
response = ChatCompletionResponse(
id=ret[0]["meta_info"]["id"],
model=request.model,
choices=choices,
usage=UsageInfo(
prompt_tokens=total_prompt_tokens,
completion_tokens=total_completion_tokens,
total_tokens=total_prompt_tokens + total_completion_tokens,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
),
)
return response