Rename prefill_token_logprobs -> input_token_logprobs; decode_token_logprobs -> output_token_logprobs (#776)

This commit is contained in:
Lianmin Zheng
2024-07-27 19:50:34 -07:00
committed by GitHub
parent 0a409bd438
commit 30db99b3d9
16 changed files with 188 additions and 184 deletions

View File

@@ -56,14 +56,14 @@ def srt_api_request(name):
# fout.write(json.dumps(res, indent=4))
meta_info = res["meta_info"]
assert len(meta_info["prefill_token_logprobs"]) == len(
meta_info["prefill_top_logprobs"]
assert len(meta_info["input_token_logprobs"]) == len(
meta_info["input_top_logprobs"]
)
assert len(meta_info["decode_token_logprobs"]) == len(
meta_info["decode_top_logprobs"]
assert len(meta_info["output_token_logprobs"]) == len(
meta_info["output_top_logprobs"]
)
assert len(meta_info["prefill_token_logprobs"]) == meta_info["prompt_tokens"]
assert len(meta_info["decode_token_logprobs"]) == meta_info["completion_tokens"] - 1
assert len(meta_info["input_token_logprobs"]) == meta_info["prompt_tokens"]
assert len(meta_info["output_token_logprobs"]) == meta_info["completion_tokens"] - 1
return res
@@ -72,11 +72,11 @@ def pretty_print(res):
meta_info = res["meta_info"]
print("\n\n", "=" * 30, "Prefill", "=" * 30)
for i in range(len(meta_info["prefill_token_logprobs"])):
print(f"{str(meta_info['prefill_token_logprobs'][i][2].encode()): <20}", end="")
for i in range(len(meta_info["input_token_logprobs"])):
print(f"{str(meta_info['input_token_logprobs'][i][2].encode()): <20}", end="")
top_ks = (
[str(t[2].encode()) for t in meta_info["prefill_top_logprobs"][i]]
if meta_info["prefill_top_logprobs"][i]
[str(t[2].encode()) for t in meta_info["input_top_logprobs"][i]]
if meta_info["input_top_logprobs"][i]
else []
)
for top_k in top_ks:
@@ -84,9 +84,9 @@ def pretty_print(res):
print()
print("\n\n", "=" * 30, "Decode", "=" * 30)
for i in range(len(meta_info["decode_token_logprobs"])):
print(f"{str(meta_info['decode_token_logprobs'][i][2].encode()): <20}", end="")
top_ks = [str(t[2].encode()) for t in meta_info["decode_top_logprobs"][i]]
for i in range(len(meta_info["output_token_logprobs"])):
print(f"{str(meta_info['output_token_logprobs'][i][2].encode()): <20}", end="")
top_ks = [str(t[2].encode()) for t in meta_info["output_top_logprobs"][i]]
for top_k in top_ks:
print(f"{top_k: <15}", end="")
print()