From b88ea90d4ad98992790395d11ae20bf27b9657f8 Mon Sep 17 00:00:00 2001 From: Liangsheng Yin Date: Mon, 30 Sep 2024 17:09:54 -0700 Subject: [PATCH] Fix bugs of `logprobs_nums` (#1548) --- python/sglang/srt/managers/schedule_batch.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 6fcc9616f..8ff204abd 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -748,6 +748,8 @@ class ScheduleBatch: self.top_logprobs_nums = [ self.top_logprobs_nums[i] for i in unfinished_indices ] + else: + self.top_logprobs_nums = None self.has_stream = any(req.stream for req in self.reqs) self.sampling_info.filter_batch(unfinished_indices, new_indices) @@ -758,13 +760,11 @@ class ScheduleBatch: # needs to be called with pre-merged Batch.reqs. self.sampling_info.merge_batch(other.sampling_info) - self.reqs.extend(other.reqs) self.req_pool_indices = torch.concat( [self.req_pool_indices, other.req_pool_indices] ) self.seq_lens = torch.concat([self.seq_lens, other.seq_lens]) self.out_cache_loc = None - self.return_logprob = any(req.return_logprob for req in self.reqs) if self.return_logprob and other.return_logprob: self.top_logprobs_nums.extend(other.top_logprobs_nums) elif self.return_logprob: @@ -772,6 +772,8 @@ class ScheduleBatch: elif other.return_logprob: self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums self.has_stream = any(req.stream for req in self.reqs) + self.reqs.extend(other.reqs) + self.return_logprob = self.return_logprob or other.return_logprob def get_model_worker_batch(self): if self.forward_mode.is_decode():