################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################

# SPDX-License-Identifier: Apache-2.0

# Adapted from
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
# Copyright 2025 The vLLM team.
# Copyright 2025 The Qwen Team.
# Copyright 2025 The HuggingFace Inc. team.
# All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2.5-VL model compatible with HuggingFace weights."""
import math
from collections.abc import Iterable
from functools import partial
from typing import Callable, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_br
from einops import rearrange
from fastcore.basics import patch_to

import vllm
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.qwen2_5_vl import (Qwen2_5_VisionBlock,
                                                   Qwen2_5_VisionMLP,
                                                   Qwen2_5_VisionPatchMerger,
                                                   Qwen2_5_VisionTransformer)
from vllm.model_executor.models.qwen2_vl import apply_rotary_pos_emb_vision
from vllm.model_executor.models.utils import cast_overflow_tensors
from vllm.platforms import _Backend
from vllm_br import envs

from .br_utils import convBB, convSB


def all_gather_interleave(local_tensor, hidden_size: int, tp_size: int):
    """All-gather the input tensor interleavely across model parallel group.

    Every rank's ``local_tensor`` is gathered over the tensor-parallel device
    group, each gathered tensor is split along its last dimension into pieces
    of width ``hidden_size // tp_size``, and the pieces are concatenated so
    that piece ``i`` of every rank precedes piece ``i + 1`` of any rank.

    Args:
        local_tensor: This rank's shard; all ranks must contribute tensors of
            the same shape (``torch.zeros_like`` is used for the receive
            buffers).
        hidden_size: Full (un-sharded) hidden size used to derive the split
            width.
        tp_size: Number of ranks to gather from.

    Returns:
        The interleaved concatenation along the last dimension.
    """
    import torch.distributed as dist

    gathered_tensors = [torch.zeros_like(local_tensor) for _ in range(tp_size)]
    dist.all_gather(gathered_tensors,
                    local_tensor,
                    group=parallel_state.get_tp_group().device_group)

    gathered_tensors_split = [
        torch.split(tensor, hidden_size // tp_size, -1)
        for tensor in gathered_tensors
    ]
    # zip(*...) walks the ranks in lockstep, so the i-th piece of every rank
    # is emitted before any (i + 1)-th piece.
    ordered_tensors = [
        tensor for pair in zip(*gathered_tensors_split, strict=False)
        for tensor in pair
    ]
    result_tensor = torch.cat(ordered_tensors, dim=-1)
    return result_tensor


class Qwen2_5_VisionAttention_fit(nn.Module):
    """Vision attention that runs on the ``torch_br`` fused attention op.

    Installed at the bottom of this module as a replacement for
    ``vllm.model_executor.models.qwen2_5_vl.Qwen2_5_VisionAttention``.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        projection_size: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend: _Backend = _Backend.TORCH_SDPA,
        use_upstream_fa: bool = False,
    ) -> None:
        """Build the QKV and output projections.

        Args:
            embed_dim: Input/output feature size of the attention layer.
            num_heads: Total number of attention heads (used for Q, K and V).
            projection_size: Total width of the attention output fed to
                ``proj``; must be divisible by ``num_heads``.
            quant_config: Optional quantization config forwarded to the
                linear layers.
            prefix: Parameter-name prefix for the sub-layers.
            use_data_parallel: When true, ``tp_size`` is treated as 1.
            attn_backend: Accepted for signature compatibility; not read here.
            use_upstream_fa: Accepted for signature compatibility; not read
                here.
        """
        super().__init__()
        # Per attention head and per partition values.
        self.tp_size = (1 if use_data_parallel else
                        parallel_state.get_tensor_model_parallel_world_size())
        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads)
        self.num_attention_heads_per_partition = dist_utils.divide(
            num_heads, self.tp_size)
        # The attention scale passed to the kernel is 1 / norm_factor.
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)

        self.qkv = QKVParallelLinear(
            hidden_size=embed_dim,
            head_size=self.hidden_size_per_attention_head,
            total_num_heads=num_heads,
            total_num_kv_heads=num_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv")
        self.proj = RowParallelLinear(input_size=projection_size,
                                      output_size=embed_dim,
                                      quant_config=quant_config,
                                      prefix=f"{prefix}.proj")

    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Split a fused QKV activation into per-rank Q, K and V tensors.

        Args:
            qkv: Fused projection output, ``[s, b, 3 * head * head_dim]``.

        Returns:
            ``(q, k, v)``, each viewed as ``[s, b, head, head_dim]`` with
            ``head`` being the per-partition head count.
        """
        # [s, b, 3 * head * head_dim]
        seq_len, bs, width = qkv.shape
        qkv = qkv.reshape(-1, width)
        if self.tp_size > 1:
            qkv = all_gather_interleave(qkv, self.qkv.hidden_size,
                                        self.tp_size)

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
        q, k, v = qkv.chunk(3, dim=-1)

        # 3 * [s, b, head * head_dim]
        if self.tp_size > 1:
            # Keep only this rank's slice of each of Q, K and V.
            splitter = partial(dist_utils.split_tensor_along_last_dim,
                               num_partitions=self.tp_size)
            q = splitter(q)[self.tp_rank]
            k = splitter(k)[self.tp_rank]
            v = splitter(v)[self.tp_rank]

        # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
        new_shape = (seq_len, bs, self.num_attention_heads_per_partition,
                     self.hidden_size_per_attention_head)
        q, k, v = (x.view(*new_shape) for x in (q, k, v))
        return q, k, v

    def transform_qkv_shape(self,
                            qkv_layer,
                            cur_qkv_shape_state,
                            obj_qkv_shape_state,
                            obj_shape=None):
        """Convert a Q/K/V tensor between the supported layouts.

        Layout tags: ``"b_s_n_h"`` is ``[b, s, n, h]``, ``"b_n_s_h"`` is
        ``[b, n, s, h]`` and ``"bn_s_h"`` is ``[b * n, s, h]``.

        Args:
            qkv_layer: Tensor currently laid out as ``cur_qkv_shape_state``.
            cur_qkv_shape_state: Layout tag of ``qkv_layer``.
            obj_qkv_shape_state: Layout tag to convert to.
            obj_shape: Target shape; only ``obj_shape[0]`` (the batch size) is
                read, and only when converting away from ``"bn_s_h"``.

        Returns:
            The tensor in the requested layout (the input itself when the two
            tags are equal).

        Raises:
            AssertionError: If the pair of layout tags is not supported.
        """
        if obj_qkv_shape_state == "bn_s_h":
            if cur_qkv_shape_state == "bn_s_h":
                return qkv_layer
            if cur_qkv_shape_state == "b_s_n_h":
                # [b, sq, np or nkvp, hn] --> [b, np or nkvp, sq, hn]
                #   --> [b*(np or nkvp), sq, hn]
                qkv_layer = qkv_layer.permute(0, 2, 1, 3)
                # view 4d matrix to 3d matrix, TODO: use fused_split_view here
                qkv_layer = qkv_layer.reshape(-1, qkv_layer.size(2),
                                              qkv_layer.size(3)).contiguous()
                return qkv_layer
            if cur_qkv_shape_state == "b_n_s_h":
                qkv_layer = qkv_layer.reshape(-1, qkv_layer.size(2),
                                              qkv_layer.size(3))
                return qkv_layer

        if obj_qkv_shape_state == "b_n_s_h":
            if cur_qkv_shape_state == "b_n_s_h":
                return qkv_layer
            if cur_qkv_shape_state == "bn_s_h":
                qkv_layer = qkv_layer.reshape(obj_shape[0], -1,
                                              qkv_layer.size(1),
                                              qkv_layer.size(2))
                return qkv_layer
            if cur_qkv_shape_state == "b_s_n_h":
                qkv_layer = qkv_layer.permute(0, 2, 1, 3).contiguous()
                return qkv_layer

        if obj_qkv_shape_state == "b_s_n_h":
            if cur_qkv_shape_state == "b_s_n_h":
                return qkv_layer
            if cur_qkv_shape_state == "b_n_s_h":
                qkv_layer = qkv_layer.permute(0, 2, 1, 3).contiguous()
                return qkv_layer
            if cur_qkv_shape_state == "bn_s_h":
                qkv_layer = qkv_layer.reshape(obj_shape[0], -1,
                                              qkv_layer.size(1),
                                              qkv_layer.size(2))
                qkv_layer = qkv_layer.permute(0, 2, 1, 3).contiguous()
                return qkv_layer

        # The exception used to be constructed without `raise`, which made an
        # unsupported transform silently return None.
        raise AssertionError(
            f"unsupported shape transform, ori:{cur_qkv_shape_state} obj:{obj_qkv_shape_state}"
        )

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor,
        max_seqlen: Optional[int] = None,  # Only used for Flash Attention
        seqlens: Optional[list[int]] = None,  # Only used for xFormers
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run QKV projection, rotary embedding, attention and output proj.

        Args:
            x: Input activations. A leading dimension of size 1 is treated as
                the batch dimension and swapped behind the sequence dimension
                before the QKV split.
            cu_seqlens: Not read by this implementation.
            rotary_pos_emb: Rotary embedding applied to Q and K; skipped when
                ``None``.
            max_seqlen: Not read by this implementation.
            seqlens: Not read by this implementation.
            mask: Attention mask handed to the ``torch_br`` attention op.

        Returns:
            The output of the ``proj`` layer.
        """
        # [s, b, c] --> [s, b, head * 3 * head_dim]
        x, _ = self.qkv(x)
        if envs.VLLM_BR_DEVICE_SPC_NUM > 16:
            x = convBB(x)
            seql = x.shape[-2]
            # NOTE(review): presumably regroups the two device halves so that
            # Q, K and V are each contiguous again -- confirm against convBB.
            x = x.reshape(seql, 2, 3,
                          -1).permute(0, 2, 1,
                                      3).contiguous().reshape(1, seql, -1)

        if x.shape[0] == 1:
            x = x.permute(1, 0, 2).contiguous()

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
        q, k, v = self.split_qkv(x)

        q, k, v = (rearrange(t, "s b ... -> b s ...").contiguous()
                   for t in (q, k, v))
        if rotary_pos_emb is not None:
            q = apply_rotary_pos_emb_vision(
                q,
                rotary_pos_emb,
            )
            k = apply_rotary_pos_emb_vision(
                k,
                rotary_pos_emb,
            )

        # q, k, v: [b, s, n, h] -> reshape: [b, n, s, h]
        #   -> reshape: [b * n, s, h]
        q = q.permute(0, 2, 1, 3).contiguous()
        k = k.permute(0, 2, 1, 3).contiguous()
        v = v.permute(0, 2, 1, 3).contiguous()

        q = self.transform_qkv_shape(q, "b_n_s_h", "bn_s_h")
        k = self.transform_qkv_shape(k, "b_n_s_h", "bn_s_h")
        v = self.transform_qkv_shape(v, "b_n_s_h", "bn_s_h")

        # TODO(qingqi): works around a sueager bug for these sequence lengths;
        # remove this cast once the sueager op is fixed.
        if q.shape[1] == 8192 or q.shape[1] == 8424 or q.shape[1] == 8464:
            mask = mask.to(torch.bfloat16)

        context_layer, _ = torch_br.sueager_scaled_dot_product_attention_fwd(
            query=q,
            key=k,
            value=v,
            mask=mask,
            dropout_prob=0.0,
            is_causal=False,
            scale=1 / self.norm_factor,
            algorithm="FMHA",
        )

        # reshape attn out: [b*n, s, h] -> [s, b, h*n]
        context_layer = torch_br.supa_shape_transform_qkv(
            context_layer, 1, context_layer.shape[-2],
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head, False, False, None)

        if context_layer.shape[0] != 1:
            context_layer = context_layer.permute(1, 0, 2).contiguous()
        if envs.VLLM_BR_DEVICE_SPC_NUM > 16:
            context_layer = convSB(context_layer, -1)

        output, _ = self.proj(context_layer)
        return output


def vision_block_forward(
    self,
    x: torch.Tensor,
    cu_seqlens: torch.Tensor,
    rotary_pos_emb: torch.Tensor,
    max_seqlen: Optional[int] = None,  # Only used for Flash Attention
    seqlens: Optional[list[int]] = None,  # Only used for xFormers
    mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Replacement ``forward`` for ``Qwen2_5_VisionBlock``.

    Applies pre-norm attention and pre-norm MLP, each with a residual add.
    If the leading dimension of ``x`` is not 1, the first two dimensions are
    swapped first so the sub-layers see a leading dimension taken from
    ``x.shape[1]``.

    Args:
        x: Block input.
        cu_seqlens: Forwarded to the attention layer.
        rotary_pos_emb: Forwarded to the attention layer.
        max_seqlen: Forwarded to the attention layer.
        seqlens: Forwarded to the attention layer.
        mask: Attention mask forwarded to the attention layer.

    Returns:
        The block output.
    """
    if x.shape[0] != 1:
        x = x.permute(1, 0, 2).contiguous()
    x = x + self.attn(self.norm1(x),
                      cu_seqlens=cu_seqlens,
                      rotary_pos_emb=rotary_pos_emb,
                      max_seqlen=max_seqlen,
                      seqlens=seqlens,
                      mask=mask)
    x = x + self.mlp(self.norm2(x))
    return x


class Qwen2_5_VisionPatchEmbed_fit(nn.Module):
    """Patch embedding implemented as a single linear projection.

    Installed at the bottom of this module as a replacement for
    ``vllm.model_executor.models.qwen2_5_vl.Qwen2_5_VisionPatchEmbed``.
    """

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_channels: int = 3,
        hidden_size: int = 1152,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Create the projection.

        Args:
            patch_size: Spatial patch edge length.
            temporal_patch_size: Temporal patch length.
            in_channels: Number of input channels.
            hidden_size: Output feature size.
            quant_config: Optional quantization config for the projection.
        """
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.hidden_size = hidden_size

        # One flattened patch has
        # in_channels * temporal_patch_size * patch_size**2 features.
        self.proj = ColumnParallelLinear(in_channels * temporal_patch_size *
                                         patch_size * patch_size,
                                         hidden_size,
                                         bias=False,
                                         gather_output=True,
                                         quant_config=quant_config,
                                         prefix="")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project flattened patches to ``[num_patches, hidden_size]``.

        Args:
            x: Flattened patches; the second-to-last dimension (after a
                leading dimension of 1 is added) is taken as the patch count.

        Returns:
            Tensor viewed as ``[num_patches, hidden_size]``.
        """
        x = x.unsqueeze(0)
        num_patches = x.shape[-2]
        x = self.proj(x)[0].view(num_patches, self.hidden_size)
        return x


@patch_to(vllm.model_executor.models.qwen2_5_vl.Qwen2_5_VisionTransformer)
def gen_normal_mask(self, cu_seqlens, grid_thw, device):
    """Build a block-diagonal attention mask from cumulative sequence lengths.

    The mask is an ``int32`` tensor of shape ``[1, seq_len, seq_len]`` filled
    with 1, where ``seq_len = max(cu_seqlens)``. For every consecutive pair
    ``(cu_seqlens[i - 1], cu_seqlens[i])`` the corresponding square block on
    the diagonal is set to 0.

    NOTE(review): presumably 0 marks positions that may attend to each other
    and 1 marks masked positions -- confirm against the attention kernel.

    Args:
        cu_seqlens: Cumulative sequence boundaries.
        grid_thw: Not read by this function.
        device: Device on which the mask is allocated.

    Returns:
        The mask tensor.
    """
    # NOTE: for mask-mock-pack, we precompute mask and store in PackedSeqParams
    seq_len = max(cu_seqlens)
    attention_mask = torch.full([1, seq_len, seq_len],
                                1,
                                dtype=torch.int32,
                                device=device)
    for i in range(1, len(cu_seqlens)):
        attention_mask[..., cu_seqlens[i - 1]:cu_seqlens[i],
                       cu_seqlens[i - 1]:cu_seqlens[i]] = 0
    return attention_mask


def vision_transformer_forward(
    self,
    x: torch.Tensor,
    grid_thw: list[list[int]],
) -> torch.Tensor:
    """Replacement ``forward`` for ``Qwen2_5_VisionTransformer``.

    Steps: patch-embed the input, build per-item rotary embeddings, window
    indices and cumulative sequence lengths from ``grid_thw``, reorder the
    tokens by window index, run every block with one precomputed attention
    mask, merge patches, and finally undo the window reordering.

    Args:
        x: Flattened patches, 2-D (``x.size()`` is unpacked into two values).
        grid_thw: One ``(t, h, w)`` triple per visual item.

    Returns:
        The merged hidden states in the original (pre-window) order.
    """
    # patchify
    seq_len, _ = x.size()
    rotary_pos_emb_list = []
    window_index_list: list = []
    cu_window_seqlens_list: list = [
        torch.tensor([0], dtype=torch.int32, device="cpu")
    ]
    cu_seqlens_list: list = []

    hidden_states = x.to(device=self.device, dtype=self.dtype)
    hidden_states = self.patch_embed(hidden_states)

    window_index_id = 0
    cu_window_seqlens_last = 0
    for t, h, w in grid_thw:
        t, h, w = int(t), int(h), int(w)
        llm_h = h // self.spatial_merge_size
        llm_w = w // self.spatial_merge_size

        (
            rotary_pos_emb_thw,
            window_index_thw,
            cu_seqlens_window_thw,
            cu_seqlens_thw,
        ) = self.get_rope_by_thw(t, h, w)

        # Offset per-item indices/boundaries so they address the concatenated
        # sequence of all items.
        window_index_list.append(window_index_thw + window_index_id)
        window_index_id += (t * llm_h * llm_w)

        cu_seqlens_window_thw = (cu_seqlens_window_thw +
                                 cu_window_seqlens_last)
        cu_window_seqlens_last = cu_seqlens_window_thw[-1]
        cu_window_seqlens_list.append(cu_seqlens_window_thw)

        rotary_pos_emb_list.append(rotary_pos_emb_thw)

        cu_seqlens_list.append(cu_seqlens_thw)

    rotary_pos_emb = torch.cat(rotary_pos_emb_list)
    window_index = torch.cat(window_index_list)
    cu_window_seqlens = torch.cat(cu_window_seqlens_list)
    cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
    cu_seqlens = torch.cat(cu_seqlens_list)
    cu_seqlens = torch.cumsum(cu_seqlens, dim=0, dtype=torch.int32)
    cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)

    # transformers
    # pre-compute seqlens for window/full attn to reduce cuMemcpy operations
    max_seqlen_full, seqlens_full = self.compute_attn_mask_seqlen(cu_seqlens)
    max_seqlen_window, seqlens_window = self.compute_attn_mask_seqlen(
        cu_window_seqlens)

    cu_seqlens = cu_seqlens.to(device=self.device, non_blocking=True)
    cu_window_seqlens = cu_window_seqlens.to(device=self.device,
                                             non_blocking=True)
    rotary_pos_emb = rotary_pos_emb.to(device=self.device, non_blocking=True)
    window_index = window_index.to(device=hidden_states.device,
                                   non_blocking=True)

    # Reorder tokens in groups of `spatial_merge_unit` by window index.
    hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit,
                                          self.spatial_merge_unit, -1)
    hidden_states = hidden_states[window_index, :, :]
    hidden_states = hidden_states.reshape(seq_len, -1)

    hidden_states = hidden_states.unsqueeze(1)

    # The same mask (built from the full-attention boundaries) is passed to
    # every block, including the window-attention ones.
    attention_mask = self.gen_normal_mask(cu_seqlens, grid_thw, x.device)
    for layer_num, blk in enumerate(self.blocks):
        if layer_num in self.fullatt_block_indexes:
            cu_seqlens_now = cu_seqlens
            max_seqlen_now = max_seqlen_full
            seqlens_now = seqlens_full
        else:
            cu_seqlens_now = cu_window_seqlens
            max_seqlen_now = max_seqlen_window
            seqlens_now = seqlens_window

        hidden_states = blk(hidden_states,
                            cu_seqlens=cu_seqlens_now,
                            rotary_pos_emb=rotary_pos_emb,
                            max_seqlen=max_seqlen_now,
                            seqlens=seqlens_now,
                            mask=attention_mask)

    # For Qwen2.5-VL-3B, float16 will overflow at last block
    # for long visual tokens sequences.
    if hidden_states.dtype == torch.float16:
        hidden_states = cast_overflow_tensors(hidden_states)

    # adapter
    hidden_states = self.merger(hidden_states).squeeze(0)
    # argsort of the window permutation is its inverse, restoring the
    # original token order.
    reverse_indices = torch.argsort(window_index)
    hidden_states = hidden_states[reverse_indices, :]
    return hidden_states


def vision_transformer_load_weights(
        self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """Replacement ``load_weights`` for ``Qwen2_5_VisionTransformer``.

    Checkpoint entries named ``attn.q.`` / ``attn.k.`` / ``attn.v.`` are
    loaded as shards of the fused ``attn.qkv.`` parameter. The weight named
    ``patch_embed.proj.weight`` is flattened to 2-D before loading, matching
    the linear patch embedding defined in this module.

    Args:
        weights: Iterable of ``(name, tensor)`` checkpoint entries.

    Returns:
        The set of parameter names that were loaded.

    Raises:
        KeyError: If an entry's (possibly remapped) name is not a parameter of
            this module.
    """
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("attn.qkv.", "attn.q.", "q"),
        ("attn.qkv.", "attn.k.", "k"),
        ("attn.qkv.", "attn.v.", "v"),
    ]
    params_dict = dict(self.named_parameters(remove_duplicate=False))
    loaded_params: set[str] = set()

    for name, loaded_weight in weights:
        for (param_name, weight_name, shard_id) in stacked_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)

            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            break
        else:
            # Not a Q/K/V shard: load directly.
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            if name == 'patch_embed.proj.weight':
                loaded_weight = loaded_weight.reshape(loaded_weight.shape[0],
                                                      -1).contiguous()
            weight_loader(param, loaded_weight)
        loaded_params.add(name)
    return loaded_params


def Qwen2_5_VisionPatchMerger_forward_fit(self,
                                          x: torch.Tensor) -> torch.Tensor:
    """Replacement ``forward`` for ``Qwen2_5_VisionPatchMerger``.

    Normalizes ``x`` with ``ln_q``, views it as rows of ``self.hidden_size``
    features with a leading dimension of 1, and applies ``self.mlp``.

    Args:
        x: Hidden states to merge.

    Returns:
        The output of ``self.mlp``.
    """
    x = self.ln_q(x)
    x = x.view(-1, self.hidden_size).unsqueeze(0)
    out = self.mlp(x)
    return out


def Qwen2_5_VisionMLP__init__(
        self,
        in_features: int,
        hidden_features: int,
        bias: bool = False,
        act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False):
    """Replacement ``__init__`` for ``Qwen2_5_VisionMLP``.

    Creates separate ``gate_proj`` / ``up_proj`` column-parallel layers and a
    row-parallel ``down_proj``.

    Args:
        in_features: Input (and output) feature size.
        hidden_features: Intermediate feature size.
        bias: Whether the three linear layers carry a bias.
        act_fn: Accepted for signature compatibility; the activation is always
            ``F.silu`` regardless of this argument.
        quant_config: Optional quantization config for the linear layers.
        prefix: Parameter-name prefix for the sub-layers.
        use_data_parallel: Forwarded to ``down_proj`` as ``disable_tp``.
    """
    super(Qwen2_5_VisionMLP, self).__init__()
    # NOTE(review): only down_proj receives disable_tp; gate_proj and up_proj
    # do not -- confirm this is intended when use_data_parallel is set.
    self.gate_proj = ColumnParallelLinear(in_features,
                                          hidden_features,
                                          bias=bias,
                                          quant_config=quant_config,
                                          prefix=f"{prefix}.gate_proj")
    self.up_proj = ColumnParallelLinear(in_features,
                                        hidden_features,
                                        bias=bias,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.up_proj")
    self.down_proj = RowParallelLinear(hidden_features,
                                       in_features,
                                       bias=bias,
                                       quant_config=quant_config,
                                       prefix=f"{prefix}.down_proj",
                                       disable_tp=use_data_parallel)
    self.act_fn = F.silu


def Qwen2_5_VisionMLP_forward(self, x: torch.Tensor):
    """Replacement ``forward`` for ``Qwen2_5_VisionMLP``.

    Computes ``down_proj(act_fn(gate_proj(x)) * up_proj(x))``.

    Args:
        x: Input activations.

    Returns:
        The output of ``down_proj``.
    """
    x_gate, _ = self.gate_proj(x)
    x_gate = self.act_fn(x_gate)
    x_up, _ = self.up_proj(x)
    x_down, _ = self.down_proj(x_gate * x_up)
    return x_down


# Install the replacements above into upstream vLLM's Qwen2.5-VL vision tower.
vllm.model_executor.models.qwen2_5_vl.Qwen2_5_VisionAttention = Qwen2_5_VisionAttention_fit
vllm.model_executor.models.qwen2_5_vl.Qwen2_5_VisionPatchEmbed = Qwen2_5_VisionPatchEmbed_fit
Qwen2_5_VisionBlock.forward = vision_block_forward
Qwen2_5_VisionTransformer.forward = vision_transformer_forward
Qwen2_5_VisionTransformer.load_weights = vision_transformer_load_weights
Qwen2_5_VisionPatchMerger.forward = Qwen2_5_VisionPatchMerger_forward_fit
Qwen2_5_VisionMLP.__init__ = Qwen2_5_VisionMLP__init__
Qwen2_5_VisionMLP.forward = Qwen2_5_VisionMLP_forward