Fix logprobs with logprob_start_len (#193)

This commit is contained in:
Cody Yu
2024-02-22 10:33:03 -08:00
committed by GitHub
parent 9de9a46815
commit 4cb9aaedf3

View File

@@ -432,9 +432,14 @@ class ModelRpcServer(rpyc.Service):
req.logprob = logprobs[pt : pt + req.extend_input_len - 1]
req.normalized_logprob = normalized_logprobs[i]
token_ids = req.input_ids + [next_token_ids[i]]
token_logprobs = [None] + req.logprob + [last_logprobs[i]]
# If logprob_start_len > 0, then first logprob_start_len prompt tokens
# will be ignored.
prompt_token_len = len(req.logprob)
token_ids = req.input_ids[-prompt_token_len :] + [next_token_ids[i]]
token_logprobs = req.logprob + [last_logprobs[i]]
req.token_logprob = list(zip(token_ids, token_logprobs))
if req.logprob_start_len == 0:
req.token_logprob = [(req.input_ids[0], None)] + req.token_logprob
pt += req.extend_input_len
self.handle_finished_requests(batch)