Remove mrope position sync (#9460)
Co-authored-by: Nathan Wang <nathan.r.wang@gmail.com>
@@ -1433,24 +1433,6 @@ class MRotaryEmbedding(RotaryEmbedding):
 
         return position_ids, mrope_position_deltas
 
-    @staticmethod
-    def get_next_input_positions(
-        mrope_position_delta: int,
-        context_len: int,
-        seq_len: int,
-    ) -> torch.Tensor:
-        return torch.tensor(
-            [
-                list(
-                    range(
-                        context_len + mrope_position_delta,
-                        seq_len + mrope_position_delta,
-                    )
-                )
-                for _ in range(3)
-            ]
-        )
-
 
 class DualChunkRotaryEmbedding(CustomOp):
     """Rotary positional embedding for Dual Chunk Attention."""
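For reference only (not part of the diff): a minimal sketch of what the removed MRotaryEmbedding.get_next_input_positions helper computed. From plain Python ints it built a [3, seq_len - context_len] tensor on the CPU, repeating the same delta-shifted position range for each of the three M-RoPE sections. The function name mirrors the removed helper; the example values are made up.

import torch

# Sketch of the removed helper; built on the CPU from Python ints.
def next_input_positions(
    mrope_position_delta: int, context_len: int, seq_len: int
) -> torch.Tensor:
    # One row per M-RoPE section; during decode all three sections share the
    # same delta-shifted position range.
    return torch.arange(
        context_len + mrope_position_delta, seq_len + mrope_position_delta
    ).unsqueeze(0).repeat(3, 1)

# Decode step for a sequence of length 10 with delta -4 -> tensor([[5], [5], [5]])
print(next_input_positions(-4, 9, 10))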
@@ -516,24 +516,23 @@ class ForwardBatch:
         for batch_idx in range(batch_size):
             mm_input = batch.multimodal_inputs[batch_idx]
             if self.forward_mode.is_decode():
-                mrope_position_deltas = (
-                    [0]
-                    if mm_input is None
-                    else flatten_nested_list(mm_input.mrope_position_delta.tolist())
-                )
-                next_input_positions = []
-                for mrope_position_delta in mrope_position_deltas:
-                    # batched deltas needs to be processed separately
-                    # Convert list of lists to tensor with shape [3, seq_len]
-                    next_input_positions += [
-                        MRotaryEmbedding.get_next_input_positions(
-                            mrope_position_delta,
-                            int(self.seq_lens[batch_idx]) - 1,
-                            int(self.seq_lens[batch_idx]),
-                        )
-                    ]
                 # 3 * N
-                mrope_positions_list[batch_idx] = torch.cat(next_input_positions, dim=1)
+                if mm_input is None:
+                    mrope_positions_list[batch_idx] = torch.full(
+                        (3, 1),
+                        self.seq_lens[batch_idx] - 1,
+                        dtype=torch.int64,
+                        device=model_runner.device,
+                    )
+                else:
+                    mrope_position_deltas = mm_input.mrope_position_delta.flatten().to(
+                        model_runner.device, non_blocking=True
+                    )
+                    mrope_positions_list[batch_idx] = (
+                        (mrope_position_deltas + self.seq_lens[batch_idx] - 1)
+                        .unsqueeze(0)
+                        .repeat(3, 1)
+                    )
             elif self.forward_mode.is_extend():
                 extend_seq_len, extend_prefix_len = (
                     batch.extend_seq_lens[batch_idx],
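A rough standalone sketch of the ForwardBatch change, under assumed inputs: in the old decode path, int(self.seq_lens[batch_idx]) forced a device-to-host copy of seq_lens (the position sync the title refers to) and the [3, 1] positions were then assembled on the CPU, while the new path computes the same values with tensor arithmetic that stays on the device. Tensor names (seq_lens, delta) and the single-request setup below are illustrative, not from the diff.

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Illustrative per-request state (names are assumptions, not from the diff).
seq_lens = torch.tensor([10], device=device)  # current sequence length
delta = torch.tensor([-4], device=device)     # mrope_position_delta for this request

# Old path (sketch): reading the length as a Python int blocks until the
# device value is available, then the positions are built on the CPU.
seq_len = int(seq_lens[0])                    # device-to-host sync
d = int(delta[0])
old_positions = torch.tensor(
    [list(range(seq_len - 1 + d, seq_len + d)) for _ in range(3)]
)

# New path (sketch): the decode-step position is just seq_len - 1 + delta,
# broadcast to the three M-RoPE sections without leaving the device.
new_positions = (delta + seq_lens[0] - 1).unsqueeze(0).repeat(3, 1)

# Both are [[5], [5], [5]]; only the new path avoids blocking on the device.
assert torch.equal(old_positions, new_positions.cpu())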