fix: fix one more bug from merging mm_inputs (#5718)

Co-authored-by: Xinyuan Tong <justinning0323@outlook.com>
Co-authored-by: XinyuanTong <115166877+JustinTong0323@users.noreply.github.com>
This commit is contained in:
Mick
2025-04-26 09:28:33 +09:00
committed by GitHub
parent c3948ba67e
commit feda9b11b3
4 changed files with 54 additions and 39 deletions

View File

@@ -351,7 +351,6 @@ class MultimodalInputs:
optional_args = [
"mm_items",
"image_pad_len",
"mrope_position_delta",
]
for arg in optional_args:
self_arg = getattr(self, arg, None)
@@ -367,6 +366,14 @@ class MultimodalInputs:
[self.mrope_positions, other.mrope_positions], dim=1
)
mrope_position_delta = self.mrope_position_delta
if mrope_position_delta is not None:
if other.mrope_position_delta is None:
self.mrope_position_delta = mrope_position_delta
else:
self.mrope_position_delta = torch.cat(
[self.mrope_position_delta, other.mrope_position_delta], dim=0
)
# other args would be kept intact
@@ -1455,7 +1462,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
if self.model_config.is_encoder_decoder:
self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
self.encoder_lens_cpu.extend(other.encoder_lens_cpu)
self.req_pool_indices = torch.cat(
[self.req_pool_indices, other.req_pool_indices]
)