[Bugfix] Fix mm_merge (#5249)

### What this PR does / why we need it?
We should cast `mm_embeds` to the dtype of `inputs_embeds` before performing the in-place assignment; otherwise the masked index-put fails when the two tensors have different dtypes.

- vLLM version: release/v0.13.0
- vLLM main: ad32e3e19c

Signed-off-by: wangli <wangli858794774@gmail.com>
Author: Li Wang
Date: 2025-12-31 09:49:55 +08:00 (committed via GitHub)
Parent: 3c2d3e52e5
Commit: a5ae07a5d2


```diff
@@ -37,8 +37,9 @@ def _merge_multimodal_embeddings(
     This updates ``inputs_embeds`` in place.
     """
     flattened = _flatten_embeddings(multimodal_embeddings)
+    input_dtype = inputs_embeds.dtype
     try:
-        inputs_embeds[is_multimodal] = flattened
+        inputs_embeds[is_multimodal] = flattened.to(dtype=input_dtype)
     except RuntimeError as e:
         num_expected_tokens = is_multimodal.sum().item()
         assert isinstance(num_expected_tokens, int)
```
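The fix works around PyTorch's requirement that the source and destination of a masked index-put share a dtype. A minimal sketch of the pattern (tensor shapes, sizes, and the function name here are illustrative, not vLLM's actual ones):

```python
import torch

def merge_multimodal_embeddings(inputs_embeds, mm_embeds, is_multimodal):
    """Scatter mm_embeds into the masked rows of inputs_embeds, in place.

    Casting the source to the destination dtype first avoids a
    RuntimeError when, e.g., text embeddings are float16 but the
    multimodal encoder produced float32 outputs.
    """
    inputs_embeds[is_multimodal] = mm_embeds.to(dtype=inputs_embeds.dtype)
    return inputs_embeds

# Text embeddings in half precision, multimodal embeddings in float32.
inputs_embeds = torch.zeros(4, 8, dtype=torch.float16)
mm_embeds = torch.ones(2, 8, dtype=torch.float32)
is_multimodal = torch.tensor([False, True, False, True])

merge_multimodal_embeddings(inputs_embeds, mm_embeds, is_multimodal)
print(inputs_embeds[1, 0].item())  # 1.0
```

Note the number of `True` entries in the mask must equal the number of rows in `mm_embeds`; that invariant is what the surrounding `try`/`except RuntimeError` block in the diff reports on when it is violated.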