[qwen3-omni] Add Qwen3-Omni MoE thinker

This commit is contained in:
2025-10-09 17:51:14 +08:00
parent bc57e2ef60
commit 24fab12b2f
8 changed files with 1543 additions and 37 deletions

View File

@@ -1,4 +1,5 @@
FROM cr.metax-tech.com/public-ai-release/maca/vllm:maca.ai3.0.0.5-torch2.6-py310-ubuntu22.04-amd64 FROM cr.metax-tech.com/public-ai-release/maca/vllm:maca.ai3.0.0.5-torch2.6-py310-ubuntu22.04-amd64
RUN /opt/conda/bin/pip install --no-cache-dir --upgrade transformers
COPY vllm/ /opt/conda/lib/python3.10/site-packages/vllm/ COPY vllm/ /opt/conda/lib/python3.10/site-packages/vllm/
COPY code_generator.py /opt/conda/lib/python3.10/site-packages/triton/compiler/code_generator.py COPY code_generator.py /opt/conda/lib/python3.10/site-packages/triton/compiler/code_generator.py

View File

@@ -1,6 +1,6 @@
# metax-c500-vllm # metax-c500-vllm
本项目包含了对于原版 vllm 的升级,使其可以在沐曦 C500芯片上支持运行 gpt-oss 本项目包含了对于原版 vllm 的升级,使其可以在沐曦 C500芯片上支持运行 gpt-oss, qwen3-omni
本项目中提供的 Dockerfile 中的主要内容为: 本项目中提供的 Dockerfile 中的主要内容为:
1.`vllm` 目录覆盖到镜像中的 `/opt/conda/lib/python3.10/site-packages/vllm`。运行`gpt-oss`时需指定`VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1` 1.`vllm` 目录覆盖到镜像中的 `/opt/conda/lib/python3.10/site-packages/vllm`。运行`gpt-oss`时需指定`VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1`

View File

@@ -54,6 +54,14 @@ def check_xformers_availability():
return USE_XFORMERS_OPS return USE_XFORMERS_OPS
def check_upstream_fa_availability(dtype: torch.dtype):
    """Report whether upstream (HF ``transformers``) FlashAttention-2 is usable.

    Requires a half-precision dtype, a CUDA platform, and device compute
    capability >= 8.0; otherwise returns False without importing transformers.
    """
    half_precision = dtype in (torch.float16, torch.bfloat16)
    if not half_precision:
        return False
    if not (current_platform.is_cuda()
            and current_platform.has_device_capability(80)):
        return False
    # Imported lazily so the check stays cheap on unsupported platforms.
    from transformers.utils import is_flash_attn_2_available
    return is_flash_attn_2_available()
class Attention(nn.Module): class Attention(nn.Module):
"""Attention layer. """Attention layer.

View File

@@ -531,7 +531,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
return "<image>" return "<image>"
if model_type in ("mllama", "llama4"): if model_type in ("mllama", "llama4"):
return "<|image|>" return "<|image|>"
if model_type in ("qwen2_vl", "qwen2_5_vl"): if model_type in ("qwen2_vl", "qwen2_5_vl", "qwen3_omni_moe"):
return "<|vision_start|><|image_pad|><|vision_end|>" return "<|vision_start|><|image_pad|><|vision_end|>"
if model_type == "qwen2_5_omni": if model_type == "qwen2_5_omni":
return "<|vision_start|><|IMAGE|><|vision_end|>" return "<|vision_start|><|IMAGE|><|vision_end|>"
@@ -553,13 +553,15 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
if model_type in ("qwen2_audio", "qwen2_5_omni"): if model_type in ("qwen2_audio", "qwen2_5_omni"):
return (f"Audio {current_count}: " return (f"Audio {current_count}: "
f"<|audio_bos|><|AUDIO|><|audio_eos|>") f"<|audio_bos|><|AUDIO|><|audio_eos|>")
if model_type == "qwen3_omni_moe":
return f"<|audio_start|><|audio_pad|><|audio_end|>"
if model_type == "minicpmo": if model_type == "minicpmo":
return "(<audio>./</audio>)" return "(<audio>./</audio>)"
raise TypeError(f"Unknown model type: {model_type}") raise TypeError(f"Unknown model type: {model_type}")
elif modality == "video": elif modality == "video":
if model_type == "internvl_chat": if model_type == "internvl_chat":
return "<video>" return "<video>"
if model_type in ("qwen2_vl", "qwen2_5_vl"): if model_type in ("qwen2_vl", "qwen2_5_vl", "qwen3_omni_moe"):
return "<|vision_start|><|video_pad|><|vision_end|>" return "<|vision_start|><|video_pad|><|vision_end|>"
if model_type == "qwen2_5_omni": if model_type == "qwen2_5_omni":
return "<|vision_start|><|VIDEO|><|vision_end|>" return "<|vision_start|><|VIDEO|><|vision_end|>"

View File

@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import abstractmethod from abc import abstractmethod
from collections.abc import Iterable from collections.abc import Iterable, Sequence
from enum import Enum from enum import Enum
from typing import Callable, Literal, Optional, overload from typing import Callable, Literal, Optional, overload
@@ -1667,37 +1667,57 @@ class FusedMoE(CustomOp):
return final_hidden_states return final_hidden_states
# @classmethod @classmethod
# def make_expert_params_mapping( def make_expert_params_mapping(
# cls, cls,
# ckpt_gate_proj_name: str, ckpt_gate_proj_name: str,
# ckpt_down_proj_name: str, ckpt_down_proj_name: str,
# ckpt_up_proj_name: str, ckpt_up_proj_name: str,
# num_experts: int, num_experts: int,
# num_redundant_experts: int = 0) -> list[tuple[str, str, int, str]]: num_redundant_experts: int = 0) -> list[tuple[str, str, int, str]]:
# num_physical_experts = num_experts + num_redundant_experts num_physical_experts = num_experts + num_redundant_experts
# # In the returned mapping:
# # - `expert_id` is the physical expert id
# # - `weight_name` contains the weight name of the logical expert
# # So that we should map the expert id to logical in `weight_name`
# physical_to_logical_map = \
# EplbState.build_initial_global_physical_to_logical_map(
# num_experts, num_redundant_experts)
# return [ def build_initial_global_physical_to_logical_map(
# # (param_name, weight_name, expert_id, shard_id) num_routed_experts: int,
# ("experts.w13_" if weight_name num_redundant_experts: int,
# in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_", ) -> Sequence[int]:
# f"experts.{physical_to_logical_map[expert_id]}.{weight_name}.", """
# expert_id, shard_id) for expert_id in range(num_physical_experts) Build an initial expert arrangement using the following structure:
# for shard_id, weight_name in [ [original routed experts, redundant experts]
# ("w1", ckpt_gate_proj_name),
# ("w2", ckpt_down_proj_name), Returns:
# ("w3", ckpt_up_proj_name), physical_to_logical_map (Sequence[int]): A list of integers,
# ] where each integer is the index of the logical expert
# ] that the corresponding physical expert maps to.
"""
global_physical_to_logical_map = list(range(num_routed_experts))
global_physical_to_logical_map += [
i % num_routed_experts for i in range(num_redundant_experts)
]
return global_physical_to_logical_map
# In the returned mapping:
# - `expert_id` is the physical expert id
# - `weight_name` contains the weight name of the logical expert
# So that we should map the expert id to logical in `weight_name`
physical_to_logical_map = \
build_initial_global_physical_to_logical_map(
num_experts, num_redundant_experts)
return [
# (param_name, weight_name, expert_id, shard_id)
("experts.w13_" if weight_name
in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_",
f"experts.{physical_to_logical_map[expert_id]}.{weight_name}.",
expert_id, shard_id) for expert_id in range(num_physical_experts)
for shard_id, weight_name in [
("w1", ckpt_gate_proj_name),
("w2", ckpt_down_proj_name),
("w3", ckpt_up_proj_name),
]
]
def extra_repr(self) -> str: def extra_repr(self) -> str:

View File

@@ -1227,6 +1227,214 @@ class MRotaryEmbedding(RotaryEmbedding):
return llm_positions, mrope_position_delta return llm_positions, mrope_position_delta
@classmethod
def _omni3_get_input_positions_tensor(
    cls,
    config,
    input_ids: Optional[torch.LongTensor] = None,
    image_grid_thw: Optional[torch.LongTensor] = None,
    video_grid_thw: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    use_audio_in_video: bool = False,
    audio_seqlens: Optional[torch.LongTensor] = None,
    second_per_grids: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute 3-axis M-RoPE position ids for Qwen3-Omni thinker inputs.

    Walks each sequence token by token: plain text tokens get identical,
    monotonically increasing positions on all three axes; image/video
    spans get (t, h, w) grid positions via ``_get_llm_pos_ids_for_vision``;
    audio spans get sequential positions whose length is derived from
    ``audio_seqlens``. With ``use_audio_in_video`` the audio and video
    position columns of one clip are interleaved by their temporal index.

    Args:
        config: thinker config providing the special token ids,
            ``vision_config.spatial_merge_size`` and
            ``position_id_per_seconds``.
        input_ids: (batch, seq_len) token ids.
        image_grid_thw / video_grid_thw: per-item (t, h, w) patch grids.
        attention_mask: optional (batch, seq_len) mask; defaults to all-ones
            in the multimodal branch. NOTE(review): the text-only ``else``
            branch dereferences ``attention_mask`` unconditionally — callers
            taking that path must supply it; confirm against call sites.
        use_audio_in_video: interleave audio/video positions of one clip.
        audio_seqlens: raw audio feature lengths, one per audio span.
        second_per_grids: seconds covered by each video temporal grid step.

    Returns:
        (position_ids, mrope_position_deltas): position_ids is
        (3, batch, seq_len); deltas is (batch, 1), the offset to continue
        positions after the prompt.
    """

    def _get_feat_extract_output_lengths(input_lengths: torch.LongTensor):
        # Maps raw audio feature lengths to encoder output lengths.
        # NOTE(review): the %100 / //100 * 13 chunking mirrors the HF
        # Qwen3-Omni audio tower's 100-frame windowing — confirm against
        # the transformers implementation if the tower changes.
        input_lengths_leave = input_lengths % 100
        feat_lengths = (input_lengths_leave - 1) // 2 + 1
        output_lengths = ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
        return output_lengths

    # Special token ids and scaling factors from the thinker config.
    spatial_merge_size = config.vision_config.spatial_merge_size
    image_token_id = config.image_token_id
    video_token_id = config.video_token_id
    audio_token_id = config.audio_token_id
    vision_start_token_id = config.vision_start_token_id
    audio_start_token_id = config.audio_start_token_id
    position_id_per_seconds = config.position_id_per_seconds

    mrope_position_deltas = []
    if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
        total_input_ids = input_ids
        if attention_mask is None:
            attention_mask = torch.ones_like(total_input_ids)
        # One (t, h, w) position per token, filled per batch row below.
        position_ids = torch.zeros(
            3,
            input_ids.shape[0],
            input_ids.shape[1],
            dtype=input_ids.dtype,
            device=input_ids.device,
        )
        # Cursors into image/video/audio metadata, shared across the batch.
        image_idx, video_idx, audio_idx = 0, 0, 0
        attention_mask = attention_mask.to(total_input_ids.device)
        for i, input_ids in enumerate(total_input_ids):
            # Work only on the unpadded tokens of this row.
            input_ids = input_ids[attention_mask[i] == 1]
            image_nums, video_nums, audio_nums = 0, 0, 0
            # The token right after each vision_start marker tells whether
            # the span is an image or a video.
            vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
            vision_tokens = input_ids[vision_start_indices + 1]
            audio_nums = torch.sum(input_ids == audio_start_token_id)
            image_nums = (vision_tokens == image_token_id).sum()
            video_nums = (
                # With audio-in-video, a video span opens with an audio
                # start marker instead of a video token.
                (vision_tokens == audio_start_token_id).sum()
                if use_audio_in_video
                else (vision_tokens == video_token_id).sum()
            )
            input_tokens = input_ids.tolist()
            llm_pos_ids_list: list = []
            st = 0  # scan cursor into input_tokens
            remain_images, remain_videos, remain_audios = image_nums, video_nums, audio_nums
            multimodal_nums = (
                # Audio-in-video folds each video+audio pair into one span.
                image_nums + audio_nums if use_audio_in_video else image_nums + video_nums + audio_nums
            )
            for _ in range(multimodal_nums):
                # Next free position index = max assigned so far + 1.
                st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
                # Locate the next vision and audio span starts; sentinel
                # len+1 means "none remaining".
                if (image_token_id in input_tokens or video_token_id in input_tokens) and (
                    remain_videos > 0 or remain_images > 0
                ):
                    ed_vision_start = input_tokens.index(vision_start_token_id, st)
                else:
                    ed_vision_start = len(input_tokens) + 1
                if audio_token_id in input_tokens and remain_audios > 0:
                    ed_audio_start = input_tokens.index(audio_start_token_id, st)
                else:
                    ed_audio_start = len(input_tokens) + 1
                min_ed = min(ed_vision_start, ed_audio_start)
                if min_ed == ed_audio_start:
                    # --- standalone audio span: text | bos | audio | eos ---
                    text_len = min_ed - st
                    if text_len != 0:
                        st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
                        llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
                    st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
                    bos_len = 1
                    llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx)
                    st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
                    audio_len = _get_feat_extract_output_lengths(audio_seqlens[audio_idx])
                    llm_pos_ids = torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx
                    llm_pos_ids_list.append(llm_pos_ids)
                    st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
                    eos_len = 1
                    llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx)
                    st += text_len + bos_len + audio_len + eos_len
                    audio_idx += 1
                    remain_audios -= 1
                elif min_ed == ed_vision_start and input_ids[ed_vision_start + 1] == image_token_id:
                    # --- image span: text | bos | (t,h,w) grid | eos ---
                    text_len = min_ed - st
                    if text_len != 0:
                        st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
                        llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
                    st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
                    bos_len = 1
                    llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx)
                    st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
                    grid_t = image_grid_thw[image_idx][0]
                    grid_hs = image_grid_thw[:, 1]
                    grid_ws = image_grid_thw[:, 2]
                    # Images are treated as 1-second-per-frame clips.
                    t_index = ((torch.arange(grid_t)) * 1 * position_id_per_seconds)
                    llm_pos_ids = cls._get_llm_pos_ids_for_vision(
                        st_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws
                    )
                    # Token count after spatial merging of patches.
                    image_len = image_grid_thw[image_idx].prod() // (spatial_merge_size**2)
                    llm_pos_ids_list.append(llm_pos_ids)
                    st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
                    eos_len = 1
                    llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx)
                    st += text_len + bos_len + image_len + eos_len
                    image_idx += 1
                    remain_images -= 1
                elif min_ed == ed_vision_start and input_ids[ed_vision_start + 1] == video_token_id and not use_audio_in_video:
                    # --- video-only span: text | bos | (t,h,w) grid | eos ---
                    text_len = min_ed - st
                    if text_len != 0:
                        st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
                        llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
                    st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
                    bos_len = 1
                    llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx)
                    st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
                    grid_t = video_grid_thw[video_idx][0]
                    grid_hs = video_grid_thw[:, 1]
                    grid_ws = video_grid_thw[:, 2]
                    # Temporal axis scaled by real seconds per grid step.
                    t_index = (
                        (torch.arange(grid_t)) * second_per_grids[video_idx].cpu() * position_id_per_seconds
                    )
                    llm_pos_ids = cls._get_llm_pos_ids_for_vision(
                        st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
                    )
                    video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2)
                    llm_pos_ids_list.append(llm_pos_ids)
                    st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
                    eos_len = 1
                    llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx)
                    st += text_len + bos_len + video_len + eos_len
                    video_idx += 1
                    remain_videos -= 1
                elif min_ed == ed_vision_start and ed_vision_start + 1 == ed_audio_start and use_audio_in_video:
                    # --- video with embedded audio:
                    # text | 2x bos | interleaved(audio, video) | 2x eos ---
                    text_len = min_ed - st
                    if text_len != 0:
                        st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
                        llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
                    st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
                    bos_len = 1
                    # Two bos markers (vision start + audio start) share one
                    # position value — appended twice on purpose.
                    llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx)
                    llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx)
                    st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
                    audio_len = _get_feat_extract_output_lengths(audio_seqlens[audio_idx])
                    audio_llm_pos_ids = torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx
                    grid_t = video_grid_thw[video_idx][0]
                    grid_hs = video_grid_thw[:, 1]
                    grid_ws = video_grid_thw[:, 2]
                    t_index = (
                        (torch.arange(grid_t)) * second_per_grids[video_idx].cpu() * position_id_per_seconds
                    )
                    video_llm_pos_ids = cls._get_llm_pos_ids_for_vision(
                        st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
                    )
                    # Merge the two position streams, ordered by temporal
                    # position (row 0), video winning ties.
                    video_data_index, audio_data_index = 0, 0
                    while (
                        video_data_index < video_llm_pos_ids.shape[-1]
                        and audio_data_index < audio_llm_pos_ids.shape[-1]
                    ):
                        if video_llm_pos_ids[0][video_data_index] <= audio_llm_pos_ids[0][audio_data_index]:
                            llm_pos_ids_list.append(video_llm_pos_ids[:, video_data_index : video_data_index + 1])
                            video_data_index += 1
                        else:
                            llm_pos_ids_list.append(audio_llm_pos_ids[:, audio_data_index : audio_data_index + 1])
                            audio_data_index += 1
                    # Flush whichever stream still has columns left.
                    if video_data_index < video_llm_pos_ids.shape[-1]:
                        llm_pos_ids_list.append(
                            video_llm_pos_ids[:, video_data_index : video_llm_pos_ids.shape[-1]]
                        )
                    if audio_data_index < audio_llm_pos_ids.shape[-1]:
                        llm_pos_ids_list.append(
                            audio_llm_pos_ids[:, audio_data_index : audio_llm_pos_ids.shape[-1]]
                        )
                    video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2)
                    st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
                    eos_len = 1
                    # Two eos markers (audio end + vision end), same value.
                    llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx)
                    llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx)
                    st += text_len + bos_len * 2 + audio_len + video_len + eos_len * 2
                    audio_idx += 1
                    video_idx += 1
                    remain_videos -= 1
                    remain_audios -= 1
            # Trailing text after the last multimodal span.
            if st < len(input_tokens):
                st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
                text_len = len(input_tokens) - st
                llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
            llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
            # Scatter the computed positions back into the unpadded slots.
            position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
            # Delta = next position after the prompt minus prompt length.
            mrope_position_deltas.append(llm_positions.long().max() + 1 - len(input_ids))
        mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
        return position_ids, mrope_position_deltas.long()
    else:
        # Text-only fallback: plain sequential positions replicated on all
        # three axes; padded slots are masked to a harmless value.
        position_ids = attention_mask.cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
        max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
        mrope_position_deltas = max_position_ids + 1 - torch.sum(attention_mask, dim=-1, keepdim=True)
        return position_ids, mrope_position_deltas.long()
@classmethod @classmethod
def _omni_get_input_positions_tensor( def _omni_get_input_positions_tensor(
cls, cls,
@@ -1259,7 +1467,29 @@ class MRotaryEmbedding(RotaryEmbedding):
# TODO(fyabc): refactor and share more code with # TODO(fyabc): refactor and share more code with
# _vl_get_input_positions_tensor. # _vl_get_input_positions_tensor.
model_type = hf_config.model_type
thinker_config = hf_config.thinker_config thinker_config = hf_config.thinker_config
if isinstance(image_grid_thw, list):
image_grid_thw = torch.tensor(image_grid_thw)
if isinstance(video_grid_thw, list):
video_grid_thw = torch.tensor(video_grid_thw)
if "qwen3_omni" in model_type:
llm_positions, mrope_position_delta = cls._omni3_get_input_positions_tensor(
thinker_config,
torch.tensor([input_tokens]),
image_grid_thw,
video_grid_thw,
None,
use_audio_in_video,
audio_feature_lengths,
torch.ones(len(video_grid_thw))
)
llm_positions = llm_positions.squeeze(1)
mrope_position_delta = mrope_position_delta.squeeze()
return llm_positions, mrope_position_delta
audio_token_id = thinker_config.audio_token_index audio_token_id = thinker_config.audio_token_index
image_token_id = thinker_config.image_token_index image_token_id = thinker_config.image_token_index
video_token_id = thinker_config.video_token_index video_token_id = thinker_config.video_token_index
@@ -1272,11 +1502,6 @@ class MRotaryEmbedding(RotaryEmbedding):
tokens_per_second = getattr(thinker_config.vision_config, tokens_per_second = getattr(thinker_config.vision_config,
"tokens_per_second", 25) "tokens_per_second", 25)
if isinstance(image_grid_thw, list):
image_grid_thw = torch.tensor(image_grid_thw)
if isinstance(video_grid_thw, list):
video_grid_thw = torch.tensor(video_grid_thw)
src_item = input_tokens src_item = input_tokens
audio_seqlens = audio_feature_lengths audio_seqlens = audio_feature_lengths
if not second_per_grid_ts: if not second_per_grid_ts:

File diff suppressed because it is too large Load Diff

View File

@@ -214,6 +214,8 @@ _MULTIMODAL_MODELS = {
"Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"), # noqa: E501 "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"), # noqa: E501
"Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"), # noqa: E501 "Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"), # noqa: E501
"Qwen2_5OmniForConditionalGeneration": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"), # noqa: E501 "Qwen2_5OmniForConditionalGeneration": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"), # noqa: E501
"Qwen3OmniMoeForConditionalGeneration": ("qwen3_omni_moe_thinker", "Qwen3OmniMoeThinkerForConditionalGeneration"),
"Qwen3OmniMoeModel": ("qwen3_omni_moe_thinker", "Qwen3OmniMoeThinkerForConditionalGeneration"),
"UltravoxModel": ("ultravox", "UltravoxModel"), "UltravoxModel": ("ultravox", "UltravoxModel"),
"Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"), "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
"TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"), # noqa: E501 "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"), # noqa: E501