Remove normalized_prompt_logprobs from the engine to make code easier to maintain (#2902)

This commit is contained in:
Lianmin Zheng
2025-01-15 04:27:18 -08:00
parent b803b395b7
commit f65c13b559
12 changed files with 11 additions and 153 deletions

View File

@@ -251,11 +251,12 @@ class RuntimeEndpoint(BaseBackend):
}
obj = self._generate_http_request(s, data)
normalized_prompt_logprobs = [
r["meta_info"]["normalized_prompt_logprob"] for r in obj
]
input_token_logprobs = [r["meta_info"]["input_token_logprobs"] for r in obj]
output_token_logprobs = [r["meta_info"]["output_token_logprobs"] for r in obj]
normalized_prompt_logprobs = [
compute_normalized_prompt_logprobs(r["meta_info"]["input_token_logprobs"])
for r in obj
]
# Remove extra token if no token healing occurred
for i in range(len(input_token_logprobs)):
@@ -319,3 +320,8 @@ class RuntimeEndpoint(BaseBackend):
def _assert_success(self, res):
if res.status_code != 200:
raise RuntimeError(res.json())
def compute_normalized_prompt_logprobs(input_logprobs):
    """Return the mean log probability over the prompt tokens.

    Args:
        input_logprobs: A sequence of ``(logprob, ...)`` tuples, one per
            prompt token. The first token's logprob may be ``None``
            (no logprob exists for the very first token).

    Returns:
        The arithmetic mean of all non-``None`` logprobs, or ``0.0``
        when no logprob is available.
    """
    # Filter with `is not None` rather than truthiness: a logprob of
    # exactly 0.0 (probability 1.0) is valid and must not be dropped.
    values = [x[0] for x in input_logprobs if x[0] is not None]
    if not values:
        # Avoid ZeroDivisionError when every entry was None.
        return 0.0
    return sum(values) / len(values)