From bb3a826e089f861736664f828050716671ae6431 Mon Sep 17 00:00:00 2001
From: Canlin Guo
Date: Tue, 16 Dec 2025 11:43:52 +0800
Subject: [PATCH] [Refactor] Remove the process patches of Qwen2.5-VL and
 Qwen2.5-Omni (#5035)

### What this PR does / why we need it?
Related to #4084.

Previously, we temporarily added patches so that `set_forward_context` is replaced by `set_ascend_forward_context` in the `_process_image_input` and `_process_video_input` methods of the `Qwen2.5-VL` and `Qwen2.5-Omni` models. Once those patches are removed, these methods run under the plain vLLM `ForwardContext`, which has no `prefetch_mlp_enabled` attribute, so accessing it raises an `AttributeError`. This PR therefore removes the patches and adds a defensive check for `prefetch_mlp_enabled` in the custom prefetch ops.
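For reference, the fix boils down to the defensive-lookup pattern sketched below. This is a minimal, hypothetical example: `PlainForwardContext` and `maybe_prefetch_mlp` are illustrative stand-ins, not names from this repo.

```
class PlainForwardContext:
    """Stand-in for a forward context created without
    set_ascend_forward_context: it carries no prefetch_mlp_enabled."""


def maybe_prefetch_mlp(forward_context) -> None:
    # getattr with a False default turns the missing attribute into a
    # silent no-op instead of raising AttributeError.
    if not getattr(forward_context, 'prefetch_mlp_enabled', False):
        return
    print("prefetching MLP weights ...")


maybe_prefetch_mlp(PlainForwardContext())  # no-op, no AttributeError
```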
### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
```
vllm serve Qwen/Qwen2.5-VL-7B-Instruct \
    --max-model-len 30000 \
    --max-num-batched-tokens 50000 \
    --max-num-seqs 30 \
    --no-enable-prefix-caching \
    --trust-remote-code \
    --dtype bfloat16
```
```
{"id":"chatcmpl-b66d8acb76905c49","object":"chat.completion","created":1765796863,"model":"/root/.cache/modelscope/hub/models/Qwen/Qwen2.5-VL-7B-Instruct","choices":[{"index":0,"message":{"role":"assistant","content":"The text in the illustration reads \"TONGYI Qwen.\"","refusal":null,"annotations":null,"audio":null,"function_call":null,"tool_calls":[],"reasoning":null,"reasoning_content":null},"logprobs":null,"finish_reason":"stop","stop_reason":null,"token_ids":null}],"service_tier":null,"system_fingerprint":null,"usage":{"prompt_tokens":73,"total_tokens":88,"completion_tokens":15,"prompt_tokens_details":null},"prompt_logprobs":null,"prompt_token_ids":null,"kv_transfer_params":null}
```

- vLLM version: v0.12.0
- vLLM main: https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9

---------

Signed-off-by: gcanlin
Co-authored-by: wangxiyuan
---
 vllm_ascend/ops/register_custom_ops.py         |  6 +++---
 vllm_ascend/patch/worker/patch_qwen2_5_omni.py |  8 +-------
 vllm_ascend/patch/worker/patch_qwen2_5_vl.py   | 10 +++-------
 3 files changed, 7 insertions(+), 17 deletions(-)

diff --git a/vllm_ascend/ops/register_custom_ops.py b/vllm_ascend/ops/register_custom_ops.py
index 06a52ae1..b7100991 100644
--- a/vllm_ascend/ops/register_custom_ops.py
+++ b/vllm_ascend/ops/register_custom_ops.py
@@ -116,7 +116,7 @@ def _maybe_prefetch_mlp_gate_up_proj_impl(x_dependency: torch.Tensor,
     except AssertionError:
         return
 
-    if not forward_context.prefetch_mlp_enabled:
+    if not getattr(forward_context, 'prefetch_mlp_enabled', False):
         return
     model_instance = forward_context.model_instance
     prefetch_stream = forward_context.prefetch_stream
@@ -173,7 +173,7 @@ def _maybe_prefetch_mlp_down_proj_impl(x_dependency: torch.Tensor) -> None:
     except AssertionError:
         return
 
-    if not forward_context.prefetch_mlp_enabled:
+    if not getattr(forward_context, 'prefetch_mlp_enabled', False):
         return
     forward_context.prefetch_mlp_down_proj = True
     model_instance = forward_context.model_instance
@@ -202,7 +202,7 @@ def _maybe_wait_prefetch_done_impl(x: torch.Tensor) -> None:
     except AssertionError:
         return
 
-    if not forward_context.prefetch_mlp_enabled:
+    if not getattr(forward_context, 'prefetch_mlp_enabled', False):
         return
     if forward_context.prefetch_mlp_gate_up_proj or \
             forward_context.prefetch_mlp_down_proj:
diff --git a/vllm_ascend/patch/worker/patch_qwen2_5_omni.py b/vllm_ascend/patch/worker/patch_qwen2_5_omni.py
index f52d1a1d..c272edb3 100644
--- a/vllm_ascend/patch/worker/patch_qwen2_5_omni.py
+++ b/vllm_ascend/patch/worker/patch_qwen2_5_omni.py
@@ -18,8 +18,7 @@
 import torch
 import torch.nn as nn
 from vllm.model_executor.models.qwen2_5_omni_thinker import (
-    Qwen2_5_VLImageInputs, Qwen2_5_VLVideoInputs,
-    Qwen2_5OmniThinkerForConditionalGeneration)
+    Qwen2_5_VLImageInputs, Qwen2_5_VLVideoInputs)
 
 from vllm_ascend.ascend_forward_context import set_ascend_forward_context
 
@@ -65,8 +64,3 @@ class AscendQwen2_5OmniThinkerForConditionalGeneration(nn.Module):
         sizes = grid_thw.prod(-1) // merge_size // merge_size
 
         return video_embeds.split(sizes.tolist())
-
-
-# NOTE: These will be removed after ascend_forward_context is refactored.
-Qwen2_5OmniThinkerForConditionalGeneration._process_image_input = AscendQwen2_5OmniThinkerForConditionalGeneration._process_image_input
-Qwen2_5OmniThinkerForConditionalGeneration._process_video_input = AscendQwen2_5OmniThinkerForConditionalGeneration._process_video_input
diff --git a/vllm_ascend/patch/worker/patch_qwen2_5_vl.py b/vllm_ascend/patch/worker/patch_qwen2_5_vl.py
index 7db4323d..62a1e67e 100644
--- a/vllm_ascend/patch/worker/patch_qwen2_5_vl.py
+++ b/vllm_ascend/patch/worker/patch_qwen2_5_vl.py
@@ -20,9 +20,9 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch_npu
-from vllm.model_executor.models.qwen2_5_vl import (
-    Qwen2_5_VisionAttention, Qwen2_5_VLForConditionalGeneration,
-    Qwen2_5_VLImageInputs, Qwen2_5_VLVideoInputs)
+from vllm.model_executor.models.qwen2_5_vl import (Qwen2_5_VisionAttention,
+                                                   Qwen2_5_VLImageInputs,
+                                                   Qwen2_5_VLVideoInputs)
 from vllm.model_executor.models.qwen2_vl import Qwen2VisionAttention
 from vllm.model_executor.models.vision import run_dp_sharded_mrope_vision_model
@@ -169,7 +169,3 @@ class AscendQwen2_5_VLForConditionalGeneration(nn.Module):
 # NOTE: This will be removed after MMEncoderAttention has been extract as a CustomOp in vllm.
 Qwen2VisionAttention.forward = AscendQwen2_5_VisionAttention.forward
 Qwen2_5_VisionAttention.forward = AscendQwen2_5_VisionAttention.forward
-
-# NOTE: These will be removed after ascend_forward_context is refactored.
-Qwen2_5_VLForConditionalGeneration._process_image_input = AscendQwen2_5_VLForConditionalGeneration._process_image_input
-Qwen2_5_VLForConditionalGeneration._process_video_input = AscendQwen2_5_VLForConditionalGeneration._process_video_input