[qwen3-omni] Add Qwen3-Omni moe thinker
This commit is contained in:
@@ -1227,6 +1227,214 @@ class MRotaryEmbedding(RotaryEmbedding):
|
||||
|
||||
return llm_positions, mrope_position_delta
|
||||
|
||||
@classmethod
|
||||
def _omni3_get_input_positions_tensor(
|
||||
cls,
|
||||
config,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
image_grid_thw: Optional[torch.LongTensor] = None,
|
||||
video_grid_thw: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
use_audio_in_video: bool = False,
|
||||
audio_seqlens: Optional[torch.LongTensor] = None,
|
||||
second_per_grids: Optional[torch.Tensor] = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
def _get_feat_extract_output_lengths(input_lengths: torch.LongTensor):
|
||||
input_lengths_leave = input_lengths % 100
|
||||
feat_lengths = (input_lengths_leave - 1) // 2 + 1
|
||||
output_lengths = ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
|
||||
return output_lengths
|
||||
spatial_merge_size = config.vision_config.spatial_merge_size
|
||||
image_token_id = config.image_token_id
|
||||
video_token_id = config.video_token_id
|
||||
audio_token_id = config.audio_token_id
|
||||
vision_start_token_id = config.vision_start_token_id
|
||||
audio_start_token_id = config.audio_start_token_id
|
||||
position_id_per_seconds = config.position_id_per_seconds
|
||||
mrope_position_deltas = []
|
||||
if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
|
||||
total_input_ids = input_ids
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(total_input_ids)
|
||||
position_ids = torch.zeros(
|
||||
3,
|
||||
input_ids.shape[0],
|
||||
input_ids.shape[1],
|
||||
dtype=input_ids.dtype,
|
||||
device=input_ids.device,
|
||||
)
|
||||
image_idx, video_idx, audio_idx = 0, 0, 0
|
||||
attention_mask = attention_mask.to(total_input_ids.device)
|
||||
for i, input_ids in enumerate(total_input_ids):
|
||||
input_ids = input_ids[attention_mask[i] == 1]
|
||||
image_nums, video_nums, audio_nums = 0, 0, 0
|
||||
vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
|
||||
vision_tokens = input_ids[vision_start_indices + 1]
|
||||
audio_nums = torch.sum(input_ids == audio_start_token_id)
|
||||
image_nums = (vision_tokens == image_token_id).sum()
|
||||
video_nums = (
|
||||
(vision_tokens == audio_start_token_id).sum()
|
||||
if use_audio_in_video
|
||||
else (vision_tokens == video_token_id).sum()
|
||||
)
|
||||
input_tokens = input_ids.tolist()
|
||||
llm_pos_ids_list: list = []
|
||||
st = 0
|
||||
remain_images, remain_videos, remain_audios = image_nums, video_nums, audio_nums
|
||||
multimodal_nums = (
|
||||
image_nums + audio_nums if use_audio_in_video else image_nums + video_nums + audio_nums
|
||||
)
|
||||
for _ in range(multimodal_nums):
|
||||
st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
||||
if (image_token_id in input_tokens or video_token_id in input_tokens) and (
|
||||
remain_videos > 0 or remain_images > 0
|
||||
):
|
||||
ed_vision_start = input_tokens.index(vision_start_token_id, st)
|
||||
else:
|
||||
ed_vision_start = len(input_tokens) + 1
|
||||
if audio_token_id in input_tokens and remain_audios > 0:
|
||||
ed_audio_start = input_tokens.index(audio_start_token_id, st)
|
||||
else:
|
||||
ed_audio_start = len(input_tokens) + 1
|
||||
min_ed = min(ed_vision_start, ed_audio_start)
|
||||
if min_ed == ed_audio_start:
|
||||
text_len = min_ed - st
|
||||
if text_len != 0:
|
||||
st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
||||
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
|
||||
st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
||||
bos_len = 1
|
||||
llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx)
|
||||
st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
||||
audio_len = _get_feat_extract_output_lengths(audio_seqlens[audio_idx])
|
||||
llm_pos_ids = torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx
|
||||
llm_pos_ids_list.append(llm_pos_ids)
|
||||
st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
||||
eos_len = 1
|
||||
llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx)
|
||||
st += text_len + bos_len + audio_len + eos_len
|
||||
audio_idx += 1
|
||||
remain_audios -= 1
|
||||
elif min_ed == ed_vision_start and input_ids[ed_vision_start + 1] == image_token_id:
|
||||
text_len = min_ed - st
|
||||
if text_len != 0:
|
||||
st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
||||
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
|
||||
st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
||||
bos_len = 1
|
||||
llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx)
|
||||
st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
||||
grid_t = image_grid_thw[image_idx][0]
|
||||
grid_hs = image_grid_thw[:, 1]
|
||||
grid_ws = image_grid_thw[:, 2]
|
||||
t_index = ((torch.arange(grid_t)) * 1 * position_id_per_seconds)
|
||||
llm_pos_ids = cls._get_llm_pos_ids_for_vision(
|
||||
st_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws
|
||||
)
|
||||
image_len = image_grid_thw[image_idx].prod() // (spatial_merge_size**2)
|
||||
llm_pos_ids_list.append(llm_pos_ids)
|
||||
st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
||||
eos_len = 1
|
||||
llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx)
|
||||
st += text_len + bos_len + image_len + eos_len
|
||||
image_idx += 1
|
||||
remain_images -= 1
|
||||
elif min_ed == ed_vision_start and input_ids[ed_vision_start + 1] == video_token_id and not use_audio_in_video:
|
||||
text_len = min_ed - st
|
||||
if text_len != 0:
|
||||
st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
||||
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
|
||||
st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
||||
bos_len = 1
|
||||
llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx)
|
||||
st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
||||
grid_t = video_grid_thw[video_idx][0]
|
||||
grid_hs = video_grid_thw[:, 1]
|
||||
grid_ws = video_grid_thw[:, 2]
|
||||
t_index = (
|
||||
(torch.arange(grid_t)) * second_per_grids[video_idx].cpu() * position_id_per_seconds
|
||||
)
|
||||
llm_pos_ids = cls._get_llm_pos_ids_for_vision(
|
||||
st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
|
||||
)
|
||||
video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2)
|
||||
llm_pos_ids_list.append(llm_pos_ids)
|
||||
st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
||||
eos_len = 1
|
||||
llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx)
|
||||
st += text_len + bos_len + video_len + eos_len
|
||||
video_idx += 1
|
||||
remain_videos -= 1
|
||||
elif min_ed == ed_vision_start and ed_vision_start + 1 == ed_audio_start and use_audio_in_video:
|
||||
text_len = min_ed - st
|
||||
if text_len != 0:
|
||||
st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
||||
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
|
||||
st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
||||
bos_len = 1
|
||||
llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx)
|
||||
llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx)
|
||||
st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
||||
audio_len = _get_feat_extract_output_lengths(audio_seqlens[audio_idx])
|
||||
audio_llm_pos_ids = torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx
|
||||
grid_t = video_grid_thw[video_idx][0]
|
||||
grid_hs = video_grid_thw[:, 1]
|
||||
grid_ws = video_grid_thw[:, 2]
|
||||
t_index = (
|
||||
(torch.arange(grid_t)) * second_per_grids[video_idx].cpu() * position_id_per_seconds
|
||||
)
|
||||
video_llm_pos_ids = cls._get_llm_pos_ids_for_vision(
|
||||
st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
|
||||
)
|
||||
video_data_index, audio_data_index = 0, 0
|
||||
while (
|
||||
video_data_index < video_llm_pos_ids.shape[-1]
|
||||
and audio_data_index < audio_llm_pos_ids.shape[-1]
|
||||
):
|
||||
if video_llm_pos_ids[0][video_data_index] <= audio_llm_pos_ids[0][audio_data_index]:
|
||||
llm_pos_ids_list.append(video_llm_pos_ids[:, video_data_index : video_data_index + 1])
|
||||
video_data_index += 1
|
||||
else:
|
||||
llm_pos_ids_list.append(audio_llm_pos_ids[:, audio_data_index : audio_data_index + 1])
|
||||
audio_data_index += 1
|
||||
if video_data_index < video_llm_pos_ids.shape[-1]:
|
||||
llm_pos_ids_list.append(
|
||||
video_llm_pos_ids[:, video_data_index : video_llm_pos_ids.shape[-1]]
|
||||
)
|
||||
if audio_data_index < audio_llm_pos_ids.shape[-1]:
|
||||
llm_pos_ids_list.append(
|
||||
audio_llm_pos_ids[:, audio_data_index : audio_llm_pos_ids.shape[-1]]
|
||||
)
|
||||
video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2)
|
||||
st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
||||
eos_len = 1
|
||||
llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx)
|
||||
llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx)
|
||||
st += text_len + bos_len * 2 + audio_len + video_len + eos_len * 2
|
||||
audio_idx += 1
|
||||
video_idx += 1
|
||||
remain_videos -= 1
|
||||
remain_audios -= 1
|
||||
if st < len(input_tokens):
|
||||
st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
||||
text_len = len(input_tokens) - st
|
||||
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
|
||||
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
|
||||
|
||||
position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
|
||||
mrope_position_deltas.append(llm_positions.long().max() + 1 - len(input_ids))
|
||||
mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
|
||||
return position_ids, mrope_position_deltas.long()
|
||||
else:
|
||||
position_ids = attention_mask.cumsum(-1) - 1
|
||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
|
||||
max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
|
||||
mrope_position_deltas = max_position_ids + 1 - torch.sum(attention_mask, dim=-1, keepdim=True)
|
||||
return position_ids, mrope_position_deltas.long()
|
||||
|
||||
|
||||
@classmethod
|
||||
def _omni_get_input_positions_tensor(
|
||||
cls,
|
||||
@@ -1259,7 +1467,29 @@ class MRotaryEmbedding(RotaryEmbedding):
|
||||
# TODO(fyabc): refactor and share more code with
|
||||
# _vl_get_input_positions_tensor.
|
||||
|
||||
model_type = hf_config.model_type
|
||||
thinker_config = hf_config.thinker_config
|
||||
|
||||
if isinstance(image_grid_thw, list):
|
||||
image_grid_thw = torch.tensor(image_grid_thw)
|
||||
if isinstance(video_grid_thw, list):
|
||||
video_grid_thw = torch.tensor(video_grid_thw)
|
||||
|
||||
if "qwen3_omni" in model_type:
|
||||
llm_positions, mrope_position_delta = cls._omni3_get_input_positions_tensor(
|
||||
thinker_config,
|
||||
torch.tensor([input_tokens]),
|
||||
image_grid_thw,
|
||||
video_grid_thw,
|
||||
None,
|
||||
use_audio_in_video,
|
||||
audio_feature_lengths,
|
||||
torch.ones(len(video_grid_thw))
|
||||
)
|
||||
llm_positions = llm_positions.squeeze(1)
|
||||
mrope_position_delta = mrope_position_delta.squeeze()
|
||||
return llm_positions, mrope_position_delta
|
||||
|
||||
audio_token_id = thinker_config.audio_token_index
|
||||
image_token_id = thinker_config.image_token_index
|
||||
video_token_id = thinker_config.video_token_index
|
||||
@@ -1272,11 +1502,6 @@ class MRotaryEmbedding(RotaryEmbedding):
|
||||
tokens_per_second = getattr(thinker_config.vision_config,
|
||||
"tokens_per_second", 25)
|
||||
|
||||
if isinstance(image_grid_thw, list):
|
||||
image_grid_thw = torch.tensor(image_grid_thw)
|
||||
if isinstance(video_grid_thw, list):
|
||||
video_grid_thw = torch.tensor(video_grid_thw)
|
||||
|
||||
src_item = input_tokens
|
||||
audio_seqlens = audio_feature_lengths
|
||||
if not second_per_grid_ts:
|
||||
|
||||
Reference in New Issue
Block a user