From 4cb9aaedf3dfe4f876ba447ab2ac1ac9c75da911 Mon Sep 17 00:00:00 2001
From: Cody Yu
Date: Thu, 22 Feb 2024 10:33:03 -0800
Subject: [PATCH] Fix logprobs with logprob_start_len (#193)

---
 python/sglang/srt/managers/router/model_rpc.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/python/sglang/srt/managers/router/model_rpc.py b/python/sglang/srt/managers/router/model_rpc.py
index f59c3f0a1..40ae6dd98 100644
--- a/python/sglang/srt/managers/router/model_rpc.py
+++ b/python/sglang/srt/managers/router/model_rpc.py
@@ -432,9 +432,14 @@ class ModelRpcServer(rpyc.Service):
                 req.logprob = logprobs[pt : pt + req.extend_input_len - 1]
                 req.normalized_logprob = normalized_logprobs[i]
 
-                token_ids = req.input_ids + [next_token_ids[i]]
-                token_logprobs = [None] + req.logprob + [last_logprobs[i]]
+                # If logprob_start_len > 0, then first logprob_start_len prompt tokens
+                # will be ignored.
+                prompt_token_len = len(req.logprob)
+                token_ids = req.input_ids[-prompt_token_len :] + [next_token_ids[i]]
+                token_logprobs = req.logprob + [last_logprobs[i]]
                 req.token_logprob = list(zip(token_ids, token_logprobs))
+                if req.logprob_start_len == 0:
+                    req.token_logprob = [(req.input_ids[0], None)] + req.token_logprob
                 pt += req.extend_input_len
 
         self.handle_finished_requests(batch)