Upgrade to vllm 0.17.0 corex v4.1 overlay
@@ -170,9 +170,9 @@ def recompute_mrope_positions(
     multimodal_embeddings may contain zero, some or even some part of all
     multimodal_embeddings for a given prompt.
 
-    Each multimodal_positions has 4 extra channels
-    (First 3 channels corresponds to original 3 mrope positions, last channel
-    is the maximum width of the media repeated). Provided multimodal_positions
+    Each multimodal_positions has 4 or 5 extra channels
+    (first 3 channels correspond to the original 3 mrope positions;
+    remaining channels vary by model, see below). Provided multimodal_positions
     do not reflect location of media position in sequence - they are computed
     like the media is in the 0-th position in the sequence.
 
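For orientation, here is a minimal sketch (not part of the diff) of what the two channel layouts described above can look like. All tensor values are invented for illustration:

import torch

# 4-channel layout (e.g. Qwen2.5 VL): one contiguous 2x3 image patch grid.
# Channels are [t, h, w, max_width]; the max width is repeated along the row.
mm_pos_4ch = torch.tensor([
    [0, 0, 0, 0, 0, 0],  # t: single frame
    [0, 0, 0, 1, 1, 1],  # h: row within the grid
    [0, 1, 2, 0, 1, 2],  # w: column within the grid
    [3, 3, 3, 3, 3, 3],  # max width of the media, repeated
])

# 5-channel layout (e.g. Qwen3 VL): two timestamp text tokens, a vision_start
# token, then four video embeddings for a 2x2 frame.
# Channels are [t, h, w, is_vision_start, is_vision].
mm_pos_5ch = torch.tensor([
    [0, 1, 2, 3, 3, 3, 3],  # t: text tokens advance all three dims together
    [0, 1, 2, 3, 3, 4, 4],  # h
    [0, 1, 2, 3, 4, 3, 4],  # w
    [0, 0, 1, 0, 0, 0, 0],  # is_vision_start: flags the vision_start token
    [0, 0, 0, 1, 1, 1, 1],  # is_vision: flags video embedding positions
])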
@@ -186,6 +186,16 @@ def recompute_mrope_positions(
     Args:
         input_ids: (N,) All input tokens of the prompt (entire sequence).
         multimodal_positions: List of mrope positions for each media.
+            If a given element is of shape (4, N), it is assumed to only describe
+            positions for video / image embeddings. This is the case for e.g.
+            Qwen2.5 VL, where each multimodal input is a contiguous chunk of
+            embeddings. The expected channels are [t, h, w, max_width].
+            If it is of shape (5, N), it is assumed to possibly describe positions
+            for both video / image embeddings and text embeddings. This is the
+            case for e.g. Qwen3 VL, where each video input is composed of
+            individual frames' embeddings, interleaved with embeddings for
+            timestamp tokens and vision start / end tokens. The expected channels
+            are [t, h, w, is_vision_start, is_vision].
         mrope_positions: Existing mrope positions (4, N) for entire sequence.
         num_computed_tokens: A number of computed tokens so far.
         vision_start_token_id: Token indicating start of vision media.
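As a usage note, here is a hypothetical helper (the name describe_mm_pos is ours, not vLLM's) that dispatches on the channel count exactly as the Args section describes:

import torch

def describe_mm_pos(mm_pos: torch.Tensor) -> str:
    # Hypothetical helper, not part of vLLM: classify one element of
    # multimodal_positions by its channel count, per the docstring above.
    if mm_pos.shape[0] == 4:
        # [t, h, w, max_width]: contiguous image/video chunk (Qwen2.5 VL style)
        return f"4-channel media, max_width={int(mm_pos[3, 0])}"
    if mm_pos.shape[0] == 5:
        # [t, h, w, is_vision_start, is_vision]: interleaved text and vision
        return f"5-channel media, {int(mm_pos[4].sum())} vision positions"
    raise ValueError(f"unexpected channel count: {mm_pos.shape[0]}")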
@@ -233,6 +243,21 @@ def recompute_mrope_positions(
     # - Current prefill chunk has no vision start indexes at all
     # - Vision start token appeared in previous prefill round
     # - Regular case
+    has_video_tokens = False
+    num_timestamp_tokens = 0
+    if mm_pos.shape[0] == 5 and mm_pos.shape[1] > 0:
+        # mm_pos[4, :] indicates which positions are for video embeddings.
+        # If there are no video embeddings, skip timestamp adjustment.
+        has_video_tokens = torch.any(mm_pos[4, :]).item()
+        if has_video_tokens:
+            # Channel 3 flags VISION_START tokens. Timestamp tokens
+            # precede the first VISION_START, so its index gives us the
+            # exact timestamp count. This is robust even when early
+            # frames have all their video tokens pruned (which would
+            # push argmax(channel 4) far into a later frame).
+            first_vs = (mm_pos[3, :] == 1).nonzero(as_tuple=True)[0]
+            num_timestamp_tokens = first_vs[0].item() if len(first_vs) > 0 else 0
+
     seen_vision_start_indices = vision_start_indices[
         vision_start_indices < num_computed_tokens
     ]
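A self-contained repro (values invented) of the timestamp-count logic added above: channel 3 flags the vision_start token, and since timestamp tokens precede the first one, its index is exactly the timestamp count:

import torch

# Two timestamp text tokens, a vision_start token, then two video embeddings.
mm_pos = torch.tensor([
    [0, 1, 2, 3, 3],
    [0, 1, 2, 3, 3],
    [0, 1, 2, 3, 4],
    [0, 0, 1, 0, 0],  # is_vision_start
    [0, 0, 0, 1, 1],  # is_vision
])
has_video_tokens = torch.any(mm_pos[4, :]).item()  # True
first_vs = (mm_pos[3, :] == 1).nonzero(as_tuple=True)[0]
num_timestamp_tokens = first_vs[0].item() if len(first_vs) > 0 else 0
assert has_video_tokens and num_timestamp_tokens == 2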
@@ -249,6 +274,18 @@ def recompute_mrope_positions(
     in_the_middle_of_media = (
         seen_mm_tokens > seem_mm_tokens_before_last_vision_start
     )
+    # For Qwen3 VL, we can be inside a media segment even before any
+    # video tokens appear (timestamp tokens are text). If we've passed
+    # the last vision_start token but haven't reached the first video
+    # embedding, treat this as "in the middle of media".
+    if (
+        not in_the_middle_of_media
+        and has_video_tokens
+        and num_computed_tokens > last_vision_start_token
+        and num_computed_tokens
+        <= last_vision_start_token + num_timestamp_tokens + 1
+    ):
+        in_the_middle_of_media = True
 
     if in_the_middle_of_media:
         mm_embeddings_seen = (
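A worked example with invented numbers for the window check above: with the last vision_start at index 10 and 2 timestamp tokens counted, computed-token counts in (10, 13] are treated as still inside the media segment:

last_vision_start_token = 10
num_timestamp_tokens = 2

for num_computed_tokens in (10, 11, 13, 14):
    inside = (
        num_computed_tokens > last_vision_start_token
        and num_computed_tokens
        <= last_vision_start_token + num_timestamp_tokens + 1
    )
    print(num_computed_tokens, inside)  # 10 False, 11 True, 13 True, 14 False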
@@ -274,14 +311,39 @@ def recompute_mrope_positions(
         mm_embeddings_seen = 0
         global_mm_start = next_vision_start_token
 
-    # Offset right after vision_start_token
-    base = positions[-1, global_mm_start] + 1
-    local_start = global_mm_start + 1 + mm_embeddings_seen
+    # For Qwen3 VL, mm_pos includes timestamp tokens before vision_start
+    # when starting a new media. Adjust global_mm_start to point to where
+    # the sequence actually begins (before timestamp tokens).
+    adjusted_for_timestamps = False
+    if mm_pos.shape[0] == 5 and mm_embeddings_seen == 0 and has_video_tokens:
+        # NOTE: the vision_start token sits immediately after the
+        # timestamp tokens, before any video embeddings appear.
+
+        # Adjust global_mm_start to point to the first timestamp token
+        # instead of the vision_start token.
+        global_mm_start -= num_timestamp_tokens
+        adjusted_for_timestamps = True
+
+    # Offset calculation depends on whether we adjusted for timestamp tokens
+    if adjusted_for_timestamps:
+        # Start from position before the first timestamp token
+        base = positions[-1, global_mm_start - 1] + 1
+        local_start = global_mm_start + mm_embeddings_seen
+    else:
+        # Original logic: start after vision_start_token
+        base = positions[-1, global_mm_start] + 1
+        local_start = global_mm_start + 1 + mm_embeddings_seen
+
     local_end = local_start + mm_pos.shape[1]
     positions[:, local_start:local_end] = mm_pos[0:3] + base
 
-    # mm_pos[3, 0] is the max width of the media
-    offset = mm_pos[3, 0] + base
+    # For Qwen3 VL (5-channel), use the maximum position reached across
+    # all tokens (both video and text) in all dimensions (t, h, w).
+    # For Qwen2.5 VL (4-channel), mm_pos[3, 0] is the max width.
+    if mm_pos.shape[0] == 5:
+        offset = mm_pos[0:3, :].max() + base + 1
+    else:
+        offset = mm_pos[3, 0] + base
 
     text_pos_sum = torch.cumsum(text_mask[local_end:].long(), dim=0)
 
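Finally, a small sketch (invented tensors and base value) contrasting how the post-media text offset is derived in the two layouts, mirroring the new if/else above:

import torch

base = 7  # invented: base position computed above for this media

# 4-channel (Qwen2.5 VL style): channel 3 repeats the media's max width.
mm_pos_4ch = torch.tensor([[0, 0], [0, 0], [0, 1], [2, 2]])
offset_4ch = mm_pos_4ch[3, 0] + base  # 2 + 7 = 9

# 5-channel (Qwen3 VL style): take the max position reached across t/h/w
# over all tokens (video and text), then advance one past it.
mm_pos_5ch = torch.tensor([[0, 1], [0, 1], [0, 1], [0, 0], [1, 1]])
offset_5ch = mm_pos_5ch[0:3, :].max() + base + 1  # 1 + 7 + 1 = 9
print(int(offset_4ch), int(offset_5ch))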