diff --git a/tests/ut/models/test_qwen2_5_vl_without_padding.py b/tests/ut/models/test_qwen2_5_vl_without_padding.py
index 0ae1afa..d6c9954 100644
--- a/tests/ut/models/test_qwen2_5_vl_without_padding.py
+++ b/tests/ut/models/test_qwen2_5_vl_without_padding.py
@@ -231,6 +231,8 @@ class TestAscendQwen2_5_VisionTransformer_Without_Padding(PytestBase):
         vision_config.in_channels = 3
         vision_config.hidden_act = "gelu"
         vision_config.depth = 0
+        vision_config.hidden_size = 1280
+        vision_config.num_heads = 16
 
         mocker.patch("torch.nn.Module.__setattr__")
         mocker.patch("torch.nn.Module.__getattr__")
@@ -239,6 +241,10 @@ class TestAscendQwen2_5_VisionTransformer_Without_Padding(PytestBase):
             "vllm.model_executor.models.qwen2_5_vl.Qwen2_5_VisionTransformer.__init__",
             return_value=None,
         )
+        mocker_vision_rotary_embedding = mocker.patch(
+            "vllm_ascend.models.qwen2_5_vl.AscendQwen2_5_VisionRotaryEmbedding.__init__",
+            return_value=None,
+        )
         mocker.patch(
             "vllm_ascend.models.qwen2_5_vl_without_padding.AscendQwen2_5_VisionBlock_Without_Padding.__init__",
             return_value=None,
@@ -264,7 +270,7 @@ class TestAscendQwen2_5_VisionTransformer_Without_Padding(PytestBase):
         args, kwargs = mocker_vit.call_args
         assert args == (vision_config, norm_eps, None, "")
         assert not kwargs
-
+        mocker_vision_rotary_embedding.assert_called_once()
         return vision_transformer
 
     def test_init_vision_transformer(self, mocker: MockerFixture):
diff --git a/vllm_ascend/models/qwen2_5_vl_without_padding.py b/vllm_ascend/models/qwen2_5_vl_without_padding.py
index 47ddd44..85c8bad 100644
--- a/vllm_ascend/models/qwen2_5_vl_without_padding.py
+++ b/vllm_ascend/models/qwen2_5_vl_without_padding.py
@@ -41,6 +41,8 @@ from vllm.model_executor.models.qwen2_5_vl import (
 from vllm.model_executor.models.utils import maybe_prefix
 from vllm.multimodal import MULTIMODAL_REGISTRY
 
+from vllm_ascend.models.qwen2_5_vl import AscendQwen2_5_VisionRotaryEmbedding
+
 
 class AscendQwen2_5_VisionAttention_Without_Padding(Qwen2_5_VisionAttention):
@@ -160,6 +162,9 @@ class AscendQwen2_5_VisionTransformer_Without_Padding(Qwen2_5_VisionTransformer
         super().__init__(vision_config, norm_eps, quant_config, prefix)
         norm_layer = partial(RMSNorm, eps=norm_eps)
         self.interleaved = interleaved
+        head_dim = self.hidden_size // self.num_heads
+        self.rotary_pos_emb = AscendQwen2_5_VisionRotaryEmbedding(head_dim //
+                                                                  2)
         self.patch_embed = AscendQwen2_5_VisionPatchEmbed_Without_Padding(
             patch_size=vision_config.patch_size,
             temporal_patch_size=vision_config.temporal_patch_size,
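
For context on the `head_dim // 2` construction above: a minimal, self-contained sketch of a vision rotary embedding, assuming the standard inverse-frequency formulation. `VisionRotaryEmbedding` here is a simplified stand-in, not the actual `AscendQwen2_5_VisionRotaryEmbedding` implementation; the dimension math mirrors the diff, where each attention head's channels are split between the two spatial axes, so each axis receives `head_dim // 2` rotary channels.

```python
import torch


class VisionRotaryEmbedding(torch.nn.Module):
    """Simplified stand-in for AscendQwen2_5_VisionRotaryEmbedding."""

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        # dim == head_dim // 2; frequencies are generated for every
        # second channel, giving dim // 2 rotary frequency pairs.
        inv_freq = 1.0 / (theta**(
            torch.arange(0, dim, 2, dtype=torch.float) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, seqlen: int) -> torch.Tensor:
        # One row of rotary angles per position: shape (seqlen, dim // 2).
        seq = torch.arange(seqlen, dtype=self.inv_freq.dtype)
        return torch.outer(seq, self.inv_freq)


# Values matching the unit test above (hypothetical usage, not test code).
hidden_size, num_heads = 1280, 16
head_dim = hidden_size // num_heads          # 80 channels per head
rotary = VisionRotaryEmbedding(head_dim // 2)  # 40 rotary channels per axis
print(rotary(4).shape)                       # torch.Size([4, 20])
```

The unit test does not exercise this math directly; it only mocks `AscendQwen2_5_VisionRotaryEmbedding.__init__` and asserts it is called exactly once during transformer construction, which is why `hidden_size` and `num_heads` must now be present on the mocked `vision_config`.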