Fix: incorrect top_logprobs in chat completion (#2088)
This commit is contained in:
@@ -989,11 +989,15 @@ def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
|
|||||||
output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"],
|
output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"],
|
||||||
)
|
)
|
||||||
token_logprobs = []
|
token_logprobs = []
|
||||||
for token, logprob in zip(logprobs.tokens, logprobs.token_logprobs):
|
for token_idx, (token, logprob) in enumerate(
|
||||||
|
zip(logprobs.tokens, logprobs.token_logprobs)
|
||||||
|
):
|
||||||
token_bytes = list(token.encode("utf-8"))
|
token_bytes = list(token.encode("utf-8"))
|
||||||
top_logprobs = []
|
top_logprobs = []
|
||||||
if logprobs.top_logprobs:
|
if logprobs.top_logprobs:
|
||||||
for top_token, top_logprob in logprobs.top_logprobs[0].items():
|
for top_token, top_logprob in logprobs.top_logprobs[
|
||||||
|
token_idx
|
||||||
|
].items():
|
||||||
top_token_bytes = list(top_token.encode("utf-8"))
|
top_token_bytes = list(top_token.encode("utf-8"))
|
||||||
top_logprobs.append(
|
top_logprobs.append(
|
||||||
TopLogprob(
|
TopLogprob(
|
||||||
|
|||||||
Reference in New Issue
Block a user