295 lines
11 KiB
Python
295 lines
11 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
#
|
|
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
|
# and proprietary rights in and to this software, related documentation
|
|
# and any modifications thereto. Any use, reproduction, disclosure or
|
|
# distribution of this software and related documentation without an express
|
|
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
|
|
|
import typing
|
|
|
|
import torch
|
|
|
|
|
|
def compute_retained_tokens_count(
    tokens_per_frame: int, num_frames: int, q: float
) -> int:
    """
    Return how many tokens survive pruning for a given video.

    The count is floored at one full frame's worth of tokens, so the
    first frame is always fully retained regardless of the pruning rate.

    Args:
        tokens_per_frame: The number of tokens per frame.
        num_frames: The total number of frames.
        q: The pruning rate.

    Returns:
        The number of retained tokens.
    """
    # Fraction of the total token budget kept after pruning at rate q.
    kept_after_pruning = int(tokens_per_frame * num_frames * (1 - q))
    # Never drop below one full frame of tokens.
    return max(tokens_per_frame, kept_after_pruning)
|
|
|
|
|
|
def compute_retention_mask(
    video_embeds: torch.Tensor,
    video_size_thw: torch.LongTensor | tuple[int, int, int],
    spatial_merge_size: int,
    q: float,
) -> torch.Tensor:
    """
    Compute a boolean retention mask for input video embeddings (EVS).

    Tokens whose embedding is most similar to the same spatial location in
    the previous frame are pruned first; the entire first frame is always
    retained.

    Args:
        video_embeds (`torch.Tensor`): The input video embeddings
            of shape `(T * H * W // spatial_merge_size ^ 2, hidden_size)`
        video_size_thw (`torch.LongTensor` of shape `(3)`):
            The temporal, height and width of video.
        spatial_merge_size: Size reduction for rows & cols dimensions.
        q: (`float`): Pruning rate factor [0,1)

    Returns:
        `torch.Tensor`: The retention mask for the video embeddings of
        `(T * H * W // spatial_merge_size ^ 2)` shape.
    """
    num_frames, height, width = map(int, video_size_thw)
    rows = height // spatial_merge_size
    cols = width // spatial_merge_size

    # Use reshape instead of einops to avoid graph breaks.
    grid = video_embeds.reshape(num_frames, rows, cols, video_embeds.size(-1))

    # Core EVS: dissimilarity of each frame vs. its predecessor.
    frame_similarity = torch.nn.functional.cosine_similarity(
        grid[1:, ...], grid[:-1, ...], dim=-1
    )

    # Prepend a sentinel (255 >> any possible dissimilarity in [0, 2]) for
    # frame 0 so all of its tokens are always ranked first and retained.
    dissimilarity = torch.cat(
        [255 * torch.ones_like(grid[:1, :, :, 0]), 1 - frame_similarity], dim=0
    )

    flat_scores = dissimilarity.view(-1)
    # Stable sort keeps the original (raster) order among ties.
    ranking = torch.argsort(flat_scores, dim=-1, descending=True, stable=True)
    keep_count = compute_retained_tokens_count(
        tokens_per_frame=rows * cols, num_frames=num_frames, q=q
    )

    mask = torch.zeros_like(flat_scores, dtype=torch.bool)
    mask[ranking[:keep_count]] = True
    # Already flat: "T H W -> (T H W)".
    return mask
|
|
|
|
|
|
def compute_mrope_for_media(
    video_size_thw: torch.LongTensor,
    spatial_merge_size: int,
    tokens_per_second: float = 1.0,
    video_second_per_grid: float = 1.0,
) -> torch.Tensor:
    """
    Compute mrope positions for video embeddings from the grid dimensions.

    Positions match the original qwen 2.5 implementation, but are built as
    if the media were the first element in the sequence.

    Args:
        video_size_thw: Media size (num frames, rows, cols)
        spatial_merge_size: Size reduction for rows & cols dimensions.
        tokens_per_second: Number of tokens per second.
        video_second_per_grid: Number of seconds per video.

    Returns:
        Tensor of shape `(T * H * W, 4)` where last dimension
        represents mrope positions [0:3), while the last channel
        contains value of llm_grid_w repeated for all positions.
    """
    grid_t = video_size_thw[0]
    grid_h = video_size_thw[1] // spatial_merge_size
    grid_w = video_size_thw[2] // spatial_merge_size

    # Temporal index: scaled by seconds per grid step, truncated to long.
    time_scale = tokens_per_second * video_second_per_grid
    t_pos = (
        torch.arange(grid_t)
        .view(-1, 1)
        .expand(-1, grid_h * grid_w)
        .mul(time_scale)
        .long()
        .flatten()
    )
    # Row index, repeated over frames and columns.
    h_pos = (
        torch.arange(grid_h).view(1, -1, 1).expand(grid_t, -1, grid_w).flatten()
    )
    # Column index, repeated over frames and rows.
    w_pos = (
        torch.arange(grid_w).view(1, 1, -1).expand(grid_t, grid_h, -1).flatten()
    )
    # Extra channel: llm_grid_w broadcast to every position.
    width_channel = (
        torch.tensor([grid_w]).view(1, 1, 1).expand(grid_t, grid_h, grid_w).flatten()
    )

    return torch.stack([t_pos, h_pos, w_pos, width_channel], dim=1)
|
|
|
|
|
|
def recompute_mrope_positions(
    input_ids: torch.LongTensor,
    multimodal_positions: list[torch.Tensor],
    mrope_positions: torch.LongTensor,
    num_computed_tokens: int,
    vision_start_token_id: int,
    image_token_id: int,
    video_token_id: int,
) -> tuple[torch.LongTensor, int]:
    """
    Update part of input mrope positions.
    Original mrope_positions are computed incorrectly, so once we prune media
    tokens we should reflect this in the mrope positions for the LLM.

    This method supports chunked prefill approach where
    multimodal_embeddings are passed to LLM in chunks, so input
    multimodal_embeddings may contain zero, some or even some part of all
    multimodal_embeddings for a given prompt.

    Each multimodal_positions has 4 extra channels
    (First 3 channels corresponds to original 3 mrope positions, last channel
    is the maximum width of the media repeated). Provided multimodal_positions
    do not reflect location of media position in sequence - they are computed
    like the media is in the 0-th position in the sequence.

    Method works as follows: it recomputes mrope_positions starting from the
    `num_computed_tokens` for `total_len_of_multimodal_embeddings` and then
    shifts all text tokens that goes after total_len_of_multimodal_embeddings.

    It also handles case when multimodal_embeddings is partial
    (e.g. one media is split into two prefill stages)

    Args:
        input_ids: (N,) All input tokens of the prompt (entire sequence).
        multimodal_positions: List of mrope positions for each media.
        mrope_positions: Existing mrope positions (4, N) for entire sequence.
        num_computed_tokens: A number of computed tokens so far.
        vision_start_token_id: Token indicating start of vision media.
        image_token_id: Image token id
        video_token_id: Video token id

    Returns:
        Tuple of (mrope_positions, mrope_position_delta).
    """

    # Tensors
    # NOTE(review): the Args section says mrope_positions is (4, N) while the
    # comment below says (3, N); the writes below assign 3 rows (mm_pos[0:3]),
    # which only shapes correctly for a (3, N) tensor - verify with callers.
    positions: torch.LongTensor = typing.cast(
        torch.LongTensor, mrope_positions.clone()
    )  # (3, N)
    N = input_ids.numel()

    # Boolean masks over the full sequence: which tokens are media vs text.
    image_mask = input_ids.eq(image_token_id)
    video_mask = input_ids.eq(video_token_id)
    media_mask = image_mask | video_mask
    text_mask = ~media_mask

    # Early exit: no media in this chunk
    # delta is the next-position offset relative to sequence length.
    if len(multimodal_positions) == 0:
        delta = int((positions.max().item() + 1) - N) if positions.numel() else -N
        return positions, delta

    total_mm_tokens = torch.count_nonzero(media_mask)
    seen_mm_tokens = torch.count_nonzero(media_mask[:num_computed_tokens])

    # Early exit: we've updated positions for all media tokens
    # (and consequently - for all remaining text tokens)
    if seen_mm_tokens == total_mm_tokens:
        delta = int((positions.max().item() + 1) - N) if positions.numel() else -N
        return positions, delta

    # Indices of every vision-start token in the full sequence.
    vision_start_indices = (input_ids == vision_start_token_id).nonzero(as_tuple=True)[
        0
    ]

    for mm_pos in multimodal_positions:
        # Each mm_pos can be a complete embedding for single media
        # or it can be a part of a single media (due to chunked prefill)

        # Cases to cover
        # - Current prefill chunk has no vision start indexes at all
        # - Vision start token appeared in previous prefill round
        # - Regular case
        seen_vision_start_indices = vision_start_indices[
            vision_start_indices < num_computed_tokens
        ]

        if len(seen_vision_start_indices):
            # If we have encountered some vision start indexes,
            # then we should check the condition:
            # | --- prefill 1 ------| ---- prefill 2 ----- |
            # | TTTTTTTTTSVVVVVVVVVV|VVVVVVTTTTTTTTTTTTTTTT|
            last_vision_start_token = seen_vision_start_indices[-1]
            # Media tokens that precede the most recent vision-start token;
            # if we've seen more media tokens than that, the last media item
            # straddles the chunk boundary.
            seem_mm_tokens_before_last_vision_start = torch.count_nonzero(
                media_mask[:last_vision_start_token]
            )
            in_the_middle_of_media = (
                seen_mm_tokens > seem_mm_tokens_before_last_vision_start
            )

            if in_the_middle_of_media:
                # Resume the partially-prefilled media: mm_pos continues at
                # this offset within the media that started last chunk.
                mm_embeddings_seen = (
                    seen_mm_tokens - seem_mm_tokens_before_last_vision_start
                )
                global_mm_start = last_vision_start_token
            else:
                # We have completed previous mm_embedding part and
                # ready to start a new one
                next_vision_start_token = vision_start_indices[
                    vision_start_indices >= num_computed_tokens
                ][0]
                mm_embeddings_seen = 0
                global_mm_start = next_vision_start_token

        else:
            # If there were no vision start indexes so far,
            # let's find first vision start index
            next_vision_start_token = vision_start_indices[
                vision_start_indices >= num_computed_tokens
            ][0]

            mm_embeddings_seen = 0
            global_mm_start = next_vision_start_token

        # Offset right after vision_start_token
        base = positions[-1, global_mm_start] + 1
        # Slice of the sequence this mm_pos chunk covers.
        local_start = global_mm_start + 1 + mm_embeddings_seen
        local_end = local_start + mm_pos.shape[1]
        # First 3 channels are the media-local mrope positions; shift them to
        # their true location in the sequence.
        positions[:, local_start:local_end] = mm_pos[0:3] + base

        # mm_pos[3, 0] is the max width of the media
        offset = mm_pos[3, 0] + base

        # Renumber every token after this media: text tokens advance the
        # position counter by cumulative count starting from `offset`.
        text_pos_sum = torch.cumsum(text_mask[local_end:].long(), dim=0)

        positions[:, local_end:N] = text_pos_sum + offset - 1

        # Include distance to the next vision start token
        num_computed_tokens += mm_pos.shape[1]

    mrope_positions_delta = (positions.max() + 1 - N).item()
    return positions, mrope_positions_delta
|