Rename prefill_token_logprobs -> input_token_logprobs; decode_token_logprobs -> output_token_logprobs (#776)
This commit is contained in:
@@ -56,14 +56,14 @@ def srt_api_request(name):
|
||||
# fout.write(json.dumps(res, indent=4))
|
||||
|
||||
meta_info = res["meta_info"]
|
||||
assert len(meta_info["prefill_token_logprobs"]) == len(
|
||||
meta_info["prefill_top_logprobs"]
|
||||
assert len(meta_info["input_token_logprobs"]) == len(
|
||||
meta_info["input_top_logprobs"]
|
||||
)
|
||||
assert len(meta_info["decode_token_logprobs"]) == len(
|
||||
meta_info["decode_top_logprobs"]
|
||||
assert len(meta_info["output_token_logprobs"]) == len(
|
||||
meta_info["output_top_logprobs"]
|
||||
)
|
||||
assert len(meta_info["prefill_token_logprobs"]) == meta_info["prompt_tokens"]
|
||||
assert len(meta_info["decode_token_logprobs"]) == meta_info["completion_tokens"] - 1
|
||||
assert len(meta_info["input_token_logprobs"]) == meta_info["prompt_tokens"]
|
||||
assert len(meta_info["output_token_logprobs"]) == meta_info["completion_tokens"] - 1
|
||||
|
||||
return res
|
||||
|
||||
@@ -72,11 +72,11 @@ def pretty_print(res):
|
||||
meta_info = res["meta_info"]
|
||||
|
||||
print("\n\n", "=" * 30, "Prefill", "=" * 30)
|
||||
for i in range(len(meta_info["prefill_token_logprobs"])):
|
||||
print(f"{str(meta_info['prefill_token_logprobs'][i][2].encode()): <20}", end="")
|
||||
for i in range(len(meta_info["input_token_logprobs"])):
|
||||
print(f"{str(meta_info['input_token_logprobs'][i][2].encode()): <20}", end="")
|
||||
top_ks = (
|
||||
[str(t[2].encode()) for t in meta_info["prefill_top_logprobs"][i]]
|
||||
if meta_info["prefill_top_logprobs"][i]
|
||||
[str(t[2].encode()) for t in meta_info["input_top_logprobs"][i]]
|
||||
if meta_info["input_top_logprobs"][i]
|
||||
else []
|
||||
)
|
||||
for top_k in top_ks:
|
||||
@@ -84,9 +84,9 @@ def pretty_print(res):
|
||||
print()
|
||||
|
||||
print("\n\n", "=" * 30, "Decode", "=" * 30)
|
||||
for i in range(len(meta_info["decode_token_logprobs"])):
|
||||
print(f"{str(meta_info['decode_token_logprobs'][i][2].encode()): <20}", end="")
|
||||
top_ks = [str(t[2].encode()) for t in meta_info["decode_top_logprobs"][i]]
|
||||
for i in range(len(meta_info["output_token_logprobs"])):
|
||||
print(f"{str(meta_info['output_token_logprobs'][i][2].encode()): <20}", end="")
|
||||
top_ks = [str(t[2].encode()) for t in meta_info["output_top_logprobs"][i]]
|
||||
for top_k in top_ks:
|
||||
print(f"{top_k: <15}", end="")
|
||||
print()
|
||||
|
||||
Reference in New Issue
Block a user