From e53a0b3d5b3ef814c40e273f80aefefa132bb7d7 Mon Sep 17 00:00:00 2001 From: Mick Date: Fri, 11 Apr 2025 16:29:45 +0800 Subject: [PATCH] [fix] fix mrope positions not picked up (#5265) --- python/sglang/srt/layers/attention/vision.py | 2 +- python/sglang/srt/managers/schedule_batch.py | 13 ++-- .../srt/model_executor/forward_batch_info.py | 54 +++++++++++++---- .../sglang/srt/model_executor/model_runner.py | 3 +- python/sglang/srt/models/qwen2_5_vl.py | 59 ++++--------------- python/sglang/srt/models/qwen2_vl.py | 5 +- python/sglang/srt/openai_api/adapter.py | 2 + 7 files changed, 69 insertions(+), 69 deletions(-) diff --git a/python/sglang/srt/layers/attention/vision.py b/python/sglang/srt/layers/attention/vision.py index c76188e5e..860994913 100644 --- a/python/sglang/srt/layers/attention/vision.py +++ b/python/sglang/srt/layers/attention/vision.py @@ -94,7 +94,7 @@ class VisionAttention(nn.Module): input_size=embed_dim, output_size=embed_dim, quant_config=quant_config, - prefix=add_prefix("out_proj", prefix), + prefix=add_prefix("proj", prefix), ) def forward( diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index cce17729e..86612a135 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -268,6 +268,9 @@ class MultimodalDataItem: self.modality == Modality.VIDEO ) and not MultimodalDataItem.is_empty_list(self.pixel_values) + def is_valid(self) -> bool: + return self.is_image() or self.is_video() or self.is_audio() + def validate(self): ... # TODO @@ -306,11 +309,7 @@ class MultimodalInputs: ) assert isinstance(ret.mm_items, list) - ret.mm_items = [ - item - for item in ret.mm_items - if item.is_audio() or item.is_image() or item.is_video() - ] + ret.mm_items = [item for item in ret.mm_items if item.is_valid()] assert len(ret.mm_items) != 0 @@ -345,8 +344,8 @@ class MultimodalInputs: """ """ return any(item.is_audio() for item in self.mm_items) - def collect_image_inputs(self) -> List[torch.Tensor]: - return [item.pixel_values for item in self.mm_items if item.is_image()] + def contains_mm_input(self) -> bool: + return any(True for item in self.mm_items if item.is_valid()) def merge(self, other: MultimodalInputs): """ diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index ae3e3eb8c..a0ead1784 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -33,7 +33,6 @@ from dataclasses import dataclass from enum import IntEnum, auto from typing import TYPE_CHECKING, List, Optional, Union -import numpy as np import torch import triton import triton.language as tl @@ -399,13 +398,13 @@ class ForwardBatch: ) elif self.forward_mode.is_extend(): extend_start_loc_cpu = self.extend_start_loc.cpu().numpy() - for i, multimodal_inputs in enumerate(batch.multimodal_inputs): + for i, mm_input in enumerate(batch.multimodal_inputs): extend_start_loc, extend_seq_len, extend_prefix_len = ( extend_start_loc_cpu[i], batch.extend_seq_lens[i], batch.extend_prefix_lens[i], ) - if multimodal_inputs is None: + if mm_input is None: # text only mrope_positions = [ [ @@ -416,23 +415,58 @@ class ForwardBatch: ] ] * 3 else: + image_grid_thws_list = [ + item.image_grid_thws + for item in mm_input.mm_items + if item.image_grid_thws is not None + ] + image_grid_thw = ( + None + if len(image_grid_thws_list) == 0 + else torch.cat(image_grid_thws_list, dim=0) + ) + + video_grid_thws_list = [ + item.video_grid_thws + for item in mm_input.mm_items + if item.video_grid_thws is not None + ] + video_grid_thw = ( + None + if len(video_grid_thws_list) == 0 + else torch.cat(video_grid_thws_list, dim=0) + ) + + second_per_grid_ts_list = [ + item.second_per_grid_ts + for item in mm_input.mm_items + if item.second_per_grid_ts is not None + ] + second_per_grid_ts = ( + None + if len(second_per_grid_ts_list) == 0 + else torch.cat(second_per_grid_ts_list, dim=0) + ) + # TODO: current qwen2-vl do not support radix cache since mrope position calculation mrope_positions, mrope_position_delta = ( MRotaryEmbedding.get_input_positions( input_tokens=self.input_ids[ extend_start_loc : extend_start_loc + extend_seq_len - ], - image_grid_thw=multimodal_inputs.image_grid_thws, - video_grid_thw=multimodal_inputs.video_grid_thws, - image_token_id=multimodal_inputs.im_token_id, - video_token_id=multimodal_inputs.video_token_id, + ].tolist(), + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + image_token_id=hf_config.image_token_id, + video_token_id=hf_config.video_token_id, vision_start_token_id=hf_config.vision_start_token_id, vision_end_token_id=hf_config.vision_end_token_id, spatial_merge_size=hf_config.vision_config.spatial_merge_size, context_len=0, seq_len=len(self.input_ids), - second_per_grid_ts=multimodal_inputs.second_per_grid_ts, - tokens_per_second=hf_config.vision_config.tokens_per_second, + second_per_grid_ts=second_per_grid_ts, + tokens_per_second=getattr( + hf_config.vision_config, "tokens_per_second", None + ), ) ) batch.multimodal_inputs[i].mrope_position_delta = ( diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index f49cdae60..d254cb73f 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -1070,7 +1070,8 @@ class ModelRunner: rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {}) if rope_scaling is None: return False - return rope_scaling.get("type", None) == "mrope" + is_mrope_enabled = "mrope_section" in rope_scaling + return is_mrope_enabled def save_remote_model(self, url: str): from sglang.srt.model_loader.loader import RemoteModelLoader diff --git a/python/sglang/srt/models/qwen2_5_vl.py b/python/sglang/srt/models/qwen2_5_vl.py index 4608c39b3..a9a586cf5 100644 --- a/python/sglang/srt/models/qwen2_5_vl.py +++ b/python/sglang/srt/models/qwen2_5_vl.py @@ -30,12 +30,16 @@ import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange -from transformers import Qwen2VLConfig from transformers.activations import ACT2FN from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import ( + Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig, ) +from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( + Qwen2_5_VisionPatchEmbed, + Qwen2_5_VisionRotaryEmbedding, +) from sglang.srt.hf_transformers_utils import get_processor from sglang.srt.layers.attention.vision import VisionAttention @@ -173,33 +177,6 @@ class Qwen2_5_VisionBlock(nn.Module): return x -class Qwen2_5_VisionPatchEmbed(nn.Module): - - def __init__( - self, - patch_size: int = 14, - temporal_patch_size: int = 2, - in_chans: int = 3, - embed_dim: int = 1152, - ) -> None: - super().__init__() - self.patch_size = patch_size - self.temporal_patch_size = temporal_patch_size - self.embed_dim = embed_dim - - kernel_size = [temporal_patch_size, patch_size, patch_size] - self.proj = nn.Conv3d( - in_chans, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - target_dtype = self.proj.weight.dtype - L, C = x.shape - x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size) - x = self.proj(x.to(dtype=target_dtype)).view(L, self.embed_dim) - return x - - class Qwen2_5_VisionPatchMerger(nn.Module): def __init__( @@ -244,21 +221,6 @@ class Qwen2_5_VisionPatchMerger(nn.Module): return out -class Qwen2_5_VisionRotaryEmbedding(nn.Module): - - def __init__(self, dim: int, theta: float = 10000.0) -> None: - super().__init__() - inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - def forward(self, seqlen: int) -> torch.Tensor: - seq = torch.arange( - seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype - ) - freqs = torch.outer(seq, self.inv_freq) - return freqs - - class Qwen2_5_VisionTransformer(nn.Module): def __init__( @@ -275,7 +237,7 @@ class Qwen2_5_VisionTransformer(nn.Module): spatial_merge_size: int = vision_config.spatial_merge_size self.spatial_merge_size = spatial_merge_size self.spatial_merge_unit: int = spatial_merge_size * spatial_merge_size - in_chans: int = vision_config.in_channels + in_channels: int = vision_config.in_channels hidden_size: int = vision_config.hidden_size depth: int = vision_config.depth num_heads: int = vision_config.num_heads @@ -286,7 +248,7 @@ class Qwen2_5_VisionTransformer(nn.Module): self.patch_embed = Qwen2_5_VisionPatchEmbed( patch_size=patch_size, temporal_patch_size=temporal_patch_size, - in_chans=in_chans, + in_channels=in_channels, embed_dim=hidden_size, ) @@ -469,7 +431,7 @@ cached_get_processor = lru_cache(get_processor) class Qwen2_5_VLForConditionalGeneration(nn.Module): def __init__( self, - config: Qwen2VLConfig, + config: Qwen2_5_VLConfig, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ) -> None: @@ -553,14 +515,15 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module): otherwise it will be `(seq_len,). (Use input_metadata.mrope_positions to replace it) """ - if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope": + is_mrope_enabled = "mrope_section" in self.config.rope_scaling + if is_mrope_enabled: positions = forward_batch.mrope_positions if not ( forward_batch.forward_mode.is_decode() or not forward_batch.contains_image_inputs() ): - if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope": + if is_mrope_enabled: assert positions.ndim == 2 and positions.size(0) == 3, ( "multimodal section rotary embedding requires " f"(3, seq_len) positions, but got {positions.size()}" diff --git a/python/sglang/srt/models/qwen2_vl.py b/python/sglang/srt/models/qwen2_vl.py index a708db504..da878e867 100644 --- a/python/sglang/srt/models/qwen2_vl.py +++ b/python/sglang/srt/models/qwen2_vl.py @@ -521,14 +521,15 @@ class Qwen2VLForConditionalGeneration(nn.Module): otherwise it will be `(seq_len,). (Use input_metadata.mrope_positions to replace it) """ - if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope": + is_mrope_enabled = "mrope_section" in self.config.rope_scaling + if is_mrope_enabled: positions = forward_batch.mrope_positions if not ( forward_batch.forward_mode.is_decode() or not forward_batch.contains_image_inputs() ): - if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope": + if is_mrope_enabled: assert positions.ndim == 2 and positions.size(0) == 3, ( "multimodal section rotary embedding requires " f"(3, seq_len) positions, but got {positions.size()}" diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py index 13caec271..64be034a4 100644 --- a/python/sglang/srt/openai_api/adapter.py +++ b/python/sglang/srt/openai_api/adapter.py @@ -983,6 +983,8 @@ def v1_chat_generate_request( ): encoded = encoded[1:] prompt_ids += encoded + if tokenizer_manager.model_config.is_multimodal: + prompt = tokenizer_manager.tokenizer.decode(prompt_ids) stop = request.stop image_data = None audio_data = None