[Misc] Add data preprocess functions to qwen2.5_vl_without_padding (#2148)

### What this PR does / why we need it?
Cherry pick #1705 from v0.9.1-dev
Compared qwen2_5_vl.py, qwen2_5_vl_without_padding.py missing some
funtions. The purpose of this PR is to supplement these.

add:
- rot_pos_emb(self, grid_thw: torch.Tensor)
- get_window_index(self, grid_thw)
- _process_image_input(self, image_input)
- _process_video_input(self, video_input)

Co-authored-by: zheliuyu
[15750543867@163.com](mailto:15750543867@163.com)
Co-authored-by: wangli
[wangli858794774@gmail.com](mailto:wangli858794774@gmail.com)

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.10.0
- vLLM main:
207b750e19

Signed-off-by: wangli <wangli858794774@gmail.com>
This commit is contained in:
Li Wang
2025-08-01 08:54:02 +08:00
committed by GitHub
parent e3b3ffb875
commit 968e6791d3

View File

@@ -207,6 +207,66 @@ class AscendQwen2_5_VisionTransformer_Without_Padding(Qwen2_5_VisionTransformer
self.hidden_size_per_attention_head)
return cos_new, sin_new
def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
pos_ids = []
for t, h, w in grid_thw:
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
hpos_ids = hpos_ids.reshape(
h // self.spatial_merge_size,
self.spatial_merge_size,
w // self.spatial_merge_size,
self.spatial_merge_size,
).permute(0, 2, 1, 3).flatten()
wpos_ids = wpos_ids.reshape(
h // self.spatial_merge_size,
self.spatial_merge_size,
w // self.spatial_merge_size,
self.spatial_merge_size,
).permute(0, 2, 1, 3).flatten()
pos_ids.append(
torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
pos_ids = torch.cat(pos_ids, dim=0)
max_grid_size = grid_thw[:, 1:].max()
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
return rotary_pos_emb
def get_window_index(self, grid_thw):
window_index: list = []
cu_window_seqlens: list = [0]
window_index_id = 0
vit_merger_window_size = (self.window_size //
self.spatial_merge_size // self.patch_size)
for grid_t, grid_h, grid_w in grid_thw:
llm_grid_h = grid_h // self.spatial_merge_size
llm_grid_w = grid_w // self.spatial_merge_size
index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
grid_t, llm_grid_h, llm_grid_w)
pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
index_padded = F.pad(index, (0, pad_w, 0, pad_h), 'constant', -100)
index_padded = index_padded.reshape(grid_t, num_windows_h,
vit_merger_window_size,
num_windows_w,
vit_merger_window_size)
index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
grid_t, num_windows_h * num_windows_w, vit_merger_window_size,
vit_merger_window_size)
seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
index_padded = index_padded.reshape(-1)
index_new = index_padded[index_padded != -100]
window_index.append(index_new + window_index_id)
cu_seqlens_tmp = seqlens.cumsum(
0) * self.spatial_merge_unit + cu_window_seqlens[-1]
cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
window_index = torch.cat(window_index, dim=0)
return window_index, cu_window_seqlens
def forward(
self,
x: torch.Tensor,
@@ -258,6 +318,39 @@ class AscendQwen2_5_VisionTransformer_Without_Padding(Qwen2_5_VisionTransformer
x = x[reverse_indices, :]
return x
def _process_image_input(self, image_input) -> tuple[torch.Tensor, ...]:
grid_thw = image_input["image_grid_thw"]
assert grid_thw.ndim == 2
if image_input["type"] == "image_embeds":
image_embeds = image_input["image_embeds"].type(self.visual.dtype)
else:
pixel_values = image_input["pixel_values"].type(self.visual.dtype)
image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
# Split concatenated embeddings for each image item.
merge_size = self.visual.spatial_merge_size
sizes = grid_thw.prod(-1) // merge_size // merge_size
return image_embeds.split(sizes.tolist())
def _process_video_input(self, video_input) -> tuple[torch.Tensor, ...]:
grid_thw = video_input["video_grid_thw"]
assert grid_thw.ndim == 2
if video_input["type"] == "video_embeds":
video_embeds = video_input["video_embeds"].type(self.visual.dtype)
else:
pixel_values_videos = video_input["pixel_values_videos"].type(
self.visual.dtype)
video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
# Split concatenated embeddings for each video item.
merge_size = self.visual.spatial_merge_size
sizes = grid_thw.prod(-1) // merge_size // merge_size
return video_embeds.split(sizes.tolist())
@MULTIMODAL_REGISTRY.register_processor(
Qwen2_5_VLMultiModalProcessor,