fix: handle None multimodal_inputs during merging and filtering batches in disaggregation decode mode (#6169)

This commit is contained in:
Yusong Gao
2025-05-11 15:28:21 +08:00
committed by GitHub
parent e9bebafb19
commit 41273fd71f

View File

@@ -1485,7 +1485,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]
self.reqs = [self.reqs[i] for i in keep_indices]
self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
if self.multimodal_inputs is not None:
self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
self.req_pool_indices = self.req_pool_indices[keep_indices_device]
self.seq_lens = self.seq_lens[keep_indices_device]
self.out_cache_loc = None
@@ -1534,7 +1535,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
self.token_ids_logprobs = [None] * len(self.reqs) + other.token_ids_logprobs
self.reqs.extend(other.reqs)
self.multimodal_inputs.extend(other.multimodal_inputs)
if self.multimodal_inputs is not None:
self.multimodal_inputs.extend(other.multimodal_inputs)
self.return_logprob |= other.return_logprob
self.has_stream |= other.has_stream