From 732664451309f34828a1f387f20cec2cbf757f14 Mon Sep 17 00:00:00 2001
From: wangxiyuan
Date: Sat, 17 May 2025 05:13:32 +0800
Subject: [PATCH] [CI] Fix qwen2.5 vl CI failure (#888)

The [vllm commit](https://github.com/vllm-project/vllm/commit/67da5720d4ed2aa1f615ec812031f4f3753b3f62)
changed the input and rotary position embedding for qwen 2.5 vl, which breaks
CI. This PR is a quick fix for the qwen2.5 vl CI failure.

Signed-off-by: wangxiyuan
---
 vllm_ascend/models/qwen2_5_vl.py | 113 +++++++++++++++++++++++++++++--
 1 file changed, 109 insertions(+), 4 deletions(-)

diff --git a/vllm_ascend/models/qwen2_5_vl.py b/vllm_ascend/models/qwen2_5_vl.py
index c06111e..b78bcdf 100644
--- a/vllm_ascend/models/qwen2_5_vl.py
+++ b/vllm_ascend/models/qwen2_5_vl.py
@@ -36,9 +36,9 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.qwen2_5_vl import (
     Qwen2_5_VisionAttention, Qwen2_5_VisionBlock, Qwen2_5_VisionPatchEmbed,
-    Qwen2_5_VisionTransformer, Qwen2_5_VLDummyInputsBuilder,
-    Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLMultiModalProcessor,
-    Qwen2_5_VLProcessingInfo)
+    Qwen2_5_VisionRotaryEmbedding, Qwen2_5_VisionTransformer,
+    Qwen2_5_VLDummyInputsBuilder, Qwen2_5_VLForConditionalGeneration,
+    Qwen2_5_VLMultiModalProcessor, Qwen2_5_VLProcessingInfo)
 from vllm.model_executor.models.utils import maybe_prefix
 from vllm.multimodal import MULTIMODAL_REGISTRY
 
@@ -152,6 +152,15 @@ class AscendQwen2_5_VisionPatchEmbed(Qwen2_5_VisionPatchEmbed):
         return x
 
 
+class AscendQwen2_5_VisionRotaryEmbedding(Qwen2_5_VisionRotaryEmbedding):
+
+    def __init__(self, dim: int, theta: float = 10000.0) -> None:
+        super().__init__(dim, theta)
+        inv_freq = 1.0 / (theta
+                          **(torch.arange(0, dim, 2, dtype=torch.float) / dim))
+        self.inv_freq = inv_freq
+
+
 class AscendQwen2_5_VisionTransformer(Qwen2_5_VisionTransformer):
 
     def __init__(
@@ -166,6 +175,9 @@ class AscendQwen2_5_VisionTransformer(Qwen2_5_VisionTransformer):
         norm_layer = partial(RMSNorm, eps=norm_eps)
         self.interleaved = interleaved
         self.enable_pad = False
+        head_dim = self.hidden_size // self.num_heads
+        self.rotary_pos_emb = AscendQwen2_5_VisionRotaryEmbedding(head_dim //
+                                                                  2)
         self.patch_embed = AscendQwen2_5_VisionPatchEmbed(
             patch_size=vision_config.patch_size,
             temporal_patch_size=vision_config.temporal_patch_size,
@@ -298,6 +310,66 @@ class AscendQwen2_5_VisionTransformer(Qwen2_5_VisionTransformer):
             loaded_params.add(name)
         return loaded_params
 
+    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
+        pos_ids = []
+        for t, h, w in grid_thw:
+            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
+            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+            hpos_ids = hpos_ids.reshape(
+                h // self.spatial_merge_size,
+                self.spatial_merge_size,
+                w // self.spatial_merge_size,
+                self.spatial_merge_size,
+            ).permute(0, 2, 1, 3).flatten()
+            wpos_ids = wpos_ids.reshape(
+                h // self.spatial_merge_size,
+                self.spatial_merge_size,
+                w // self.spatial_merge_size,
+                self.spatial_merge_size,
+            ).permute(0, 2, 1, 3).flatten()
+            pos_ids.append(
+                torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+        pos_ids = torch.cat(pos_ids, dim=0)
+        max_grid_size = grid_thw[:, 1:].max()
+        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
+        return rotary_pos_emb
+
+    def get_window_index(self, grid_thw):
+        window_index: list = []
+        cu_window_seqlens: list = [0]
+        window_index_id = 0
+        vit_merger_window_size = (self.window_size //
+                                  self.spatial_merge_size // self.patch_size)
+
+        for grid_t, grid_h, grid_w in grid_thw:
+            llm_grid_h = grid_h // self.spatial_merge_size
+            llm_grid_w = grid_w // self.spatial_merge_size
+            index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
+                grid_t, llm_grid_h, llm_grid_w)
+            pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
+            pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
+            num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
+            num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
+            index_padded = F.pad(index, (0, pad_w, 0, pad_h), 'constant', -100)
+            index_padded = index_padded.reshape(grid_t, num_windows_h,
+                                                vit_merger_window_size,
+                                                num_windows_w,
+                                                vit_merger_window_size)
+            index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
+                grid_t, num_windows_h * num_windows_w, vit_merger_window_size,
+                vit_merger_window_size)
+            seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
+            index_padded = index_padded.reshape(-1)
+            index_new = index_padded[index_padded != -100]
+            window_index.append(index_new + window_index_id)
+            cu_seqlens_tmp = seqlens.cumsum(
+                0) * self.spatial_merge_unit + cu_window_seqlens[-1]
+            cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
+            window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
+        window_index = torch.cat(window_index, dim=0)
+        return window_index, cu_window_seqlens
+
     def forward(
         self,
         x: torch.Tensor,
@@ -366,4 +438,37 @@ class AscendQwen2_5_VLForConditionalGeneration(
             norm_eps=getattr(config, "rms_norm_eps", 1e-6),
             quant_config=self._maybe_ignore_quant_config(quant_config),
             prefix=maybe_prefix(prefix, "visual"),
-        )
\ No newline at end of file
+        )
+
+    def _process_image_input(self, image_input) -> tuple[torch.Tensor, ...]:
+
+        grid_thw = image_input["image_grid_thw"]
+        assert grid_thw.ndim == 2
+
+        if image_input["type"] == "image_embeds":
+            image_embeds = image_input["image_embeds"].type(self.visual.dtype)
+        else:
+            pixel_values = image_input["pixel_values"].type(self.visual.dtype)
+            image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
+
+        # Split concatenated embeddings for each image item.
+        merge_size = self.visual.spatial_merge_size
+        sizes = grid_thw.prod(-1) // merge_size // merge_size
+        return image_embeds.split(sizes.tolist())
+
+    def _process_video_input(self, video_input) -> tuple[torch.Tensor, ...]:
+
+        grid_thw = video_input["video_grid_thw"]
+        assert grid_thw.ndim == 2
+
+        if video_input["type"] == "video_embeds":
+            video_embeds = video_input["video_embeds"].type(self.visual.dtype)
+        else:
+            pixel_values_videos = video_input["pixel_values_videos"].type(
+                self.visual.dtype)
+            video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
+
+        # Split concatenated embeddings for each video item.
+        merge_size = self.visual.spatial_merge_size
+        sizes = grid_thw.prod(-1) // merge_size // merge_size
+        return video_embeds.split(sizes.tolist())
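
For reviewers, here is a minimal standalone sketch (not part of the patch) of how the recomputed `inv_freq` in `AscendQwen2_5_VisionRotaryEmbedding` and the grid indexing in `rot_pos_emb` fit together. The head dim, theta, merge size, and the 1x4x4 grid below are illustrative values only, and the `torch.outer` step stands in for what the upstream rotary module's forward is assumed to return.

```python
import torch

# Illustrative values, not taken from the model config.
dim = 16                   # stands in for head_dim // 2
theta = 10000.0
spatial_merge_size = 2

# Inverse frequencies, computed the same way as in the patched module.
inv_freq = 1.0 / (theta**(torch.arange(0, dim, 2, dtype=torch.float) / dim))

# One image whose patch grid is 1 (temporal) x 4 x 4.
grid_thw = torch.tensor([[1, 4, 4]])

pos_ids = []
for t, h, w in grid_thw.tolist():
    hpos = torch.arange(h).unsqueeze(1).expand(-1, w)
    wpos = torch.arange(w).unsqueeze(0).expand(h, -1)
    # Group positions into spatial_merge_size x spatial_merge_size blocks,
    # mirroring the reshape/permute in rot_pos_emb.
    hpos = hpos.reshape(h // spatial_merge_size, spatial_merge_size,
                        w // spatial_merge_size,
                        spatial_merge_size).permute(0, 2, 1, 3).flatten()
    wpos = wpos.reshape(h // spatial_merge_size, spatial_merge_size,
                        w // spatial_merge_size,
                        spatial_merge_size).permute(0, 2, 1, 3).flatten()
    pos_ids.append(torch.stack([hpos, wpos], dim=-1).repeat(t, 1))
pos_ids = torch.cat(pos_ids, dim=0)

# Build the full rotary table for the largest spatial extent, then gather
# the per-patch rows, as rot_pos_emb does with rotary_pos_emb_full[pos_ids].
max_grid_size = int(grid_thw[:, 1:].max())
freqs = torch.outer(torch.arange(max_grid_size, dtype=torch.float), inv_freq)
rotary_pos_emb = freqs[pos_ids].flatten(1)

print(rotary_pos_emb.shape)  # torch.Size([16, 16]): (num_patches, dim)
```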