[fix] fix mrope positions not picked up (#5265)

Mick
2025-04-11 16:29:45 +08:00
committed by GitHub
parent 038bc5d521
commit e53a0b3d5b
7 changed files with 69 additions and 69 deletions

@@ -30,12 +30,16 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
-from transformers import Qwen2VLConfig
 from transformers.activations import ACT2FN
 from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm
+from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
+    Qwen2_5_VLConfig,
+    Qwen2_5_VLVisionConfig,
+)
+from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
+    Qwen2_5_VisionPatchEmbed,
+    Qwen2_5_VisionRotaryEmbedding,
+)
 from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.attention.vision import VisionAttention
@@ -173,33 +177,6 @@ class Qwen2_5_VisionBlock(nn.Module):
         return x
 
-class Qwen2_5_VisionPatchEmbed(nn.Module):
-    def __init__(
-        self,
-        patch_size: int = 14,
-        temporal_patch_size: int = 2,
-        in_chans: int = 3,
-        embed_dim: int = 1152,
-    ) -> None:
-        super().__init__()
-        self.patch_size = patch_size
-        self.temporal_patch_size = temporal_patch_size
-        self.embed_dim = embed_dim
-        kernel_size = [temporal_patch_size, patch_size, patch_size]
-        self.proj = nn.Conv3d(
-            in_chans, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False
-        )
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        target_dtype = self.proj.weight.dtype
-        L, C = x.shape
-        x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
-        x = self.proj(x.to(dtype=target_dtype)).view(L, self.embed_dim)
-        return x
-
 class Qwen2_5_VisionPatchMerger(nn.Module):
     def __init__(
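
The class deleted above duplicates the upstream implementation, so the commit imports Qwen2_5_VisionPatchEmbed from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl instead (see the import hunk). For reference, a minimal standalone sketch, not part of the diff, of the shape flow it computes; the input is assumed to be flattened patches of shape (L, C) with C = in_chans * temporal_patch_size * patch_size * patch_size:

# Sketch (not part of the diff): shape walk-through of the patch embedding
# the commit now imports from transformers instead of redefining locally.
import torch
import torch.nn as nn

patch_size, temporal_patch_size, in_chans, embed_dim = 14, 2, 3, 1152
kernel_size = [temporal_patch_size, patch_size, patch_size]
proj = nn.Conv3d(
    in_chans, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False
)

L = 16  # number of patches (illustrative)
x = torch.randn(L, in_chans * temporal_patch_size * patch_size * patch_size)
# Unflatten each row into a (C, T, H, W) volume; since kernel_size == stride,
# the convolution collapses each volume into a single embed_dim vector.
out = proj(x.view(L, -1, temporal_patch_size, patch_size, patch_size))
assert out.view(L, embed_dim).shape == (L, embed_dim)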
@@ -244,21 +221,6 @@ class Qwen2_5_VisionPatchMerger(nn.Module):
         return out
 
-class Qwen2_5_VisionRotaryEmbedding(nn.Module):
-    def __init__(self, dim: int, theta: float = 10000.0) -> None:
-        super().__init__()
-        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-
-    def forward(self, seqlen: int) -> torch.Tensor:
-        seq = torch.arange(
-            seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
-        )
-        freqs = torch.outer(seq, self.inv_freq)
-        return freqs
-
 class Qwen2_5_VisionTransformer(nn.Module):
     def __init__(
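
Like the patch embed, this rotary-embedding helper now comes from transformers rather than a local copy. As a minimal sketch (not part of the diff), this is the angle table the deleted class computes, a position-by-inverse-frequency outer product; dim and seqlen below are illustrative values:

# Sketch (not part of the diff): the (seqlen, dim // 2) rotary angle table
# produced by Qwen2_5_VisionRotaryEmbedding.
import torch

dim, theta, seqlen = 8, 10000.0, 4
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
freqs = torch.outer(torch.arange(seqlen, dtype=inv_freq.dtype), inv_freq)
assert freqs.shape == (seqlen, dim // 2)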
@@ -275,7 +237,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
         spatial_merge_size: int = vision_config.spatial_merge_size
         self.spatial_merge_size = spatial_merge_size
         self.spatial_merge_unit: int = spatial_merge_size * spatial_merge_size
-        in_chans: int = vision_config.in_channels
+        in_channels: int = vision_config.in_channels
         hidden_size: int = vision_config.hidden_size
         depth: int = vision_config.depth
         num_heads: int = vision_config.num_heads
@@ -286,7 +248,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
         self.patch_embed = Qwen2_5_VisionPatchEmbed(
             patch_size=patch_size,
             temporal_patch_size=temporal_patch_size,
-            in_chans=in_chans,
+            in_channels=in_channels,
             embed_dim=hidden_size,
         )
@@ -469,7 +431,7 @@ cached_get_processor = lru_cache(get_processor)
 class Qwen2_5_VLForConditionalGeneration(nn.Module):
     def __init__(
         self,
-        config: Qwen2VLConfig,
+        config: Qwen2_5_VLConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> None:
@@ -553,14 +515,15 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
                 otherwise it will be `(seq_len,)`.
                 (Use input_metadata.mrope_positions to replace it)
         """
-        if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
+        is_mrope_enabled = "mrope_section" in self.config.rope_scaling
+        if is_mrope_enabled:
             positions = forward_batch.mrope_positions
         if not (
             forward_batch.forward_mode.is_decode()
             or not forward_batch.contains_image_inputs()
         ):
-            if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
+            if is_mrope_enabled:
                 assert positions.ndim == 2 and positions.size(0) == 3, (
                     "multimodal section rotary embedding requires "
                     f"(3, seq_len) positions, but got {positions.size()}"
                 )