From faf8cd89cb9a853bd8a3c16b8d7321e2da4a2342 Mon Sep 17 00:00:00 2001
From: zouyida2002
Date: Fri, 7 Mar 2025 15:41:47 +0800
Subject: [PATCH] register qwen2_vl to rewrite qwen2_vl forward (#241)

Add qwen2-vl ascend implementation.

---------

Signed-off-by: zouyida
---
 setup.py                       |   6 +-
 vllm_ascend/__init__.py        |   6 +
 vllm_ascend/models/__init__.py |  10 ++
 vllm_ascend/models/qwen2_vl.py | 191 +++++++++++++++++++++++++++++++++
 4 files changed, 212 insertions(+), 1 deletion(-)
 create mode 100644 vllm_ascend/models/__init__.py
 create mode 100644 vllm_ascend/models/qwen2_vl.py

diff --git a/setup.py b/setup.py
index d278ef9..492934a 100644
--- a/setup.py
+++ b/setup.py
@@ -99,4 +99,8 @@ setup(
     python_requires=">=3.9",
     install_requires=get_requirements(),
     extras_require={},
-    entry_points={'vllm.platform_plugins': ["ascend = vllm_ascend:register"]})
+    entry_points={
+        'vllm.platform_plugins': ["ascend = vllm_ascend:register"],
+        'vllm.general_plugins':
+        ["ascend_enhanced_model = vllm_ascend:register_model"]
+    })
diff --git a/vllm_ascend/__init__.py b/vllm_ascend/__init__.py
index 80af5a5..c3b7661 100644
--- a/vllm_ascend/__init__.py
+++ b/vllm_ascend/__init__.py
@@ -19,3 +19,9 @@
 def register():
     """Register the NPU platform."""
     return "vllm_ascend.platform.NPUPlatform"
+
+
+def register_model():
+    """Register the custom models via ``vllm_ascend.models``."""
+    from .models import register_model
+    register_model()
diff --git a/vllm_ascend/models/__init__.py b/vllm_ascend/models/__init__.py
new file mode 100644
index 0000000..64b994d
--- /dev/null
+++ b/vllm_ascend/models/__init__.py
@@ -0,0 +1,10 @@
+from vllm import ModelRegistry
+
+
+def register_model():
+    """Register the custom class as ``Qwen2VLForConditionalGeneration``."""
+    from .qwen2_vl import CustomQwen2VLForConditionalGeneration  # noqa: F401
+
+    ModelRegistry.register_model(
+        "Qwen2VLForConditionalGeneration",
+        "vllm_ascend.models.qwen2_vl:CustomQwen2VLForConditionalGeneration")
diff --git a/vllm_ascend/models/qwen2_vl.py b/vllm_ascend/models/qwen2_vl.py
new file mode 100644
index 0000000..d4b108b
--- /dev/null
+++ b/vllm_ascend/models/qwen2_vl.py
@@ -0,0 +1,191 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# Adapted from vllm/model_executor/models/qwen2_vl.py
+# Copyright 2023 The vLLM team.
+#
+# This file is a part of the vllm-ascend project.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import partial
+from typing import Callable, Optional, Type
+
+import torch
+import torch.nn as nn
+import torch_npu
+from einops import rearrange
+from transformers.models.qwen2_vl.configuration_qwen2_vl import \
+    Qwen2VLVisionConfig
+from vllm.config import VllmConfig
+from vllm.model_executor.layers.activation import QuickGELU
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.models.qwen2_vl import (
+    Qwen2VisionAttention, Qwen2VisionBlock, Qwen2VisionTransformer,
+    Qwen2VLDummyInputsBuilder, Qwen2VLForConditionalGeneration,
+    Qwen2VLMultiModalProcessor, Qwen2VLProcessingInfo,
+    apply_rotary_pos_emb_vision)
+from vllm.model_executor.models.utils import maybe_prefix
+from vllm.multimodal import MULTIMODAL_REGISTRY
+
+
+class CustomQwen2VisionAttention(Qwen2VisionAttention):
+    """Vision attention that calls the torch_npu unpadded flash attention."""
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        rotary_pos_emb: torch.Tensor,
+    ) -> torch.Tensor:
+        """Run self-attention over ``x`` and return the projected result.
+
+        Args:
+            x: Input laid out as ``[s, b, c]``.
+            cu_seqlens: Forwarded to the NPU operator as ``seq_len``.
+            rotary_pos_emb: Rotary embedding applied to q and k; skipped
+                when it is ``None``.
+        """
+        # [s, b, c] --> [s, b, 3 * head * head_dim]
+        x, _ = self.qkv(x)
+
+        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
+        q, k, v = self.split_qkv(x)
+        batch_size = q.shape[1]
+
+        q, k, v = [
+            rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)
+        ]
+        if rotary_pos_emb is not None:
+            q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
+            k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
+        q, k, v = [
+            rearrange(x, "b s h d -> (b s) h d").contiguous()
+            for x in (q, k, v)
+        ]
+
+        # The operator writes its result into this preallocated buffer.
+        context_layer = torch.empty_like(q)
+
+        # operator requires pta version >= 2.5.1.dev20250226
+        torch_npu._npu_flash_attention_unpad(
+            query=q,
+            key=k,
+            value=v,
+            seq_len=cu_seqlens,
+            scale_value=self.hidden_size_per_attention_head**-0.5,
+            num_heads=self.num_attention_heads_per_partition,
+            num_kv_heads=self.num_attention_heads_per_partition,
+            out=context_layer)
+        context_layer = rearrange(context_layer,
+                                  "(b s) h d -> s b (h d)",
+                                  b=batch_size).contiguous()
+
+        output, _ = self.proj(context_layer)
+        return output
+
+
+class CustomQwen2VisionBlock(Qwen2VisionBlock):
+    """Vision block that uses ``CustomQwen2VisionAttention`` as ``attn``."""
+
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int,
+        mlp_ratio: float,
+        act_layer: Type[nn.Module] = QuickGELU,
+        norm_layer: Optional[Callable[[int], nn.Module]] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__(dim, num_heads, mlp_ratio, act_layer, norm_layer,
+                         quant_config, prefix)
+        # Replace the attention module assigned by the parent constructor.
+        self.attn = CustomQwen2VisionAttention(embed_dim=dim,
+                                               num_heads=num_heads,
+                                               projection_size=dim,
+                                               quant_config=quant_config,
+                                               prefix=f"{prefix}.attn")
+
+
+class CustomQwen2VisionTransformer(Qwen2VisionTransformer):
+    """Vision transformer built from ``CustomQwen2VisionBlock`` layers."""
+
+    def __init__(
+        self,
+        vision_config: Qwen2VLVisionConfig,
+        norm_eps: float = 1e-6,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__(vision_config, norm_eps, quant_config, prefix)
+        # Rebuild the block list with the Ascend-specific block class.
+        self.blocks = nn.ModuleList([
+            CustomQwen2VisionBlock(dim=self.embed_dim,
+                                   num_heads=self.num_heads,
+                                   mlp_ratio=vision_config.mlp_ratio,
+                                   norm_layer=partial(nn.LayerNorm,
+                                                      eps=norm_eps),
+                                   quant_config=quant_config,
+                                   prefix=f"{prefix}.blocks.{layer_idx}")
+            for layer_idx in range(vision_config.depth)
+        ])
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        grid_thw: torch.Tensor,
+    ) -> torch.Tensor:
+        """Embed ``x``, run every block, and return the merger output.
+
+        Args:
+            x: Input moved to this module's device and dtype first.
+            grid_thw: 2-D tensor; columns 1 and 2 are multiplied and the
+                product is repeated column-0 times to get sequence lengths.
+        """
+        # patchify
+        x = x.to(device=self.device, dtype=self.dtype)
+        x = self.patch_embed(x)
+
+        # compute position embedding
+        rotary_pos_emb = self.rot_pos_emb(grid_thw)
+
+        # compute cu_seqlens and avoid cumsum to fit operator unpadFA
+        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2],
+                                             grid_thw[:,
+                                                      0]).cpu().to(torch.int32)
+
+        x = x.unsqueeze(1)
+        for blk in self.blocks:
+            x = blk(x, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
+
+        # adapter
+        x = self.merger(x)
+        return x
+
+
+@MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor,
+                                        info=Qwen2VLProcessingInfo,
+                                        dummy_inputs=Qwen2VLDummyInputsBuilder)
+class CustomQwen2VLForConditionalGeneration(Qwen2VLForConditionalGeneration):
+    """Qwen2-VL model whose ``visual`` is ``CustomQwen2VisionTransformer``."""
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        # Pass ``prefix`` through to the parent instead of dropping it.
+        super().__init__(vllm_config=vllm_config, prefix=prefix)
+        self.visual = CustomQwen2VisionTransformer(
+            self.config.vision_config,
+            norm_eps=getattr(self.config, "rms_norm_eps", 1e-6),
+            quant_config=self._maybe_ignore_quant_config(
+                vllm_config.quant_config),
+            prefix=maybe_prefix(prefix, "visual"),
+        )