fix: second_per_grid_ts should be used to get mrope position (#3682)

This commit is contained in:
Mick
2025-03-18 09:12:38 +08:00
committed by GitHub
parent 98be3bd306
commit d373a48c98
8 changed files with 93 additions and 69 deletions

View File

@@ -880,8 +880,17 @@ class MRotaryEmbedding(RotaryEmbedding):
spatial_merge_size: int,
context_len: int = 0,
seq_len: Optional[int] = None,
second_per_grid_ts: Optional[torch.Tensor] = None,
tokens_per_second: Optional[int] = None,
) -> Tuple[List[List[int]], int]:
"""Get mrope input positions and delta value."""
"""
Get mrope input positions and delta value.
:arg
second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
"""
if isinstance(image_grid_thw, torch.Tensor):
image_grid_thw = image_grid_thw.tolist()
@@ -918,6 +927,7 @@ class MRotaryEmbedding(RotaryEmbedding):
)
image_index += 1
remain_images -= 1
second_per_grid_t = 0
ed = ed_image
else:
t, h, w = (
@@ -925,6 +935,10 @@ class MRotaryEmbedding(RotaryEmbedding):
video_grid_thw[video_index][1],
video_grid_thw[video_index][2],
)
if second_per_grid_ts is not None:
second_per_grid_t = second_per_grid_ts[video_index]
else:
second_per_grid_t = 1.0
video_index += 1
remain_videos -= 1
ed = ed_video
@@ -941,11 +955,11 @@ class MRotaryEmbedding(RotaryEmbedding):
)
t_index = (
torch.arange(llm_grid_t)
.view(-1, 1)
.expand(-1, llm_grid_h * llm_grid_w)
.flatten()
)
torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w)
* second_per_grid_t
* tokens_per_second
).flatten()
h_index = (
torch.arange(llm_grid_h)
.view(1, -1, 1)