[Bugfix] fix qwen2.5-vl-72b shape ERROR during the _prepare_inputs phase under high concurrency. (#4553)

### What this PR does / why we need it?
qwen2.5-vl-72b reports a shape ERROR during the _prepare_inputs phase
under high concurrency【 issue
https://github.com/vllm-project/vllm-ascend/issues/4430 】

This PR fix it.

The related PR in main branch
:https://github.com/vllm-project/vllm-ascend/pull/3612

The related commit in vllm :
17c540a993/vllm/model_executor/models/interfaces.py

【The _get_text_embeddings function has been refactored to
interfaces.pyin vllm.】

Signed-off-by: Levi-JQ <yujinqi2@huawei.com>
Co-authored-by: Levi-JQ <yujinqi2@huawei.com>
This commit is contained in:
Levi
2025-12-02 14:20:45 +08:00
committed by GitHub
parent 52abd47f8c
commit 3b4cb23616
3 changed files with 168 additions and 10 deletions

View File

@@ -34,6 +34,7 @@ from vllm.model_executor.layers.activation import get_act_and_mul_fn
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
from vllm.model_executor.models.qwen2_5_vl import (
Qwen2_5_VisionAttention, Qwen2_5_VisionBlock, Qwen2_5_VisionPatchEmbed,
Qwen2_5_VisionRotaryEmbedding, Qwen2_5_VisionTransformer,
@@ -560,3 +561,68 @@ class AscendQwen2_5_VLForConditionalGeneration(
merge_size = self.visual.spatial_merge_size
sizes = grid_thw.prod(-1) // merge_size // merge_size
return video_embeds.split(sizes.tolist())
def _get_text_embeddings(
self,
input_ids: torch.Tensor,
get_input_embeddings: Callable[[torch.Tensor], torch.Tensor],
*,
is_multimodal: Optional[torch.Tensor],
handle_oov_mm_token: bool,
) -> torch.Tensor:
if handle_oov_mm_token and is_multimodal is not None:
is_text = ~is_multimodal
text_embeds = get_input_embeddings(input_ids[is_text])
return torch.empty(
(input_ids.shape[0], text_embeds.shape[1]),
dtype=text_embeds.dtype,
device=text_embeds.device,
).masked_scatter_(is_text.unsqueeze_(-1), text_embeds)
return get_input_embeddings(input_ids)
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
*,
is_multimodal: Optional[torch.Tensor] = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
"""
Apply token embeddings to `input_ids`.
If `multimodal_embeddings` is passed, scatter them into
`input_ids` according to the mask `is_multimodal`.
In case the multi-modal token IDs exceed the vocabulary size of
the language model, you can set `handle_oov_mm_token=False`
to avoid calling the language model's `get_input_embeddings` method
on those tokens. Note however that doing so increases memory usage
as an additional buffer is needed to hold the input embeddings.
"""
from vllm.model_executor.models.utils import \
_merge_multimodal_embeddings
inputs_embeds = self._get_text_embeddings(
input_ids,
self.get_language_model().get_input_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
return inputs_embeds
if is_multimodal is None:
raise ValueError(
"`get_input_embeddings` now requires `is_multimodal` arg, "
"please update your model runner according to "
"https://github.com/vllm-project/vllm/pull/16229.")
return _merge_multimodal_embeddings(
inputs_embeds=inputs_embeds,
is_multimodal=is_multimodal,
multimodal_embeddings=multimodal_embeddings,
)

View File

@@ -26,6 +26,7 @@ import torch_npu
from einops import rearrange
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig)
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
try:
from transformers.models.qwen3_vl.configuration_qwen3_vl import \
@@ -523,6 +524,71 @@ class AscendQwen2_5_VLForConditionalGeneration_Without_Padding(
sizes = grid_thw.prod(-1) // merge_size // merge_size
return video_embeds.split(sizes.tolist())
def _get_text_embeddings(
self,
input_ids: torch.Tensor,
get_input_embeddings: Callable[[torch.Tensor], torch.Tensor],
*,
is_multimodal: Optional[torch.Tensor],
handle_oov_mm_token: bool,
) -> torch.Tensor:
if handle_oov_mm_token and is_multimodal is not None:
is_text = ~is_multimodal
text_embeds = get_input_embeddings(input_ids[is_text])
return torch.empty(
(input_ids.shape[0], text_embeds.shape[1]),
dtype=text_embeds.dtype,
device=text_embeds.device,
).masked_scatter_(is_text.unsqueeze_(-1), text_embeds)
return get_input_embeddings(input_ids)
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
*,
is_multimodal: Optional[torch.Tensor] = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
"""
Apply token embeddings to `input_ids`.
If `multimodal_embeddings` is passed, scatter them into
`input_ids` according to the mask `is_multimodal`.
In case the multi-modal token IDs exceed the vocabulary size of
the language model, you can set `handle_oov_mm_token=False`
to avoid calling the language model's `get_input_embeddings` method
on those tokens. Note however that doing so increases memory usage
as an additional buffer is needed to hold the input embeddings.
"""
from vllm.model_executor.models.utils import \
_merge_multimodal_embeddings
inputs_embeds = self._get_text_embeddings(
input_ids,
self.get_language_model().get_input_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
return inputs_embeds
if is_multimodal is None:
raise ValueError(
"`get_input_embeddings` now requires `is_multimodal` arg, "
"please update your model runner according to "
"https://github.com/vllm-project/vllm/pull/16229.")
return _merge_multimodal_embeddings(
inputs_embeds=inputs_embeds,
is_multimodal=is_multimodal,
multimodal_embeddings=multimodal_embeddings,
)
@MULTIMODAL_REGISTRY.register_processor(Qwen3VLMultiModalProcessor,
info=Qwen3VLProcessingInfo,

View File

@@ -62,6 +62,7 @@ from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models.interfaces import supports_transcription
from vllm.model_executor.models.interfaces_base import (
VllmModelForPooling, is_pooling_model, is_text_generation_model)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
from vllm.multimodal.utils import group_mm_kwargs_by_modality
from vllm.pooling_params import PoolingParams
@@ -550,6 +551,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):
num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
self.mc2_tokens_capacity = num_tokens_per_tp_rank * tp_size
# Only relevant for multimodal models
self.mm_registry = MULTIMODAL_REGISTRY
self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
self.model_config)
if self.supports_mm_inputs:
self.is_mm_embed = self._make_buffer(self.max_num_tokens,
dtype=torch.bool)
def _make_buffer(self,
*size: Union[int, torch.SymInt],
dtype: torch.dtype,
@@ -1034,7 +1043,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
def _gather_mm_embeddings(
self,
scheduler_output: "SchedulerOutput",
) -> list[torch.Tensor]:
) -> tuple[list[torch.Tensor], torch.Tensor]:
def _iter_mm_features(req_state: CachedRequestState):
assert req_state.mm_features is not None
@@ -1044,8 +1053,15 @@ class NPUModelRunner(LoRAModelRunnerMixin):
pos_info, "is_embed", None)
mm_embeds: list[torch.Tensor] = []
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
is_mm_embed = self.is_mm_embed.cpu
is_mm_embed[:total_num_scheduled_tokens] = False
req_start_idx = 0
for req_id in self.input_batch.req_ids:
mm_embeds_req: list[torch.Tensor] = []
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
req_id]
req_state = self.requests[req_id]
@@ -1074,12 +1090,22 @@ class NPUModelRunner(LoRAModelRunnerMixin):
if is_embed is not None:
is_embed = is_embed[start_idx:end_idx]
req_start_pos = req_start_idx + start_pos - num_computed_tokens
is_mm_embed[req_start_pos+start_idx:req_start_pos + end_idx] \
= True if is_embed is None else is_embed
mm_embeds_item = gather_mm_placeholders(
encoder_output[start_idx:end_idx],
is_embed=is_embed,
)
mm_embeds.append(mm_embeds_item)
return mm_embeds
mm_embeds_req.append(mm_embeds_item)
mm_embeds.extend(mm_embeds_req)
req_start_idx += num_scheduled_tokens
is_mm_embed = self.is_mm_embed.copy_to_gpu(total_num_scheduled_tokens)
return mm_embeds, is_mm_embed
def _get_cumsum_and_arange(
self,
@@ -1362,17 +1388,17 @@ class NPUModelRunner(LoRAModelRunnerMixin):
if self.is_multimodal_model:
# Run the multimodal encoder if any.
self._execute_mm_encoder(scheduler_output)
mm_embeds = self._gather_mm_embeddings(scheduler_output)
mm_embeds, is_mm_embed = self._gather_mm_embeddings(
scheduler_output)
# NOTE(woosuk): To unify token ids and soft tokens (vision
# embeddings), we always use embeddings (rather than token ids)
# as input to the multimodal model, even when the input is text.
input_ids = self.input_ids[:total_num_scheduled_tokens]
if mm_embeds:
inputs_embeds = self.model.get_input_embeddings(
input_ids, mm_embeds)
else:
inputs_embeds = self.model.get_input_embeddings(input_ids)
input_ids,
multimodal_embeddings=mm_embeds,
is_multimodal=is_mm_embed,
)
# TODO(woosuk): Avoid the copy. Optimize.
self.inputs_embeds[:total_num_scheduled_tokens].copy_(
inputs_embeds)