fix: handle None multimodal_inputs during merging and filtering batches in disaggregation decode mode (#6169)
This commit is contained in:
@@ -1485,7 +1485,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]
|
||||
|
||||
self.reqs = [self.reqs[i] for i in keep_indices]
|
||||
self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
|
||||
if self.multimodal_inputs is not None:
|
||||
self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
|
||||
self.req_pool_indices = self.req_pool_indices[keep_indices_device]
|
||||
self.seq_lens = self.seq_lens[keep_indices_device]
|
||||
self.out_cache_loc = None
|
||||
@@ -1534,7 +1535,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
|
||||
self.token_ids_logprobs = [None] * len(self.reqs) + other.token_ids_logprobs
|
||||
self.reqs.extend(other.reqs)
|
||||
self.multimodal_inputs.extend(other.multimodal_inputs)
|
||||
if self.multimodal_inputs is not None:
|
||||
self.multimodal_inputs.extend(other.multimodal_inputs)
|
||||
|
||||
self.return_logprob |= other.return_logprob
|
||||
self.has_stream |= other.has_stream
|
||||
|
||||
Reference in New Issue
Block a user