Qwen2vl support cuda graph and disable radix cache (#1780)
This commit is contained in:
@@ -22,64 +22,33 @@ class MRotaryEmbedding:
|
||||
|
||||
@staticmethod
|
||||
def get_input_positions(
|
||||
input_tokens: List[int],
|
||||
input_tokens: torch.Tensor,
|
||||
image_grid_thw: Union[List[List[int]], torch.Tensor],
|
||||
video_grid_thw: Union[List[List[int]], torch.Tensor],
|
||||
image_token_id: int,
|
||||
video_token_id: int,
|
||||
vision_start_token_id: int,
|
||||
vision_end_token_id: int,
|
||||
spatial_merge_size: int,
|
||||
context_len: int = 0,
|
||||
extend_prefix_len: int = 0,
|
||||
) -> Tuple[List[List[int]], int]:
|
||||
"""Get mrope input positions and delta value."""
|
||||
|
||||
if isinstance(image_grid_thw, torch.Tensor):
|
||||
image_grid_thw = image_grid_thw.tolist()
|
||||
if isinstance(video_grid_thw, torch.Tensor):
|
||||
video_grid_thw = video_grid_thw.tolist()
|
||||
|
||||
input_tokens_tensor = torch.tensor(input_tokens)
|
||||
vision_start_indices = torch.argwhere(
|
||||
input_tokens_tensor == vision_start_token_id
|
||||
input_tokens == vision_start_token_id
|
||||
).squeeze(1)
|
||||
vision_tokens = input_tokens_tensor[vision_start_indices + 1]
|
||||
image_nums = (vision_tokens == image_token_id).sum()
|
||||
video_nums = (vision_tokens == video_token_id).sum()
|
||||
image_indices = vision_start_indices + 1
|
||||
image_nums = image_indices.shape[0]
|
||||
llm_pos_ids_list: list = []
|
||||
|
||||
st = 0
|
||||
remain_images, remain_videos = image_nums, video_nums
|
||||
|
||||
image_index, video_index = 0, 0
|
||||
for _ in range(image_nums + video_nums):
|
||||
if image_token_id in input_tokens and remain_images > 0:
|
||||
ed_image = input_tokens.index(image_token_id, st)
|
||||
else:
|
||||
ed_image = len(input_tokens) + 1
|
||||
if video_token_id in input_tokens and remain_videos > 0:
|
||||
ed_video = input_tokens.index(video_token_id, st)
|
||||
else:
|
||||
ed_video = len(input_tokens) + 1
|
||||
if ed_image < ed_video:
|
||||
t, h, w = (
|
||||
image_grid_thw[image_index][0],
|
||||
image_grid_thw[image_index][1],
|
||||
image_grid_thw[image_index][2],
|
||||
)
|
||||
image_index += 1
|
||||
remain_images -= 1
|
||||
ed = ed_image
|
||||
else:
|
||||
t, h, w = (
|
||||
video_grid_thw[video_index][0],
|
||||
video_grid_thw[video_index][1],
|
||||
video_grid_thw[video_index][2],
|
||||
)
|
||||
video_index += 1
|
||||
remain_videos -= 1
|
||||
ed = ed_video
|
||||
input_tokens_len = input_tokens.shape[0]
|
||||
for image_index in range(image_nums):
|
||||
ed = image_indices[image_index].item()
|
||||
t, h, w = (
|
||||
image_grid_thw[image_index][0],
|
||||
image_grid_thw[image_index][1],
|
||||
image_grid_thw[image_index][2],
|
||||
)
|
||||
llm_grid_t, llm_grid_h, llm_grid_w = (
|
||||
t,
|
||||
h // spatial_merge_size,
|
||||
@@ -115,18 +84,16 @@ class MRotaryEmbedding:
|
||||
)
|
||||
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
|
||||
|
||||
if st < len(input_tokens):
|
||||
if st < input_tokens_len:
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
||||
text_len = len(input_tokens) - st
|
||||
text_len = input_tokens_len - st
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
|
||||
)
|
||||
|
||||
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
|
||||
llm_positions = llm_positions[:, context_len:]
|
||||
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
|
||||
llm_positions += extend_prefix_len
|
||||
|
||||
mrope_position_delta = (llm_positions.max() + 1 - input_tokens_len).item()
|
||||
return llm_positions.tolist(), mrope_position_delta
|
||||
|
||||
@staticmethod
|
||||
|
||||
Reference in New Issue
Block a user