Rename prefill_token_logprobs -> input_token_logprobs; decode_token_logprobs -> output_token_logprobs (#776)

This commit is contained in:
Lianmin Zheng
2024-07-27 19:50:34 -07:00
committed by GitHub
parent 0a409bd438
commit 30db99b3d9
16 changed files with 188 additions and 184 deletions

View File

@@ -20,8 +20,8 @@ def main():
print("questions:", question)
print("choice:", state["tool"])
meta_info = state.get_meta_info("tool")
print("logprobs of choice 1", meta_info["prefill_token_logprobs"][0])
print("logprobs of choice 2", meta_info["prefill_token_logprobs"][1])
print("logprobs of choice 1", meta_info["input_token_logprobs"][0])
print("logprobs of choice 2", meta_info["input_token_logprobs"][1])
print("-" * 50)
# Run a batch
@@ -34,8 +34,8 @@ def main():
print("questions:", question)
print("choice:", state["tool"])
meta_info = state.get_meta_info("tool")
print("logprobs of choice 1", meta_info["prefill_token_logprobs"][0])
print("logprobs of choice 2", meta_info["prefill_token_logprobs"][1])
print("logprobs of choice 1", meta_info["input_token_logprobs"][0])
print("logprobs of choice 2", meta_info["input_token_logprobs"][1])
print("-" * 50)

View File

@@ -31,7 +31,7 @@ def cot_decoding(s, question, get_top_k, is_chat_model, verbose):
top_logprobs_num=get_top_k,
return_text_in_logprobs=True,
)
logprobs = step_0.get_meta_info("get_top_k")["decode_top_logprobs"][0]
logprobs = step_0.get_meta_info("get_top_k")["output_top_logprobs"][0]
print("Decoding step 0:", ", ".join(pformat(token[2]) for token in logprobs))
for idx, (f, token) in enumerate(zip(forks, logprobs)):
@@ -55,9 +55,9 @@ def cot_decoding(s, question, get_top_k, is_chat_model, verbose):
)
# calculate probability disparity between the top and secondary tokens
x1s = [exp(xt[0][0]) for xt in f.get_meta_info("answer")["decode_top_logprobs"]]
x2s = [exp(xt[1][0]) for xt in f.get_meta_info("answer")["decode_top_logprobs"]]
tokens = [xt[0][2] for xt in f.get_meta_info("answer")["decode_top_logprobs"]]
x1s = [exp(xt[0][0]) for xt in f.get_meta_info("answer")["output_top_logprobs"]]
x2s = [exp(xt[1][0]) for xt in f.get_meta_info("answer")["output_top_logprobs"]]
tokens = [xt[0][2] for xt in f.get_meta_info("answer")["output_top_logprobs"]]
delta = (sum(x1s) - sum(x2s)) / len(x1s)
# extract the answer span (without the '<|end_of_text|>' token)
@@ -81,19 +81,19 @@ def cot_decoding(s, question, get_top_k, is_chat_model, verbose):
answer_tokens = [
xt[0][2]
for xt in answer_forks[idx].get_meta_info("answer_span")[
"decode_top_logprobs"
"output_top_logprobs"
]
]
answer_x1s = [
exp(xt[0][0])
for xt in answer_forks[idx].get_meta_info("answer_span")[
"decode_top_logprobs"
"output_top_logprobs"
]
]
answer_x2s = [
exp(xt[1][0])
for xt in answer_forks[idx].get_meta_info("answer_span")[
"decode_top_logprobs"
"output_top_logprobs"
]
]

View File

@@ -56,14 +56,14 @@ def srt_api_request(name):
# fout.write(json.dumps(res, indent=4))
meta_info = res["meta_info"]
assert len(meta_info["prefill_token_logprobs"]) == len(
meta_info["prefill_top_logprobs"]
assert len(meta_info["input_token_logprobs"]) == len(
meta_info["input_top_logprobs"]
)
assert len(meta_info["decode_token_logprobs"]) == len(
meta_info["decode_top_logprobs"]
assert len(meta_info["output_token_logprobs"]) == len(
meta_info["output_top_logprobs"]
)
assert len(meta_info["prefill_token_logprobs"]) == meta_info["prompt_tokens"]
assert len(meta_info["decode_token_logprobs"]) == meta_info["completion_tokens"] - 1
assert len(meta_info["input_token_logprobs"]) == meta_info["prompt_tokens"]
assert len(meta_info["output_token_logprobs"]) == meta_info["completion_tokens"] - 1
return res
@@ -72,11 +72,11 @@ def pretty_print(res):
meta_info = res["meta_info"]
print("\n\n", "=" * 30, "Prefill", "=" * 30)
for i in range(len(meta_info["prefill_token_logprobs"])):
print(f"{str(meta_info['prefill_token_logprobs'][i][2].encode()): <20}", end="")
for i in range(len(meta_info["input_token_logprobs"])):
print(f"{str(meta_info['input_token_logprobs'][i][2].encode()): <20}", end="")
top_ks = (
[str(t[2].encode()) for t in meta_info["prefill_top_logprobs"][i]]
if meta_info["prefill_top_logprobs"][i]
[str(t[2].encode()) for t in meta_info["input_top_logprobs"][i]]
if meta_info["input_top_logprobs"][i]
else []
)
for top_k in top_ks:
@@ -84,9 +84,9 @@ def pretty_print(res):
print()
print("\n\n", "=" * 30, "Decode", "=" * 30)
for i in range(len(meta_info["decode_token_logprobs"])):
print(f"{str(meta_info['decode_token_logprobs'][i][2].encode()): <20}", end="")
top_ks = [str(t[2].encode()) for t in meta_info["decode_top_logprobs"][i]]
for i in range(len(meta_info["output_token_logprobs"])):
print(f"{str(meta_info['output_token_logprobs'][i][2].encode()): <20}", end="")
top_ks = [str(t[2].encode()) for t in meta_info["output_top_logprobs"][i]]
for top_k in top_ks:
print(f"{top_k: <15}", end="")
print()