fix: fix one more bug from merging mm_inputs (#5718)
Co-authored-by: Xinyuan Tong <justinning0323@outlook.com> Co-authored-by: XinyuanTong <115166877+JustinTong0323@users.noreply.github.com>
This commit is contained in:
@@ -351,7 +351,6 @@ class MultimodalInputs:
|
||||
optional_args = [
|
||||
"mm_items",
|
||||
"image_pad_len",
|
||||
"mrope_position_delta",
|
||||
]
|
||||
for arg in optional_args:
|
||||
self_arg = getattr(self, arg, None)
|
||||
@@ -367,6 +366,14 @@ class MultimodalInputs:
|
||||
[self.mrope_positions, other.mrope_positions], dim=1
|
||||
)
|
||||
|
||||
mrope_position_delta = self.mrope_position_delta
|
||||
if mrope_position_delta is not None:
|
||||
if other.mrope_position_delta is None:
|
||||
self.mrope_position_delta = mrope_position_delta
|
||||
else:
|
||||
self.mrope_position_delta = torch.cat(
|
||||
[self.mrope_position_delta, other.mrope_position_delta], dim=0
|
||||
)
|
||||
# other args would be kept intact
|
||||
|
||||
|
||||
@@ -1455,7 +1462,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
if self.model_config.is_encoder_decoder:
|
||||
self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
|
||||
self.encoder_lens_cpu.extend(other.encoder_lens_cpu)
|
||||
|
||||
self.req_pool_indices = torch.cat(
|
||||
[self.req_pool_indices, other.req_pool_indices]
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user