qwen2vl fix bug for #1971 #1897 (#1984)

This commit is contained in:
yizhang2077
2024-11-11 00:10:45 +08:00
committed by GitHub
parent 47ffe7af81
commit a8aad9357d
3 changed files with 8 additions and 14 deletions

View File

@@ -136,8 +136,13 @@ class ForwardBatch:
mrope_positions_list = [None] * self.seq_lens.shape[0]
if self.forward_mode.is_decode():
for i, _ in enumerate(mrope_positions_list):
mrope_position_delta = (
0
if batch.image_inputs[i] is None
else batch.image_inputs[i].mrope_position_delta
)
mrope_positions_list[i] = MRotaryEmbedding.get_next_input_positions(
batch.mrope_positions_delta[i][0],
mrope_position_delta,
int(self.seq_lens[i]) - 1,
int(self.seq_lens[i]),
)
@@ -159,7 +164,6 @@ class ForwardBatch:
)
]
] * 3
mrope_position_delta = 0
else:
# TODO: current qwen2-vl do not support radix cache since mrope position calculation
mrope_positions, mrope_position_delta = (
@@ -173,8 +177,8 @@ class ForwardBatch:
context_len=0,
)
)
batch.image_inputs[i].mrope_position_delta = mrope_position_delta
mrope_positions_list[i] = mrope_positions
batch.mrope_positions_delta[i].append(mrope_position_delta)
self.mrope_positions = torch.concat(
[torch.tensor(pos, device=device) for pos in mrope_positions_list],