Remove mrope position sync (#9460)
Co-authored-by: Nathan Wang <nathan.r.wang@gmail.com>
This commit is contained in:
@@ -1433,24 +1433,6 @@ class MRotaryEmbedding(RotaryEmbedding):
|
||||
|
||||
return position_ids, mrope_position_deltas
|
||||
|
||||
@staticmethod
|
||||
def get_next_input_positions(
|
||||
mrope_position_delta: int,
|
||||
context_len: int,
|
||||
seq_len: int,
|
||||
) -> torch.Tensor:
|
||||
return torch.tensor(
|
||||
[
|
||||
list(
|
||||
range(
|
||||
context_len + mrope_position_delta,
|
||||
seq_len + mrope_position_delta,
|
||||
)
|
||||
)
|
||||
for _ in range(3)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class DualChunkRotaryEmbedding(CustomOp):
|
||||
"""Rotary positional embedding for Dual Chunk Attention."""
|
||||
|
||||
@@ -516,24 +516,23 @@ class ForwardBatch:
|
||||
for batch_idx in range(batch_size):
|
||||
mm_input = batch.multimodal_inputs[batch_idx]
|
||||
if self.forward_mode.is_decode():
|
||||
mrope_position_deltas = (
|
||||
[0]
|
||||
if mm_input is None
|
||||
else flatten_nested_list(mm_input.mrope_position_delta.tolist())
|
||||
)
|
||||
next_input_positions = []
|
||||
for mrope_position_delta in mrope_position_deltas:
|
||||
# batched deltas needs to be processed separately
|
||||
# Convert list of lists to tensor with shape [3, seq_len]
|
||||
next_input_positions += [
|
||||
MRotaryEmbedding.get_next_input_positions(
|
||||
mrope_position_delta,
|
||||
int(self.seq_lens[batch_idx]) - 1,
|
||||
int(self.seq_lens[batch_idx]),
|
||||
)
|
||||
]
|
||||
# 3 * N
|
||||
mrope_positions_list[batch_idx] = torch.cat(next_input_positions, dim=1)
|
||||
if mm_input is None:
|
||||
mrope_positions_list[batch_idx] = torch.full(
|
||||
(3, 1),
|
||||
self.seq_lens[batch_idx] - 1,
|
||||
dtype=torch.int64,
|
||||
device=model_runner.device,
|
||||
)
|
||||
else:
|
||||
mrope_position_deltas = mm_input.mrope_position_delta.flatten().to(
|
||||
model_runner.device, non_blocking=True
|
||||
)
|
||||
mrope_positions_list[batch_idx] = (
|
||||
(mrope_position_deltas + self.seq_lens[batch_idx] - 1)
|
||||
.unsqueeze(0)
|
||||
.repeat(3, 1)
|
||||
)
|
||||
elif self.forward_mode.is_extend():
|
||||
extend_seq_len, extend_prefix_len = (
|
||||
batch.extend_seq_lens[batch_idx],
|
||||
|
||||
Reference in New Issue
Block a user