[Misc] Add data preprocess functions to qwen2.5_vl_without_padding (#2148)
### What this PR does / why we need it?
Cherry pick #1705 from v0.9.1-dev
Compared qwen2_5_vl.py, qwen2_5_vl_without_padding.py missing some
funtions. The purpose of this PR is to supplement these.
add:
- rot_pos_emb(self, grid_thw: torch.Tensor)
- get_window_index(self, grid_thw)
- _process_image_input(self, image_input)
- _process_video_input(self, video_input)
Co-authored-by: zheliuyu
[15750543867@163.com](mailto:15750543867@163.com)
Co-authored-by: wangli
[wangli858794774@gmail.com](mailto:wangli858794774@gmail.com)
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.10.0
- vLLM main:
207b750e19
Signed-off-by: wangli <wangli858794774@gmail.com>
This commit is contained in:
@@ -207,6 +207,66 @@ class AscendQwen2_5_VisionTransformer_Without_Padding(Qwen2_5_VisionTransformer
|
||||
self.hidden_size_per_attention_head)
|
||||
return cos_new, sin_new
|
||||
|
||||
def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
|
||||
pos_ids = []
|
||||
for t, h, w in grid_thw:
|
||||
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
|
||||
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
|
||||
hpos_ids = hpos_ids.reshape(
|
||||
h // self.spatial_merge_size,
|
||||
self.spatial_merge_size,
|
||||
w // self.spatial_merge_size,
|
||||
self.spatial_merge_size,
|
||||
).permute(0, 2, 1, 3).flatten()
|
||||
wpos_ids = wpos_ids.reshape(
|
||||
h // self.spatial_merge_size,
|
||||
self.spatial_merge_size,
|
||||
w // self.spatial_merge_size,
|
||||
self.spatial_merge_size,
|
||||
).permute(0, 2, 1, 3).flatten()
|
||||
pos_ids.append(
|
||||
torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
|
||||
pos_ids = torch.cat(pos_ids, dim=0)
|
||||
max_grid_size = grid_thw[:, 1:].max()
|
||||
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
|
||||
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
|
||||
return rotary_pos_emb
|
||||
|
||||
def get_window_index(self, grid_thw):
|
||||
window_index: list = []
|
||||
cu_window_seqlens: list = [0]
|
||||
window_index_id = 0
|
||||
vit_merger_window_size = (self.window_size //
|
||||
self.spatial_merge_size // self.patch_size)
|
||||
|
||||
for grid_t, grid_h, grid_w in grid_thw:
|
||||
llm_grid_h = grid_h // self.spatial_merge_size
|
||||
llm_grid_w = grid_w // self.spatial_merge_size
|
||||
index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
|
||||
grid_t, llm_grid_h, llm_grid_w)
|
||||
pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
|
||||
pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
|
||||
num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
|
||||
num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
|
||||
index_padded = F.pad(index, (0, pad_w, 0, pad_h), 'constant', -100)
|
||||
index_padded = index_padded.reshape(grid_t, num_windows_h,
|
||||
vit_merger_window_size,
|
||||
num_windows_w,
|
||||
vit_merger_window_size)
|
||||
index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
|
||||
grid_t, num_windows_h * num_windows_w, vit_merger_window_size,
|
||||
vit_merger_window_size)
|
||||
seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
|
||||
index_padded = index_padded.reshape(-1)
|
||||
index_new = index_padded[index_padded != -100]
|
||||
window_index.append(index_new + window_index_id)
|
||||
cu_seqlens_tmp = seqlens.cumsum(
|
||||
0) * self.spatial_merge_unit + cu_window_seqlens[-1]
|
||||
cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
|
||||
window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
|
||||
window_index = torch.cat(window_index, dim=0)
|
||||
return window_index, cu_window_seqlens
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
@@ -258,6 +318,39 @@ class AscendQwen2_5_VisionTransformer_Without_Padding(Qwen2_5_VisionTransformer
|
||||
x = x[reverse_indices, :]
|
||||
return x
|
||||
|
||||
def _process_image_input(self, image_input) -> tuple[torch.Tensor, ...]:
|
||||
|
||||
grid_thw = image_input["image_grid_thw"]
|
||||
assert grid_thw.ndim == 2
|
||||
|
||||
if image_input["type"] == "image_embeds":
|
||||
image_embeds = image_input["image_embeds"].type(self.visual.dtype)
|
||||
else:
|
||||
pixel_values = image_input["pixel_values"].type(self.visual.dtype)
|
||||
image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
|
||||
|
||||
# Split concatenated embeddings for each image item.
|
||||
merge_size = self.visual.spatial_merge_size
|
||||
sizes = grid_thw.prod(-1) // merge_size // merge_size
|
||||
return image_embeds.split(sizes.tolist())
|
||||
|
||||
def _process_video_input(self, video_input) -> tuple[torch.Tensor, ...]:
|
||||
|
||||
grid_thw = video_input["video_grid_thw"]
|
||||
assert grid_thw.ndim == 2
|
||||
|
||||
if video_input["type"] == "video_embeds":
|
||||
video_embeds = video_input["video_embeds"].type(self.visual.dtype)
|
||||
else:
|
||||
pixel_values_videos = video_input["pixel_values_videos"].type(
|
||||
self.visual.dtype)
|
||||
video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
|
||||
|
||||
# Split concatenated embeddings for each video item.
|
||||
merge_size = self.visual.spatial_merge_size
|
||||
sizes = grid_thw.prod(-1) // merge_size // merge_size
|
||||
return video_embeds.split(sizes.tolist())
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
Qwen2_5_VLMultiModalProcessor,
|
||||
|
||||
Reference in New Issue
Block a user