[Bugfix] Fix mm_merge (#5249)
### What this PR does / why we need it?
We should cast `mm_embed` to the dtype of `inputs_embeds` before
performing the in-place assignment.
- vLLM version: release/v0.13.0
- vLLM main:
ad32e3e19c
Signed-off-by: wangli <wangli858794774@gmail.com>
```diff
@@ -37,8 +37,9 @@ def _merge_multimodal_embeddings(
     This updates ``inputs_embeds`` in place.
     """
     flattened = _flatten_embeddings(multimodal_embeddings)
+    input_dtype = inputs_embeds.dtype
     try:
-        inputs_embeds[is_multimodal] = flattened
+        inputs_embeds[is_multimodal] = flattened.to(dtype=input_dtype)
     except RuntimeError as e:
         num_expected_tokens = is_multimodal.sum().item()
         assert isinstance(num_expected_tokens, int)
```
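The pattern above can be sketched outside of vLLM. Below is a minimal NumPy analogue (hypothetical shapes and helper name; the real code uses PyTorch tensors): the multimodal embeddings typically come out of the encoder in a wider dtype (e.g. fp32) than the model's input embeddings (e.g. fp16), so they are cast to the target dtype before the masked in-place assignment, which keeps `inputs_embeds` in its original dtype.

```python
import numpy as np

def merge_multimodal_embeddings(inputs_embeds, is_multimodal, flattened):
    # Cast to the dtype of the target tensor before the in-place
    # masked assignment, mirroring the fix in this commit.
    input_dtype = inputs_embeds.dtype
    inputs_embeds[is_multimodal] = flattened.astype(input_dtype)
    return inputs_embeds

# Hypothetical example: fp16 input embeddings, fp32 encoder output.
inputs_embeds = np.zeros((4, 8), dtype=np.float16)
is_multimodal = np.array([False, True, True, False])
flattened = np.ones((2, 8), dtype=np.float32)

out = merge_multimodal_embeddings(inputs_embeds, is_multimodal, flattened)
assert out.dtype == np.float16  # target dtype preserved
```

NumPy silently casts on masked assignment, but some accelerator backends reject a dtype-mismatched in-place write at runtime, which is why the explicit cast matters here.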