[fix] fix mrope positions not picked up (#5265)
This commit is contained in:
@@ -33,7 +33,6 @@ from dataclasses import dataclass
|
||||
from enum import IntEnum, auto
|
||||
from typing import TYPE_CHECKING, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
@@ -399,13 +398,13 @@ class ForwardBatch:
|
||||
)
|
||||
elif self.forward_mode.is_extend():
|
||||
extend_start_loc_cpu = self.extend_start_loc.cpu().numpy()
|
||||
for i, multimodal_inputs in enumerate(batch.multimodal_inputs):
|
||||
for i, mm_input in enumerate(batch.multimodal_inputs):
|
||||
extend_start_loc, extend_seq_len, extend_prefix_len = (
|
||||
extend_start_loc_cpu[i],
|
||||
batch.extend_seq_lens[i],
|
||||
batch.extend_prefix_lens[i],
|
||||
)
|
||||
if multimodal_inputs is None:
|
||||
if mm_input is None:
|
||||
# text only
|
||||
mrope_positions = [
|
||||
[
|
||||
@@ -416,23 +415,58 @@ class ForwardBatch:
|
||||
]
|
||||
] * 3
|
||||
else:
|
||||
image_grid_thws_list = [
|
||||
item.image_grid_thws
|
||||
for item in mm_input.mm_items
|
||||
if item.image_grid_thws is not None
|
||||
]
|
||||
image_grid_thw = (
|
||||
None
|
||||
if len(image_grid_thws_list) == 0
|
||||
else torch.cat(image_grid_thws_list, dim=0)
|
||||
)
|
||||
|
||||
video_grid_thws_list = [
|
||||
item.video_grid_thws
|
||||
for item in mm_input.mm_items
|
||||
if item.video_grid_thws is not None
|
||||
]
|
||||
video_grid_thw = (
|
||||
None
|
||||
if len(video_grid_thws_list) == 0
|
||||
else torch.cat(video_grid_thws_list, dim=0)
|
||||
)
|
||||
|
||||
second_per_grid_ts_list = [
|
||||
item.second_per_grid_ts
|
||||
for item in mm_input.mm_items
|
||||
if item.second_per_grid_ts is not None
|
||||
]
|
||||
second_per_grid_ts = (
|
||||
None
|
||||
if len(second_per_grid_ts_list) == 0
|
||||
else torch.cat(second_per_grid_ts_list, dim=0)
|
||||
)
|
||||
|
||||
# TODO: qwen2-vl does not currently support radix cache because of its mrope position calculation
|
||||
mrope_positions, mrope_position_delta = (
|
||||
MRotaryEmbedding.get_input_positions(
|
||||
input_tokens=self.input_ids[
|
||||
extend_start_loc : extend_start_loc + extend_seq_len
|
||||
],
|
||||
image_grid_thw=multimodal_inputs.image_grid_thws,
|
||||
video_grid_thw=multimodal_inputs.video_grid_thws,
|
||||
image_token_id=multimodal_inputs.im_token_id,
|
||||
video_token_id=multimodal_inputs.video_token_id,
|
||||
].tolist(),
|
||||
image_grid_thw=image_grid_thw,
|
||||
video_grid_thw=video_grid_thw,
|
||||
image_token_id=hf_config.image_token_id,
|
||||
video_token_id=hf_config.video_token_id,
|
||||
vision_start_token_id=hf_config.vision_start_token_id,
|
||||
vision_end_token_id=hf_config.vision_end_token_id,
|
||||
spatial_merge_size=hf_config.vision_config.spatial_merge_size,
|
||||
context_len=0,
|
||||
seq_len=len(self.input_ids),
|
||||
second_per_grid_ts=multimodal_inputs.second_per_grid_ts,
|
||||
tokens_per_second=hf_config.vision_config.tokens_per_second,
|
||||
second_per_grid_ts=second_per_grid_ts,
|
||||
tokens_per_second=getattr(
|
||||
hf_config.vision_config, "tokens_per_second", None
|
||||
),
|
||||
)
|
||||
)
|
||||
batch.multimodal_inputs[i].mrope_position_delta = (
|
||||
|
||||
@@ -1070,7 +1070,8 @@ class ModelRunner:
|
||||
rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {})
|
||||
if rope_scaling is None:
|
||||
return False
|
||||
return rope_scaling.get("type", None) == "mrope"
|
||||
is_mrope_enabled = "mrope_section" in rope_scaling
|
||||
return is_mrope_enabled
|
||||
|
||||
def save_remote_model(self, url: str):
|
||||
from sglang.srt.model_loader.loader import RemoteModelLoader
|
||||
|
||||
Reference in New Issue
Block a user