[Bugfix] Fix mm_merge (#5249)
### What this PR does / why we need it?
We should cast the mm_embed to the dtype of input_embed before
performing the in-place assignment
- vLLM version: release/v0.13.0
- vLLM main:
ad32e3e19c
Signed-off-by: wangli <wangli858794774@gmail.com>
This commit is contained in:
@@ -37,8 +37,9 @@ def _merge_multimodal_embeddings(
|
|||||||
This updates ``inputs_embeds`` in place.
|
This updates ``inputs_embeds`` in place.
|
||||||
"""
|
"""
|
||||||
flattened = _flatten_embeddings(multimodal_embeddings)
|
flattened = _flatten_embeddings(multimodal_embeddings)
|
||||||
|
input_dtype = inputs_embeds.dtype
|
||||||
try:
|
try:
|
||||||
inputs_embeds[is_multimodal] = flattened
|
inputs_embeds[is_multimodal] = flattened.to(dtype=input_dtype)
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
num_expected_tokens = is_multimodal.sum().item()
|
num_expected_tokens = is_multimodal.sum().item()
|
||||||
assert isinstance(num_expected_tokens, int)
|
assert isinstance(num_expected_tokens, int)
|
||||||
|
|||||||
Reference in New Issue
Block a user