Fix RuntimeEndpoint.select method (#1495)
This commit is contained in:
@@ -244,7 +244,8 @@ class RuntimeEndpoint(BaseBackend):
|
||||
"temperature": 0,
|
||||
},
|
||||
"return_logprob": True,
|
||||
"logprob_start_len": max(prompt_len - 2, 0), # for token healing
|
||||
"return_text_in_logprobs": True,
|
||||
"logprob_start_len": prompt_len - 2, # For token healing
|
||||
}
|
||||
obj = self._generate_http_request(s, data)
|
||||
|
||||
@@ -254,6 +255,17 @@ class RuntimeEndpoint(BaseBackend):
|
||||
input_token_logprobs = [r["meta_info"]["input_token_logprobs"] for r in obj]
|
||||
output_token_logprobs = [r["meta_info"]["output_token_logprobs"] for r in obj]
|
||||
|
||||
# Remove extra token if no token healing occurred
|
||||
for i in range(len(input_token_logprobs)):
|
||||
healed_token_str = input_token_logprobs[i][0][-1]
|
||||
healed_token_logprob = input_token_logprobs[i][0][0]
|
||||
if s.text_.endswith(healed_token_str):
|
||||
normalized_prompt_logprobs[i] = (
|
||||
normalized_prompt_logprobs[i] * len(input_token_logprobs[i])
|
||||
- healed_token_logprob
|
||||
) / (len(input_token_logprobs[i]) - 1)
|
||||
input_token_logprobs[i] = input_token_logprobs[i][1:]
|
||||
|
||||
# Compute unconditional logprobs if required
|
||||
if choices_method.requires_unconditional_logprobs:
|
||||
input_ids = [[el[1] for el in subl] for subl in input_token_logprobs]
|
||||
|
||||
Reference in New Issue
Block a user