[Bugfix] Fix the hang of multimodal models when running with DP > 1 (#4392)

### What this PR does / why we need it?
When cudagraph_mode is set to FULL_DECODE_ONLY and dp > 1, the dummy-run
path is triggered. The update_attn_params function requires a num_tokens
argument, which was previously derived from positions.shape[0]. However,
multimodal models use mRoPE (Multimodal Rotary Position Embedding), which
makes positions a 2-D tensor with one row per rotary dimension, so
positions.shape[0] returns the number of rotary dimensions rather than the
number of tokens. This mismatch caused the hang. We fix it by passing
num_tokens directly instead of positions.shape[0].
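
For illustration, a minimal sketch of the shape mismatch (not the runner
code itself; the `(3, num_tokens)` layout assumed here follows vLLM's
mRoPE convention of one row per temporal/height/width rotary dimension):

```python
import torch

num_tokens = 8

# Text-only models: positions is 1-D, one entry per token.
positions_1d = torch.arange(num_tokens)
assert positions_1d.shape[0] == num_tokens  # 8 -- the real token count

# Multimodal models with mRoPE: positions is 2-D, one row per rotary
# dimension, so the leading dimension is NOT the token count.
positions_mrope = torch.zeros(3, num_tokens, dtype=torch.long)
assert positions_mrope.shape[0] == 3  # number of rotary dims, not tokens

# Deriving the token count from positions.shape[0] therefore breaks for
# multimodal models; the fix passes num_tokens through explicitly.
```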

### Does this PR introduce _any_ user-facing change?
NO

### How was this patch tested?
- vLLM version: v0.11.0
- vLLM main: 2918c1b49c

---------

Signed-off-by: wujinyuan1 <wjy9595@qq.com>
Co-authored-by: wujinyuan1 <wjy9595@qq.com>
Author: wujinyuan1
Date: 2025-11-25 09:33:49 +08:00
Committed by: GitHub
Parent: 84eae97f27
Commit: 06f6cc1c81


```diff
@@ -2810,8 +2810,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             else:
                 # FIXME: Try using `auto_dispatch_capture=True`
                 update_mla_attn_params(self.update_stream, forward_context,
-                                       positions.shape[0],
-                                       self.speculative_config)
+                                       num_tokens, self.speculative_config)
         else:
             if self.pcp_size * self.dcp_size > 1:
                 update_attn_dcp_pcp_params(self.update_stream,
@@ -2819,7 +2818,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                                            positions.shape[0])
             else:
                 update_attn_params(self.update_stream, forward_context,
-                                   positions.shape[0])
+                                   num_tokens)
         if self.drafter and self.drafter.name == SpecDcodeType.EAGLE3:
             hidden_states, _ = hidden_states
```