Rename prefill_token_logprobs -> input_token_logprobs; decode_token_logprobs -> output_token_logprobs (#776)
This commit is contained in:
@@ -20,8 +20,8 @@ def main():
|
||||
print("questions:", question)
|
||||
print("choice:", state["tool"])
|
||||
meta_info = state.get_meta_info("tool")
|
||||
print("logprobs of choice 1", meta_info["prefill_token_logprobs"][0])
|
||||
print("logprobs of choice 2", meta_info["prefill_token_logprobs"][1])
|
||||
print("logprobs of choice 1", meta_info["input_token_logprobs"][0])
|
||||
print("logprobs of choice 2", meta_info["input_token_logprobs"][1])
|
||||
print("-" * 50)
|
||||
|
||||
# Run a batch
|
||||
@@ -34,8 +34,8 @@ def main():
|
||||
print("questions:", question)
|
||||
print("choice:", state["tool"])
|
||||
meta_info = state.get_meta_info("tool")
|
||||
print("logprobs of choice 1", meta_info["prefill_token_logprobs"][0])
|
||||
print("logprobs of choice 2", meta_info["prefill_token_logprobs"][1])
|
||||
print("logprobs of choice 1", meta_info["input_token_logprobs"][0])
|
||||
print("logprobs of choice 2", meta_info["input_token_logprobs"][1])
|
||||
print("-" * 50)
|
||||
|
||||
|
||||
|
||||
@@ -31,7 +31,7 @@ def cot_decoding(s, question, get_top_k, is_chat_model, verbose):
|
||||
top_logprobs_num=get_top_k,
|
||||
return_text_in_logprobs=True,
|
||||
)
|
||||
logprobs = step_0.get_meta_info("get_top_k")["decode_top_logprobs"][0]
|
||||
logprobs = step_0.get_meta_info("get_top_k")["output_top_logprobs"][0]
|
||||
|
||||
print("Decoding step 0:", ", ".join(pformat(token[2]) for token in logprobs))
|
||||
for idx, (f, token) in enumerate(zip(forks, logprobs)):
|
||||
@@ -55,9 +55,9 @@ def cot_decoding(s, question, get_top_k, is_chat_model, verbose):
|
||||
)
|
||||
|
||||
# calculate probability disparity between the top and secondary tokens
|
||||
x1s = [exp(xt[0][0]) for xt in f.get_meta_info("answer")["decode_top_logprobs"]]
|
||||
x2s = [exp(xt[1][0]) for xt in f.get_meta_info("answer")["decode_top_logprobs"]]
|
||||
tokens = [xt[0][2] for xt in f.get_meta_info("answer")["decode_top_logprobs"]]
|
||||
x1s = [exp(xt[0][0]) for xt in f.get_meta_info("answer")["output_top_logprobs"]]
|
||||
x2s = [exp(xt[1][0]) for xt in f.get_meta_info("answer")["output_top_logprobs"]]
|
||||
tokens = [xt[0][2] for xt in f.get_meta_info("answer")["output_top_logprobs"]]
|
||||
delta = (sum(x1s) - sum(x2s)) / len(x1s)
|
||||
|
||||
# extract the answer span (without the '<|end_of_text|>' token)
|
||||
@@ -81,19 +81,19 @@ def cot_decoding(s, question, get_top_k, is_chat_model, verbose):
|
||||
answer_tokens = [
|
||||
xt[0][2]
|
||||
for xt in answer_forks[idx].get_meta_info("answer_span")[
|
||||
"decode_top_logprobs"
|
||||
"output_top_logprobs"
|
||||
]
|
||||
]
|
||||
answer_x1s = [
|
||||
exp(xt[0][0])
|
||||
for xt in answer_forks[idx].get_meta_info("answer_span")[
|
||||
"decode_top_logprobs"
|
||||
"output_top_logprobs"
|
||||
]
|
||||
]
|
||||
answer_x2s = [
|
||||
exp(xt[1][0])
|
||||
for xt in answer_forks[idx].get_meta_info("answer_span")[
|
||||
"decode_top_logprobs"
|
||||
"output_top_logprobs"
|
||||
]
|
||||
]
|
||||
|
||||
|
||||
@@ -56,14 +56,14 @@ def srt_api_request(name):
|
||||
# fout.write(json.dumps(res, indent=4))
|
||||
|
||||
meta_info = res["meta_info"]
|
||||
assert len(meta_info["prefill_token_logprobs"]) == len(
|
||||
meta_info["prefill_top_logprobs"]
|
||||
assert len(meta_info["input_token_logprobs"]) == len(
|
||||
meta_info["input_top_logprobs"]
|
||||
)
|
||||
assert len(meta_info["decode_token_logprobs"]) == len(
|
||||
meta_info["decode_top_logprobs"]
|
||||
assert len(meta_info["output_token_logprobs"]) == len(
|
||||
meta_info["output_top_logprobs"]
|
||||
)
|
||||
assert len(meta_info["prefill_token_logprobs"]) == meta_info["prompt_tokens"]
|
||||
assert len(meta_info["decode_token_logprobs"]) == meta_info["completion_tokens"] - 1
|
||||
assert len(meta_info["input_token_logprobs"]) == meta_info["prompt_tokens"]
|
||||
assert len(meta_info["output_token_logprobs"]) == meta_info["completion_tokens"] - 1
|
||||
|
||||
return res
|
||||
|
||||
@@ -72,11 +72,11 @@ def pretty_print(res):
|
||||
meta_info = res["meta_info"]
|
||||
|
||||
print("\n\n", "=" * 30, "Prefill", "=" * 30)
|
||||
for i in range(len(meta_info["prefill_token_logprobs"])):
|
||||
print(f"{str(meta_info['prefill_token_logprobs'][i][2].encode()): <20}", end="")
|
||||
for i in range(len(meta_info["input_token_logprobs"])):
|
||||
print(f"{str(meta_info['input_token_logprobs'][i][2].encode()): <20}", end="")
|
||||
top_ks = (
|
||||
[str(t[2].encode()) for t in meta_info["prefill_top_logprobs"][i]]
|
||||
if meta_info["prefill_top_logprobs"][i]
|
||||
[str(t[2].encode()) for t in meta_info["input_top_logprobs"][i]]
|
||||
if meta_info["input_top_logprobs"][i]
|
||||
else []
|
||||
)
|
||||
for top_k in top_ks:
|
||||
@@ -84,9 +84,9 @@ def pretty_print(res):
|
||||
print()
|
||||
|
||||
print("\n\n", "=" * 30, "Decode", "=" * 30)
|
||||
for i in range(len(meta_info["decode_token_logprobs"])):
|
||||
print(f"{str(meta_info['decode_token_logprobs'][i][2].encode()): <20}", end="")
|
||||
top_ks = [str(t[2].encode()) for t in meta_info["decode_top_logprobs"][i]]
|
||||
for i in range(len(meta_info["output_token_logprobs"])):
|
||||
print(f"{str(meta_info['output_token_logprobs'][i][2].encode()): <20}", end="")
|
||||
top_ks = [str(t[2].encode()) for t in meta_info["output_top_logprobs"][i]]
|
||||
for top_k in top_ks:
|
||||
print(f"{top_k: <15}", end="")
|
||||
print()
|
||||
|
||||
Reference in New Issue
Block a user