### What this PR does / why we need it?

Enable Flash Comm V1 (sequence parallelism) for Qwen3-VL models (both dense and MoE variants).

Root cause: Qwen3-VL's deepstack embeddings remain full-size `[N, H]` while hidden states become `[N/tp_size, H]` after reduce-scatter, causing a shape mismatch on the add.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- [x] Run the Qwen3-VL dense model with FC1 enabled (TP > 1), verify correct output
- [x] Run the Qwen3-VL MoE model with FC1 enabled (TP > 1), verify correct output

---------

Signed-off-by: betta18 <jiangmengyu1@huawei.com>
Signed-off-by: jiangmengyu18 <56633611+jiangmengyu18@users.noreply.github.com>
Co-authored-by: betta18 <jiangmengyu1@huawei.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
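For illustration, a minimal sketch of the root-cause shape mismatch described above and of the per-rank chunking that resolves it; the tensor names and sizes here are made up for the example and are not taken from the model code:

```python
import torch

# Hypothetical sizes: N tokens, hidden size H, 4-way tensor parallelism.
N, H, tp_size, tp_rank = 8, 16, 4, 1

# Deepstack embeddings stay full-size [N, H] ...
deepstack_embeds = torch.randn(N, H)
# ... while hidden states hold only this rank's slice [N / tp_size, H]
# after the Flash Comm V1 reduce-scatter.
hidden_states = torch.randn(N // tp_size, H)

# hidden_states + deepstack_embeds would fail: [2, 16] vs. [8, 16].
# Chunking the embeddings the same way restores shape compatibility.
local_embeds = deepstack_embeds.chunk(tp_size)[tp_rank]
assert local_embeds.shape == hidden_states.shape
fused = hidden_states + local_embeds  # shape-compatible add
```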
import torch

from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
from vllm.model_executor.models.qwen3 import Qwen3Attention
from vllm.model_executor.models.qwen3_moe import Qwen3MoeAttention
from vllm.model_executor.models.qwen3_vl import Qwen3VLForConditionalGeneration

from vllm_ascend.ascend_forward_context import _EXTRA_CTX
from vllm_ascend.ops.rotary_embedding import AscendMRotaryEmbedding


def tensor_parallel_wrap(func):
    """Chunk deepstack input embeddings per TP rank when Flash Comm V1 is enabled."""

    def wrap(*args, **kwargs):
        deepstack_input_embeds = func(*args, **kwargs)
        if deepstack_input_embeds is None:
            return deepstack_input_embeds
        try:
            flash_comm_v1_enabled = _EXTRA_CTX.flash_comm_v1_enabled
        except (AssertionError, AttributeError, KeyError):
            flash_comm_v1_enabled = False
        if flash_comm_v1_enabled:
            # With FC1, hidden states are reduce-scattered to [N / tp_size, H],
            # so keep only this rank's slice of each full-size embedding.
            tp_size = get_tensor_model_parallel_world_size()
            tp_rank = get_tensor_model_parallel_rank()
            deepstack_input_embeds.tensors = {
                k: v.chunk(tp_size)[tp_rank]
                for k, v in deepstack_input_embeds.tensors.items()
            }
        return deepstack_input_embeds

    return wrap


def forward_with_split_qkv_rmsnorm_mrope(self, positions: torch.Tensor,
                                         hidden_states: torch.Tensor):
    qkv, _ = self.qkv_proj(hidden_states)
    if isinstance(self.rotary_emb, AscendMRotaryEmbedding):
        # Fused path: split QKV, apply per-head RMSNorm and MRoPE in one kernel.
        cos_sin = self.rotary_emb.cos_sin_cache[positions]
        if cos_sin.device != qkv.device:
            cos_sin = cos_sin.to(qkv.device)
        if cos_sin.dtype != qkv.dtype:
            cos_sin = cos_sin.to(qkv.dtype)
        q, k, v, _ = torch.ops.vllm.triton_split_qkv_rmsnorm_mrope(
            qkv=qkv,
            q_weight=self.q_norm.weight,
            k_weight=self.k_norm.weight,
            cos_sin=cos_sin,
            num_q_heads=self.num_heads,
            num_kv_heads=self.num_kv_heads,
            head_size=self.head_dim,
            eps=self.q_norm.variance_epsilon,
            mrope_section=self.rotary_emb.mrope_section,
            is_interleaved=self.rotary_emb.mrope_interleaved,
            rope_dim=self.rotary_emb.rotary_dim,
        )
    else:
        # Fallback path: split QKV, apply per-head RMSNorm to Q/K, then rotary embedding.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
                           self.head_dim)
        q_by_head = self.q_norm(q_by_head)
        q = q_by_head.view(q.shape)
        k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
                           self.head_dim)
        k_by_head = self.k_norm(k_by_head)
        k = k_by_head.view(k.shape)
        q, k = self.rotary_emb(positions, q, k)
    attn_output = self.attn(q, k, v)
    output, _ = self.o_proj(attn_output)
    return output


# Patch the attention forward and the deepstack embedding getter.
Qwen3Attention.forward = forward_with_split_qkv_rmsnorm_mrope
Qwen3MoeAttention.forward = forward_with_split_qkv_rmsnorm_mrope
Qwen3VLForConditionalGeneration._get_deepstack_input_embeds = tensor_parallel_wrap(
    Qwen3VLForConditionalGeneration._get_deepstack_input_embeds
)
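A usage sketch of `tensor_parallel_wrap` on a stand-in object. `FakeDeepstackEmbeds` and `fake_get_deepstack_input_embeds` are hypothetical stand-ins that mirror only the `.tensors` dict the wrapper touches; it also assumes that outside a forward context the `_EXTRA_CTX` lookup falls into the `except` branch, making the wrapper a pass-through:

```python
import torch


class FakeDeepstackEmbeds:
    # Hypothetical stand-in exposing only the .tensors attribute the wrapper uses.
    def __init__(self, tensors):
        self.tensors = tensors


def fake_get_deepstack_input_embeds():
    return FakeDeepstackEmbeds({"layer_0": torch.randn(8, 16),
                                "layer_1": torch.randn(8, 16)})


wrapped = tensor_parallel_wrap(fake_get_deepstack_input_embeds)

# With flash_comm_v1_enabled unresolved, the wrapper leaves the embeddings
# untouched and each tensor keeps its full [8, 16] shape.
embeds = wrapped()
assert all(v.shape == (8, 16) for v in embeds.tensors.values())
```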