Files
xc-llm-ascend/vllm_ascend/patch/worker/patch_multimodal_merge.py

59 lines
2.1 KiB
Python
Raw Permalink Normal View History

#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
import torch
import vllm
[Lint]Style: Convert `vllm-ascend/` to ruff format(Batch #10) (#6173) ### What this PR does / why we need it? **Scope of Changes**: | File Path | | :--- | |`vllm_ascend/ops/layer_shard_linear.py`| |`vllm_ascend/ops/linear.py`| |`vllm_ascend/ops/linear_op.py`| |`vllm_ascend/worker/worker.py`| | ` vllm_ascend/patch/worker/patch_bert.py` | | ` vllm_ascend/patch/worker/patch_deepseek.py` | | ` vllm_ascend/patch/worker/patch_distributed.py` | | ` vllm_ascend/patch/worker/patch_module.py` | | ` vllm_ascend/patch/worker/patch_multimodal_merge.py` | | ` vllm_ascend/patch/worker/patch_qwen3_next.py` | | ` vllm_ascend/patch/worker/patch_qwen3_next_mtp.py` | | ` vllm_ascend/patch/worker/patch_rejection_sampler.py` | | ` vllm_ascend/patch/worker/patch_rope.py` | | ` vllm_ascend/patch/worker/patch_triton.py` | | ` vllm_ascend/patch/worker/patch_unquantized_gemm.py` | | ` vllm_ascend/patch/worker/patch_v2_egale.py` | |` vllm_ascend/worker/npu_input_batch.py`| |` vllm_ascend/worker/v2/aclgraph_utils.py`| |` vllm_ascend/worker/v2/attn_utils.py`| |` vllm_ascend/worker/v2/model_runner.py`| |` vllm_ascend/worker/v2/sample/gumbel.py`| |` vllm_ascend/worker/v2/sample/penalties.py`| |` vllm_ascend/worker/v2/sample/sampler.py`| |` vllm_ascend/worker/v2/spec_decode/__init__.py`| |` vllm_ascend/worker/v2/spec_decode/eagle.py`| |` vllm_ascend/worker/v2/states.py`| ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.14.0 - vLLM main: https://github.com/vllm-project/vllm/commit/d68209402ddab3f54a09bc1f4de9a9495a283b60 Signed-off-by: MrZ20 <2609716663@qq.com> Signed-off-by: SILONG ZENG <2609716663@qq.com> Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com> Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
2026-02-06 15:35:06 +08:00
from vllm.model_executor.models.utils import _embedding_count_expression, _flatten_embeddings
from vllm.multimodal import NestedTensors
def _token_count_error(
    multimodal_embeddings: NestedTensors,
    num_actual_tokens: int,
    num_expected_tokens: int,
) -> ValueError:
    """Build the error for a mismatch between embedding rows and placeholders.

    Args:
        multimodal_embeddings: The (unflattened) embeddings; used only to
            describe how the row count was obtained.
        num_actual_tokens: Number of flattened embedding rows.
        num_expected_tokens: Number of placeholder positions in the mask.

    Returns:
        A ``ValueError`` for the caller to raise (optionally chained).
    """
    expr = _embedding_count_expression(multimodal_embeddings)
    return ValueError(
        f"Attempted to assign {expr} = {num_actual_tokens} "
        f"multimodal tokens to {num_expected_tokens} placeholders"
    )


def _merge_multimodal_embeddings(
    inputs_embeds: torch.Tensor,
    multimodal_embeddings: NestedTensors,
    is_multimodal: torch.Tensor,
) -> torch.Tensor:
    """
    Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
    positions of ``inputs_embeds`` selected by the ``is_multimodal`` mask.

    Args:
        inputs_embeds: Token embeddings to update; the rows selected by
            ``is_multimodal`` are overwritten (cast to this tensor's dtype).
        multimodal_embeddings: A tensor or nested sequence of tensors, which
            ``_flatten_embeddings`` turns into one row per multimodal token.
        is_multimodal: Mask marking the placeholder positions in
            ``inputs_embeds`` (expected to be boolean).

    Returns:
        ``inputs_embeds`` itself. It is returned unchanged when
        ``multimodal_embeddings`` is empty.

    Raises:
        ValueError: If the number of embedding rows differs from the number of
            placeholder positions, or the masked write fails for another
            reason.

    Note:
        This updates ``inputs_embeds`` in place.
    """
    # Nothing to merge. Short-circuit before flattening, which would otherwise
    # have to concatenate zero tensors.
    if len(multimodal_embeddings) == 0:
        return inputs_embeds

    flattened = _flatten_embeddings(multimodal_embeddings)
    num_actual_tokens = flattened.shape[0]

    # Mask assignment broadcasts a single source row to every selected
    # position (or to none) instead of raising, so a 1-row vs N-placeholder
    # mismatch would silently corrupt ``inputs_embeds``. Counting placeholders
    # needs ``.item()`` (a host sync when the mask is on device), so only pay
    # for it in the one case where broadcasting can hide a mismatch.
    if num_actual_tokens == 1:
        num_expected_tokens = int(is_multimodal.sum().item())
        if num_expected_tokens != 1:
            raise _token_count_error(multimodal_embeddings, 1, num_expected_tokens)

    try:
        inputs_embeds[is_multimodal] = flattened.to(dtype=inputs_embeds.dtype)
    except RuntimeError as e:
        # Already on the failure path, so the extra sync for the count is fine.
        # ``int(...)`` replaces an ``assert isinstance(..., int)`` that could
        # mask the original error (and vanished under ``python -O``).
        num_expected_tokens = int(is_multimodal.sum().item())
        if num_actual_tokens != num_expected_tokens:
            raise _token_count_error(
                multimodal_embeddings, num_actual_tokens, num_expected_tokens
            ) from e
        raise ValueError("Error during masked scatter operation") from e

    return inputs_embeds
# Install the Ascend implementation by rebinding the private helper on vLLM's
# ``models.utils`` module. Only lookups that go through that module attribute
# at call time see the replacement; code that already bound the name with
# ``from ... import _merge_multimodal_embeddings`` before this line runs keeps
# the original function. NOTE(review): confirm this patch is applied early
# enough for every caller.
setattr(
    vllm.model_executor.models.utils,
    "_merge_multimodal_embeddings",
    _merge_multimodal_embeddings,
)