Support precomputed features for Kimi-VL (#6599)

This commit is contained in:
Lifu Huang
2025-05-26 01:24:13 -07:00
committed by GitHub
parent 501efc3d36
commit 0d503090aa
5 changed files with 93 additions and 47 deletions

View File

@@ -7,6 +7,7 @@ import requests
import torch
from PIL import Image
from transformers import (
AutoModel,
AutoProcessor,
Gemma3ForConditionalGeneration,
Qwen2_5_VLForConditionalGeneration,
@@ -51,6 +52,7 @@ class VLMInputTestBase:
mem_fraction_static=0.8,
enable_multimodal=True,
disable_cuda_graph=True,
trust_remote_code=True,
)
def tearDown(self):
@@ -183,5 +185,32 @@ class TestGemmaUnderstandsImage(VLMInputTestBase, unittest.IsolatedAsyncioTestCa
)
class TestKimiVLImageUnderstandsImage(
    VLMInputTestBase, unittest.IsolatedAsyncioTestCase
):
    """VLM input test for Kimi-VL: verifies that precomputed image features
    (pixel values run through the model's own vision tower + projector) are
    accepted as input, using the shared ``VLMInputTestBase`` harness.
    """

    # Hugging Face Hub id of the model under test; requires trust_remote_code.
    model_path = "moonshotai/Kimi-VL-A3B-Instruct"
    # Chat template name registered for this model family.
    chat_template = "kimi-vl"

    @classmethod
    def _init_visual(cls):
        """Load the Kimi-VL vision tower and multimodal projector and expose a
        ``cls.visual`` callable that maps a processor output dict to projected
        image features.

        Assumes ``cls.device`` is set by the base class — TODO confirm.
        """
        model = AutoModel.from_pretrained(cls.model_path, trust_remote_code=True)
        cls.vision_tower = model.vision_tower.eval().to(cls.device)
        cls.mm_projector = model.multi_modal_projector.eval().to(cls.device)
        # NOTE(review): assigning a lambda to the class makes it a plain
        # function attribute — calling it through an *instance* would bind the
        # instance as `tokenizer_output`. Presumably the harness only invokes
        # it as `cls.visual(...)`; confirm against VLMInputTestBase.
        cls.visual = lambda tokenizer_output: cls.mm_projector(
            cls.vision_tower(
                pixel_values=tokenizer_output["pixel_values"],
                grid_hws=tokenizer_output["image_grid_hws"],
            )
        )

    def _pixel_values_image_data(self, processor_output):
        """Build the image-data dict sent to the server from processor output.

        NOTE(review): the outgoing key is ``image_grid_thws`` while the
        Kimi-VL processor emits ``image_grid_hws`` — presumably an intentional
        mapping onto the server's generic multimodal schema; verify against
        the consumer of this dict.
        """
        return dict(
            modality="IMAGE",
            image_grid_thws=processor_output["image_grid_hws"],
            pixel_values=processor_output["pixel_values"],
        )
# Allow running this test module directly (outside a test runner).
if __name__ == "__main__":
    unittest.main()