[fix] fix mrope positions not picked up (#5265)

This commit is contained in:
Mick
2025-04-11 16:29:45 +08:00
committed by GitHub
parent 038bc5d521
commit e53a0b3d5b
7 changed files with 69 additions and 69 deletions

View File

@@ -33,7 +33,6 @@ from dataclasses import dataclass
from enum import IntEnum, auto
from typing import TYPE_CHECKING, List, Optional, Union
import numpy as np
import torch
import triton
import triton.language as tl
@@ -399,13 +398,13 @@ class ForwardBatch:
)
elif self.forward_mode.is_extend():
extend_start_loc_cpu = self.extend_start_loc.cpu().numpy()
for i, multimodal_inputs in enumerate(batch.multimodal_inputs):
for i, mm_input in enumerate(batch.multimodal_inputs):
extend_start_loc, extend_seq_len, extend_prefix_len = (
extend_start_loc_cpu[i],
batch.extend_seq_lens[i],
batch.extend_prefix_lens[i],
)
if multimodal_inputs is None:
if mm_input is None:
# text only
mrope_positions = [
[
@@ -416,23 +415,58 @@ class ForwardBatch:
]
] * 3
else:
image_grid_thws_list = [
item.image_grid_thws
for item in mm_input.mm_items
if item.image_grid_thws is not None
]
image_grid_thw = (
None
if len(image_grid_thws_list) == 0
else torch.cat(image_grid_thws_list, dim=0)
)
video_grid_thws_list = [
item.video_grid_thws
for item in mm_input.mm_items
if item.video_grid_thws is not None
]
video_grid_thw = (
None
if len(video_grid_thws_list) == 0
else torch.cat(video_grid_thws_list, dim=0)
)
second_per_grid_ts_list = [
item.second_per_grid_ts
for item in mm_input.mm_items
if item.second_per_grid_ts is not None
]
second_per_grid_ts = (
None
if len(second_per_grid_ts_list) == 0
else torch.cat(second_per_grid_ts_list, dim=0)
)
# TODO: qwen2-vl currently does not support radix cache because of how mrope positions are calculated
mrope_positions, mrope_position_delta = (
MRotaryEmbedding.get_input_positions(
input_tokens=self.input_ids[
extend_start_loc : extend_start_loc + extend_seq_len
],
image_grid_thw=multimodal_inputs.image_grid_thws,
video_grid_thw=multimodal_inputs.video_grid_thws,
image_token_id=multimodal_inputs.im_token_id,
video_token_id=multimodal_inputs.video_token_id,
].tolist(),
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
image_token_id=hf_config.image_token_id,
video_token_id=hf_config.video_token_id,
vision_start_token_id=hf_config.vision_start_token_id,
vision_end_token_id=hf_config.vision_end_token_id,
spatial_merge_size=hf_config.vision_config.spatial_merge_size,
context_len=0,
seq_len=len(self.input_ids),
second_per_grid_ts=multimodal_inputs.second_per_grid_ts,
tokens_per_second=hf_config.vision_config.tokens_per_second,
second_per_grid_ts=second_per_grid_ts,
tokens_per_second=getattr(
hf_config.vision_config, "tokens_per_second", None
),
)
)
batch.multimodal_inputs[i].mrope_position_delta = (

View File

@@ -1070,7 +1070,8 @@ class ModelRunner:
rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {})
if rope_scaling is None:
return False
return rope_scaling.get("type", None) == "mrope"
is_mrope_enabled = "mrope_section" in rope_scaling
return is_mrope_enabled
def save_remote_model(self, url: str):
from sglang.srt.model_loader.loader import RemoteModelLoader