From 41273fd71f4e5058d69387f77ddf2b04e1fbfed6 Mon Sep 17 00:00:00 2001 From: Yusong Gao Date: Sun, 11 May 2025 15:28:21 +0800 Subject: [PATCH] fix: handle None multimodal_inputs during merging and filtering batches in disaggregation decode mode (#6169) --- python/sglang/srt/managers/schedule_batch.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 05065c237..38420076a 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -1485,7 +1485,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices] self.reqs = [self.reqs[i] for i in keep_indices] - self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices] + if self.multimodal_inputs is not None: + self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices] self.req_pool_indices = self.req_pool_indices[keep_indices_device] self.seq_lens = self.seq_lens[keep_indices_device] self.out_cache_loc = None @@ -1534,7 +1535,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums self.token_ids_logprobs = [None] * len(self.reqs) + other.token_ids_logprobs self.reqs.extend(other.reqs) - self.multimodal_inputs.extend(other.multimodal_inputs) + if self.multimodal_inputs is not None: + self.multimodal_inputs.extend(other.multimodal_inputs) self.return_logprob |= other.return_logprob self.has_stream |= other.has_stream