Fix RuntimeEndpoint.select method (#1495)
This commit is contained in:
@@ -244,7 +244,8 @@ class RuntimeEndpoint(BaseBackend):
                 "temperature": 0,
             },
             "return_logprob": True,
-            "logprob_start_len": max(prompt_len - 2, 0), # for token healing
+            "return_text_in_logprobs": True,
+            "logprob_start_len": prompt_len - 2, # For token healing
         }
         obj = self._generate_http_request(s, data)
 
@@ -254,6 +255,17 @@ class RuntimeEndpoint(BaseBackend):
         input_token_logprobs = [r["meta_info"]["input_token_logprobs"] for r in obj]
         output_token_logprobs = [r["meta_info"]["output_token_logprobs"] for r in obj]
 
+        # Remove extra token if no token healing occurred
+        for i in range(len(input_token_logprobs)):
+            healed_token_str = input_token_logprobs[i][0][-1]
+            healed_token_logprob = input_token_logprobs[i][0][0]
+            if s.text_.endswith(healed_token_str):
+                normalized_prompt_logprobs[i] = (
+                    normalized_prompt_logprobs[i] * len(input_token_logprobs[i])
+                    - healed_token_logprob
+                ) / (len(input_token_logprobs[i]) - 1)
+                input_token_logprobs[i] = input_token_logprobs[i][1:]
+
         # Compute unconditional logprobs if required
         if choices_method.requires_unconditional_logprobs:
             input_ids = [[el[1] for el in subl] for subl in input_token_logprobs]
Reference in New Issue
Block a user