xc-llm-ascend/vllm_ascend/patch/worker/patch_qwen2_5_omni.py
Canlin Guo bb3a826e08 [Refactor] Remove the process patches of Qwen2.5-VL and Qwen2.5-Omni (#5035)
### What this PR does / why we need it?

Related to #4084. Previously, we temporarily added patches so that
`set_forward_context` was replaced by `set_ascend_forward_context` in the
`_process_image_input` and `_process_video_input` functions of the
`Qwen2.5-VL` and `Qwen2.5-Omni` models. After removing these patches, I
hit an `AttributeError` because `ForwardContext` is missing
`prefetch_mlp_enabled`, so we need to add a defensive check for
`prefetch_mlp_enabled`.
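The failure mode and the defensive check can be illustrated with a minimal, self-contained sketch that does not import vLLM or vllm-ascend. `ForwardContextStub`, `set_plain_context`, `set_ascend_like_context`, and `prefetch_mlp_enabled()` are hypothetical stand-ins for `ForwardContext`, `set_forward_context`, `set_ascend_forward_context`, and the code that reads the flag. The sketch assumes the Ascend context attaches `prefetch_mlp_enabled` as an extra attribute and that a missing attribute should be treated as "disabled".

```python
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Iterator, Optional


@dataclass
class ForwardContextStub:
    """Hypothetical stand-in for vLLM's ForwardContext (not the real class)."""
    attn_metadata: object = None


# Hypothetical module-level "current context", set by the managers below.
_current: Optional[ForwardContextStub] = None


@contextmanager
def set_plain_context() -> Iterator[None]:
    # Hypothetical analogue of set_forward_context: no Ascend-only fields.
    global _current
    prev, _current = _current, ForwardContextStub()
    try:
        yield
    finally:
        _current = prev


@contextmanager
def set_ascend_like_context(prefetch_mlp_enabled: bool) -> Iterator[None]:
    # Hypothetical analogue of set_ascend_forward_context: attaches an
    # extra attribute that downstream Ascend code later reads.
    global _current
    prev = _current
    ctx = ForwardContextStub()
    ctx.prefetch_mlp_enabled = prefetch_mlp_enabled  # dynamic attribute
    _current = ctx
    try:
        yield
    finally:
        _current = prev


def prefetch_mlp_enabled() -> bool:
    # Defensive check: a context that lacks the field (or no context at
    # all) means "disabled" instead of raising AttributeError.
    return bool(getattr(_current, "prefetch_mlp_enabled", False))
```

Inside `set_plain_context()`, a direct read such as `_current.prefetch_mlp_enabled` would raise `AttributeError`, which mirrors the error described above; the `getattr` default turns that case into `False`.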

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

```
vllm serve Qwen/Qwen2.5-VL-7B-Instruct \
    --max-model-len 30000 \
    --max-num-batched-tokens 50000 \
    --max-num-seqs 30 \
    --no-enable-prefix-caching \
    --trust-remote-code \
    --dtype bfloat16
```

```
{"id":"chatcmpl-b66d8acb76905c49","object":"chat.completion","created":1765796863,"model":"/root/.cache/modelscope/hub/models/Qwen/Qwen2.5-VL-7B-Instruct","choices":[{"index":0,"message":{"role":"assistant","content":"The text in the illustration reads \"TONGYI Qwen.\"","refusal":null,"annotations":null,"audio":null,"function_call":null,"tool_calls":[],"reasoning":null,"reasoning_content":null},"logprobs":null,"finish_reason":"stop","stop_reason":null,"token_ids":null}],"service_tier":null,"system_fingerprint":null,"usage":{"prompt_tokens":73,"total_tokens":88,"completion_tokens":15,"prompt_tokens_details":null},"prompt_logprobs":null,"prompt_token_ids":null,"kv_transfer_params":null}
```
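A request like the one that produced this response can be sketched with the standard library alone. This assumes the server listens on `http://localhost:8000` and exposes the OpenAI-compatible `/v1/chat/completions` route; `BASE_URL`, `IMAGE_URL`, the question text, and both helper names are hypothetical, since the test above does not show the prompt or image. Pass whatever model name the server actually registered.

```python
import json
import urllib.request

# Hypothetical values: the original test does not show the prompt or image.
BASE_URL = "http://localhost:8000"
IMAGE_URL = "https://example.com/illustration.png"


def build_chat_request(model: str, question: str,
                       image_url: str) -> urllib.request.Request:
    """Build (but do not send) an OpenAI-style chat request with one image."""
    payload = {
        "model": model,
        "messages": [{
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": image_url}},
                {"type": "text", "text": question},
            ],
        }],
    }
    return urllib.request.Request(
        f"{BASE_URL}/v1/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def summarize_response(raw: str) -> tuple[str, int]:
    """Return the assistant text and total token count from a response body."""
    body = json.loads(raw)
    usage = body["usage"]
    # Sanity check: total tokens should equal prompt + completion tokens.
    if usage["total_tokens"] != usage["prompt_tokens"] + usage["completion_tokens"]:
        raise ValueError("inconsistent token usage")
    return body["choices"][0]["message"]["content"], usage["total_tokens"]
```

Sending the request is a matter of `urllib.request.urlopen(req)`. Applied to the response above, `summarize_response` returns the assistant text and a total of 88 tokens (73 prompt + 15 completion).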


- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

---------

Signed-off-by: gcanlin <canlinguosdu@gmail.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
2025-12-16 11:43:52 +08:00


#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
import torch.nn as nn
from vllm.model_executor.models.qwen2_5_omni_thinker import (
    Qwen2_5_VLImageInputs, Qwen2_5_VLVideoInputs)

from vllm_ascend.ascend_forward_context import set_ascend_forward_context


class AscendQwen2_5OmniThinkerForConditionalGeneration(nn.Module):

    def _process_image_input(
            self,
            image_input: Qwen2_5_VLImageInputs) -> tuple[torch.Tensor, ...]:
        if image_input["type"] == "image_embeds":
            return image_input["image_embeds"].type(self.visual.dtype)

        grid_thw = image_input["image_grid_thw"]
        assert grid_thw.ndim == 2

        pixel_values = image_input["pixel_values"].type(self.visual.dtype)
        with set_ascend_forward_context(None, self.vllm_config):
            image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
        # Split concatenated embeddings for each image item.
        merge_size = self.visual.spatial_merge_size
        sizes = grid_thw.prod(-1) // merge_size // merge_size
        return image_embeds.split(sizes.tolist())

    def _process_video_input(
        self,
        video_input: Qwen2_5_VLVideoInputs,
        video_hashes: list[str] | None = None,
        cached_video_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        if video_input["type"] == "video_embeds":
            return video_input["video_embeds"].type(self.visual.dtype)

        grid_thw = video_input["video_grid_thw"]
        assert grid_thw.ndim == 2

        pixel_values_videos = video_input["pixel_values_videos"].type(
            self.visual.dtype)
        with set_ascend_forward_context(None, self.vllm_config):
            video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
        # Split concatenated embeddings for each video item.
        merge_size = self.visual.spatial_merge_size
        sizes = grid_thw.prod(-1) // merge_size // merge_size
        return video_embeds.split(sizes.tolist())
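Both methods end by splitting the visual encoder's concatenated output back into one chunk per image or video: each item contributes `t * h * w // merge_size // merge_size` rows, where `(t, h, w)` is its row of `grid_thw`. The sketch below mirrors that arithmetic in plain Python so it runs without torch; `split_sizes` and `split_rows` are hypothetical helpers, with `split_rows` standing in for `Tensor.split(sizes)` along dim 0, and the `merge_size` of 2 in the example is an assumed value.

```python
from math import prod
from typing import Sequence


def split_sizes(grid_thw: Sequence[Sequence[int]], merge_size: int) -> list[int]:
    """Rows per item after spatial merging: (t * h * w) // merge_size // merge_size."""
    return [prod(row) // merge_size // merge_size for row in grid_thw]


def split_rows(rows: list, sizes: list[int]) -> list[list]:
    """Pure-Python analogue of Tensor.split(sizes) along dim 0."""
    if sum(sizes) != len(rows):
        raise ValueError("embedding rows must match the grid")
    out, start = [], 0
    for n in sizes:
        out.append(rows[start:start + n])
        start += n
    return out


# Two items with grids (1, 4, 6) and (2, 2, 2), assuming merge_size = 2:
# 24 // 4 = 6 rows and 8 // 4 = 2 rows.
sizes = split_sizes([[1, 4, 6], [2, 2, 2]], merge_size=2)
parts = split_rows(list(range(8)), sizes)
```

Here `sizes` is `[6, 2]` and `parts` is `[[0, 1, 2, 3, 4, 5], [6, 7]]`. In the torch code, passing `sizes.tolist()` to `split` yields the per-item chunks in the same way.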