Remove normalized_prompt_logprobs from the engine to make code easier to maintain (#2902)
This commit is contained in:
@@ -251,11 +251,12 @@ class RuntimeEndpoint(BaseBackend):
|
||||
}
|
||||
obj = self._generate_http_request(s, data)
|
||||
|
||||
normalized_prompt_logprobs = [
|
||||
r["meta_info"]["normalized_prompt_logprob"] for r in obj
|
||||
]
|
||||
input_token_logprobs = [r["meta_info"]["input_token_logprobs"] for r in obj]
|
||||
output_token_logprobs = [r["meta_info"]["output_token_logprobs"] for r in obj]
|
||||
normalized_prompt_logprobs = [
|
||||
compute_normalized_prompt_logprobs(r["meta_info"]["input_token_logprobs"])
|
||||
for r in obj
|
||||
]
|
||||
|
||||
# Remove extra token if no token healing occurred
|
||||
for i in range(len(input_token_logprobs)):
|
||||
@@ -319,3 +320,8 @@ class RuntimeEndpoint(BaseBackend):
|
||||
def _assert_success(self, res):
|
||||
if res.status_code != 200:
|
||||
raise RuntimeError(res.json())
|
||||
|
||||
|
||||
def compute_normalized_prompt_logprobs(input_logprobs):
    """Return the mean per-token log probability of a prompt.

    Args:
        input_logprobs: sequence of tuples whose first element is the token's
            logprob; the first prompt token conventionally carries ``None``
            because it has no preceding context.

    Returns:
        The arithmetic mean of the available logprobs, or ``0.0`` when no
        entry carries a logprob (e.g. empty input).
    """
    # Filter with `is not None` rather than truthiness: a logprob of exactly
    # 0.0 (probability 1) is a valid value and must not be dropped.
    values = [x[0] for x in input_logprobs if x[0] is not None]
    if not values:
        # Avoid ZeroDivisionError on empty/all-None input.
        return 0.0
    return sum(values) / len(values)
|
||||
|
||||
Reference in New Issue
Block a user