[v0.18.0][Feature] Support Flash Comm V1 for Qwen3-VL models (#7893)

### What this PR does / why we need it?
Enable Flash Comm V1 (sequence parallelism) for Qwen3-VL models (both
dense and MoE variants).

Root cause: Qwen3-VL's deepstack embeddings remain full-size `[N, H]` while the hidden states become `[N/tp_size, H]` after reduce-scatter, causing a shape mismatch on the add.

Fix: when FC1 is enabled, slice the deepstack input embeddings along the token dimension and keep only the current TP rank's chunk so they match the reduce-scattered hidden states (see the sketch below).
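
A minimal sketch of the mismatch and the per-rank slicing (illustrative only, not part of the patch; tensor sizes and variable names are made up):

```python
import torch

N, H = 8, 4              # tokens, hidden size (toy values)
tp_size, tp_rank = 2, 0  # pretend TP world

hidden_states = torch.randn(N, H)
deepstack_embeds = torch.randn(N, H)

# Under Flash Comm V1, hidden states are reduce-scattered along the token
# dimension, so each rank only holds N/tp_size rows.
hidden_local = hidden_states.chunk(tp_size, dim=0)[tp_rank]  # [N/tp_size, H]

# Adding the full-size deepstack embeddings would fail:
#   hidden_local + deepstack_embeds  -> shape mismatch (4 vs 8 rows)

# The fix slices the deepstack embeddings the same way before the add.
deepstack_local = deepstack_embeds.chunk(tp_size, dim=0)[tp_rank]
out = hidden_local + deepstack_local
assert out.shape == (N // tp_size, H)
```
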
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
- [x] Run Qwen3-VL dense model with FC1 enabled (TP > 1), verify correct
output
- [x] Run Qwen3-VL MoE model with FC1 enabled (TP > 1), verify correct
output

---------

Signed-off-by: betta18 <jiangmengyu1@huawei.com>
Signed-off-by: jiangmengyu18 <56633611+jiangmengyu18@users.noreply.github.com>
Co-authored-by: betta18 <jiangmengyu1@huawei.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
@@ -1,10 +1,33 @@
import torch
from vllm.distributed import get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size
from vllm.model_executor.models.qwen3 import Qwen3Attention
from vllm.model_executor.models.qwen3_moe import Qwen3MoeAttention
from vllm.model_executor.models.qwen3_vl import Qwen3VLForConditionalGeneration
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
from vllm_ascend.ops.rotary_embedding import AscendMRotaryEmbedding

def tensor_parallel_wrap(func):
    """Slice deepstack input embeddings per TP rank when Flash Comm V1 is enabled."""

    def wrap(*args, **kwargs):
        deepstack_input_embeds = func(*args, **kwargs)
        if deepstack_input_embeds is None:
            return deepstack_input_embeds
        try:
            flash_comm_v1_enabled = _EXTRA_CTX.flash_comm_v1_enabled
        except (AssertionError, AttributeError, KeyError):
            # Forward context not available; treat FC1 as disabled.
            flash_comm_v1_enabled = False
        if flash_comm_v1_enabled:
            tp_size = get_tensor_model_parallel_world_size()
            tp_rank = get_tensor_model_parallel_rank()
            # Hidden states are [N/tp_size, H] after reduce-scatter, so keep only
            # this rank's slice of each full-size [N, H] deepstack tensor.
            deepstack_input_embeds.tensors = {
                k: v.chunk(tp_size)[tp_rank]
                for k, v in deepstack_input_embeds.tensors.items()
            }
        return deepstack_input_embeds

    return wrap


def forward_with_split_qkv_rmsnorm_mrope(self, positions: torch.Tensor, hidden_states: torch.Tensor):
    qkv, _ = self.qkv_proj(hidden_states)
    if isinstance(self.rotary_emb, AscendMRotaryEmbedding):
@@ -42,3 +65,6 @@ def forward_with_split_qkv_rmsnorm_mrope(self, positions: torch.Tensor, hidden_s
Qwen3Attention.forward = forward_with_split_qkv_rmsnorm_mrope
Qwen3MoeAttention.forward = forward_with_split_qkv_rmsnorm_mrope
Qwen3VLForConditionalGeneration._get_deepstack_input_embeds = tensor_parallel_wrap(
    Qwen3VLForConditionalGeneration._get_deepstack_input_embeds
)
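
For reference, a toy usage sketch of the wrapper above. It reuses `tensor_parallel_wrap` from this module; `FakeDeepstackEmbeds` is an assumed stand-in for the real return value of `_get_deepstack_input_embeds`, of which the wrapper only touches the `.tensors` dict:

```python
import torch

class FakeDeepstackEmbeds:
    """Stand-in object: the wrapper only reads and rewrites the `.tensors` dict."""

    def __init__(self, tensors):
        self.tensors = tensors

def fake_get_deepstack_input_embeds():
    return FakeDeepstackEmbeds({"deepstack_0": torch.randn(8, 4)})

wrapped = tensor_parallel_wrap(fake_get_deepstack_input_embeds)
embeds = wrapped()
# Without a forward context, flash_comm_v1_enabled falls back to False and the
# tensors stay full-size; with FC1 enabled and tp_size == 2, each rank would
# keep only its [N/tp_size, H] slice, e.g. shape (4, 4) here.
print(embeds.tensors["deepstack_0"].shape)
```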