Fix logprobs with logprob_start_len (#193)
This commit is contained in:
@@ -432,9 +432,14 @@ class ModelRpcServer(rpyc.Service):
|
||||
req.logprob = logprobs[pt : pt + req.extend_input_len - 1]
|
||||
req.normalized_logprob = normalized_logprobs[i]
|
||||
|
||||
token_ids = req.input_ids + [next_token_ids[i]]
|
||||
token_logprobs = [None] + req.logprob + [last_logprobs[i]]
|
||||
# If logprob_start_len > 0, then first logprob_start_len prompt tokens
|
||||
# will be ignored.
|
||||
prompt_token_len = len(req.logprob)
|
||||
token_ids = req.input_ids[-prompt_token_len :] + [next_token_ids[i]]
|
||||
token_logprobs = req.logprob + [last_logprobs[i]]
|
||||
req.token_logprob = list(zip(token_ids, token_logprobs))
|
||||
if req.logprob_start_len == 0:
|
||||
req.token_logprob = [(req.input_ids[0], None)] + req.token_logprob
|
||||
pt += req.extend_input_len
|
||||
|
||||
self.handle_finished_requests(batch)
|
||||
|
||||
Reference in New Issue
Block a user