[fix] fix mrope positions not picked up (#5265)
This commit is contained in:
@@ -94,7 +94,7 @@ class VisionAttention(nn.Module):
|
|||||||
input_size=embed_dim,
|
input_size=embed_dim,
|
||||||
output_size=embed_dim,
|
output_size=embed_dim,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=add_prefix("out_proj", prefix),
|
prefix=add_prefix("proj", prefix),
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
|||||||
@@ -268,6 +268,9 @@ class MultimodalDataItem:
|
|||||||
self.modality == Modality.VIDEO
|
self.modality == Modality.VIDEO
|
||||||
) and not MultimodalDataItem.is_empty_list(self.pixel_values)
|
) and not MultimodalDataItem.is_empty_list(self.pixel_values)
|
||||||
|
|
||||||
|
def is_valid(self) -> bool:
|
||||||
|
return self.is_image() or self.is_video() or self.is_audio()
|
||||||
|
|
||||||
def validate(self):
|
def validate(self):
|
||||||
...
|
...
|
||||||
# TODO
|
# TODO
|
||||||
@@ -306,11 +309,7 @@ class MultimodalInputs:
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert isinstance(ret.mm_items, list)
|
assert isinstance(ret.mm_items, list)
|
||||||
ret.mm_items = [
|
ret.mm_items = [item for item in ret.mm_items if item.is_valid()]
|
||||||
item
|
|
||||||
for item in ret.mm_items
|
|
||||||
if item.is_audio() or item.is_image() or item.is_video()
|
|
||||||
]
|
|
||||||
|
|
||||||
assert len(ret.mm_items) != 0
|
assert len(ret.mm_items) != 0
|
||||||
|
|
||||||
@@ -345,8 +344,8 @@ class MultimodalInputs:
|
|||||||
""" """
|
""" """
|
||||||
return any(item.is_audio() for item in self.mm_items)
|
return any(item.is_audio() for item in self.mm_items)
|
||||||
|
|
||||||
def collect_image_inputs(self) -> List[torch.Tensor]:
|
def contains_mm_input(self) -> bool:
|
||||||
return [item.pixel_values for item in self.mm_items if item.is_image()]
|
return any(True for item in self.mm_items if item.is_valid())
|
||||||
|
|
||||||
def merge(self, other: MultimodalInputs):
|
def merge(self, other: MultimodalInputs):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -33,7 +33,6 @@ from dataclasses import dataclass
|
|||||||
from enum import IntEnum, auto
|
from enum import IntEnum, auto
|
||||||
from typing import TYPE_CHECKING, List, Optional, Union
|
from typing import TYPE_CHECKING, List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
@@ -399,13 +398,13 @@ class ForwardBatch:
|
|||||||
)
|
)
|
||||||
elif self.forward_mode.is_extend():
|
elif self.forward_mode.is_extend():
|
||||||
extend_start_loc_cpu = self.extend_start_loc.cpu().numpy()
|
extend_start_loc_cpu = self.extend_start_loc.cpu().numpy()
|
||||||
for i, multimodal_inputs in enumerate(batch.multimodal_inputs):
|
for i, mm_input in enumerate(batch.multimodal_inputs):
|
||||||
extend_start_loc, extend_seq_len, extend_prefix_len = (
|
extend_start_loc, extend_seq_len, extend_prefix_len = (
|
||||||
extend_start_loc_cpu[i],
|
extend_start_loc_cpu[i],
|
||||||
batch.extend_seq_lens[i],
|
batch.extend_seq_lens[i],
|
||||||
batch.extend_prefix_lens[i],
|
batch.extend_prefix_lens[i],
|
||||||
)
|
)
|
||||||
if multimodal_inputs is None:
|
if mm_input is None:
|
||||||
# text only
|
# text only
|
||||||
mrope_positions = [
|
mrope_positions = [
|
||||||
[
|
[
|
||||||
@@ -416,23 +415,58 @@ class ForwardBatch:
|
|||||||
]
|
]
|
||||||
] * 3
|
] * 3
|
||||||
else:
|
else:
|
||||||
|
image_grid_thws_list = [
|
||||||
|
item.image_grid_thws
|
||||||
|
for item in mm_input.mm_items
|
||||||
|
if item.image_grid_thws is not None
|
||||||
|
]
|
||||||
|
image_grid_thw = (
|
||||||
|
None
|
||||||
|
if len(image_grid_thws_list) == 0
|
||||||
|
else torch.cat(image_grid_thws_list, dim=0)
|
||||||
|
)
|
||||||
|
|
||||||
|
video_grid_thws_list = [
|
||||||
|
item.video_grid_thws
|
||||||
|
for item in mm_input.mm_items
|
||||||
|
if item.video_grid_thws is not None
|
||||||
|
]
|
||||||
|
video_grid_thw = (
|
||||||
|
None
|
||||||
|
if len(video_grid_thws_list) == 0
|
||||||
|
else torch.cat(video_grid_thws_list, dim=0)
|
||||||
|
)
|
||||||
|
|
||||||
|
second_per_grid_ts_list = [
|
||||||
|
item.second_per_grid_ts
|
||||||
|
for item in mm_input.mm_items
|
||||||
|
if item.second_per_grid_ts is not None
|
||||||
|
]
|
||||||
|
second_per_grid_ts = (
|
||||||
|
None
|
||||||
|
if len(second_per_grid_ts_list) == 0
|
||||||
|
else torch.cat(second_per_grid_ts_list, dim=0)
|
||||||
|
)
|
||||||
|
|
||||||
# TODO: qwen2-vl currently does not support radix cache because of the mrope position calculation
|
# TODO: qwen2-vl currently does not support radix cache because of the mrope position calculation
|
||||||
mrope_positions, mrope_position_delta = (
|
mrope_positions, mrope_position_delta = (
|
||||||
MRotaryEmbedding.get_input_positions(
|
MRotaryEmbedding.get_input_positions(
|
||||||
input_tokens=self.input_ids[
|
input_tokens=self.input_ids[
|
||||||
extend_start_loc : extend_start_loc + extend_seq_len
|
extend_start_loc : extend_start_loc + extend_seq_len
|
||||||
],
|
].tolist(),
|
||||||
image_grid_thw=multimodal_inputs.image_grid_thws,
|
image_grid_thw=image_grid_thw,
|
||||||
video_grid_thw=multimodal_inputs.video_grid_thws,
|
video_grid_thw=video_grid_thw,
|
||||||
image_token_id=multimodal_inputs.im_token_id,
|
image_token_id=hf_config.image_token_id,
|
||||||
video_token_id=multimodal_inputs.video_token_id,
|
video_token_id=hf_config.video_token_id,
|
||||||
vision_start_token_id=hf_config.vision_start_token_id,
|
vision_start_token_id=hf_config.vision_start_token_id,
|
||||||
vision_end_token_id=hf_config.vision_end_token_id,
|
vision_end_token_id=hf_config.vision_end_token_id,
|
||||||
spatial_merge_size=hf_config.vision_config.spatial_merge_size,
|
spatial_merge_size=hf_config.vision_config.spatial_merge_size,
|
||||||
context_len=0,
|
context_len=0,
|
||||||
seq_len=len(self.input_ids),
|
seq_len=len(self.input_ids),
|
||||||
second_per_grid_ts=multimodal_inputs.second_per_grid_ts,
|
second_per_grid_ts=second_per_grid_ts,
|
||||||
tokens_per_second=hf_config.vision_config.tokens_per_second,
|
tokens_per_second=getattr(
|
||||||
|
hf_config.vision_config, "tokens_per_second", None
|
||||||
|
),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
batch.multimodal_inputs[i].mrope_position_delta = (
|
batch.multimodal_inputs[i].mrope_position_delta = (
|
||||||
|
|||||||
@@ -1070,7 +1070,8 @@ class ModelRunner:
|
|||||||
rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {})
|
rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {})
|
||||||
if rope_scaling is None:
|
if rope_scaling is None:
|
||||||
return False
|
return False
|
||||||
return rope_scaling.get("type", None) == "mrope"
|
is_mrope_enabled = "mrope_section" in rope_scaling
|
||||||
|
return is_mrope_enabled
|
||||||
|
|
||||||
def save_remote_model(self, url: str):
|
def save_remote_model(self, url: str):
|
||||||
from sglang.srt.model_loader.loader import RemoteModelLoader
|
from sglang.srt.model_loader.loader import RemoteModelLoader
|
||||||
|
|||||||
@@ -30,12 +30,16 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from transformers import Qwen2VLConfig
|
|
||||||
from transformers.activations import ACT2FN
|
from transformers.activations import ACT2FN
|
||||||
from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm
|
from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm
|
||||||
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
|
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
|
||||||
|
Qwen2_5_VLConfig,
|
||||||
Qwen2_5_VLVisionConfig,
|
Qwen2_5_VLVisionConfig,
|
||||||
)
|
)
|
||||||
|
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
|
||||||
|
Qwen2_5_VisionPatchEmbed,
|
||||||
|
Qwen2_5_VisionRotaryEmbedding,
|
||||||
|
)
|
||||||
|
|
||||||
from sglang.srt.hf_transformers_utils import get_processor
|
from sglang.srt.hf_transformers_utils import get_processor
|
||||||
from sglang.srt.layers.attention.vision import VisionAttention
|
from sglang.srt.layers.attention.vision import VisionAttention
|
||||||
@@ -173,33 +177,6 @@ class Qwen2_5_VisionBlock(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class Qwen2_5_VisionPatchEmbed(nn.Module):
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
patch_size: int = 14,
|
|
||||||
temporal_patch_size: int = 2,
|
|
||||||
in_chans: int = 3,
|
|
||||||
embed_dim: int = 1152,
|
|
||||||
) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.patch_size = patch_size
|
|
||||||
self.temporal_patch_size = temporal_patch_size
|
|
||||||
self.embed_dim = embed_dim
|
|
||||||
|
|
||||||
kernel_size = [temporal_patch_size, patch_size, patch_size]
|
|
||||||
self.proj = nn.Conv3d(
|
|
||||||
in_chans, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
||||||
target_dtype = self.proj.weight.dtype
|
|
||||||
L, C = x.shape
|
|
||||||
x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
|
|
||||||
x = self.proj(x.to(dtype=target_dtype)).view(L, self.embed_dim)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class Qwen2_5_VisionPatchMerger(nn.Module):
|
class Qwen2_5_VisionPatchMerger(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -244,21 +221,6 @@ class Qwen2_5_VisionPatchMerger(nn.Module):
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
class Qwen2_5_VisionRotaryEmbedding(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, dim: int, theta: float = 10000.0) -> None:
|
|
||||||
super().__init__()
|
|
||||||
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
|
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
|
||||||
|
|
||||||
def forward(self, seqlen: int) -> torch.Tensor:
|
|
||||||
seq = torch.arange(
|
|
||||||
seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
|
|
||||||
)
|
|
||||||
freqs = torch.outer(seq, self.inv_freq)
|
|
||||||
return freqs
|
|
||||||
|
|
||||||
|
|
||||||
class Qwen2_5_VisionTransformer(nn.Module):
|
class Qwen2_5_VisionTransformer(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -275,7 +237,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
|||||||
spatial_merge_size: int = vision_config.spatial_merge_size
|
spatial_merge_size: int = vision_config.spatial_merge_size
|
||||||
self.spatial_merge_size = spatial_merge_size
|
self.spatial_merge_size = spatial_merge_size
|
||||||
self.spatial_merge_unit: int = spatial_merge_size * spatial_merge_size
|
self.spatial_merge_unit: int = spatial_merge_size * spatial_merge_size
|
||||||
in_chans: int = vision_config.in_channels
|
in_channels: int = vision_config.in_channels
|
||||||
hidden_size: int = vision_config.hidden_size
|
hidden_size: int = vision_config.hidden_size
|
||||||
depth: int = vision_config.depth
|
depth: int = vision_config.depth
|
||||||
num_heads: int = vision_config.num_heads
|
num_heads: int = vision_config.num_heads
|
||||||
@@ -286,7 +248,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
|||||||
self.patch_embed = Qwen2_5_VisionPatchEmbed(
|
self.patch_embed = Qwen2_5_VisionPatchEmbed(
|
||||||
patch_size=patch_size,
|
patch_size=patch_size,
|
||||||
temporal_patch_size=temporal_patch_size,
|
temporal_patch_size=temporal_patch_size,
|
||||||
in_chans=in_chans,
|
in_channels=in_channels,
|
||||||
embed_dim=hidden_size,
|
embed_dim=hidden_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -469,7 +431,7 @@ cached_get_processor = lru_cache(get_processor)
|
|||||||
class Qwen2_5_VLForConditionalGeneration(nn.Module):
|
class Qwen2_5_VLForConditionalGeneration(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: Qwen2VLConfig,
|
config: Qwen2_5_VLConfig,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
@@ -553,14 +515,15 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
|
|||||||
otherwise it will be `(seq_len,)`.
|
otherwise it will be `(seq_len,)`.
|
||||||
(Use input_metadata.mrope_positions to replace it)
|
(Use input_metadata.mrope_positions to replace it)
|
||||||
"""
|
"""
|
||||||
if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
|
is_mrope_enabled = "mrope_section" in self.config.rope_scaling
|
||||||
|
if is_mrope_enabled:
|
||||||
positions = forward_batch.mrope_positions
|
positions = forward_batch.mrope_positions
|
||||||
|
|
||||||
if not (
|
if not (
|
||||||
forward_batch.forward_mode.is_decode()
|
forward_batch.forward_mode.is_decode()
|
||||||
or not forward_batch.contains_image_inputs()
|
or not forward_batch.contains_image_inputs()
|
||||||
):
|
):
|
||||||
if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
|
if is_mrope_enabled:
|
||||||
assert positions.ndim == 2 and positions.size(0) == 3, (
|
assert positions.ndim == 2 and positions.size(0) == 3, (
|
||||||
"multimodal section rotary embedding requires "
|
"multimodal section rotary embedding requires "
|
||||||
f"(3, seq_len) positions, but got {positions.size()}"
|
f"(3, seq_len) positions, but got {positions.size()}"
|
||||||
|
|||||||
@@ -521,14 +521,15 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
|||||||
otherwise it will be `(seq_len,)`.
|
otherwise it will be `(seq_len,)`.
|
||||||
(Use input_metadata.mrope_positions to replace it)
|
(Use input_metadata.mrope_positions to replace it)
|
||||||
"""
|
"""
|
||||||
if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
|
is_mrope_enabled = "mrope_section" in self.config.rope_scaling
|
||||||
|
if is_mrope_enabled:
|
||||||
positions = forward_batch.mrope_positions
|
positions = forward_batch.mrope_positions
|
||||||
|
|
||||||
if not (
|
if not (
|
||||||
forward_batch.forward_mode.is_decode()
|
forward_batch.forward_mode.is_decode()
|
||||||
or not forward_batch.contains_image_inputs()
|
or not forward_batch.contains_image_inputs()
|
||||||
):
|
):
|
||||||
if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
|
if is_mrope_enabled:
|
||||||
assert positions.ndim == 2 and positions.size(0) == 3, (
|
assert positions.ndim == 2 and positions.size(0) == 3, (
|
||||||
"multimodal section rotary embedding requires "
|
"multimodal section rotary embedding requires "
|
||||||
f"(3, seq_len) positions, but got {positions.size()}"
|
f"(3, seq_len) positions, but got {positions.size()}"
|
||||||
|
|||||||
@@ -983,6 +983,8 @@ def v1_chat_generate_request(
|
|||||||
):
|
):
|
||||||
encoded = encoded[1:]
|
encoded = encoded[1:]
|
||||||
prompt_ids += encoded
|
prompt_ids += encoded
|
||||||
|
if tokenizer_manager.model_config.is_multimodal:
|
||||||
|
prompt = tokenizer_manager.tokenizer.decode(prompt_ids)
|
||||||
stop = request.stop
|
stop = request.stop
|
||||||
image_data = None
|
image_data = None
|
||||||
audio_data = None
|
audio_data = None
|
||||||
|
|||||||
Reference in New Issue
Block a user