From 3cbd6acc8920723e6e026bd66a3c19764eebe67e Mon Sep 17 00:00:00 2001 From: jiangmengyu18 <56633611+jiangmengyu18@users.noreply.github.com> Date: Fri, 3 Apr 2026 11:38:41 +0800 Subject: [PATCH] [v0.18.0][Feature] Support Flash Comm V1 for Qwen3-VL models (#7893) ### What this PR does / why we need it? Enable Flash Comm V1 (sequence parallelism) for Qwen3-VL models (both dense and MoE variants). Root cause: Qwen3-VL's deepstack embeddings remain full-size [N, H] while hidden states become [N/tp_size, H] after reduce-scatter, causing shape mismatch on add. ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - [x] Run Qwen3-VL dense model with FC1 enabled (TP > 1), verify correct output - [x] Run Qwen3-VL MoE model with FC1 enabled (TP > 1), verify correct output --------- Signed-off-by: betta18 Signed-off-by: jiangmengyu18 <56633611+jiangmengyu18@users.noreply.github.com> Co-authored-by: betta18 Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- vllm_ascend/patch/worker/patch_qwen3vl.py | 26 +++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/vllm_ascend/patch/worker/patch_qwen3vl.py b/vllm_ascend/patch/worker/patch_qwen3vl.py index 514de099..103e3d42 100644 --- a/vllm_ascend/patch/worker/patch_qwen3vl.py +++ b/vllm_ascend/patch/worker/patch_qwen3vl.py @@ -1,10 +1,33 @@ import torch +from vllm.distributed import get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size from vllm.model_executor.models.qwen3 import Qwen3Attention from vllm.model_executor.models.qwen3_moe import Qwen3MoeAttention +from vllm.model_executor.models.qwen3_vl import Qwen3VLForConditionalGeneration +from vllm_ascend.ascend_forward_context import _EXTRA_CTX from vllm_ascend.ops.rotary_embedding import AscendMRotaryEmbedding +def tensor_parallel_wrap(func): + def wrap(*args, **kwargs): + deepstack_input_embeds = func(*args, **kwargs) + if deepstack_input_embeds is None: + return deepstack_input_embeds + try: + flash_comm_v1_enabled = _EXTRA_CTX.flash_comm_v1_enabled + except (AssertionError, AttributeError, KeyError): + flash_comm_v1_enabled = False + if flash_comm_v1_enabled: + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + deepstack_input_embeds.tensors = { + k: v.chunk(tp_size)[tp_rank] for k, v in deepstack_input_embeds.tensors.items() + } + return deepstack_input_embeds + + return wrap + + def forward_with_split_qkv_rmsnorm_mrope(self, positions: torch.Tensor, hidden_states: torch.Tensor): qkv, _ = self.qkv_proj(hidden_states) if isinstance(self.rotary_emb, AscendMRotaryEmbedding): @@ -42,3 +65,6 @@ def forward_with_split_qkv_rmsnorm_mrope(self, positions: torch.Tensor, hidden_s Qwen3Attention.forward = forward_with_split_qkv_rmsnorm_mrope Qwen3MoeAttention.forward = forward_with_split_qkv_rmsnorm_mrope +Qwen3VLForConditionalGeneration._get_deepstack_input_embeds = tensor_parallel_wrap( + Qwen3VLForConditionalGeneration._get_deepstack_input_embeds +)