[Bugfix] fix qwen2.5-vl-72b shape ERROR during the _prepare_inputs phase under high concurrency. (#4553)
### What this PR does / why we need it?
qwen2.5-vl-72b reports a shape ERROR during the _prepare_inputs phase
under high concurrency【 issue
https://github.com/vllm-project/vllm-ascend/issues/4430 】
This PR fix it.
The related PR in main branch
:https://github.com/vllm-project/vllm-ascend/pull/3612
The related commit in vllm :
17c540a993/vllm/model_executor/models/interfaces.py
【The _get_text_embeddings function has been refactored to
interfaces.pyin vllm.】
Signed-off-by: Levi-JQ <yujinqi2@huawei.com>
Co-authored-by: Levi-JQ <yujinqi2@huawei.com>
This commit is contained in:
@@ -62,6 +62,7 @@ from vllm.model_executor.model_loader import get_model
|
||||
from vllm.model_executor.models.interfaces import supports_transcription
|
||||
from vllm.model_executor.models.interfaces_base import (
|
||||
VllmModelForPooling, is_pooling_model, is_text_generation_model)
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
|
||||
from vllm.multimodal.utils import group_mm_kwargs_by_modality
|
||||
from vllm.pooling_params import PoolingParams
|
||||
@@ -550,6 +551,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
|
||||
self.mc2_tokens_capacity = num_tokens_per_tp_rank * tp_size
|
||||
|
||||
# Only relevant for multimodal models
|
||||
self.mm_registry = MULTIMODAL_REGISTRY
|
||||
self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
|
||||
self.model_config)
|
||||
if self.supports_mm_inputs:
|
||||
self.is_mm_embed = self._make_buffer(self.max_num_tokens,
|
||||
dtype=torch.bool)
|
||||
|
||||
def _make_buffer(self,
|
||||
*size: Union[int, torch.SymInt],
|
||||
dtype: torch.dtype,
|
||||
@@ -1034,7 +1043,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
def _gather_mm_embeddings(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
) -> list[torch.Tensor]:
|
||||
) -> tuple[list[torch.Tensor], torch.Tensor]:
|
||||
|
||||
def _iter_mm_features(req_state: CachedRequestState):
|
||||
assert req_state.mm_features is not None
|
||||
@@ -1044,8 +1053,15 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
pos_info, "is_embed", None)
|
||||
|
||||
mm_embeds: list[torch.Tensor] = []
|
||||
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
is_mm_embed = self.is_mm_embed.cpu
|
||||
is_mm_embed[:total_num_scheduled_tokens] = False
|
||||
|
||||
req_start_idx = 0
|
||||
|
||||
for req_id in self.input_batch.req_ids:
|
||||
mm_embeds_req: list[torch.Tensor] = []
|
||||
|
||||
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
|
||||
req_id]
|
||||
req_state = self.requests[req_id]
|
||||
@@ -1074,12 +1090,22 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
if is_embed is not None:
|
||||
is_embed = is_embed[start_idx:end_idx]
|
||||
|
||||
req_start_pos = req_start_idx + start_pos - num_computed_tokens
|
||||
is_mm_embed[req_start_pos+start_idx:req_start_pos + end_idx] \
|
||||
= True if is_embed is None else is_embed
|
||||
|
||||
mm_embeds_item = gather_mm_placeholders(
|
||||
encoder_output[start_idx:end_idx],
|
||||
is_embed=is_embed,
|
||||
)
|
||||
mm_embeds.append(mm_embeds_item)
|
||||
return mm_embeds
|
||||
mm_embeds_req.append(mm_embeds_item)
|
||||
|
||||
mm_embeds.extend(mm_embeds_req)
|
||||
req_start_idx += num_scheduled_tokens
|
||||
|
||||
is_mm_embed = self.is_mm_embed.copy_to_gpu(total_num_scheduled_tokens)
|
||||
|
||||
return mm_embeds, is_mm_embed
|
||||
|
||||
def _get_cumsum_and_arange(
|
||||
self,
|
||||
@@ -1362,17 +1388,17 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
if self.is_multimodal_model:
|
||||
# Run the multimodal encoder if any.
|
||||
self._execute_mm_encoder(scheduler_output)
|
||||
mm_embeds = self._gather_mm_embeddings(scheduler_output)
|
||||
|
||||
mm_embeds, is_mm_embed = self._gather_mm_embeddings(
|
||||
scheduler_output)
|
||||
# NOTE(woosuk): To unify token ids and soft tokens (vision
|
||||
# embeddings), we always use embeddings (rather than token ids)
|
||||
# as input to the multimodal model, even when the input is text.
|
||||
input_ids = self.input_ids[:total_num_scheduled_tokens]
|
||||
if mm_embeds:
|
||||
inputs_embeds = self.model.get_input_embeddings(
|
||||
input_ids, mm_embeds)
|
||||
else:
|
||||
inputs_embeds = self.model.get_input_embeddings(input_ids)
|
||||
inputs_embeds = self.model.get_input_embeddings(
|
||||
input_ids,
|
||||
multimodal_embeddings=mm_embeds,
|
||||
is_multimodal=is_mm_embed,
|
||||
)
|
||||
# TODO(woosuk): Avoid the copy. Optimize.
|
||||
self.inputs_embeds[:total_num_scheduled_tokens].copy_(
|
||||
inputs_embeds)
|
||||
|
||||
Reference in New Issue
Block a user