[qwen3-omni] Add Qwen3-Omni moe thinker

This commit is contained in:
2025-10-09 17:51:14 +08:00
parent bc57e2ef60
commit 24fab12b2f
8 changed files with 1543 additions and 37 deletions

View File

@@ -1,4 +1,5 @@
FROM cr.metax-tech.com/public-ai-release/maca/vllm:maca.ai3.0.0.5-torch2.6-py310-ubuntu22.04-amd64
RUN /opt/conda/bin/pip install --no-cache-dir --upgrade transformers
COPY vllm/ /opt/conda/lib/python3.10/site-packages/vllm/
COPY code_generator.py /opt/conda/lib/python3.10/site-packages/triton/compiler/code_generator.py

View File

@@ -1,6 +1,6 @@
# metax-c500-vllm
本项目包含了对于原版 vllm 的升级,使其可以在沐曦 C500芯片上支持运行 gpt-oss
本项目包含了对于原版 vllm 的升级,使其可以在沐曦 C500芯片上支持运行 gpt-oss, qwen3-omni
本项目中提供的 Dockerfile 中的主要内容为:
1.`vllm` 目录覆盖到镜像中的 `/opt/conda/lib/python3.10/site-packages/vllm`。运行`gpt-oss`时需指定`VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1`

View File

@@ -54,6 +54,14 @@ def check_xformers_availability():
return USE_XFORMERS_OPS
def check_upstream_fa_availability(dtype: torch.dtype):
if dtype in (torch.float16, torch.bfloat16) and current_platform.is_cuda(
) and current_platform.has_device_capability(80):
from transformers.utils import is_flash_attn_2_available
return is_flash_attn_2_available()
return False
class Attention(nn.Module):
"""Attention layer.

View File

@@ -531,7 +531,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
return "<image>"
if model_type in ("mllama", "llama4"):
return "<|image|>"
if model_type in ("qwen2_vl", "qwen2_5_vl"):
if model_type in ("qwen2_vl", "qwen2_5_vl", "qwen3_omni_moe"):
return "<|vision_start|><|image_pad|><|vision_end|>"
if model_type == "qwen2_5_omni":
return "<|vision_start|><|IMAGE|><|vision_end|>"
@@ -553,13 +553,15 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
if model_type in ("qwen2_audio", "qwen2_5_omni"):
return (f"Audio {current_count}: "
f"<|audio_bos|><|AUDIO|><|audio_eos|>")
if model_type == "qwen3_omni_moe":
return f"<|audio_start|><|audio_pad|><|audio_end|>"
if model_type == "minicpmo":
return "(<audio>./</audio>)"
raise TypeError(f"Unknown model type: {model_type}")
elif modality == "video":
if model_type == "internvl_chat":
return "<video>"
if model_type in ("qwen2_vl", "qwen2_5_vl"):
if model_type in ("qwen2_vl", "qwen2_5_vl", "qwen3_omni_moe"):
return "<|vision_start|><|video_pad|><|vision_end|>"
if model_type == "qwen2_5_omni":
return "<|vision_start|><|VIDEO|><|vision_end|>"

View File

@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import abstractmethod
from collections.abc import Iterable
from collections.abc import Iterable, Sequence
from enum import Enum
from typing import Callable, Literal, Optional, overload
@@ -1667,37 +1667,57 @@ class FusedMoE(CustomOp):
return final_hidden_states
# @classmethod
# def make_expert_params_mapping(
# cls,
# ckpt_gate_proj_name: str,
# ckpt_down_proj_name: str,
# ckpt_up_proj_name: str,
# num_experts: int,
# num_redundant_experts: int = 0) -> list[tuple[str, str, int, str]]:
@classmethod
def make_expert_params_mapping(
cls,
ckpt_gate_proj_name: str,
ckpt_down_proj_name: str,
ckpt_up_proj_name: str,
num_experts: int,
num_redundant_experts: int = 0) -> list[tuple[str, str, int, str]]:
# num_physical_experts = num_experts + num_redundant_experts
num_physical_experts = num_experts + num_redundant_experts
# # In the returned mapping:
# # - `expert_id` is the physical expert id
# # - `weight_name` contains the weight name of the logical expert
# # So that we should map the expert id to logical in `weight_name`
# physical_to_logical_map = \
# EplbState.build_initial_global_physical_to_logical_map(
# num_experts, num_redundant_experts)
# return [
# # (param_name, weight_name, expert_id, shard_id)
# ("experts.w13_" if weight_name
# in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_",
# f"experts.{physical_to_logical_map[expert_id]}.{weight_name}.",
# expert_id, shard_id) for expert_id in range(num_physical_experts)
# for shard_id, weight_name in [
# ("w1", ckpt_gate_proj_name),
# ("w2", ckpt_down_proj_name),
# ("w3", ckpt_up_proj_name),
# ]
# ]
def build_initial_global_physical_to_logical_map(
num_routed_experts: int,
num_redundant_experts: int,
) -> Sequence[int]:
"""
Build an initial expert arrangement using the following structure:
[original routed experts, redundant experts]
Returns:
physical_to_logical_map (Sequence[int]): A list of integers,
where each integer is the index of the logical expert
that the corresponding physical expert maps to.
"""
global_physical_to_logical_map = list(range(num_routed_experts))
global_physical_to_logical_map += [
i % num_routed_experts for i in range(num_redundant_experts)
]
return global_physical_to_logical_map
# In the returned mapping:
# - `expert_id` is the physical expert id
# - `weight_name` contains the weight name of the logical expert
# So that we should map the expert id to logical in `weight_name`
physical_to_logical_map = \
build_initial_global_physical_to_logical_map(
num_experts, num_redundant_experts)
return [
# (param_name, weight_name, expert_id, shard_id)
("experts.w13_" if weight_name
in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_",
f"experts.{physical_to_logical_map[expert_id]}.{weight_name}.",
expert_id, shard_id) for expert_id in range(num_physical_experts)
for shard_id, weight_name in [
("w1", ckpt_gate_proj_name),
("w2", ckpt_down_proj_name),
("w3", ckpt_up_proj_name),
]
]
def extra_repr(self) -> str:

View File

@@ -1227,6 +1227,214 @@ class MRotaryEmbedding(RotaryEmbedding):
return llm_positions, mrope_position_delta
@classmethod
def _omni3_get_input_positions_tensor(
cls,
config,
input_ids: Optional[torch.LongTensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
use_audio_in_video: bool = False,
audio_seqlens: Optional[torch.LongTensor] = None,
second_per_grids: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
def _get_feat_extract_output_lengths(input_lengths: torch.LongTensor):
input_lengths_leave = input_lengths % 100
feat_lengths = (input_lengths_leave - 1) // 2 + 1
output_lengths = ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
return output_lengths
spatial_merge_size = config.vision_config.spatial_merge_size
image_token_id = config.image_token_id
video_token_id = config.video_token_id
audio_token_id = config.audio_token_id
vision_start_token_id = config.vision_start_token_id
audio_start_token_id = config.audio_start_token_id
position_id_per_seconds = config.position_id_per_seconds
mrope_position_deltas = []
if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
total_input_ids = input_ids
if attention_mask is None:
attention_mask = torch.ones_like(total_input_ids)
position_ids = torch.zeros(
3,
input_ids.shape[0],
input_ids.shape[1],
dtype=input_ids.dtype,
device=input_ids.device,
)
image_idx, video_idx, audio_idx = 0, 0, 0
attention_mask = attention_mask.to(total_input_ids.device)
for i, input_ids in enumerate(total_input_ids):
input_ids = input_ids[attention_mask[i] == 1]
image_nums, video_nums, audio_nums = 0, 0, 0
vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
vision_tokens = input_ids[vision_start_indices + 1]
audio_nums = torch.sum(input_ids == audio_start_token_id)
image_nums = (vision_tokens == image_token_id).sum()
video_nums = (
(vision_tokens == audio_start_token_id).sum()
if use_audio_in_video
else (vision_tokens == video_token_id).sum()
)
input_tokens = input_ids.tolist()
llm_pos_ids_list: list = []
st = 0
remain_images, remain_videos, remain_audios = image_nums, video_nums, audio_nums
multimodal_nums = (
image_nums + audio_nums if use_audio_in_video else image_nums + video_nums + audio_nums
)
for _ in range(multimodal_nums):
st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
if (image_token_id in input_tokens or video_token_id in input_tokens) and (
remain_videos > 0 or remain_images > 0
):
ed_vision_start = input_tokens.index(vision_start_token_id, st)
else:
ed_vision_start = len(input_tokens) + 1
if audio_token_id in input_tokens and remain_audios > 0:
ed_audio_start = input_tokens.index(audio_start_token_id, st)
else:
ed_audio_start = len(input_tokens) + 1
min_ed = min(ed_vision_start, ed_audio_start)
if min_ed == ed_audio_start:
text_len = min_ed - st
if text_len != 0:
st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
bos_len = 1
llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx)
st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
audio_len = _get_feat_extract_output_lengths(audio_seqlens[audio_idx])
llm_pos_ids = torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx
llm_pos_ids_list.append(llm_pos_ids)
st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
eos_len = 1
llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx)
st += text_len + bos_len + audio_len + eos_len
audio_idx += 1
remain_audios -= 1
elif min_ed == ed_vision_start and input_ids[ed_vision_start + 1] == image_token_id:
text_len = min_ed - st
if text_len != 0:
st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
bos_len = 1
llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx)
st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
grid_t = image_grid_thw[image_idx][0]
grid_hs = image_grid_thw[:, 1]
grid_ws = image_grid_thw[:, 2]
t_index = ((torch.arange(grid_t)) * 1 * position_id_per_seconds)
llm_pos_ids = cls._get_llm_pos_ids_for_vision(
st_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws
)
image_len = image_grid_thw[image_idx].prod() // (spatial_merge_size**2)
llm_pos_ids_list.append(llm_pos_ids)
st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
eos_len = 1
llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx)
st += text_len + bos_len + image_len + eos_len
image_idx += 1
remain_images -= 1
elif min_ed == ed_vision_start and input_ids[ed_vision_start + 1] == video_token_id and not use_audio_in_video:
text_len = min_ed - st
if text_len != 0:
st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
bos_len = 1
llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx)
st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
grid_t = video_grid_thw[video_idx][0]
grid_hs = video_grid_thw[:, 1]
grid_ws = video_grid_thw[:, 2]
t_index = (
(torch.arange(grid_t)) * second_per_grids[video_idx].cpu() * position_id_per_seconds
)
llm_pos_ids = cls._get_llm_pos_ids_for_vision(
st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
)
video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2)
llm_pos_ids_list.append(llm_pos_ids)
st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
eos_len = 1
llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx)
st += text_len + bos_len + video_len + eos_len
video_idx += 1
remain_videos -= 1
elif min_ed == ed_vision_start and ed_vision_start + 1 == ed_audio_start and use_audio_in_video:
text_len = min_ed - st
if text_len != 0:
st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
bos_len = 1
llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx)
llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx)
st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
audio_len = _get_feat_extract_output_lengths(audio_seqlens[audio_idx])
audio_llm_pos_ids = torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx
grid_t = video_grid_thw[video_idx][0]
grid_hs = video_grid_thw[:, 1]
grid_ws = video_grid_thw[:, 2]
t_index = (
(torch.arange(grid_t)) * second_per_grids[video_idx].cpu() * position_id_per_seconds
)
video_llm_pos_ids = cls._get_llm_pos_ids_for_vision(
st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
)
video_data_index, audio_data_index = 0, 0
while (
video_data_index < video_llm_pos_ids.shape[-1]
and audio_data_index < audio_llm_pos_ids.shape[-1]
):
if video_llm_pos_ids[0][video_data_index] <= audio_llm_pos_ids[0][audio_data_index]:
llm_pos_ids_list.append(video_llm_pos_ids[:, video_data_index : video_data_index + 1])
video_data_index += 1
else:
llm_pos_ids_list.append(audio_llm_pos_ids[:, audio_data_index : audio_data_index + 1])
audio_data_index += 1
if video_data_index < video_llm_pos_ids.shape[-1]:
llm_pos_ids_list.append(
video_llm_pos_ids[:, video_data_index : video_llm_pos_ids.shape[-1]]
)
if audio_data_index < audio_llm_pos_ids.shape[-1]:
llm_pos_ids_list.append(
audio_llm_pos_ids[:, audio_data_index : audio_llm_pos_ids.shape[-1]]
)
video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2)
st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
eos_len = 1
llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx)
llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx)
st += text_len + bos_len * 2 + audio_len + video_len + eos_len * 2
audio_idx += 1
video_idx += 1
remain_videos -= 1
remain_audios -= 1
if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
text_len = len(input_tokens) - st
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
mrope_position_deltas.append(llm_positions.long().max() + 1 - len(input_ids))
mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
return position_ids, mrope_position_deltas.long()
else:
position_ids = attention_mask.cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
mrope_position_deltas = max_position_ids + 1 - torch.sum(attention_mask, dim=-1, keepdim=True)
return position_ids, mrope_position_deltas.long()
@classmethod
def _omni_get_input_positions_tensor(
cls,
@@ -1259,7 +1467,29 @@ class MRotaryEmbedding(RotaryEmbedding):
# TODO(fyabc): refactor and share more code with
# _vl_get_input_positions_tensor.
model_type = hf_config.model_type
thinker_config = hf_config.thinker_config
if isinstance(image_grid_thw, list):
image_grid_thw = torch.tensor(image_grid_thw)
if isinstance(video_grid_thw, list):
video_grid_thw = torch.tensor(video_grid_thw)
if "qwen3_omni" in model_type:
llm_positions, mrope_position_delta = cls._omni3_get_input_positions_tensor(
thinker_config,
torch.tensor([input_tokens]),
image_grid_thw,
video_grid_thw,
None,
use_audio_in_video,
audio_feature_lengths,
torch.ones(len(video_grid_thw))
)
llm_positions = llm_positions.squeeze(1)
mrope_position_delta = mrope_position_delta.squeeze()
return llm_positions, mrope_position_delta
audio_token_id = thinker_config.audio_token_index
image_token_id = thinker_config.image_token_index
video_token_id = thinker_config.video_token_index
@@ -1272,11 +1502,6 @@ class MRotaryEmbedding(RotaryEmbedding):
tokens_per_second = getattr(thinker_config.vision_config,
"tokens_per_second", 25)
if isinstance(image_grid_thw, list):
image_grid_thw = torch.tensor(image_grid_thw)
if isinstance(video_grid_thw, list):
video_grid_thw = torch.tensor(video_grid_thw)
src_item = input_tokens
audio_seqlens = audio_feature_lengths
if not second_per_grid_ts:

File diff suppressed because it is too large Load Diff

View File

@@ -214,6 +214,8 @@ _MULTIMODAL_MODELS = {
"Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"), # noqa: E501
"Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"), # noqa: E501
"Qwen2_5OmniForConditionalGeneration": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"), # noqa: E501
"Qwen3OmniMoeForConditionalGeneration": ("qwen3_omni_moe_thinker", "Qwen3OmniMoeThinkerForConditionalGeneration"),
"Qwen3OmniMoeModel": ("qwen3_omni_moe_thinker", "Qwen3OmniMoeThinkerForConditionalGeneration"),
"UltravoxModel": ("ultravox", "UltravoxModel"),
"Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
"TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"), # noqa: E501