[Bugfix] Fix qwen3-vl-moe shape error during the _prepare_inputs phase under high concurrency (#4658)

### What this PR does / why we need it?
We previously fixed a similar issue for qwen2.5-vl
(https://github.com/vllm-project/vllm-ascend/issues/4430); all multimodal
models running on vLLM v0.11.0 are likely affected by the same problem.
This PR adds a dedicated fix for qwen3-vl-moe.
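
For context, the mask-based scatter the fix relies on can be sketched roughly as below. This is a minimal, self-contained illustration with made-up token IDs and shapes (not the actual model or runner code): the embedding buffer is sized by the number of scheduled tokens and filled through an explicit `is_multimodal` mask, so text and vision embeddings land in the right positions regardless of how requests are batched.

```python
import torch

# Hypothetical example: 8 scheduled tokens in one batch, 3 of which are
# image placeholder tokens; hidden size of 4. All values are made up.
hidden_size = 4
input_ids = torch.tensor([11, 12, 900, 900, 900, 13, 14, 15])
is_multimodal = torch.tensor(
    [False, False, True, True, True, False, False, False])

# Text embeddings are computed only for text positions; vision embeddings
# come from the vision encoder for the multimodal positions.
text_embeds = torch.randn(int((~is_multimodal).sum()), hidden_size)
mm_embeds = torch.randn(int(is_multimodal.sum()), hidden_size)

# Scattering both by the explicit mask keeps the result aligned with the
# number of scheduled tokens, no matter how the batch was formed under
# high concurrency.
inputs_embeds = torch.empty(input_ids.shape[0], hidden_size)
inputs_embeds[~is_multimodal] = text_embeds
inputs_embeds[is_multimodal] = mm_embeds
assert inputs_embeds.shape == (input_ids.shape[0], hidden_size)
```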

---------

Signed-off-by: Levi-JQ <yujinqi2@huawei.com>
Co-authored-by: Levi-JQ <yujinqi2@huawei.com>
Levi committed via GitHub, 2025-12-08 19:30:16 +08:00
parent d412565ec9, commit 4e728f1f40
2 changed files with 113 additions and 4 deletions


@@ -65,7 +65,9 @@ except ImportError:
     Qwen3VLProcessingInfo = object
     Qwen3VLMoeForConditionalGeneration = object
     Qwen3VLMoeProcessingInfo = object
-from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix
+from vllm.model_executor.models.utils import (WeightsMapper,
+                                              _merge_multimodal_embeddings,
+                                              maybe_prefix)
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm_ascend.models.qwen2_5_vl import AscendQwen2_5_VisionRotaryEmbedding
@@ -564,8 +566,6 @@ class AscendQwen2_5_VLForConditionalGeneration_Without_Padding(
         on those tokens. Note however that doing so increases memory usage
         as an additional buffer is needed to hold the input embeddings.
         """
-        from vllm.model_executor.models.utils import \
-            _merge_multimodal_embeddings

         inputs_embeds = self._get_text_embeddings(
             input_ids,
@@ -669,3 +669,112 @@ class AscendQwen3VLMoeForConditionalGeneration(
             prefix=maybe_prefix(prefix, "visual"),
             use_data_parallel=self.use_data_parallel,
         )
+
+    def _get_text_embeddings(
+        self,
+        input_ids: torch.Tensor,
+        get_input_embeddings: Callable[[torch.Tensor], torch.Tensor],
+        *,
+        is_multimodal: Optional[torch.Tensor],
+        handle_oov_mm_token: bool,
+    ) -> torch.Tensor:
+        if handle_oov_mm_token and is_multimodal is not None:
+            is_text = ~is_multimodal
+            text_embeds = get_input_embeddings(input_ids[is_text])
+
+            return torch.empty(
+                (input_ids.shape[0], text_embeds.shape[1]),
+                dtype=text_embeds.dtype,
+                device=text_embeds.device,
+            ).masked_scatter_(is_text.unsqueeze_(-1), text_embeds)
+
+        return get_input_embeddings(input_ids)
+
+    def get_input_embeddings(
+        self,
+        input_ids: torch.Tensor,
+        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
+        *,
+        is_multimodal: Optional[torch.Tensor] = None,
+        handle_oov_mm_token: bool = False,
+    ) -> torch.Tensor:
+        """
+        Apply token embeddings to `input_ids`.
+
+        If `multimodal_embeddings` is passed, scatter them into
+        `input_ids` according to the mask `is_multimodal`.
+
+        In case the multi-modal token IDs exceed the vocabulary size of
+        the language model, you can set `handle_oov_mm_token=False`
+        to avoid calling the language model's `get_input_embeddings` method
+        on those tokens. Note however that doing so increases memory usage
+        as an additional buffer is needed to hold the input embeddings.
+        """
+        inputs_embeds = self._get_text_embeddings(
+            input_ids,
+            self.get_language_model().get_input_embeddings,
+            is_multimodal=is_multimodal,
+            handle_oov_mm_token=handle_oov_mm_token,
+        )
+
+        if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
+            return inputs_embeds
+
+        if is_multimodal is None:
+            raise ValueError(
+                "`get_input_embeddings` now requires `is_multimodal` arg, "
+                "please update your model runner according to "
+                "https://github.com/vllm-project/vllm/pull/16229.")
+
+        if self.use_deepstack:
+            (
+                deepstack_input_embeds,
+                multimodal_embeddings,
+            ) = self._compute_deepstack_embeds(
+                inputs_embeds=inputs_embeds,
+                multimodal_embeddings=multimodal_embeddings,
+                is_multimodal=is_multimodal,
+            )
+        else:
+            deepstack_input_embeds = None
+
+        inputs_embeds = _merge_multimodal_embeddings(
+            inputs_embeds=inputs_embeds,
+            is_multimodal=is_multimodal,
+            multimodal_embeddings=multimodal_embeddings,
+        )
+
+        if deepstack_input_embeds is not None:
+            self._set_deepstack_input_embeds(deepstack_input_embeds)
+
+        return inputs_embeds
+
+    def _compute_deepstack_embeds(
+        self,
+        inputs_embeds: torch.Tensor,
+        multimodal_embeddings: MultiModalEmbeddings,
+        is_multimodal: torch.Tensor,
+    ) -> tuple[torch.Tensor, MultiModalEmbeddings]:
+        visual_lens = [len(x) for x in multimodal_embeddings]
+        multimodal_embeddings_cat = torch.cat(multimodal_embeddings, dim=0)
+
+        total_dim = multimodal_embeddings_cat.shape[-1]
+        assert total_dim == self.visual_dim + self.multiscale_dim, \
+            f"Total dimension mismatch: input {total_dim}, expected {self.visual_dim + self.multiscale_dim}"
+
+        multimodal_embeddings_main = multimodal_embeddings_cat[
+            ..., :self.visual_dim]
+        multimodal_embeddings_multiscale = multimodal_embeddings_cat[
+            ..., self.visual_dim:]
+
+        multimodal_embeddings = torch.split(multimodal_embeddings_main,
+                                            visual_lens,
+                                            dim=0)
+        multimodal_embeddings_multiscale = torch.split(
+            multimodal_embeddings_multiscale, visual_lens, dim=0)
+
+        deepstack_input_embeds = inputs_embeds.new_zeros(
+            inputs_embeds.size(0),
+            self.deepstack_num_level * inputs_embeds.size(1))
+        deepstack_input_embeds = _merge_multimodal_embeddings(
+            inputs_embeds=deepstack_input_embeds,
+            multimodal_embeddings=multimodal_embeddings_multiscale,
+            is_multimodal=is_multimodal,
+        )
+        deepstack_input_embeds = deepstack_input_embeds.view(
+            inputs_embeds.shape[0], self.deepstack_num_level, self.visual_dim)
+        deepstack_input_embeds = deepstack_input_embeds.permute(
+            1, 0, 2).contiguous()
+        return deepstack_input_embeds, multimodal_embeddings


@@ -1395,7 +1395,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             # as input to the multimodal model, even when the input is text.
             input_ids = self.input_ids[:total_num_scheduled_tokens]
             model_type = self.vllm_config.model_config.hf_config.model_type
-            if model_type == "qwen2_5_vl":
+            if model_type == "qwen2_5_vl" or model_type == "qwen3_vl_moe":
                 inputs_embeds = self.model.get_input_embeddings(
                     input_ids,
                     multimodal_embeddings=mm_embeds,