# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Adapted from vllm/model_executor/models/qwen2_5_vl.py
# Copyright 2025 The vLLM team.
# Copyright 2025 The Qwen Team.
#
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
from functools import lru_cache, partial
from typing import Annotated, Literal, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torchvision.transforms import v2
from transformers.utils import logging
from vllm.config import VllmConfig
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig
from vllm.model_executor.layers.rotary_embedding.common import ApplyRotaryEmb
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
    SupportsMRoPE,
    SupportsMultiModal,
    SupportsPP,
)
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.qwen2_5_vl import (
    Qwen2_5_VLDummyInputsBuilder,
    Qwen2_5_VLMultiModalProcessor,
    Qwen2_5_VLProcessingInfo,
)
from vllm.model_executor.models.utils import (
    AutoWeightsLoader,
    WeightsMapper,
    init_vllm_registered_model,
    maybe_prefix,
)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
    MultiModalFeatureSpec,
    MultiModalKwargsItems,
)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from vllm.v1.attention.backends.registry import AttentionBackendEnum

from .vision import get_vit_attn_backend

logger = logging.get_logger(__name__)


class OpenPanguVisionAttention(nn.Module):
    """Tensor-parallel self-attention of the vision tower with rotary embedding."""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        projection_size: int,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads
        )
        self.tp_size = parallel_state.get_tensor_model_parallel_world_size()
        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
        self.num_attention_heads_per_partition = dist_utils.divide(
            num_heads, self.tp_size
        )

        self.qkv = QKVParallelLinear(
            hidden_size=embed_dim,
            head_size=self.hidden_size_per_attention_head,
            total_num_heads=num_heads,
            total_num_kv_heads=num_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv",
        )
        self.proj = RowParallelLinear(
            input_size=projection_size,
            output_size=embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.proj",
        )
        self.attn = MMEncoderAttention(
            num_heads=self.num_attention_heads_per_partition,
            head_size=self.hidden_size_per_attention_head,
            scale=self.hidden_size_per_attention_head**-0.5,
            prefix=f"{prefix}.attn",
        )
        self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True)

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention over a packed 2-D sequence.

        Args:
            x: Packed hidden states; the rearrange pattern below requires a
                2-D ``(seq, hidden)`` tensor.
            cu_seqlens: Cumulative sequence boundaries handed to the
                attention backend.
            cos: Cosine table for the rotary embedding.
            sin: Sine table for the rotary embedding.

        Returns:
            The projected attention output, again as a 2-D tensor.
        """
        x, bias = self.qkv(x)
        if bias is not None:
            x = x + bias

        q, k, v = x.chunk(3, dim=1)
        # Add a singleton batch dim and split the channel dim into heads.
        q, k, v = (
            rearrange(
                t, "s (b n d) -> b s n d", d=self.hidden_size_per_attention_head, b=1
            ).contiguous()
            for t in (q, k, v)
        )

        # Rotate q and k with a single call by stacking them on the batch dim.
        qk_concat = torch.cat([q, k], dim=0)
        qk_rotated = self.apply_rotary_emb(qk_concat, cos, sin)
        q, k = torch.chunk(qk_rotated, 2, dim=0)

        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
        context_layer = self.attn(
            query=q,
            key=k,
            value=v,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
        context_layer = rearrange(
            context_layer, "b s h d -> s (b h d)", b=1
        ).contiguous()

        output, bias = self.proj(context_layer)
        if bias is not None:
            output = output + bias
        return output


class OpenPanguVisionMLP(nn.Module):
    """Feed-forward block of the vision tower.

    When ``vision_config.hidden_act == "silu"`` a gated MLP is built from a
    merged gate/up projection; otherwise a plain up-projection is used.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        bias: bool = False,
        act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
        vision_config=None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_act = vision_config.hidden_act
        if self.hidden_act == "silu":
            tp_size = parallel_state.get_tensor_model_parallel_world_size()
            # Round the hidden width up to a multiple of the TP world size;
            # the vision transformer's weight loader zero-pads to match.
            if hidden_features % tp_size != 0:
                hidden_features = (hidden_features + tp_size - 1) // tp_size * tp_size
            self.gate_up_proj = MergedColumnParallelLinear(
                input_size=in_features,
                output_sizes=[hidden_features] * 2,
                bias=bias,
                quant_config=quant_config,
                prefix=f"{prefix}.gate_up_proj",
            )
        else:
            self.up_proj = ColumnParallelLinear(
                in_features,
                hidden_features,
                bias=bias,
                quant_config=quant_config,
                prefix=f"{prefix}.up_proj",
            )
        self.down_proj = RowParallelLinear(
            hidden_features,
            in_features,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        self.act_fn = act_fn

    def forward(self, x: torch.Tensor):
        """Apply the (optionally gated) MLP to ``x``."""
        if self.hidden_act == "silu":
            x, _ = self.gate_up_proj(x)
            # ``act_fn`` is element-wise, so the merged projection has to be
            # split into its gate and up halves before the gating product;
            # otherwise ``down_proj`` would receive twice its input width.
            gate, up = x.chunk(2, dim=-1)
            x = self.act_fn(gate) * up
        else:
            x, _ = self.up_proj(x)
            x = self.act_fn(x)
        x, _ = self.down_proj(x)
        return x


class OpenPanguVisionBlock(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_hidden_dim: int,
        act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
        norm_layer: Callable[[int], nn.Module] | None = None,
        vision_config=None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.norm1 = norm_layer(dim)
        self.norm2 = norm_layer(dim)
        self.attn = OpenPanguVisionAttention(
            embed_dim=dim,
            num_heads=num_heads,
            projection_size=dim,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )
        self.mlp = OpenPanguVisionMLP(
            dim,
            mlp_hidden_dim,
            act_fn=act_fn,
            bias=True,
            vision_config=vision_config,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> torch.Tensor:
        """Return ``hidden_states`` after the attention and MLP sub-layers."""
        hidden_states = hidden_states + self.attn(
            self.norm1(hidden_states), cu_seqlens=cu_seqlens, cos=cos, sin=sin
        )
        hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
        return hidden_states


class OpenPanguVisionRotaryEmbedding(nn.Module):
    """Lazily grown table of rotary frequencies, ``outer(position, inv_freq)``."""

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        self.inv_freq = 1.0 / (
            theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)
        )
        self._seq_len_cached = 0
        self._freqs_cached = None

    def update_freqs_cache(self, seqlen: int) -> None:
        """Grow the cached table so that it covers at least ``seqlen`` rows."""
        if seqlen > self._seq_len_cached:
            # Over-allocate by 2x to amortise regrowth. A non in-place multiply
            # is used on purpose: callers may pass a 0-dim tensor, and
            # ``seqlen *= 2`` would silently double the caller's value.
            seqlen = seqlen * 2
            self._seq_len_cached = seqlen
            seq = torch.arange(
                seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
            )
            freqs = torch.outer(seq, self.inv_freq)
            self._freqs_cached = freqs

    def forward(self, seqlen: int) -> torch.Tensor:
        """Return the first ``seqlen`` rows of the frequency table.

        Returns ``None`` if the cache has never been populated (i.e.
        ``seqlen`` has never exceeded zero).
        """
        self.update_freqs_cache(seqlen)
        return (
            self._freqs_cached[:seqlen]
            if self._freqs_cached is not None
            else self._freqs_cached
        )


class OpenPanguVisionPatchEmbed(nn.Module):
    """Patch embedding whose weights live in a ``Conv3d`` but run as a matmul."""

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_channels: int = 3,
        hidden_size: int = 1152,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.hidden_size = hidden_size
        # Flattened length of one (temporal, spatial, channel) patch.
        self.input_size = (
            self.patch_size * self.patch_size * in_channels * self.temporal_patch_size
        )

        kernel_size = (temporal_patch_size, patch_size, patch_size)
        self.proj = nn.Conv3d(
            in_channels,
            hidden_size,
            kernel_size=kernel_size,
            stride=kernel_size,
            bias=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project flattened patches to ``hidden_size``."""
        if x.shape[-1] != self.input_size:
            # The input is narrower than a full patch: duplicate every
            # ``patch_size**2`` plane so rows reach ``input_size`` features.
            # NOTE(review): concatenating exactly two copies presumably
            # targets temporal_patch_size == 2 - confirm for other configs.
            x = torch.cat(
                [
                    x.reshape(-1, self.patch_size * self.patch_size),
                    x.reshape(-1, self.patch_size * self.patch_size),
                ],
                dim=-1,
            ).reshape(-1, self.input_size)
        # Equivalent to the strided convolution, expressed as one matmul
        # against the flattened kernel.
        x = x.matmul(self.proj.weight.data.view(self.hidden_size, -1).transpose(0, 1))
        return x


class OpenPanguVisionPatchMerger(nn.Module):
    """Normalise tokens, group ``spatial_merge_size**2`` of them, project to ``d_model``."""

    def __init__(
        self,
        d_model: int,
        context_dim: int,
        norm_layer: Callable[[int], nn.Module] | None = None,
        spatial_merge_size: int = 2,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.hidden_size = context_dim * (spatial_merge_size**2)
        self.ln_q = norm_layer(context_dim)
        self.mlp = nn.Sequential(
            ColumnParallelLinear(
                self.hidden_size,
                self.hidden_size,
                bias=True,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp.0",
                return_bias=False,
            ),
            nn.GELU(),
            RowParallelLinear(
                self.hidden_size,
                d_model,
                bias=True,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp.2",
                return_bias=False,
            ),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Merge neighbouring tokens of ``x`` and project them."""
        return self.mlp(self.ln_q(x).view(-1, self.hidden_size))


class OpenPanguVisionTransformer(nn.Module):
    """Vision tower: patch embedding, windowed/full attention blocks, mergers.

    Hidden states of the layers listed in ``take_indices`` are normalised,
    passed through one merger each, summed, and finally projected to the
    language-model width by ``vision_projection``.
    """

    def __init__(
        self,
        vision_config,
        out_hidden_size,
        hidden_size,
        norm_eps: float = 1e-6,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        interleaved=False,
    ) -> None:
        super().__init__()
        self.hidden_size = vision_config.hidden_size
        self.num_heads = vision_config.num_heads
        self.window_size = vision_config.window_size
        self.patch_size = vision_config.patch_size
        self.spatial_merge_size = vision_config.spatial_merge_size
        self.fullatt_block_indexes = vision_config.fullatt_block_indexes
        self.spatial_merge_unit = self.spatial_merge_size**2
        norm_layer = partial(RMSNorm, eps=norm_eps)
        self.interleaved = interleaved
        self.out_hidden_size = vision_config.out_hidden_size
        self.hidden_act = vision_config.hidden_act

        head_dim = self.hidden_size // self.num_heads
        self.attn_backend = get_vit_attn_backend(
            head_size=head_dim,
            dtype=torch.get_default_dtype(),
        )
        if self.attn_backend not in {
            AttentionBackendEnum.FLASH_ATTN,
        }:
            raise RuntimeError(
                f"Pangu-VL does not support {self.attn_backend} backend now."
            )
        self.rotary_pos_emb = OpenPanguVisionRotaryEmbedding(head_dim // 2)
        self.patch_embed = OpenPanguVisionPatchEmbed(
            patch_size=vision_config.patch_size,
            temporal_patch_size=vision_config.temporal_patch_size,
            in_channels=vision_config.in_channels,
            hidden_size=self.hidden_size,
        )
        self.blocks = nn.ModuleList(
            [
                OpenPanguVisionBlock(
                    dim=self.hidden_size,
                    num_heads=self.num_heads,
                    mlp_hidden_dim=vision_config.intermediate_size,
                    act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
                    vision_config=vision_config,
                    norm_layer=norm_layer,
                    quant_config=quant_config,
                    prefix=f"{prefix}.blocks.{layer_idx}",
                )
                for layer_idx in range(vision_config.depth)
            ]
        )
        self.tp_size = parallel_state.get_tensor_model_parallel_world_size()
        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
        self.hidden_size_per_attention_head = dist_utils.divide(
            self.hidden_size, self.num_heads
        )

        # Negative layer offsets from the config become absolute, ascending
        # layer indices (``take_indices``). ``select_layer`` is then rebuilt
        # as [-1, -2, ...] so that merger ``i`` reads the i-th most recently
        # captured intermediate in ``forward``.
        self.select_layer = getattr(
            vision_config, "mm_unit_vision_select_layer", [-1, -3]
        )
        self.select_index = [vision_config.depth + i for i in self.select_layer]
        self.select_index = self.select_index[::-1]
        self.select_layer = [-1 * (i + 1) for i in range(len(self.select_index))]
        self.take_indices = self.select_index

        self.final_layernorm = RMSNorm(self.hidden_size, eps=norm_eps)
        self.merger = nn.ModuleList(
            [
                OpenPanguVisionPatchMerger(
                    d_model=vision_config.out_hidden_size,
                    context_dim=self.hidden_size,
                    norm_layer=norm_layer,
                    spatial_merge_size=self.spatial_merge_size,
                    quant_config=quant_config,
                    prefix=f"{prefix}.merger.{i}",
                )
                for i in range(len(self.select_layer))
            ]
        )
        self.vision_projection = ProjectionSingle(out_hidden_size, hidden_size)

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the patch-embedding weight."""
        return self.patch_embed.proj.weight.dtype

    @property
    def device(self) -> torch.device:
        """Device of the patch-embedding weight."""
        return self.patch_embed.proj.weight.device

    def cal_cos_sin(self, rotary_pos_emb):
        """Return ``(cos, sin)`` of the given rotary angles."""
        cos = rotary_pos_emb.cos()
        sin = rotary_pos_emb.sin()
        return cos, sin

    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
        """Build per-token rotary angles from (h, w) patch coordinates."""
        # see https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py for details. #L209 # noqa: E501
        pos_ids = []
        for t, h, w in grid_thw:
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            # Reorder coordinates so that every spatial-merge group of
            # patches is contiguous in the flattened sequence.
            hpos_ids = (
                hpos_ids.reshape(
                    h // self.spatial_merge_size,
                    self.spatial_merge_size,
                    w // self.spatial_merge_size,
                    self.spatial_merge_size,
                )
                .permute(0, 2, 1, 3)
                .flatten()
            )
            wpos_ids = (
                wpos_ids.reshape(
                    h // self.spatial_merge_size,
                    self.spatial_merge_size,
                    w // self.spatial_merge_size,
                    self.spatial_merge_size,
                )
                .permute(0, 2, 1, 3)
                .flatten()
            )
            # Every temporal slice reuses the same spatial coordinates.
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb

    def get_window_index(self, grid_thw):
        """Compute the window-major token permutation and window boundaries.

        Returns:
            ``(window_index, cu_window_seqlens)``: a 1-D tensor that reorders
            merged-token groups window by window, and a Python list of
            cumulative window lengths counted in un-merged tokens. The list
            may contain repeated values for empty windows.
        """
        # see https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py for details. #L238 # noqa: E501
        window_index: list = []
        cu_window_seqlens: list = [0]
        window_index_id = 0
        vit_merger_window_size = (
            self.window_size // self.spatial_merge_size // self.patch_size
        )

        for grid_t, grid_h, grid_w in grid_thw:
            llm_grid_h = grid_h // self.spatial_merge_size
            llm_grid_w = grid_w // self.spatial_merge_size
            index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
                grid_t, llm_grid_h, llm_grid_w
            )
            # When a side is already a multiple of the window size this adds
            # one whole window of padding; such windows end up empty and are
            # filtered out below via the -100 sentinel.
            pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
            pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
            num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
            num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
            index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
            index_padded = index_padded.reshape(
                grid_t,
                num_windows_h,
                vit_merger_window_size,
                num_windows_w,
                vit_merger_window_size,
            )
            index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
                grid_t,
                num_windows_h * num_windows_w,
                vit_merger_window_size,
                vit_merger_window_size,
            )
            # Number of real (non-padding) entries per window.
            seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
            index_padded = index_padded.reshape(-1)
            index_new = index_padded[index_padded != -100]
            window_index.append(index_new + window_index_id)
            cu_seqlens_tmp = (
                seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]
            )
            cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
            window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
        window_index = torch.cat(window_index, dim=0)
        return window_index, cu_window_seqlens

    def forward(
        self,
        x: torch.Tensor,
        grid_thw: torch.Tensor,
    ) -> torch.Tensor:
        """Encode flattened patches into language-model-width embeddings.

        Args:
            x: Flattened patches, one row per patch.
            grid_thw: One ``(t, h, w)`` patch-grid row per image/video.

        Returns:
            Merged and projected embeddings, restored to the original
            (pre-window) token order.
        """
        # compute cu_seqlens
        cu_seqlens = (
            torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0])
            .to(torch.int32)
            .to(x.device)
        )
        cu_seqlens = torch.cumsum(cu_seqlens, dim=0, dtype=torch.int32)
        cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)

        x = self.patch_embed(x)
        rotary_pos_emb = self.rot_pos_emb(grid_thw)

        window_index, cu_window_seqlens = self.get_window_index(grid_thw)
        cu_window_seqlens = torch.tensor(
            cu_window_seqlens,
            device=x.device,
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        # Drop the duplicate boundaries produced by empty padding windows.
        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)

        # Permute tokens (in merge-unit groups) into window-major order.
        seq_len, _ = x.size()
        x = x.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        x = x[window_index, :, :]
        x = x.reshape(seq_len, -1)
        rotary_pos_emb = rotary_pos_emb.reshape(
            seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1
        )
        rotary_pos_emb = rotary_pos_emb[window_index, :, :]
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
        cos, sin = self.cal_cos_sin(rotary_pos_emb.to(x.dtype))

        intermediates = []
        for layer_num, blk in enumerate(self.blocks):
            # Full attention on the configured layers, windowed elsewhere.
            if layer_num in self.fullatt_block_indexes:
                cu_seqlens_now = cu_seqlens
            else:
                cu_seqlens_now = cu_window_seqlens
            x = blk(x, cu_seqlens=cu_seqlens_now, cos=cos, sin=sin)
            if layer_num in self.take_indices:
                ln_hs = self.final_layernorm(x)
                intermediates.append(ln_hs)

        image_embeddings_list = []
        for idx, sl in enumerate(self.select_layer):
            image_embeddings_list.append(self.merger[idx](intermediates[sl]))
        x = sum(image_embeddings_list)

        # Undo the window permutation.
        reverse_indices = torch.argsort(window_index)
        x = x[reverse_indices, :]
        x = self.vision_projection(x)
        return x

    def load_weights(self, weights) -> set[str]:
        """Load vision-tower weights, stacking q/k/v and gate/up shards.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.

        Returns:
            The set of parameter names (after renaming) that were loaded.
        """

        def _padding_weight(name: str, w: torch.Tensor) -> torch.Tensor:
            """Zero-pad MLP weights so the hidden width divides by ``tp_size``."""
            if "gate_proj" in name or "up_proj" in name:
                dim, size = 0, w.size(0)
            elif "down_proj" in name:
                # The down_proj bias is 1-D and has no hidden-width axis:
                # there is nothing to pad, and indexing ``pad`` for dim=1
                # below would raise IndexError on a 1-D tensor.
                if w.ndim < 2:
                    return w
                dim, size = 1, w.size(-1)
            else:
                return w
            pad_len = -size % self.tp_size
            if pad_len == 0:
                return w
            # F.pad lists (left, right) pairs starting from the LAST
            # dimension; this index selects the right-hand pad of ``dim``.
            pad = [0] * (w.ndim * 2)
            pad[-(dim + 1) * 2 + 1] = pad_len
            return F.pad(w, pad, mode="constant", value=0)

        stacked_params_mapping = [
            ("attn.qkv.", "attn.q.", "q"),
            ("attn.qkv.", "attn.k.", "k"),
            ("attn.qkv.", "attn.v.", "v"),
        ]
        if self.hidden_act == "silu":
            stacked_params_mapping.extend(
                [
                    ("gate_up_proj", "gate_proj", 0),
                    ("gate_up_proj", "up_proj", 1),
                ]
            )
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            if self.hidden_act == "silu":
                loaded_weight = _padding_weight(name, loaded_weight)
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not part of a stacked parameter: load it directly.
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class ProjectionSingle(nn.Module):
    """SiLU followed by a single linear layer."""

    def __init__(self, i_hidden_size: int, t_hidden_size: int):
        super().__init__()
        self.act = F.silu
        self.fc1 = nn.Linear(i_hidden_size, t_hidden_size, bias=True)

    def forward(self, hidden_states):
        """Return ``fc1(silu(hidden_states))``."""
        x = self.act(hidden_states)
        return self.fc1(x)


class OpenPanguVLProcessingInfo(Qwen2_5_VLProcessingInfo):
    """Processing info that resolves the HF config/processor from the context."""

    def get_hf_config(self):
        """Return the HF config stored on the model config."""
        return self.ctx.model_config.hf_config

    def get_hf_processor(
        self,
        *,
        min_pixels: int | None = None,
        max_pixels: int | None = None,
        size: dict[str, int] | None = None,
        fps: float | list[float] | None = None,
        **kwargs: object,
    ):
        """Return the HF processor, defaulting to ``use_fast=True``.

        Only ``fps`` (when given) and the extra ``kwargs`` are forwarded.
        NOTE(review): ``min_pixels``, ``max_pixels`` and ``size`` are accepted
        but currently ignored - confirm that this is intentional.
        """
        if fps is not None:
            kwargs["fps"] = fps

        return self.ctx.get_hf_processor(
            use_fast=kwargs.pop("use_fast", True),
            **kwargs,
        )


class OpenPanguVLImagePixelInputs(TensorSchema):
    """Schema for image pixel inputs and their ``(t, h, w)`` grids."""

    type: Literal["pixel_values"]

    pixel_values: Annotated[
        torch.Tensor,
        TensorShape("np", "cps"),
    ]

    image_grid_thw: Annotated[
        torch.Tensor,
        TensorShape("ni", 3),
    ]


class OpenPanguVLImageEmbeddingInputs(TensorSchema):
    """Schema for precomputed image embeddings and their ``(t, h, w)`` grids."""

    type: Literal["image_embeds"]

    image_embeds: Annotated[
        torch.Tensor,
        TensorShape("nf", "hs"),
    ]

    image_grid_thw: Annotated[
        torch.Tensor,
        TensorShape("ni", 3),
    ]


class OpenPanguVLVideoPixelInputs(TensorSchema):
    """Schema for video pixel inputs and their ``(t, h, w)`` grids."""

    type: Literal["pixel_values_videos"]

    pixel_values_videos: Annotated[
        torch.Tensor,
        TensorShape("np", "ctps"),
    ]

    video_grid_thw: Annotated[
        torch.Tensor,
        TensorShape("nv", 3),
    ]


class OpenPanguVLVideoEmbeddingInputs(TensorSchema):
    """Schema for precomputed video embeddings and their ``(t, h, w)`` grids."""

    type: Literal["video_embeds"]

    video_embeds: Annotated[
        torch.Tensor,
        TensorShape("nf", "hs"),
    ]

    video_grid_thw: Annotated[
        torch.Tensor,
        TensorShape("nv", 3),
    ]


class OpenPanguVLMultiModalProcessor(Qwen2_5_VLMultiModalProcessor):
    """Multimodal processor that expands image/video placeholder tokens."""

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Build one placeholder replacement rule per modality.

        Images expand to ``prod(grid_thw) // merge_size**2`` image tokens.
        Videos expand frame by frame, with vision end/start tokens inserted
        between consecutive frames.
        """
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs)
        tokenizer = self.info.get_tokenizer()
        vocab = tokenizer.get_vocab()

        image_token = hf_processor.image_token
        video_token = hf_processor.video_token
        vision_start_token = hf_processor.vision_start_token
        vision_end_token = hf_processor.vision_end_token
        image_token_id = vocab[image_token]
        video_token_id = vocab[video_token]
        vision_start_token_id = vocab[vision_start_token]
        vision_end_token_id = vocab[vision_end_token]
        placeholder = {
            "image": image_token_id,
            "video": video_token_id,
        }

        merge_length = image_processor.merge_size**2

        def get_replacement_openpangu_vision(item_idx: int, modality: str):
            out_item = out_mm_kwargs[modality][item_idx]
            grid_thw = out_item[f"{modality}_grid_thw"].data
            if not isinstance(grid_thw, torch.Tensor):
                raise TypeError("Expected 'grid_thw' to be a Tensor")
            if modality == "image":
                image_token_id_total = [image_token_id] * (
                    int(grid_thw.prod()) // merge_length
                )
                return image_token_id_total
            else:
                # When modality is video
                grid_t, grid_h, grid_w = grid_thw
                video_seq_length_per_time = (grid_h * grid_w).item() // merge_length
                video_token_id_per_time = (
                    [vision_start_token_id]
                    + [video_token_id] * video_seq_length_per_time
                    + [vision_end_token_id]
                )
                video_token_id_total = video_token_id_per_time * grid_t.item()
                # Strip the outermost start/end tokens, leaving only the
                # end/start pairs that separate consecutive frames.
                video_token_id_middle = video_token_id_total[1:-1]
                return PromptUpdateDetails.select_token_id(
                    video_token_id_middle,
                    embed_token_id=video_token_id,
                )

        return [
            PromptReplacement(
                modality=modality,
                target=[placeholder[modality]],
                replacement=partial(
                    get_replacement_openpangu_vision, modality=modality
                ),
            )
            for modality in ("image", "video")
        ]


class OpenPanguVLDummyInputsBuilder(Qwen2_5_VLDummyInputsBuilder):
    """Dummy-input builder; inherits the Qwen2.5-VL behaviour unchanged."""

    pass


@MULTIMODAL_REGISTRY.register_processor(
    OpenPanguVLMultiModalProcessor,
    info=OpenPanguVLProcessingInfo,
    dummy_inputs=OpenPanguVLDummyInputsBuilder,
)
class OpenPanguVLForConditionalGeneration(
    nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
):
    """OpenPangu-VL: a vision transformer feeding a Pangu causal language model."""

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "model.language_model.": "language_model.model.",
            "model.visual.": "visual.",
            "lm_head.": "language_model.lm_head.",
            "model.": "language_model.model.",
        }
    )
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.config = config
        self.vllm_config = vllm_config
        quant_config = vllm_config.quant_config

        with self._mark_tower_model(vllm_config, {"image", "video"}):
            self.visual = OpenPanguVisionTransformer(
                vision_config=config.vision_config,
                out_hidden_size=config.vision_config.out_hidden_size,
                hidden_size=config.hidden_size,
                norm_eps=getattr(config.vision_config, "rms_norm_eps", 1e-6),
                quant_config=self._maybe_ignore_quant_config(quant_config),
                prefix=maybe_prefix(prefix, "visual"),
            )

        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                prefix=maybe_prefix("openpangu", "language_model"),
                architectures=["PanguEmbeddedForCausalLM"],
            )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )
        self._parse_preprocess_params(config.vision_config)

    def _parse_preprocess_params(self, vision_config):
        """Cache the image processor's rescale/normalize settings on ``self``."""
        self.channel = vision_config.in_channels
        self.patch_size = vision_config.patch_size
        image_processor = (
            MULTIMODAL_REGISTRY.create_processor(self.vllm_config.model_config)
            .info.get_hf_processor()
            .image_processor
        )
        self.do_rescale = image_processor.do_rescale
        self.rescale_factor = image_processor.rescale_factor
        self.do_normalize = image_processor.do_normalize
        # Stored as tuples so they are hashable for the lru_cache'd
        # ``_fuse_mean_std_and_rescale_factor``.
        self.image_mean = tuple(image_processor.image_mean)
        self.image_std = tuple(image_processor.image_std)

    def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
        """Return ``None`` for GPTQ configs so the vision tower is not quantized."""
        if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
            return None
        return quant_config

    def _validate_and_reshape_mm_tensor(
        self, mm_input: object, name: str
    ) -> torch.Tensor:
        """Flatten a batched multimodal input to a single 2-D tensor.

        Raises:
            ValueError: If ``mm_input`` is neither a tensor nor a list, or is
                a tensor whose rank is not 2 or 3.
        """
        if not isinstance(mm_input, (torch.Tensor, list)):
            raise ValueError(f"Incorrect type of {name}. Got type: {type(mm_input)}")
        if isinstance(mm_input, torch.Tensor):
            if mm_input.ndim == 2:
                return mm_input
            if mm_input.ndim != 3:
                raise ValueError(
                    f"{name} should be 2D or batched 3D tensor. "
                    f"Got ndim: {mm_input.ndim} "
                    f"(shape={mm_input.shape})"
                )
            return torch.concat(list(mm_input))
        else:
            return torch.concat(mm_input)

    def _parse_and_validate_image_input(self, **kwargs: object):
        """Build the image input schema from kwargs, or ``None`` if absent.

        ``pixel_values`` takes precedence over ``image_embeds``.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)
        image_grid_thw = kwargs.pop("image_grid_thw", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            pixel_values = self._validate_and_reshape_mm_tensor(
                pixel_values, "image pixel values"
            )
            image_grid_thw = self._validate_and_reshape_mm_tensor(
                image_grid_thw, "image grid_thw"
            )

            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError(
                    "Incorrect type of image pixel values. "
                    f"Got type: {type(pixel_values)}"
                )

            return OpenPanguVLImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
                image_grid_thw=image_grid_thw,
            )

        if image_embeds is not None:
            image_embeds = self._validate_and_reshape_mm_tensor(
                image_embeds, "image embeds"
            )
            image_grid_thw = self._validate_and_reshape_mm_tensor(
                image_grid_thw, "image grid_thw"
            )

            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError(
                    "Incorrect type of image embeddings. "
                    f"Got type: {type(image_embeds)}"
                )
            return OpenPanguVLImageEmbeddingInputs(
                type="image_embeds",
                image_embeds=image_embeds,
                image_grid_thw=image_grid_thw,
            )

    def _parse_and_validate_video_input(self, **kwargs: object):
        """Build the video input schema from kwargs, or ``None`` if absent.

        ``pixel_values_videos`` takes precedence over ``video_embeds``.
        """
        pixel_values_videos = kwargs.pop("pixel_values_videos", None)
        video_embeds = kwargs.pop("video_embeds", None)
        video_grid_thw = kwargs.pop("video_grid_thw", None)

        if pixel_values_videos is None and video_embeds is None:
            return None

        if pixel_values_videos is not None:
            pixel_values_videos = self._validate_and_reshape_mm_tensor(
                pixel_values_videos, "video pixel values"
            )
            video_grid_thw = self._validate_and_reshape_mm_tensor(
                video_grid_thw, "video grid_thw"
            )

            return OpenPanguVLVideoPixelInputs(
                type="pixel_values_videos",
                pixel_values_videos=pixel_values_videos,
                video_grid_thw=video_grid_thw,
            )

        if video_embeds is not None:
            video_embeds = self._validate_and_reshape_mm_tensor(
                video_embeds, "video embeds"
            )
            video_grid_thw = self._validate_and_reshape_mm_tensor(
                video_grid_thw, "video grid_thw"
            )

            if not isinstance(video_embeds, torch.Tensor):
                raise ValueError(
                    "Incorrect type of video embeddings. "
                    f"Got type: {type(video_embeds)}"
                )
            return OpenPanguVLVideoEmbeddingInputs(
                type="video_embeds",
                video_embeds=video_embeds,
                video_grid_thw=video_grid_thw,
            )

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        """Parse inputs per modality, keyed in the order they appear in kwargs."""
        mm_input_by_modality = {}
        for input_key in kwargs:
            if (
                input_key in ("pixel_values", "image_embeds")
                and "image" not in mm_input_by_modality
            ):
                mm_input_by_modality["image"] = self._parse_and_validate_image_input(
                    **kwargs
                )
            if (
                input_key in ("pixel_values_videos", "video_embeds")
                and "video" not in mm_input_by_modality
            ):
                mm_input_by_modality["video"] = self._parse_and_validate_video_input(
                    **kwargs
                )
        return mm_input_by_modality

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
        """Return per-item embeddings for all image/video inputs.

        Returns ``None`` when no multimodal input is present; otherwise a
        tuple with one tensor per item, ordered as the modalities appear in
        ``kwargs``.
        """
        mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not mm_input_by_modality:
            return None

        multimodal_embeddings: tuple[torch.Tensor, ...] = ()
        for modality, multimodal_input in mm_input_by_modality.items():
            if modality == "image":
                multimodal_embeddings += self._process_image_input(multimodal_input)
            if modality == "video":
                multimodal_embeddings += self._process_video_input(multimodal_input)
        return multimodal_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings=None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and merge multimodal embeddings when given."""
        inputs_embeds = self.language_model.embed_input_ids(input_ids)
        if multimodal_embeddings is not None:
            # NOTE(review): confirm that ``embed_input_ids`` on this class
            # accepts this positional argument list.
            inputs_embeds = self.embed_input_ids(
                input_ids,
                inputs_embeds,
                multimodal_embeddings,
                [self.config.image_token_id, self.config.video_token_id],
            )
        return inputs_embeds

    def _process_image_input(self, image_input) -> tuple[torch.Tensor, ...]:
        """Encode image inputs and split the result into one tensor per image.

        Raises:
            ValueError: If ``image_grid_thw`` is not 2-D.
        """
        grid_thw = image_input["image_grid_thw"]
        if grid_thw.ndim != 2:
            raise ValueError(f"grid_thw.ndim must be 2, but it is {grid_thw.ndim}")

        if image_input["type"] == "image_embeds":
            image_embeds = image_input["image_embeds"].type(self.visual.dtype)
        else:
            pixel_values = image_input["pixel_values"].type(self.visual.dtype)
            # rescale and normalize
            pixel_values = pixel_values.reshape(
                -1, self.channel, self.patch_size, self.patch_size
            )
            pixel_values = rescale_and_normalize(
                pixel_values,
                self.do_rescale,
                self.rescale_factor,
                self.do_normalize,
                self.image_mean,
                self.image_std,
                # Cast back to the tower dtype instead of relying on the
                # helper's bfloat16 default, which would not match the
                # weights of a tower running in another dtype.
                dtype=self.visual.dtype,
            )
            pixel_values = pixel_values.reshape(
                -1, self.channel * self.patch_size * self.patch_size
            )
            image_embeds = self.visual(pixel_values, grid_thw=grid_thw)

        # Split concatenated embeddings for each image item.
        merge_size = self.visual.spatial_merge_size
        sizes = grid_thw.prod(-1) // merge_size // merge_size
        return image_embeds.split(sizes.tolist())

    def _process_video_input(self, video_input) -> tuple[torch.Tensor, ...]:
        """Encode video inputs and split the result into one tensor per video.

        Raises:
            ValueError: If ``video_grid_thw`` is not 2-D.
        """
        grid_thw = video_input["video_grid_thw"]
        if grid_thw.ndim != 2:
            raise ValueError(f"grid_thw.ndim must be 2, but it is {grid_thw.ndim}")

        if video_input["type"] == "video_embeds":
            video_embeds = video_input["video_embeds"].type(self.visual.dtype)
        else:
            pixel_values_videos = video_input["pixel_values_videos"].type(
                self.visual.dtype
            )
            video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)

        # Split concatenated embeddings for each video item.
        merge_size = self.visual.spatial_merge_size
        sizes = grid_thw.prod(-1) // merge_size // merge_size
        return video_embeds.split(sizes.tolist())

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the language model; embeddings are ignored on non-first PP ranks."""
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata=None,
    ) -> torch.Tensor | None:
        """Delegate logits computation to the language model."""
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load all weights, remapping HF prefixes via ``hf_to_vllm_mapper``."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="visual.merger.",
            tower_model="visual.",
        )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder string for an image or video item.

        Raises:
            ValueError: If ``modality`` is neither image nor video.
        """
        if modality.startswith("image"):
            return "[unused18][unused19][unused20]"
        if modality.startswith("video"):
            return "[unused18][unused32][unused20]"

        raise ValueError("Only image or video modality is supported")

    def iter_mm_grid_thw(
        self, mm_features: list[MultiModalFeatureSpec]
    ) -> Iterator[tuple[str, int, int, int, int]]:
        """Yield ``(modality, offset, t, h, w)`` per feature, sorted by offset.

        ``h`` and ``w`` are already divided by the spatial merge size.
        """
        spatial_merge_size = self.config.vision_config.spatial_merge_size
        for mm_feature in sorted(mm_features, key=lambda f: f.mm_position.offset):
            offset = mm_feature.mm_position.offset
            modality = mm_feature.modality
            if modality == "image":
                t, h, w = mm_feature.data["image_grid_thw"].data.tolist()
                assert t == 1, f"Image must have 1 frame, got {t}"
                yield (
                    modality,
                    offset,
                    1,
                    h // spatial_merge_size,
                    w // spatial_merge_size,
                )
            elif modality == "video":
                t, h, w = mm_feature.data["video_grid_thw"].data.tolist()
                yield (
                    modality,
                    offset,
                    t,
                    h // spatial_merge_size,
                    w // spatial_merge_size,
                )
            else:
                raise ValueError(f"Unsupported modality: {modality}")

    def get_mrope_input_positions(
        self,
        input_tokens: list[int],
        mm_features: list[MultiModalFeatureSpec],
    ) -> tuple[torch.Tensor, int]:
        """Compute 3-row (t, h, w) M-RoPE positions for a prompt.

        Returns:
            ``(llm_positions, mrope_position_delta)`` where the delta is
            ``max_position + 1 - len(input_tokens)``.
        """
        llm_pos_ids_list: list = []
        st = 0  # index of the first prompt token not yet assigned a position
        for (
            modality,
            offset,
            llm_grid_t,
            llm_grid_h,
            llm_grid_w,
        ) in self.iter_mm_grid_thw(mm_features):
            # Text between the previous item and this one advances all three
            # rows together.
            text_len = offset - st
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
            )

            if modality == "video":
                eot_bot_pos = torch.full((3, 1), 0, dtype=torch.long)
                offset_pos = max(llm_grid_h, llm_grid_w)
                current_pos = text_len + st_idx

                grid_h = (
                    torch.arange(llm_grid_h)
                    .view(-1, 1)
                    .expand(-1, llm_grid_w)
                    .flatten()
                )
                grid_w = (
                    torch.arange(llm_grid_w)
                    .view(1, -1)
                    .expand(llm_grid_h, -1)
                    .flatten()
                )
                frame_pos = torch.stack(
                    [
                        torch.full_like(grid_h, 0, dtype=torch.long),  # t
                        grid_h,  # h
                        grid_w,  # w
                    ]
                )
                llm_pos_ids_list.append(frame_pos + current_pos)

                # Each further frame is preceded by two separator tokens
                # (the end/start pair inserted by the prompt processor).
                for _ in range(llm_grid_t - 1):
                    current_pos = current_pos + offset_pos
                    llm_pos_ids_list.append(eot_bot_pos + current_pos)
                    llm_pos_ids_list.append(eot_bot_pos + current_pos + 1)
                    llm_pos_ids_list.append(frame_pos + current_pos + 2)
                    current_pos += 2
                st = (
                    offset + llm_grid_t * llm_grid_h * llm_grid_w + (llm_grid_t - 1) * 2
                )
            else:
                t_index = (
                    (
                        torch.arange(llm_grid_t)
                        .view(-1, 1)
                        .expand(-1, llm_grid_h * llm_grid_w)
                    )
                    .long()
                    .flatten()
                )
                h_index = (
                    torch.arange(llm_grid_h)
                    .view(1, -1, 1)
                    .expand(llm_grid_t, -1, llm_grid_w)
                    .flatten()
                )
                w_index = (
                    torch.arange(llm_grid_w)
                    .view(1, 1, -1)
                    .expand(llm_grid_t, llm_grid_h, -1)
                    .flatten()
                )
                llm_pos_ids_list.append(
                    torch.stack([t_index, h_index, w_index]) + text_len + st_idx
                )
                st = offset + llm_grid_t * llm_grid_h * llm_grid_w

        # Trailing text after the last multimodal item.
        if st < len(input_tokens):
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
            text_len = len(input_tokens) - st
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
            )

        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
        mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
        return llm_positions, mrope_position_delta


def rescale(image, scale):
    """Return ``image`` multiplied by ``scale``."""
    return image * scale


def normalize(image, mean, std):
    """Normalize ``image`` with torchvision's ``v2.functional.normalize``."""
    return v2.functional.normalize(image, mean, std)


@lru_cache(maxsize=10)
def _fuse_mean_std_and_rescale_factor(
    do_normalize: bool | None = None,
    image_mean: float | list[float] | None = None,
    image_std: float | list[float] | None = None,
    do_rescale: bool | None = None,
    rescale_factor: float | None = None,
    device: Optional["torch.device"] = None,
) -> tuple:
    """Fold the rescale factor into mean/std when both steps are enabled.

    Results are memoised, so every argument must be hashable (pass
    ``image_mean``/``image_std`` as tuples rather than lists).

    Returns:
        ``(image_mean, image_std, do_rescale)``; when fused, mean and std are
        tensors on ``device`` and ``do_rescale`` is ``False``.
    """
    if do_rescale and do_normalize:
        # Fused rescale and normalize
        image_mean = torch.tensor(image_mean, device=device) * (1.0 / rescale_factor)
        image_std = torch.tensor(image_std, device=device) * (1.0 / rescale_factor)
        do_rescale = False
    return image_mean, image_std, do_rescale


def rescale_and_normalize(
    images: "torch.Tensor",
    do_rescale: bool,
    rescale_factor: float,
    do_normalize: bool,
    image_mean: float | list[float],
    image_std: float | list[float],
    dtype: torch.dtype = torch.bfloat16,
) -> "torch.Tensor":
    """
    Rescale and normalize images.

    The result is always cast to ``dtype`` (``torch.bfloat16`` by default),
    even when neither step is enabled.
    """
    image_mean, image_std, do_rescale = _fuse_mean_std_and_rescale_factor(
        do_normalize=do_normalize,
        image_mean=image_mean,
        image_std=image_std,
        do_rescale=do_rescale,
        rescale_factor=rescale_factor,
        device=images.device,
    )
    # if/elif as we use fused rescale and normalize if both are set to True
    if do_normalize:
        images = normalize(images.to(dtype=torch.float32), image_mean, image_std)
    elif do_rescale:
        images = rescale(images, rescale_factor)
    images = images.to(dtype)

    return images