Files
xc-llm-ascend/vllm_ascend/patch/worker/patch_qwen3vl.py

71 lines
2.9 KiB
Python
Raw Permalink Normal View History

import functools

import torch
from vllm.distributed import get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size
from vllm.model_executor.models.qwen3 import Qwen3Attention
from vllm.model_executor.models.qwen3_moe import Qwen3MoeAttention
from vllm.model_executor.models.qwen3_vl import Qwen3VLForConditionalGeneration

from vllm_ascend.ascend_forward_context import _EXTRA_CTX
from vllm_ascend.ops.rotary_embedding import AscendMRotaryEmbedding
def tensor_parallel_wrap(func):
    """Wrap ``func`` so that its result is sliced per tensor-parallel rank.

    The returned callable forwards all arguments to ``func``. When ``func``
    returns ``None`` the ``None`` is passed through untouched. Otherwise, if
    ``_EXTRA_CTX.flash_comm_v1_enabled`` is truthy, the result's ``tensors``
    mapping is replaced by a new dict in which every value has been split
    with ``.chunk(tp_size)`` and only the piece at index ``tp_rank`` is kept.
    The (possibly modified) result object is then returned.

    Args:
        func: Callable whose non-``None`` return value exposes a ``tensors``
            attribute with an ``items()`` method (presumably a dict of
            ``torch.Tensor`` -- confirm against the wrapped method).

    Returns:
        A wrapper with the same call signature as ``func``. The wrapper keeps
        ``func``'s metadata (``__name__``, ``__doc__``, ``__wrapped__``, ...)
        so the patched attribute still introspects like the original.
    """

    @functools.wraps(func)
    def wrap(*args, **kwargs):
        deepstack_input_embeds = func(*args, **kwargs)
        if deepstack_input_embeds is None:
            return deepstack_input_embeds

        # Reading the flag may fail with one of these exceptions (e.g. when
        # the extra context is not available); treat that as "disabled".
        try:
            flash_comm_v1_enabled = _EXTRA_CTX.flash_comm_v1_enabled
        except (AssertionError, AttributeError, KeyError):
            flash_comm_v1_enabled = False

        if flash_comm_v1_enabled:
            tp_size = get_tensor_model_parallel_world_size()
            tp_rank = get_tensor_model_parallel_rank()
            # Keep only this rank's chunk of every entry. ``chunk`` is called
            # without a ``dim`` argument, so it uses its default dimension.
            # NOTE(review): ``chunk`` can return fewer than ``tp_size`` pieces
            # for small inputs, in which case the index below would raise --
            # confirm callers always provide enough rows.
            deepstack_input_embeds.tensors = {
                k: v.chunk(tp_size)[tp_rank] for k, v in deepstack_input_embeds.tensors.items()
            }
        return deepstack_input_embeds

    return wrap
def forward_with_split_qkv_rmsnorm_mrope(self, positions: torch.Tensor, hidden_states: torch.Tensor):
    """Attention forward pass used as a replacement ``forward`` method.

    Projects ``hidden_states`` through ``self.qkv_proj`` and then takes one
    of two routes to obtain ``q``, ``k`` and ``v``:

    * If ``self.rotary_emb`` is an ``AscendMRotaryEmbedding``, a single call
      to ``torch.ops.vllm.triton_split_qkv_rmsnorm_mrope`` produces them
      (judging by the op's name and arguments it performs the split, the
      q/k RMS norm and the mrope rotation together -- confirm against the
      op's definition).
    * Otherwise the projection is split manually, ``self.q_norm`` /
      ``self.k_norm`` are applied per head, and ``self.rotary_emb`` is
      called on the normalized ``q`` and ``k``.

    The result of ``self.attn(q, k, v)`` is passed through ``self.o_proj``
    and the first element of that output is returned.

    Args:
        positions: Used to index ``rotary_emb.cos_sin_cache`` on the fused
            route, or passed to ``rotary_emb`` on the fallback route.
        hidden_states: Input handed to ``self.qkv_proj``.

    Returns:
        The first element of ``self.o_proj``'s output.
    """
    qkv, _ = self.qkv_proj(hidden_states)
    rope = self.rotary_emb

    if not isinstance(rope, AscendMRotaryEmbedding):
        # Fallback route: split, normalize each head, then rotate.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        # View the last dim as (num_heads, head_dim) so the norm acts per
        # head, then restore the flat layout.
        q_heads = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim)
        q = self.q_norm(q_heads).view(q.shape)

        k_heads = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim)
        k = self.k_norm(k_heads).view(k.shape)

        q, k = rope(positions, q, k)
    else:
        # Fused route: gather the cached cos/sin rows for these positions
        # and align them with the projection's device and dtype.
        cos_sin = rope.cos_sin_cache[positions]
        if cos_sin.device != qkv.device:
            cos_sin = cos_sin.to(qkv.device)
        if cos_sin.dtype != qkv.dtype:
            cos_sin = cos_sin.to(qkv.dtype)

        # NOTE(review): only q_norm's epsilon is forwarded; presumably
        # k_norm uses the same value -- confirm.
        q, k, v, _ = torch.ops.vllm.triton_split_qkv_rmsnorm_mrope(
            qkv=qkv,
            q_weight=self.q_norm.weight,
            k_weight=self.k_norm.weight,
            cos_sin=cos_sin,
            num_q_heads=self.num_heads,
            num_kv_heads=self.num_kv_heads,
            head_size=self.head_dim,
            eps=self.q_norm.variance_epsilon,
            mrope_section=rope.mrope_section,
            is_interleaved=rope.mrope_interleaved,
            rope_dim=rope.rotary_dim,
        )

    attn_output = self.attn(q, k, v)
    output, _ = self.o_proj(attn_output)
    return output
# Import-time monkey patches: importing this module rebinds the attributes
# below on the upstream vLLM classes.

# Both attention classes get the same replacement forward defined above.
Qwen3Attention.forward = forward_with_split_qkv_rmsnorm_mrope
Qwen3MoeAttention.forward = forward_with_split_qkv_rmsnorm_mrope

# Wrap the existing method (rather than replacing it) so its result is
# post-processed by tensor_parallel_wrap.
Qwen3VLForConditionalGeneration._get_deepstack_input_embeds = tensor_parallel_wrap(
    Qwen3VLForConditionalGeneration._get_deepstack_input_embeds
)