Rename prefill_token_logprobs -> input_token_logprobs; decode_token_logprobs -> output_token_logprobs (#776)
This commit is contained in:
@@ -31,7 +31,7 @@ def cot_decoding(s, question, get_top_k, is_chat_model, verbose):
|
||||
top_logprobs_num=get_top_k,
|
||||
return_text_in_logprobs=True,
|
||||
)
|
||||
logprobs = step_0.get_meta_info("get_top_k")["decode_top_logprobs"][0]
|
||||
logprobs = step_0.get_meta_info("get_top_k")["output_top_logprobs"][0]
|
||||
|
||||
print("Decoding step 0:", ", ".join(pformat(token[2]) for token in logprobs))
|
||||
for idx, (f, token) in enumerate(zip(forks, logprobs)):
|
||||
@@ -55,9 +55,9 @@ def cot_decoding(s, question, get_top_k, is_chat_model, verbose):
|
||||
)
|
||||
|
||||
# calculate probability disparity between the top and secondary tokens
|
||||
x1s = [exp(xt[0][0]) for xt in f.get_meta_info("answer")["decode_top_logprobs"]]
|
||||
x2s = [exp(xt[1][0]) for xt in f.get_meta_info("answer")["decode_top_logprobs"]]
|
||||
tokens = [xt[0][2] for xt in f.get_meta_info("answer")["decode_top_logprobs"]]
|
||||
x1s = [exp(xt[0][0]) for xt in f.get_meta_info("answer")["output_top_logprobs"]]
|
||||
x2s = [exp(xt[1][0]) for xt in f.get_meta_info("answer")["output_top_logprobs"]]
|
||||
tokens = [xt[0][2] for xt in f.get_meta_info("answer")["output_top_logprobs"]]
|
||||
delta = (sum(x1s) - sum(x2s)) / len(x1s)
|
||||
|
||||
# extract the answer span (without the '<|end_of_text|>' token)
|
||||
@@ -81,19 +81,19 @@ def cot_decoding(s, question, get_top_k, is_chat_model, verbose):
|
||||
answer_tokens = [
|
||||
xt[0][2]
|
||||
for xt in answer_forks[idx].get_meta_info("answer_span")[
|
||||
"decode_top_logprobs"
|
||||
"output_top_logprobs"
|
||||
]
|
||||
]
|
||||
answer_x1s = [
|
||||
exp(xt[0][0])
|
||||
for xt in answer_forks[idx].get_meta_info("answer_span")[
|
||||
"decode_top_logprobs"
|
||||
"output_top_logprobs"
|
||||
]
|
||||
]
|
||||
answer_x2s = [
|
||||
exp(xt[1][0])
|
||||
for xt in answer_forks[idx].get_meta_info("answer_span")[
|
||||
"decode_top_logprobs"
|
||||
"output_top_logprobs"
|
||||
]
|
||||
]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user