[Bugfix] fix qwen2.5-vl-72b shape ERROR during the _prepare_inputs phase under high concurrency. (#4553)
### What this PR does / why we need it?
qwen2.5-vl-72b reports a shape ERROR during the `_prepare_inputs` phase under high concurrency (issue: https://github.com/vllm-project/vllm-ascend/issues/4430). This PR fixes it.

Related PR in the main branch: https://github.com/vllm-project/vllm-ascend/pull/3612
Related commit in vllm: 17c540a993/vllm/model_executor/models/interfaces.py (the `_get_text_embeddings` function has been refactored into `interfaces.py` in vllm).
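For context, the mechanism being ported is mask-driven embedding assembly: text tokens are embedded and scattered into a full-size buffer, then the precomputed multimodal embeddings are scattered into the masked positions. Below is a minimal standalone sketch of that mechanism; the vocabulary size, token IDs, and tensors are toy values for illustration, not taken from the patch.

```python
import torch

# Toy settings (illustrative only): 6 tokens, hidden size 4.
vocab_size, hidden = 10, 4
embed = torch.nn.Embedding(vocab_size, hidden)

input_ids = torch.tensor([1, 2, 5, 7, 7, 3])
is_multimodal = torch.tensor([False, False, False, True, True, False])

# Step 1 (cf. _get_text_embeddings): embed only the text tokens, then
# scatter them into a full-size buffer via the boolean mask.
is_text = ~is_multimodal
text_embeds = embed(input_ids[is_text])
inputs_embeds = torch.empty(
    (input_ids.shape[0], hidden),
    dtype=text_embeds.dtype,
    device=text_embeds.device,
).masked_scatter_(is_text.unsqueeze(-1), text_embeds)

# Step 2 (cf. _merge_multimodal_embeddings): scatter the precomputed
# multimodal embeddings into the masked positions of the same buffer.
mm_embeds = torch.randn(int(is_multimodal.sum()), hidden)
inputs_embeds = inputs_embeds.masked_scatter(
    is_multimodal.unsqueeze(-1), mm_embeds)

print(inputs_embeds.shape)  # torch.Size([6, 4])
```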
Signed-off-by: Levi-JQ <yujinqi2@huawei.com>
Co-authored-by: Levi-JQ <yujinqi2@huawei.com>
```diff
@@ -34,6 +34,7 @@ from vllm.model_executor.layers.activation import get_act_and_mul_fn
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.interfaces import MultiModalEmbeddings
 from vllm.model_executor.models.qwen2_5_vl import (
     Qwen2_5_VisionAttention, Qwen2_5_VisionBlock, Qwen2_5_VisionPatchEmbed,
     Qwen2_5_VisionRotaryEmbedding, Qwen2_5_VisionTransformer,
@@ -560,3 +561,68 @@ class AscendQwen2_5_VLForConditionalGeneration(
         merge_size = self.visual.spatial_merge_size
         sizes = grid_thw.prod(-1) // merge_size // merge_size
         return video_embeds.split(sizes.tolist())
+
+    def _get_text_embeddings(
+        self,
+        input_ids: torch.Tensor,
+        get_input_embeddings: Callable[[torch.Tensor], torch.Tensor],
+        *,
+        is_multimodal: Optional[torch.Tensor],
+        handle_oov_mm_token: bool,
+    ) -> torch.Tensor:
+        if handle_oov_mm_token and is_multimodal is not None:
+            is_text = ~is_multimodal
+            text_embeds = get_input_embeddings(input_ids[is_text])
+
+            return torch.empty(
+                (input_ids.shape[0], text_embeds.shape[1]),
+                dtype=text_embeds.dtype,
+                device=text_embeds.device,
+            ).masked_scatter_(is_text.unsqueeze_(-1), text_embeds)
+
+        return get_input_embeddings(input_ids)
+
+    def get_input_embeddings(
+        self,
+        input_ids: torch.Tensor,
+        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
+        *,
+        is_multimodal: Optional[torch.Tensor] = None,
+        handle_oov_mm_token: bool = False,
+    ) -> torch.Tensor:
+        """
+        Apply token embeddings to `input_ids`.
+
+        If `multimodal_embeddings` is passed, scatter them into
+        `input_ids` according to the mask `is_multimodal`.
+
+        In case the multi-modal token IDs exceed the vocabulary size of
+        the language model, you can set `handle_oov_mm_token=True`
+        to avoid calling the language model's `get_input_embeddings` method
+        on those tokens. Note however that doing so increases memory usage
+        as an additional buffer is needed to hold the input embeddings.
+        """
+        from vllm.model_executor.models.utils import \
+            _merge_multimodal_embeddings
+
+        inputs_embeds = self._get_text_embeddings(
+            input_ids,
+            self.get_language_model().get_input_embeddings,
+            is_multimodal=is_multimodal,
+            handle_oov_mm_token=handle_oov_mm_token,
+        )
+
+        if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
+            return inputs_embeds
+
+        if is_multimodal is None:
+            raise ValueError(
+                "`get_input_embeddings` now requires `is_multimodal` arg, "
+                "please update your model runner according to "
+                "https://github.com/vllm-project/vllm/pull/16229.")
+
+        return _merge_multimodal_embeddings(
+            inputs_embeds=inputs_embeds,
+            is_multimodal=is_multimodal,
+            multimodal_embeddings=multimodal_embeddings,
+        )
```
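The `handle_oov_mm_token` path above exists because multimodal placeholder token IDs can lie outside the text vocabulary, and embedding them directly fails. A toy illustration of that failure mode (sizes and IDs are hypothetical, not from the patch):

```python
import torch

embed = torch.nn.Embedding(10, 4)          # vocab_size = 10
ids = torch.tensor([1, 2, 12])             # 12 is out of vocabulary
is_multimodal = torch.tensor([False, False, True])

try:
    embed(ids)                             # raises IndexError on CPU
except IndexError as err:
    print("plain lookup fails:", err)

# Embedding only the text positions keeps the lookup in-vocab; the masked
# positions are filled later from the multimodal embeddings.
text_embeds = embed(ids[~is_multimodal])
print(text_embeds.shape)                   # torch.Size([2, 4])
```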
```diff
@@ -26,6 +26,7 @@ import torch_npu
 from einops import rearrange
 from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
     Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig)
+from vllm.model_executor.models.interfaces import MultiModalEmbeddings

 try:
     from transformers.models.qwen3_vl.configuration_qwen3_vl import \
@@ -523,6 +524,71 @@ class AscendQwen2_5_VLForConditionalGeneration_Without_Padding(
         sizes = grid_thw.prod(-1) // merge_size // merge_size
         return video_embeds.split(sizes.tolist())
+
+    def _get_text_embeddings(
+        self,
+        input_ids: torch.Tensor,
+        get_input_embeddings: Callable[[torch.Tensor], torch.Tensor],
+        *,
+        is_multimodal: Optional[torch.Tensor],
+        handle_oov_mm_token: bool,
+    ) -> torch.Tensor:
+        if handle_oov_mm_token and is_multimodal is not None:
+            is_text = ~is_multimodal
+            text_embeds = get_input_embeddings(input_ids[is_text])
+
+            return torch.empty(
+                (input_ids.shape[0], text_embeds.shape[1]),
+                dtype=text_embeds.dtype,
+                device=text_embeds.device,
+            ).masked_scatter_(is_text.unsqueeze_(-1), text_embeds)
+
+        return get_input_embeddings(input_ids)
+
+    def get_input_embeddings(
+        self,
+        input_ids: torch.Tensor,
+        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
+        *,
+        is_multimodal: Optional[torch.Tensor] = None,
+        handle_oov_mm_token: bool = False,
+    ) -> torch.Tensor:
+        """
+        Apply token embeddings to `input_ids`.
+
+        If `multimodal_embeddings` is passed, scatter them into
+        `input_ids` according to the mask `is_multimodal`.
+
+        In case the multi-modal token IDs exceed the vocabulary size of
+        the language model, you can set `handle_oov_mm_token=True`
+        to avoid calling the language model's `get_input_embeddings` method
+        on those tokens. Note however that doing so increases memory usage
+        as an additional buffer is needed to hold the input embeddings.
+        """
+        from vllm.model_executor.models.utils import \
+            _merge_multimodal_embeddings
+
+        inputs_embeds = self._get_text_embeddings(
+            input_ids,
+            self.get_language_model().get_input_embeddings,
+            is_multimodal=is_multimodal,
+            handle_oov_mm_token=handle_oov_mm_token,
+        )
+
+        if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
+            return inputs_embeds
+
+        if is_multimodal is None:
+            raise ValueError(
+                "`get_input_embeddings` now requires `is_multimodal` arg, "
+                "please update your model runner according to "
+                "https://github.com/vllm-project/vllm/pull/16229.")
+
+        return _merge_multimodal_embeddings(
+            inputs_embeds=inputs_embeds,
+            is_multimodal=is_multimodal,
+            multimodal_embeddings=multimodal_embeddings,
+        )


 @MULTIMODAL_REGISTRY.register_processor(Qwen3VLMultiModalProcessor,
                                         info=Qwen3VLProcessingInfo,
```
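On the caller side, the model runner must now supply the `is_multimodal` mask whenever multimodal embeddings are present, otherwise the ValueError in the diff is raised. A hedged sketch of that call; the function name and argument plumbing below are hypothetical, not part of the patch:

```python
from typing import Optional

import torch


def prepare_inputs_embeds(
    model,
    input_ids: torch.Tensor,
    mm_embeds,                           # per-item multimodal embeddings; may be empty
    is_multimodal: Optional[torch.Tensor],
) -> torch.Tensor:
    # Hypothetical runner-side helper showing the new calling convention.
    return model.get_input_embeddings(
        input_ids,
        multimodal_embeddings=mm_embeds,
        is_multimodal=is_multimodal,     # required when mm_embeds is non-empty
        handle_oov_mm_token=True,        # embed only in-vocab text tokens
    )
```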