Fix RuntimeEndpoint.select method (#1495)

This commit is contained in:
Jeffrey Fong
2024-09-29 05:04:06 +08:00
committed by GitHub
parent f42e9bfb52
commit 065bb94753

View File

@@ -244,7 +244,8 @@ class RuntimeEndpoint(BaseBackend):
"temperature": 0,
},
"return_logprob": True,
"logprob_start_len": max(prompt_len - 2, 0), # for token healing
"return_text_in_logprobs": True,
"logprob_start_len": prompt_len - 2, # For token healing
}
obj = self._generate_http_request(s, data)
@@ -254,6 +255,17 @@ class RuntimeEndpoint(BaseBackend):
input_token_logprobs = [r["meta_info"]["input_token_logprobs"] for r in obj]
output_token_logprobs = [r["meta_info"]["output_token_logprobs"] for r in obj]
# Remove extra token if no token healing occurred
for i in range(len(input_token_logprobs)):
healed_token_str = input_token_logprobs[i][0][-1]
healed_token_logprob = input_token_logprobs[i][0][0]
if s.text_.endswith(healed_token_str):
normalized_prompt_logprobs[i] = (
normalized_prompt_logprobs[i] * len(input_token_logprobs[i])
- healed_token_logprob
) / (len(input_token_logprobs[i]) - 1)
input_token_logprobs[i] = input_token_logprobs[i][1:]
# Compute unconditional logprobs if required
if choices_method.requires_unconditional_logprobs:
input_ids = [[el[1] for el in subl] for subl in input_token_logprobs]