@@ -136,8 +136,13 @@ class ForwardBatch:
|
||||
mrope_positions_list = [None] * self.seq_lens.shape[0]
|
||||
if self.forward_mode.is_decode():
|
||||
for i, _ in enumerate(mrope_positions_list):
|
||||
mrope_position_delta = (
|
||||
0
|
||||
if batch.image_inputs[i] is None
|
||||
else batch.image_inputs[i].mrope_position_delta
|
||||
)
|
||||
mrope_positions_list[i] = MRotaryEmbedding.get_next_input_positions(
|
||||
batch.mrope_positions_delta[i][0],
|
||||
mrope_position_delta,
|
||||
int(self.seq_lens[i]) - 1,
|
||||
int(self.seq_lens[i]),
|
||||
)
|
||||
@@ -159,7 +164,6 @@ class ForwardBatch:
|
||||
)
|
||||
]
|
||||
] * 3
|
||||
mrope_position_delta = 0
|
||||
else:
|
||||
# TODO: qwen2-vl does not currently support the radix cache because of how mrope positions are calculated
|
||||
mrope_positions, mrope_position_delta = (
|
||||
@@ -173,8 +177,8 @@ class ForwardBatch:
|
||||
context_len=0,
|
||||
)
|
||||
)
|
||||
batch.image_inputs[i].mrope_position_delta = mrope_position_delta
|
||||
mrope_positions_list[i] = mrope_positions
|
||||
batch.mrope_positions_delta[i].append(mrope_position_delta)
|
||||
|
||||
self.mrope_positions = torch.concat(
|
||||
[torch.tensor(pos, device=device) for pos in mrope_positions_list],
|
||||
|
||||
Reference in New Issue
Block a user