"""Inference-only Qwen2VL model compatible with HuggingFace weights.""" import torch import torch.nn as nn from typing import Optional from vllm.attention.layer import check_upstream_fa_availability from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce, parallel_state from vllm.distributed import utils as dist_utils from vllm.model_executor.models.qwen2_vl import Qwen2VisionAttention as Qwen2VisionAttentionOrg from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.models.vision import get_vit_attn_backend from vllm.platforms import _Backend from vllm.logger import init_logger from .hf_processor.qwenvl_processor import Qwen2VLProcessorWithVacc from .hf_processor.qwen2vl_image_processor import Qwen2VLImageProcessorFastWithVacc from vllm.distributed import (get_pp_group, get_ep_group, get_tp_group, get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank, tensor_model_parallel_all_reduce) from vllm_vacc.vllm.model_executor.models.vars import USE_FUSED_QWEN_ATTENTION logger = init_logger(__name__) class Qwen2VisionPatchEmbed(nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: if hasattr(self.proj, 'bias') and self.proj.bias is not None: return torch.nn.functional.linear(x, self.proj.weight.view(self.hidden_size, -1), self.proj.bias) return torch.matmul(x, self.proj.weight.view(self.embed_dim, -1).T) class Qwen2VLProcessingInfo(): def get_hf_processor(self, **kwargs: object) -> Qwen2VLProcessorWithVacc: return self.ctx.get_hf_processor( Qwen2VLProcessorWithVacc, use_fast=kwargs.pop("use_fast", True), **kwargs, ) def get_image_processor(self, **kwargs: object) -> Qwen2VLImageProcessorFastWithVacc: return self.get_hf_processor(**kwargs).image_processor import torch.nn.functional as F class Qwen2VisionTransformer(nn.Module): def forward( self, x: torch.Tensor, grid_thw: list[list[int]], ) -> torch.Tensor: # patchify x = x.to(device=self.device, dtype=self.dtype) x = self.patch_embed(x) # compute position embedding if USE_FUSED_QWEN_ATTENTION: try: from torch_vacc.vacc.custom_qwen3_ops import rot_pos_emb_qwenvl sin_cache, cos_cache = rot_pos_emb_qwenvl(grid_thw, self.embed_dim, self.num_heads, self.spatial_merge_size, self.dtype, self.device) except Exception as e: logger.error(f"rot_pos_emb fused ops run fail, e:{e}") rotary_pos_emb = None else: rotary_pos_emb = self.rot_pos_emb(grid_thw) sin_cache, cos_cache = None, None # tmp_rotary_pos_emb = self.transformer_rot_pos_emb(grid_thw) # qwen3_rotary_pos_emb = self.qwen3_rot_pos_emb(grid_thw) # compute cu_seqlens grid_thw_ = torch.tensor(grid_thw) cu_seqlens = torch.repeat_interleave(grid_thw_[:, 1] * grid_thw_[:, 2], grid_thw_[:, 0]).cumsum( dim=0, dtype=torch.int32) cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0) # transformers x = x.unsqueeze(1) # pre-compute seqlens for attn mask to reduce cuMemcpy operations if USE_FUSED_QWEN_ATTENTION: cu_seqlens = cu_seqlens.tolist() max_seqlen, seqlens = None, None else: max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens) for blk in self.blocks: x = blk( x, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb, sin_cache=sin_cache, cos_cache=cos_cache, max_seqlen=max_seqlen, seqlens=seqlens, ) # adapter x = self.merger(x) return x class Qwen2VisionAttention(nn.Module): def __init__( self, embed_dim: int, num_heads: int, projection_size: int, quant_config: Optional["QuantizationConfig"] = None, prefix: str = "", use_data_parallel: bool = False, 
    ) -> None:
        # When this __init__ is grafted onto the upstream vLLM class, starting
        # the super() call at Qwen2VisionAttentionOrg skips the original
        # constructor and goes straight to nn.Module.__init__.
        super(Qwen2VisionAttentionOrg, self).__init__()

        # Per attention head and per partition values.
        self.tp_size = (1 if use_data_parallel else
                        parallel_state.get_tensor_model_parallel_world_size())
        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads)
        self.num_attention_heads_per_partition = dist_utils.divide(
            num_heads, self.tp_size)

        # self.qkv = ColumnParallelLinear(input_size=embed_dim,
        #                                 output_size=3 * projection_size,
        #                                 quant_config=quant_config,
        #                                 prefix=f"{prefix}.qkv",
        #                                 disable_tp=use_data_parallel)
        self.qkv = QKVParallelLinear(
            hidden_size=embed_dim,
            head_size=self.hidden_size_per_attention_head,
            total_num_heads=num_heads,
            total_num_kv_heads=num_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv",
            disable_tp=use_data_parallel)
        self.proj = RowParallelLinear(input_size=projection_size,
                                      output_size=embed_dim,
                                      quant_config=quant_config,
                                      prefix=f"{prefix}.proj",
                                      disable_tp=use_data_parallel)

        # Detect attention implementation.
        self.attn_backend = get_vit_attn_backend(
            head_size=self.hidden_size_per_attention_head,
            dtype=torch.get_default_dtype())
        self.use_upstream_fa = False
        if self.attn_backend != _Backend.FLASH_ATTN and \
                check_upstream_fa_availability(torch.get_default_dtype()):
            self.attn_backend = _Backend.FLASH_ATTN
            self.use_upstream_fa = True

        if self.attn_backend not in {
                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
                _Backend.ROCM_AITER_FA
        }:
            raise RuntimeError(
                f"Qwen2-VL does not support the {self.attn_backend} backend.")
        self.is_flash_attn_backend = self.attn_backend in {
            _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
        }

    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
        # qkv: [seq_len, batch, 3 * num_heads_per_partition * head_dim]
        seq_len, bs, _ = qkv.shape
        new_shape = (seq_len, bs, self.num_attention_heads_per_partition,
                     self.hidden_size_per_attention_head)
        q1, k1, v1 = qkv.chunk(3, dim=-1)
        q1, k1, v1 = (x.view(*new_shape) for x in (q1, k1, v1))
        return q1, k1, v1


class Qwen2VisionBlock(nn.Module):

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor,
        sin_cache: torch.Tensor,
        cos_cache: torch.Tensor,
        max_seqlen: Optional[int] = None,  # Only used for Flash Attention
        seqlens: Optional[list[int]] = None,  # Only used for xFormers
    ) -> torch.Tensor:
        if USE_FUSED_QWEN_ATTENTION:
            # Let the fused kernel perform the reduction only when the
            # aggregated activation is small (< 4 MiB across TP ranks);
            # otherwise do an explicit all-reduce on the result below.
            total_bytes = (x.numel() * x.element_size() *
                           get_tp_group().world_size)
            reduce_result = (get_tp_group().world_size > 1
                             and total_bytes < 4 * 1024 * 1024)
            # hidden_states = self.norm1(x)
            attn_outs = torch.vacc.fuse_atten_vit(
                hidden_states=x.view(-1, x.shape[-1]),
                hidden_states_norm_weight=self.norm1.weight,
                hidden_states_norm_bias=self.norm1.bias,
                # hidden_states_norm_weight=torch.Tensor(),
                # hidden_states_norm_bias=torch.Tensor(),
                qkv_proj_weight=self.attn.qkv.weight,
                qkv_proj_bias=self.attn.qkv.bias,
                sin_cache=sin_cache,
                cos_cache=cos_cache,
                o_proj_weight=self.attn.proj.weight,
                # Only rank 0 contributes the output-projection bias so the
                # reduction does not add it once per TP rank.
                o_proj_bias=self.attn.proj.bias
                if self.attn.proj.tp_rank == 0 else torch.Tensor(),
                seq_lens=cu_seqlens,
                sm_scale=-1,
                num_attention_heads=(
                    self.attn.num_attention_heads_per_partition *
                    get_tp_group().world_size),
                flash_attention=True,
                reduce_result=reduce_result,
                world_size=get_tp_group().world_size,
                rank=get_tp_group().rank_in_group,
                group_id=get_tp_group().group_id,
                dev_info=get_tp_group().rank_device_infos,
            )
            attn_out = (attn_outs[0] if reduce_result else
                        tensor_model_parallel_all_reduce(attn_outs[0]))
            attn_out = attn_out.view(x.shape)
            x = x + attn_out
        else:
            x = x + self.attn(
                self.norm1(x),
                cu_seqlens=cu_seqlens,
                rotary_pos_emb=rotary_pos_emb,
                max_seqlen=max_seqlen,
                seqlens=seqlens,
            )
        x = x + self.mlp(self.norm2(x))
        return x


class Qwen2VisionMLP:

    def forward(self, x: torch.Tensor):
        try:
            from torch_vacc.vacc import fuse_mlp_vision
            hiddens_shape = x.shape
            tp_rank_id = get_tp_group().rank_in_group
            # Only rank 0 applies the fc2 bias so the all-reduce below does
            # not add it once per TP rank.
            fc2_bias = None if tp_rank_id > 0 else self.fc2.bias
            hidden_states = fuse_mlp_vision(
                x.view(-1, hiddens_shape[-1]),
                self.fc1.weight,  # (n, k) layout: [out_features, in_features]
                self.fc2.weight,  # (n, k) layout: [out_features, in_features]
                self.fc1.bias,
                fc2_bias,
                2)  # activation id: 0 is gelu, 1 is relu, 2 is quick_gelu
            vacc_res = tensor_model_parallel_all_reduce(
                hidden_states).view(hiddens_shape)
            return vacc_res
        except Exception as e:
            logger.error(f"fused vision MLP op failed, falling back to the "
                         f"unfused path: {e}")
            x_parallel, _ = self.fc1(x)
            x_parallel = self.act(x_parallel)
            x, _ = self.fc2(x_parallel)
            return x


class Qwen2VisionPatchMerger:

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.ln_q(x)
        x = x.view(-1, self.hidden_size)
        mlp_fc1, mlp_act, mlp_fc2 = self.mlp
        try:
            from torch_vacc.vacc import patch_merger_vision
            tp_rank_id = get_tp_group().rank_in_group
            # Only rank 0 applies the fc2 bias so the all-reduce below does
            # not add it once per TP rank.
            fc2_bias = None if tp_rank_id > 0 else mlp_fc2.bias
            hidden_states = patch_merger_vision(
                x, mlp_fc1.weight, mlp_fc2.weight, mlp_fc1.bias, fc2_bias,
                0)  # activation id: 0 is gelu, 1 is silu
            vacc_res = tensor_model_parallel_all_reduce(hidden_states)
            return vacc_res
        except Exception as e:
            logger.error(f"fused patch-merger MLP op failed, falling back to "
                         f"the unfused path: {e}")
            x_parallel, _ = mlp_fc1(x)
            x_parallel = mlp_act(x_parallel)
            out, _ = mlp_fc2(x_parallel)
            return out
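
# ---------------------------------------------------------------------------
# Illustrative sketch only: the helper below is not used by the model code
# above and its grid values are made up. It spells out how
# Qwen2VisionTransformer.forward derives the cumulative sequence lengths from
# `grid_thw`: each [t, h, w] entry contributes t * h * w patch tokens, and
# every frame of h * w tokens forms one attention segment, so the boundaries
# are cumsum(h * w repeated t times) with a leading zero.
# ---------------------------------------------------------------------------
def _demo_cu_seqlens() -> torch.Tensor:
    # One 4x6 image and one 2-frame 2x2 video (example values).
    grid_thw = [[1, 4, 6], [2, 2, 2]]
    grid = torch.tensor(grid_thw)
    cu_seqlens = torch.repeat_interleave(
        grid[:, 1] * grid[:, 2], grid[:, 0]).cumsum(dim=0, dtype=torch.int32)
    cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
    # -> tensor([ 0, 24, 28, 32], dtype=torch.int32)
    return cu_seqlens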