fix: second_per_grid_ts should be used to get mrope position (#3682)
This commit is contained in:
@@ -880,8 +880,17 @@ class MRotaryEmbedding(RotaryEmbedding):
|
||||
spatial_merge_size: int,
|
||||
context_len: int = 0,
|
||||
seq_len: Optional[int] = None,
|
||||
second_per_grid_ts: Optional[torch.Tensor] = None,
|
||||
tokens_per_second: Optional[int] = None,
|
||||
) -> Tuple[List[List[int]], int]:
|
||||
"""Get mrope input positions and delta value."""
|
||||
"""
|
||||
Get mrope input positions and delta value.
|
||||
|
||||
:arg
|
||||
second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
|
||||
The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
|
||||
|
||||
"""
|
||||
|
||||
if isinstance(image_grid_thw, torch.Tensor):
|
||||
image_grid_thw = image_grid_thw.tolist()
|
||||
@@ -918,6 +927,7 @@ class MRotaryEmbedding(RotaryEmbedding):
|
||||
)
|
||||
image_index += 1
|
||||
remain_images -= 1
|
||||
second_per_grid_t = 0
|
||||
ed = ed_image
|
||||
else:
|
||||
t, h, w = (
|
||||
@@ -925,6 +935,10 @@ class MRotaryEmbedding(RotaryEmbedding):
|
||||
video_grid_thw[video_index][1],
|
||||
video_grid_thw[video_index][2],
|
||||
)
|
||||
if second_per_grid_ts is not None:
|
||||
second_per_grid_t = second_per_grid_ts[video_index]
|
||||
else:
|
||||
second_per_grid_t = 1.0
|
||||
video_index += 1
|
||||
remain_videos -= 1
|
||||
ed = ed_video
|
||||
@@ -941,11 +955,11 @@ class MRotaryEmbedding(RotaryEmbedding):
|
||||
)
|
||||
|
||||
t_index = (
|
||||
torch.arange(llm_grid_t)
|
||||
.view(-1, 1)
|
||||
.expand(-1, llm_grid_h * llm_grid_w)
|
||||
.flatten()
|
||||
)
|
||||
torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w)
|
||||
* second_per_grid_t
|
||||
* tokens_per_second
|
||||
).flatten()
|
||||
|
||||
h_index = (
|
||||
torch.arange(llm_grid_h)
|
||||
.view(1, -1, 1)
|
||||
|
||||
Reference in New Issue
Block a user