# Source: enginex-biren-vllm/vllm_br/model_executor/models/qwen2_5_vl.py
# Snapshot: 2026-03-10 13:31:25 +08:00 (531 lines, 21 KiB, Python)
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
# SPDX-License-Identifier: Apache-2.0
# Adapted from
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
# Copyright 2025 The vLLM team.
# Copyright 2025 The Qwen Team.
# Copyright 2025 The HuggingFace Inc. team.
# All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2.5-VL model compatible with HuggingFace weights."""
import math
from collections.abc import Iterable
from functools import partial
from typing import Callable, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_br
from einops import rearrange
from fastcore.basics import patch_to
import vllm
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.qwen2_5_vl import (Qwen2_5_VisionBlock,
Qwen2_5_VisionMLP,
Qwen2_5_VisionPatchMerger,
Qwen2_5_VisionTransformer)
from vllm.model_executor.models.qwen2_vl import apply_rotary_pos_emb_vision
from vllm.model_executor.models.utils import cast_overflow_tensors
from vllm.platforms import _Backend
from vllm_br import envs
from .br_utils import convBB, convSB
def all_gather_interleave(local_tensor, hidden_size: int, tp_size: int):
    """All-gather ``local_tensor`` across the TP group, restoring interleaved
    column order.

    Each rank holds a column-parallel shard; after the all-gather, the
    per-rank slices are re-interleaved along the last dim so the result
    matches the unsharded layout.
    """
    import torch.distributed as dist
    gathered = [torch.zeros_like(local_tensor) for _ in range(tp_size)]
    dist.all_gather(gathered,
                    local_tensor,
                    group=parallel_state.get_tp_group().device_group)
    # Split each rank's contribution into tp_size equal column slices.
    slice_width = hidden_size // tp_size
    per_rank_slices = [torch.split(t, slice_width, -1) for t in gathered]
    # Interleave: take slice 0 of every rank, then slice 1 of every rank, ...
    interleaved = [
        piece for group in zip(*per_rank_slices, strict=False)
        for piece in group
    ]
    return torch.cat(interleaved, dim=-1)
class Qwen2_5_VisionAttention_fit(nn.Module):
    """Drop-in replacement for vLLM's Qwen2.5-VL vision attention.

    Keeps the upstream constructor/forward interface, but routes the
    attention kernel through
    ``torch_br.sueager_scaled_dot_product_attention_fwd`` and applies
    Biren device-specific layout conversions (``convBB``/``convSB``).
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        projection_size: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend: _Backend = _Backend.TORCH_SDPA,
        use_upstream_fa: bool = False,
    ) -> None:
        super().__init__()
        # Per attention head and per partition values.
        self.tp_size = (1 if use_data_parallel else
                        parallel_state.get_tensor_model_parallel_world_size())
        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads)
        self.num_attention_heads_per_partition = dist_utils.divide(
            num_heads, self.tp_size)
        # Softmax scale denominator: sqrt(head_dim).
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        self.qkv = QKVParallelLinear(
            hidden_size=embed_dim,
            head_size=self.hidden_size_per_attention_head,
            total_num_heads=num_heads,
            total_num_kv_heads=num_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv")
        self.proj = RowParallelLinear(input_size=projection_size,
                                      output_size=embed_dim,
                                      quant_config=quant_config,
                                      prefix=f"{prefix}.proj")

    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Split the fused qkv projection into this rank's q/k/v, each
        shaped [s, b, head, head_dim]."""
        # [s, b, 3 * head * head_dim]
        seq_len, bs, width = qkv.shape
        qkv = qkv.reshape(-1, width)
        if self.tp_size > 1:
            # Undo the column-parallel interleaving so chunk(3) below
            # yields whole q/k/v tensors.
            qkv = all_gather_interleave(qkv, self.qkv.hidden_size,
                                        self.tp_size)

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
        q, k, v = qkv.chunk(3, dim=-1)

        # Re-shard: keep only this rank's head slice of each tensor.
        if self.tp_size > 1:
            splitter = partial(dist_utils.split_tensor_along_last_dim,
                               num_partitions=self.tp_size)
            q = splitter(q)[self.tp_rank]
            k = splitter(k)[self.tp_rank]
            v = splitter(v)[self.tp_rank]

        # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
        new_shape = (seq_len, bs, self.num_attention_heads_per_partition,
                     self.hidden_size_per_attention_head)
        q, k, v = (x.view(*new_shape) for x in (q, k, v))
        return q, k, v

    def transform_qkv_shape(self,
                            qkv_layer,
                            cur_qkv_shape_state,
                            obj_qkv_shape_state,
                            obj_shape=None):
        """Convert a q/k/v tensor between layout states.

        States: ``"b_s_n_h"`` = [batch, seq, heads, head_dim],
        ``"b_n_s_h"`` = [batch, heads, seq, head_dim],
        ``"bn_s_h"``  = [batch*heads, seq, head_dim].
        ``obj_shape`` supplies the batch size when un-flattening from
        ``"bn_s_h"``.

        Raises:
            AssertionError: if the (cur, obj) state pair is unsupported.
        """
        if obj_qkv_shape_state == "bn_s_h":
            if cur_qkv_shape_state == "bn_s_h":
                return qkv_layer
            if cur_qkv_shape_state == "b_s_n_h":
                # [b, sq, np or nkvp, hn] --> [b, np or nkvp, sq, hn] --> [b*(np or nkvp), sq, hn]
                qkv_layer = qkv_layer.permute(0, 2, 1, 3)
                # view 4d matrix to 3d matrix, TODO: use fused_split_view here
                qkv_layer = qkv_layer.reshape(-1, qkv_layer.size(2),
                                              qkv_layer.size(3)).contiguous()
                return qkv_layer
            if cur_qkv_shape_state == "b_n_s_h":
                qkv_layer = qkv_layer.reshape(-1, qkv_layer.size(2),
                                              qkv_layer.size(3))
                return qkv_layer
        if obj_qkv_shape_state == "b_n_s_h":
            if cur_qkv_shape_state == "b_n_s_h":
                return qkv_layer
            if cur_qkv_shape_state == "bn_s_h":
                qkv_layer = qkv_layer.reshape(obj_shape[0], -1,
                                              qkv_layer.size(1),
                                              qkv_layer.size(2))
                return qkv_layer
            if cur_qkv_shape_state == "b_s_n_h":
                qkv_layer = qkv_layer.permute(0, 2, 1, 3).contiguous()
                return qkv_layer
        if obj_qkv_shape_state == "b_s_n_h":
            if cur_qkv_shape_state == "b_s_n_h":
                return qkv_layer
            if cur_qkv_shape_state == "b_n_s_h":
                qkv_layer = qkv_layer.permute(0, 2, 1, 3).contiguous()
                return qkv_layer
            if cur_qkv_shape_state == "bn_s_h":
                qkv_layer = qkv_layer.reshape(obj_shape[0], -1,
                                              qkv_layer.size(1),
                                              qkv_layer.size(2))
                qkv_layer = qkv_layer.permute(0, 2, 1, 3).contiguous()
                return qkv_layer
        # BUGFIX: the AssertionError was previously constructed but never
        # raised, so unsupported combinations fell through and silently
        # returned None.
        raise AssertionError(
            f"unsupported shape transform, ori:{cur_qkv_shape_state} obj:{obj_qkv_shape_state}"
        )

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor,
        max_seqlen: Optional[int] = None,  # Only used for Flash Attention
        seqlens: Optional[list[int]] = None,  # Only used for xFormers
        mask: torch.Tensor = None,
    ) -> torch.Tensor:
        """Self-attention over the vision token sequence.

        Returns the projected attention output with the same leading
        layout as the input ([s, b, c]).
        """
        # [s, b, c] --> [s, b, head * 3 * head_dim]
        x, _ = self.qkv(x)
        if envs.VLLM_BR_DEVICE_SPC_NUM > 16:
            # Device-specific layout conversion; regroups qkv columns —
            # assumes a fixed 2-way split, see convBB. TODO confirm.
            x = convBB(x)
            seql = x.shape[-2]
            x = x.reshape(seql, 2, 3,
                          -1).permute(0, 2, 1,
                                      3).contiguous().reshape(1, seql, -1)
        if x.shape[0] == 1:
            # [1, s, c] -> [s, 1, c] so split_qkv sees seq-first layout.
            x = x.permute(1, 0, 2).contiguous()

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
        q, k, v = self.split_qkv(x)

        q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous()
                   for x in (q, k, v))
        if rotary_pos_emb is not None:
            q = apply_rotary_pos_emb_vision(
                q,
                rotary_pos_emb,
            )
            k = apply_rotary_pos_emb_vision(
                k,
                rotary_pos_emb,
            )

        # q, k, v: [b, s, n, h] -> reshape: [b, n, s, h] -> reshape: [b * n, s, h]
        q = q.permute(0, 2, 1, 3).contiguous()
        k = k.permute(0, 2, 1, 3).contiguous()
        v = v.permute(0, 2, 1, 3).contiguous()
        q = self.transform_qkv_shape(q, "b_n_s_h", "bn_s_h")
        k = self.transform_qkv_shape(k, "b_n_s_h", "bn_s_h")
        v = self.transform_qkv_shape(v, "b_n_s_h", "bn_s_h")

        #TODO(qingqi), skip sueager bug, when sueager op fix the bug,remove the code
        if q.shape[1] == 8192 or q.shape[1] == 8424 or q.shape[1] == 8464:
            mask = mask.to(torch.bfloat16)
        context_layer, _ = torch_br.sueager_scaled_dot_product_attention_fwd(
            query=q,
            key=k,
            value=v,
            mask=mask,
            dropout_prob=0.0,
            is_causal=False,
            scale=1 / self.norm_factor,
            algorithm="FMHA",
        )

        # reshape attn out: [b*n, s, h] -> [s, b, h*n]
        context_layer = torch_br.supa_shape_transform_qkv(
            context_layer, 1, context_layer.shape[-2],
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head, False, False, None)
        if context_layer.shape[0] != 1:
            context_layer = context_layer.permute(1, 0, 2).contiguous()
        if envs.VLLM_BR_DEVICE_SPC_NUM > 16:
            context_layer = convSB(context_layer, -1)
        output, _ = self.proj(context_layer)
        return output
def vision_block_forward(
    self,
    x: torch.Tensor,
    cu_seqlens: torch.Tensor,
    rotary_pos_emb: torch.Tensor,
    max_seqlen: Optional[int] = None,  # Only used for Flash Attention
    seqlens: Optional[list[int]] = None,  # Only used for xFormers
    mask: torch.Tensor = None,
) -> torch.Tensor:
    """Pre-norm residual transformer block, patched onto Qwen2_5_VisionBlock.

    Applies attention then MLP, each with its own layer norm and a
    residual connection.
    """
    # The attention path expects seq-first [s, b, c]; transpose
    # batch-first inputs (anything whose leading dim is not 1).
    if x.shape[0] != 1:
        x = x.permute(1, 0, 2).contiguous()
    attn_out = self.attn(self.norm1(x),
                         cu_seqlens=cu_seqlens,
                         rotary_pos_emb=rotary_pos_emb,
                         max_seqlen=max_seqlen,
                         seqlens=seqlens,
                         mask=mask)
    x = x + attn_out
    return x + self.mlp(self.norm2(x))
class Qwen2_5_VisionPatchEmbed_fit(nn.Module):
    """Patch embedding implemented as a single column-parallel linear.

    The flattened patch pixels (channels * temporal_patch_size *
    patch_size^2) are projected to ``hidden_size``; output is gathered
    across TP ranks so every rank sees the full embedding.
    """

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_channels: int = 3,
        hidden_size: int = 1152,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.hidden_size = hidden_size
        # Input width = one flattened spatio-temporal patch.
        patch_dim = in_channels * temporal_patch_size * patch_size * patch_size
        self.proj = ColumnParallelLinear(patch_dim,
                                         hidden_size,
                                         bias=False,
                                         gather_output=True,
                                         quant_config=quant_config,
                                         prefix="")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Add a leading batch dim for the parallel linear, then flatten
        # back to [num_patches, hidden_size].
        x = x.unsqueeze(0)
        num_patches = x.shape[-2]
        projected, _ = self.proj(x)
        return projected.view(num_patches, self.hidden_size)
@patch_to(vllm.model_executor.models.qwen2_5_vl.Qwen2_5_VisionTransformer)
def gen_normal_mask(self, cu_seqlens, grid_thw, device):
    """Build a [1, S, S] int32 mask from cumulative sequence lengths.

    Entries are 1 everywhere except the block-diagonal squares covering
    each individual sequence, which are 0 (i.e. attendable).
    """
    # NOTE: for mask-mock-pack, we precompute mask and store in PackedSeqParams
    total_len = max(cu_seqlens)
    mask = torch.ones([1, total_len, total_len],
                      dtype=torch.int32,
                      device=device)
    # Zero out each sequence's own block on the diagonal.
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        mask[..., start:end, start:end] = 0
    return mask
def vision_transformer_forward(
    self,
    x: torch.Tensor,
    grid_thw: list[list[int]],
) -> torch.Tensor:
    """Vision transformer forward, patched onto Qwen2_5_VisionTransformer.

    Embeds flattened patches, builds rotary embeddings, window indices
    and cumulative sequence lengths per (t, h, w) grid, runs the blocks
    (alternating full/window attention metadata), then merges patches
    and restores the original token order.

    Args:
        x: flattened patch pixels, shape [seq_len, patch_dim].
        grid_thw: per-image/video (temporal, height, width) patch grids.

    Returns:
        Merged vision embeddings in original (pre-window-shuffle) order.
    """
    # patchify
    seq_len, _ = x.size()
    rotary_pos_emb_list = []
    window_index_list: list = []
    # Seed with 0 so concatenated window seqlens start at the origin.
    cu_window_seqlens_list: list = [
        torch.tensor([0], dtype=torch.int32, device="cpu")
    ]
    cu_seqlens_list: list = []
    hidden_states = x.to(device=self.device, dtype=self.dtype)
    hidden_states = self.patch_embed(hidden_states)
    # Running offsets so each grid's indices/seqlens continue from the
    # previous grid's end.
    window_index_id = 0
    cu_window_seqlens_last = 0
    for t, h, w in grid_thw:
        t, h, w = int(t), int(h), int(w)
        # Grid size after spatial merging (the LLM-facing token grid).
        llm_h = h // self.spatial_merge_size
        llm_w = w // self.spatial_merge_size
        (
            rotary_pos_emb_thw,
            window_index_thw,
            cu_seqlens_window_thw,
            cu_seqlens_thw,
        ) = self.get_rope_by_thw(t, h, w)
        window_index_list.append(window_index_thw + window_index_id)
        window_index_id += (t * llm_h * llm_w)
        cu_seqlens_window_thw = (cu_seqlens_window_thw +
                                 cu_window_seqlens_last)
        cu_window_seqlens_last = cu_seqlens_window_thw[-1]
        cu_window_seqlens_list.append(cu_seqlens_window_thw)
        rotary_pos_emb_list.append(rotary_pos_emb_thw)
        cu_seqlens_list.append(cu_seqlens_thw)
    rotary_pos_emb = torch.cat(rotary_pos_emb_list)
    window_index = torch.cat(window_index_list)
    cu_window_seqlens = torch.cat(cu_window_seqlens_list)
    # Adjacent grids may repeat a boundary value; deduplicate it.
    cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
    cu_seqlens = torch.cat(cu_seqlens_list)
    cu_seqlens = torch.cumsum(cu_seqlens, dim=0, dtype=torch.int32)
    # Prepend 0 so cu_seqlens[i-1]:cu_seqlens[i] spans sequence i.
    cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
    # transformers
    # pre-compute seqlens for window/full attn to reduce cuMemcpy operations
    max_seqlen_full, seqlens_full = self.compute_attn_mask_seqlen(cu_seqlens)
    max_seqlen_window, seqlens_window = self.compute_attn_mask_seqlen(
        cu_window_seqlens)
    cu_seqlens = cu_seqlens.to(device=self.device, non_blocking=True)
    cu_window_seqlens = cu_window_seqlens.to(device=self.device,
                                             non_blocking=True)
    rotary_pos_emb = rotary_pos_emb.to(device=self.device, non_blocking=True)
    window_index = window_index.to(device=hidden_states.device,
                                   non_blocking=True)
    # Reorder tokens into window order (group by spatial_merge_unit,
    # gather by window_index, flatten back, add a batch dim of 1).
    hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit,
                                          self.spatial_merge_unit, -1)
    hidden_states = hidden_states[window_index, :, :]
    hidden_states = hidden_states.reshape(seq_len, -1)
    hidden_states = hidden_states.unsqueeze(1)
    # NOTE(review): the same full-sequence mask (built from cu_seqlens) is
    # passed to every block, including window-attention layers — confirm
    # window masking is intentionally handled elsewhere (mask-mock-pack).
    attention_mask = self.gen_normal_mask(cu_seqlens, grid_thw, x.device)
    for layer_num, blk in enumerate(self.blocks):
        # Select full-attention or window-attention metadata per layer.
        if layer_num in self.fullatt_block_indexes:
            cu_seqlens_now = cu_seqlens
            max_seqlen_now = max_seqlen_full
            seqlens_now = seqlens_full
        else:
            cu_seqlens_now = cu_window_seqlens
            max_seqlen_now = max_seqlen_window
            seqlens_now = seqlens_window
        hidden_states = blk(hidden_states,
                            cu_seqlens=cu_seqlens_now,
                            rotary_pos_emb=rotary_pos_emb,
                            max_seqlen=max_seqlen_now,
                            seqlens=seqlens_now,
                            mask=attention_mask)
    # For Qwen2.5-VL-3B, float16 will overflow at last block
    # for long visual tokens sequences.
    if hidden_states.dtype == torch.float16:
        hidden_states = cast_overflow_tensors(hidden_states)
    # adapter
    hidden_states = self.merger(hidden_states).squeeze(0)
    # Undo the window-order shuffle so tokens return to original order.
    reverse_indices = torch.argsort(window_index)
    hidden_states = hidden_states[reverse_indices, :]
    return hidden_states
def vision_transformer_load_weights(
        self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """Load checkpoint weights into the vision transformer.

    Separate q/k/v checkpoint tensors are fused into the single qkv
    projection via each parameter's shard-aware ``weight_loader``;
    everything else loads directly.
    """
    # (param_name, shard_name, shard_id)
    stacked_params_mapping = [
        ("attn.qkv.", "attn.q.", "q"),
        ("attn.qkv.", "attn.k.", "k"),
        ("attn.qkv.", "attn.v.", "v"),
    ]
    params_dict = dict(self.named_parameters(remove_duplicate=False))
    loaded_params: set[str] = set()
    for name, loaded_weight in weights:
        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name in name:
                # Fused-parameter path: remap the name and load the shard.
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                param.weight_loader(param, loaded_weight, shard_id)
                break
        else:
            # Plain path: load the whole tensor.
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            if name == 'patch_embed.proj.weight':
                # Checkpoint stores the patch projection as a conv-style
                # 4D tensor; flatten to the 2D linear weight used here.
                loaded_weight = loaded_weight.reshape(
                    loaded_weight.shape[0], -1).contiguous()
            weight_loader(param, loaded_weight)
        loaded_params.add(name)
    return loaded_params
def Qwen2_5_VisionPatchMerger_forward_fit(self,
                                          x: torch.Tensor) -> torch.Tensor:
    """Normalize, flatten to [1, tokens, hidden_size], and project via MLP.

    Patched onto Qwen2_5_VisionPatchMerger.
    """
    normed = self.ln_q(x)
    flattened = normed.view(-1, self.hidden_size).unsqueeze(0)
    return self.mlp(flattened)
def Qwen2_5_VisionMLP__init__(
        self,
        in_features: int,
        hidden_features: int,
        bias: bool = False,
        act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False):
    """Replacement ``__init__`` for Qwen2_5_VisionMLP (gated SiLU MLP).

    Builds gate/up column-parallel projections and a down row-parallel
    projection; ``act_fn`` is applied to the gate branch.
    """
    super(Qwen2_5_VisionMLP, self).__init__()
    self.gate_proj = ColumnParallelLinear(in_features,
                                          hidden_features,
                                          bias=bias,
                                          quant_config=quant_config,
                                          prefix=f"{prefix}.gate_proj")
    self.up_proj = ColumnParallelLinear(in_features,
                                        hidden_features,
                                        bias=bias,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.up_proj")
    self.down_proj = RowParallelLinear(hidden_features,
                                       in_features,
                                       bias=bias,
                                       quant_config=quant_config,
                                       prefix=f"{prefix}.down_proj",
                                       disable_tp=use_data_parallel)
    # BUGFIX: honor the caller-supplied activation. Previously this was
    # hardcoded to F.silu, silently ignoring the ``act_fn`` parameter.
    # Behavior is unchanged for callers using the default.
    self.act_fn = act_fn
def Qwen2_5_VisionMLP_forward(self, x: torch.Tensor):
    """Gated MLP: down_proj(act_fn(gate_proj(x)) * up_proj(x)).

    Patched onto Qwen2_5_VisionMLP; each projection returns a
    (tensor, bias) pair, of which only the tensor is used.
    """
    gate, _ = self.gate_proj(x)
    up, _ = self.up_proj(x)
    fused = self.act_fn(gate) * up
    out, _ = self.down_proj(fused)
    return out
# Install the Biren-adapted implementations by monkey-patching vLLM's
# Qwen2.5-VL vision classes at import time of this module: the attention
# and patch-embed classes are replaced wholesale; the remaining classes
# have individual methods overridden.
vllm.model_executor.models.qwen2_5_vl.Qwen2_5_VisionAttention = Qwen2_5_VisionAttention_fit
vllm.model_executor.models.qwen2_5_vl.Qwen2_5_VisionPatchEmbed = Qwen2_5_VisionPatchEmbed_fit
Qwen2_5_VisionBlock.forward = vision_block_forward
Qwen2_5_VisionTransformer.forward = vision_transformer_forward
Qwen2_5_VisionTransformer.load_weights = vision_transformer_load_weights
Qwen2_5_VisionPatchMerger.forward = Qwen2_5_VisionPatchMerger_forward_fit
Qwen2_5_VisionMLP.__init__ = Qwen2_5_VisionMLP__init__
Qwen2_5_VisionMLP.forward = Qwen2_5_VisionMLP_forward