From 25c7395934a92a213596d8bd9d00410207074796 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sun, 31 Aug 2025 02:56:47 -0700 Subject: [PATCH] Fix input logprob index (#9841) Co-authored-by: Sheng Shen --- .../scheduler_output_processor_mixin.py | 38 ++++++++++--------- 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py index a86899f6e..c6205a094 100644 --- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py +++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py @@ -93,20 +93,21 @@ class SchedulerOutputProcessorMixin: # This updates radix so others can match self.tree_cache.cache_unfinished_req(req) - if req.return_logprob: + if batch.return_logprob: assert extend_logprob_start_len_per_req is not None assert extend_input_len_per_req is not None extend_logprob_start_len = extend_logprob_start_len_per_req[i] extend_input_len = extend_input_len_per_req[i] num_input_logprobs = extend_input_len - extend_logprob_start_len - self.add_logprob_return_values( - i, - req, - logprob_pt, - next_token_ids, - num_input_logprobs, - logits_output, - ) + if req.return_logprob: + self.add_logprob_return_values( + i, + req, + logprob_pt, + next_token_ids, + num_input_logprobs, + logits_output, + ) logprob_pt += num_input_logprobs if ( @@ -146,7 +147,7 @@ class SchedulerOutputProcessorMixin: skip_stream_req = req # Incrementally update input logprobs. - if req.return_logprob: + if batch.return_logprob: extend_logprob_start_len = extend_logprob_start_len_per_req[i] extend_input_len = extend_input_len_per_req[i] if extend_logprob_start_len < extend_input_len: @@ -154,14 +155,15 @@ class SchedulerOutputProcessorMixin: num_input_logprobs = ( extend_input_len - extend_logprob_start_len ) - self.add_input_logprob_return_values( - i, - req, - logits_output, - logprob_pt, - num_input_logprobs, - last_prefill_chunk=False, - ) + if req.return_logprob: + self.add_input_logprob_return_values( + i, + req, + logits_output, + logprob_pt, + num_input_logprobs, + last_prefill_chunk=False, + ) logprob_pt += num_input_logprobs self.set_next_batch_sampling_info_done(batch)