[fix] fix mrope positions not picked up (#5265)
@@ -30,12 +30,16 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
 from transformers import Qwen2VLConfig
 from transformers.activations import ACT2FN
 from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm
 from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
     Qwen2_5_VLConfig,
     Qwen2_5_VLVisionConfig,
 )
+from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
+    Qwen2_5_VisionPatchEmbed,
+    Qwen2_5_VisionRotaryEmbedding,
+)
 
 from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.attention.vision import VisionAttention
@@ -173,33 +177,6 @@ class Qwen2_5_VisionBlock(nn.Module):
         return x
 
 
-class Qwen2_5_VisionPatchEmbed(nn.Module):
-
-    def __init__(
-        self,
-        patch_size: int = 14,
-        temporal_patch_size: int = 2,
-        in_chans: int = 3,
-        embed_dim: int = 1152,
-    ) -> None:
-        super().__init__()
-        self.patch_size = patch_size
-        self.temporal_patch_size = temporal_patch_size
-        self.embed_dim = embed_dim
-
-        kernel_size = [temporal_patch_size, patch_size, patch_size]
-        self.proj = nn.Conv3d(
-            in_chans, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False
-        )
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        target_dtype = self.proj.weight.dtype
-        L, C = x.shape
-        x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
-        x = self.proj(x.to(dtype=target_dtype)).view(L, self.embed_dim)
-        return x
-
-
 class Qwen2_5_VisionPatchMerger(nn.Module):
 
     def __init__(
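
Aside: the class removed above is functionally replaced by the Qwen2_5_VisionPatchEmbed now imported from transformers (first hunk), whose constructor uses in_channels instead of in_chans, which is why the keywords are renamed further down. A minimal sketch of the input contract that stays the same, using illustrative sizes rather than values read from any real config:

import torch
import torch.nn as nn

# Illustrative sizes only; the real values come from the vision config.
patch_size, temporal_patch_size, in_channels, embed_dim = 14, 2, 3, 1152
kernel_size = [temporal_patch_size, patch_size, patch_size]
proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False)

# x is (L, C) with C = in_channels * temporal_patch_size * patch_size * patch_size,
# i.e. one row per flattened spatio-temporal patch.
L = 8
x = torch.randn(L, in_channels * temporal_patch_size * patch_size * patch_size)
x = x.view(L, -1, temporal_patch_size, patch_size, patch_size)
out = proj(x).view(L, embed_dim)
print(out.shape)  # torch.Size([8, 1152])
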
@@ -244,21 +221,6 @@ class Qwen2_5_VisionPatchMerger(nn.Module):
         return out
 
 
-class Qwen2_5_VisionRotaryEmbedding(nn.Module):
-
-    def __init__(self, dim: int, theta: float = 10000.0) -> None:
-        super().__init__()
-        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-
-    def forward(self, seqlen: int) -> torch.Tensor:
-        seq = torch.arange(
-            seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
-        )
-        freqs = torch.outer(seq, self.inv_freq)
-        return freqs
-
-
 class Qwen2_5_VisionTransformer(nn.Module):
 
     def __init__(
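
Aside: likewise, the deleted rotary class is covered by the Qwen2_5_VisionRotaryEmbedding import from transformers above; both compute the same table of rotary angles. A quick sketch of that computation with made-up dim/seqlen values:

import torch

dim, theta, seqlen = 16, 10000.0, 4  # illustrative values only
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
freqs = torch.outer(torch.arange(seqlen, dtype=torch.float), inv_freq)
print(freqs.shape)  # torch.Size([4, 8]), i.e. (seqlen, dim // 2)
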
@@ -275,7 +237,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
         spatial_merge_size: int = vision_config.spatial_merge_size
         self.spatial_merge_size = spatial_merge_size
         self.spatial_merge_unit: int = spatial_merge_size * spatial_merge_size
-        in_chans: int = vision_config.in_channels
+        in_channels: int = vision_config.in_channels
         hidden_size: int = vision_config.hidden_size
         depth: int = vision_config.depth
         num_heads: int = vision_config.num_heads
@@ -286,7 +248,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
         self.patch_embed = Qwen2_5_VisionPatchEmbed(
             patch_size=patch_size,
             temporal_patch_size=temporal_patch_size,
-            in_chans=in_chans,
+            in_channels=in_channels,
             embed_dim=hidden_size,
         )
 
@@ -469,7 +431,7 @@ cached_get_processor = lru_cache(get_processor)
 class Qwen2_5_VLForConditionalGeneration(nn.Module):
     def __init__(
         self,
-        config: Qwen2VLConfig,
+        config: Qwen2_5_VLConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> None:
@@ -553,14 +515,15 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
                 otherwise it will be `(seq_len,).
             (Use input_metadata.mrope_positions to replace it)
         """
-        if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
+        is_mrope_enabled = "mrope_section" in self.config.rope_scaling
+        if is_mrope_enabled:
             positions = forward_batch.mrope_positions
 
         if not (
             forward_batch.forward_mode.is_decode()
             or not forward_batch.contains_image_inputs()
         ):
-            if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
+            if is_mrope_enabled:
                 assert positions.ndim == 2 and positions.size(0) == 3, (
                     "multimodal section rotary embedding requires "
                     f"(3, seq_len) positions, but got {positions.size()}"
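
Aside: this hunk is the actual fix. The old predicate compared rope_scaling["type"] against "mrope", which can evaluate to False for Qwen2.5-VL checkpoints whose config reports the rope type as "default" while still carrying "mrope_section"; in that case forward_batch.mrope_positions was never picked up. Probing for the "mrope_section" key is robust to that. A minimal sketch with a hypothetical rope_scaling dict:

# Hypothetical rope_scaling contents; real values come from the model's config.json.
rope_scaling = {"type": "default", "rope_type": "default", "mrope_section": [16, 24, 24]}

old_check = rope_scaling.get("type", None) == "mrope"  # False -> mrope positions ignored
new_check = "mrope_section" in rope_scaling            # True  -> mrope positions used
print(old_check, new_check)  # False True

# When mrope is enabled, positions are expected as a (3, seq_len) tensor
# (one row each for the temporal, height, and width components), which the
# assert in this hunk enforces.
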