From 022012aae83d2ae4a0f7133c55245d42e8613901 Mon Sep 17 00:00:00 2001
From: Lifu Huang
Date: Sat, 24 May 2025 21:43:38 -0700
Subject: [PATCH] Support Phi-4 Multi-Modal (text + vision only) (#6494)
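
Adds text + vision support for microsoft/Phi-4-multimodal-instruct: a
"phi-4-mm" chat template, a Phi4MMImageProcessor for image inputs (audio
inputs are rejected with a warning for now), and a Phi4MMForCausalLM model
that wraps LlamaForCausalLM and reuses the SigLIP-based
Idefics2VisionTransformer from minicpmv as its vision encoder. Multi-image
understanding depends on vision LoRA support and is left as follow-up work.

Minimal usage sketch (assumptions: a server launched locally on the default
port 30000 with the flags below, mirroring the new test in
test/srt/test_vision_openai_server_b.py; the image URL is a placeholder):

    # python3 -m sglang.launch_server \
    #     --model-path microsoft/Phi-4-multimodal-instruct \
    #     --trust-remote-code --mem-fraction-static 0.75
    import openai

    client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="sk-123456")
    response = client.chat.completions.create(
        model="microsoft/Phi-4-multimodal-instruct",
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": "https://example.com/cat.jpg"}},
                    {"type": "text", "text": "Describe this image."},
                ],
            }
        ],
    )
    print(response.choices[0].message.content)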
---
 .gitignore                                    |   3 +
 python/pyproject.toml                         |   3 +-
 python/sglang/srt/configs/model_config.py     |   1 +
 python/sglang/srt/conversation.py             |  20 +
 .../managers/multimodal_processors/phi4mm.py  |  87 ++++
 python/sglang/srt/models/minicpmv.py          |  27 +-
 python/sglang/srt/models/phi4mmvllm.py        | 489 ++++++++++++++++++
 test/srt/test_vision_openai_server_b.py       |  26 +
 8 files changed, 650 insertions(+), 6 deletions(-)
 create mode 100644 python/sglang/srt/managers/multimodal_processors/phi4mm.py
 create mode 100644 python/sglang/srt/models/phi4mmvllm.py

diff --git a/.gitignore b/.gitignore
index 6d6b84e3c..4e8cebd10 100644
--- a/.gitignore
+++ b/.gitignore
@@ -228,5 +228,8 @@ compile_commands.json
 
 1
 
+# Autoenv
+.env.leave
+
 # Rust lib
 Cargo.lock

diff --git a/python/pyproject.toml b/python/pyproject.toml
index 724456dfd..9a120837f 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -17,6 +17,7 @@ dependencies = ["aiohttp", "requests", "tqdm", "numpy", "IPython", "setproctitle
 
 [project.optional-dependencies]
 runtime_common = [
+    "blobfile==3.0.0",
     "compressed-tensors",
     "datasets",
     "fastapi",
@@ -38,12 +39,12 @@ runtime_common = [
     "python-multipart",
     "pyzmq>=25.1.2",
     "soundfile==0.13.1",
+    "scipy",
     "torchao==0.9.0",
     "transformers==4.51.1",
     "uvicorn",
     "uvloop",
     "xgrammar==0.1.19",
-    "blobfile==3.0.0"
 ]
 
 srt = [

diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py
index bc0fe0cb1..ba4c3d7d2 100644
--- a/python/sglang/srt/configs/model_config.py
+++ b/python/sglang/srt/configs/model_config.py
@@ -552,6 +552,7 @@ multimodal_model_archs = [
     "Qwen2_5_VLForConditionalGeneration",
     "KimiVLForConditionalGeneration",
     "InternVLChatModel",
+    "Phi4MMForCausalLM",
 ]

diff --git a/python/sglang/srt/conversation.py b/python/sglang/srt/conversation.py
index 2280cd579..182345c74 100644
--- a/python/sglang/srt/conversation.py
+++ b/python/sglang/srt/conversation.py
@@ -661,6 +661,20 @@ register_conv_template(
     )
 )
 
+# TODO (lifuhuang): Refactor BaseMultimodalProcessor to support the default image token "<|image_{index}|>" in the future.
+register_conv_template(
+    Conversation(
+        name="phi-4-mm",
+        system_message="You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.",
+        system_template="<|system|>{system_message}<|end|>",
+        roles=("<|user|>", "<|assistant|>"),
+        sep_style=SeparatorStyle.NO_COLON_SINGLE,
+        sep="<|end|>",
+        stop_str="<|end|>",
+        image_token="<|endoftext10|>",
+    )
+)
+
 register_conv_template(
     Conversation(
         name="chatml",
@@ -945,3 +959,9 @@ def match_openbmb_minicpm(model_path: str):
 def match_moonshot_kimivl(model_path: str):
     if re.search(r"kimi.*vl", model_path, re.IGNORECASE):
         return "kimi-vl"
+
+
+@register_conv_template_matching_function
+def match_phi_4_mm(model_path: str):
+    if "phi-4-multimodal" in model_path.lower():
+        return "phi-4-mm"
"""Inference-only MiniCPM-V model compatible with HuggingFace weights.""" + from functools import partial from typing import ( Any, @@ -386,6 +387,7 @@ class Idefics2VisionTransformer(nn.Module): self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, + require_post_norm: bool = True, prefix: str = "", ) -> None: super().__init__() @@ -398,20 +400,35 @@ class Idefics2VisionTransformer(nn.Module): quant_config=quant_config, prefix=add_prefix("encoder", prefix), ) - self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.post_layernorm = ( + nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + if require_post_norm + else nn.Identity() + ) def get_input_embeddings(self) -> nn.Embedding: return self.embeddings - def compute_cu_seqlens(self, tgt_sizes: torch.Tensor) -> torch.Tensor: - patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1] # shape: (batch_size,) + def compute_cu_seqlens( + self, + tgt_sizes: Optional[torch.Tensor] = None, + atch_attention_mask: Optional[torch.BoolTensor] = None, + ) -> torch.Tensor: + # shape: (batch_size,) + if tgt_sizes is not None: + patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1] + else: + patch_len = atch_attention_mask[:, :, 0].sum(dim=1) * atch_attention_mask[ + :, 0, : + ].sum(dim=1) + cu_seqlens = torch.cat( [ torch.tensor([0], device=patch_len.device, dtype=torch.int32), torch.cumsum(patch_len, dim=0, dtype=torch.int32), ], dim=0, - ).to(tgt_sizes.device) + ).to(patch_len.device) return cu_seqlens def forward( @@ -425,7 +442,7 @@ class Idefics2VisionTransformer(nn.Module): patch_attention_mask=patch_attention_mask, tgt_sizes=tgt_sizes, ) - cu_seqlens = self.compute_cu_seqlens(tgt_sizes) + cu_seqlens = self.compute_cu_seqlens(tgt_sizes, patch_attention_mask) encoder_outputs = self.encoder( hidden_states, cu_seqlens=cu_seqlens, diff --git a/python/sglang/srt/models/phi4mmvllm.py b/python/sglang/srt/models/phi4mmvllm.py new file mode 100644 index 000000000..6078d4012 --- /dev/null +++ b/python/sglang/srt/models/phi4mmvllm.py @@ -0,0 +1,489 @@ +import logging +import math +from collections.abc import Iterable +from typing import List, Optional, Tuple + +import numpy as np +import torch +from torch import nn +from transformers import PretrainedConfig, SiglipVisionConfig + +from sglang.srt.layers.quantization import QuantizationConfig +from sglang.srt.managers.mm_utils import ( + MultiModalityDataPaddingPatternMultimodalTokens, + general_mm_embed_routine, +) +from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.llama import LlamaForCausalLM + +# TODO (lifuhuang): Idefics2VisionTransformer is introduced in minicpmv, we should extract it to a shared location as a quick follow-up. +from sglang.srt.models.minicpmv import Idefics2VisionTransformer + +logger = logging.getLogger(__name__) + +SIGLIP_NAME = "siglip-so400m-patch14-448" +VISION_ENCODER_TO_PROCESSING_CONFIG = { + "siglip-so400m-patch14-448": { + "vit_image_size": 448, + "vit_patch_size": 14, + "token_compression_factor": 2, + }, +} + + +def get_navit_vision_model(): + vision_config = { + "hidden_size": 1152, + "image_size": 448, + "intermediate_size": 4304, + "model_type": "siglip_vision_model", + "num_attention_heads": 16, + "num_hidden_layers": 26, # Model is originally 27-layer, we only need the first 26 layers for feature extraction. 
diff --git a/python/sglang/srt/models/minicpmv.py b/python/sglang/srt/models/minicpmv.py
index 5793fa819..7ef812f25 100644
--- a/python/sglang/srt/models/minicpmv.py
+++ b/python/sglang/srt/models/minicpmv.py
@@ -20,6 +20,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only MiniCPM-V model compatible with HuggingFace weights."""
+
 from functools import partial
 from typing import (
     Any,
@@ -386,6 +387,7 @@ class Idefics2VisionTransformer(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        require_post_norm: bool = True,
         prefix: str = "",
     ) -> None:
         super().__init__()
@@ -398,20 +400,35 @@ class Idefics2VisionTransformer(nn.Module):
             quant_config=quant_config,
             prefix=add_prefix("encoder", prefix),
         )
-        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+        self.post_layernorm = (
+            nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+            if require_post_norm
+            else nn.Identity()
+        )
 
     def get_input_embeddings(self) -> nn.Embedding:
         return self.embeddings
 
-    def compute_cu_seqlens(self, tgt_sizes: torch.Tensor) -> torch.Tensor:
-        patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1]  # shape: (batch_size,)
+    def compute_cu_seqlens(
+        self,
+        tgt_sizes: Optional[torch.Tensor] = None,
+        patch_attention_mask: Optional[torch.BoolTensor] = None,
+    ) -> torch.Tensor:
+        # shape: (batch_size,)
+        if tgt_sizes is not None:
+            patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1]
+        else:
+            patch_len = patch_attention_mask[:, :, 0].sum(dim=1) * patch_attention_mask[
+                :, 0, :
+            ].sum(dim=1)
+
         cu_seqlens = torch.cat(
             [
                 torch.tensor([0], device=patch_len.device, dtype=torch.int32),
                 torch.cumsum(patch_len, dim=0, dtype=torch.int32),
             ],
             dim=0,
-        ).to(tgt_sizes.device)
+        ).to(patch_len.device)
         return cu_seqlens
 
     def forward(
@@ -425,7 +442,7 @@ class Idefics2VisionTransformer(nn.Module):
             patch_attention_mask=patch_attention_mask,
             tgt_sizes=tgt_sizes,
         )
-        cu_seqlens = self.compute_cu_seqlens(tgt_sizes)
+        cu_seqlens = self.compute_cu_seqlens(tgt_sizes, patch_attention_mask)
         encoder_outputs = self.encoder(
             hidden_states,
             cu_seqlens=cu_seqlens,
diff --git a/python/sglang/srt/models/phi4mmvllm.py b/python/sglang/srt/models/phi4mmvllm.py
new file mode 100644
index 000000000..6078d4012
--- /dev/null
+++ b/python/sglang/srt/models/phi4mmvllm.py
@@ -0,0 +1,489 @@
+import logging
+import math
+from collections.abc import Iterable
+from typing import List, Optional, Tuple
+
+import numpy as np
+import torch
+from torch import nn
+from transformers import PretrainedConfig, SiglipVisionConfig
+
+from sglang.srt.layers.quantization import QuantizationConfig
+from sglang.srt.managers.mm_utils import (
+    MultiModalityDataPaddingPatternMultimodalTokens,
+    general_mm_embed_routine,
+)
+from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.llama import LlamaForCausalLM
+
+# TODO (lifuhuang): Idefics2VisionTransformer is introduced in minicpmv; we should extract it to a shared location as a quick follow-up.
+from sglang.srt.models.minicpmv import Idefics2VisionTransformer
+
+logger = logging.getLogger(__name__)
+
+SIGLIP_NAME = "siglip-so400m-patch14-448"
+VISION_ENCODER_TO_PROCESSING_CONFIG = {
+    "siglip-so400m-patch14-448": {
+        "vit_image_size": 448,
+        "vit_patch_size": 14,
+        "token_compression_factor": 2,
+    },
+}
+
+
+def get_navit_vision_model():
+    vision_config = {
+        "hidden_size": 1152,
+        "image_size": 448,
+        "intermediate_size": 4304,
+        "model_type": "siglip_vision_model",
+        "num_attention_heads": 16,
+        "num_hidden_layers": 26,  # The model originally has 27 layers; only the first 26 are needed for feature extraction.
+        "patch_size": 14,
+    }
+    model_config = SiglipVisionConfig(**vision_config)
+
+    vision_model = Idefics2VisionTransformer(
+        config=model_config, require_post_norm=False
+    )
+
+    return vision_model
+
+
+class Phi4MMImageEncoder(nn.Module):
+    """Image embedding."""
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig],
+        prefix: str = "",
+        model_dir: str = "",
+    ) -> None:
+        super().__init__()
+
+        # n_embd or hidden_size
+        hidden_size = config.n_embd if hasattr(config, "n_embd") else config.hidden_size
+        self.type_feature = "patch"
+
+        self.img_processor = get_navit_vision_model()
+
+        pe_weight = self.img_processor.embeddings.position_embedding.weight
+        L, D = pe_weight.size()
+        H = int(math.sqrt(L))
+        assert H**2 == L, f"position embedding size {L} is not square"
+        if H % 2 != 0:
+            self.img_processor_padding = nn.ReflectionPad2d((0, 1, 0, 1))
+            H += 1
+        image_dim_out = D
+        # ((448/14)//2)**2
+        self.num_img_tokens = (H // 2) ** 2
+        self.base_feat_height_target = H
+
+        self.image_dim_out = image_dim_out
+        self.img_sizes = None
+        self.image_attention_mask = None
+
+        # global_gn and sub_gn for hd transform, serves as line separator
+        self.use_hd_transform = True
+        self.with_learnable_separator = True
+        self.hd_transform_order = "sub_glb"
+        self.freeze_img_processor = False
+        self.crop_size = 448
+
+        # image token compression
+        self.image_token_compression_cls = "avg_pool_2d"
+        self.image_token_compression = nn.AvgPool2d(kernel_size=2, stride=2)
+        self.base_feat_height_reduction = 1
+        self.base_feat_height_target = self.base_feat_height_target // 2
+
+        # with_hd_transform and with_learnable_separator should have same value
+        assert (
+            self.use_hd_transform == self.with_learnable_separator
+        ), "use_hd_transform and with_learnable_separator should have same value"
+        assert self.use_hd_transform, "learnable separator is only for hd transform"
+        # 1024 * 4, merge spatial to channel dimension
+        self.glb_GN = nn.Parameter(
+            torch.zeros([1, 1, self.image_dim_out * self.base_feat_height_reduction**2])
+        )
+        self.sub_GN = nn.Parameter(
+            torch.zeros(
+                [1, 1, 1, self.image_dim_out * self.base_feat_height_reduction**2]
+            )
+        )
+
+        dim_projection = hidden_size
+        depth = 2
+        layers = [
+            nn.Linear(
+                image_dim_out * self.base_feat_height_reduction**2, dim_projection
+            )
+        ]
+        for _ in range(1, depth):
+            layers.extend([nn.GELU(), nn.Linear(dim_projection, dim_projection)])
+        self.img_projection = nn.Sequential(*layers)
+
+        self.vocab_size = config.vocab_size
+        self.img_features = None
+
+        self.use_out_place_operations = False
+
+    def get_img_features(
+        self, img_embeds: torch.FloatTensor, attention_mask=None
+    ) -> torch.FloatTensor:
+        img_feature = self.img_processor(
+            img_embeds, patch_attention_mask=attention_mask
+        )
+
+        patch_feature = img_feature
+
+        use_token_compression = self.image_token_compression is not None
+        use_padding = getattr(self, "img_processor_padding", None) is not None
+        if use_token_compression or use_padding:
+            # reshape to 2D tensor
+            width = int(math.sqrt(patch_feature.size(1)))
+            patch_feature = patch_feature.view(-1, width, width, patch_feature.size(-1))
+            # convert to NCHW
+            patch_feature = patch_feature.permute(0, 3, 1, 2)
+
+            if use_padding:
+                patch_feature = self.img_processor_padding(patch_feature)
+            if use_token_compression:
+                patch_feature = self.image_token_compression(patch_feature)
+
+            # convert to NHWC
+            patch_feature = patch_feature.permute(0, 2, 3, 1)
+            patch_feature = patch_feature.view(
+                -1,
+                patch_feature.size(1) * patch_feature.size(2),
+                patch_feature.size(-1),
+            )
+
+        return patch_feature
+
+    def forward(
+        self,
+        pixel_values: torch.FloatTensor,
+        image_sizes: torch.Tensor,
+        image_attention_mask: torch.Tensor,
+    ) -> list[torch.FloatTensor]:
+        """
+        Process images and return vision embeddings.
+
+        pixel_values: (num_images, num_crops, c, h, w)
+        image_sizes: [[h1, w1], [h2, w2]]
+        image_attention_mask: num_images x num_crops x 32 x 32
+        output: (num_images, num_img_tokens, hidden_size)
+        """
+
+        # e.g.
+        # pixel_values: torch.Size([1, 7, 3, 448, 448])
+        # image_sizes: tensor([[ 896, 1344]], device='cuda:0')
+        # output: torch.Size([1, 1841, 3072])
+
+        img_projection_params = next(self.img_projection.parameters())
+        target_device = img_projection_params.device
+        target_dtype = img_projection_params.dtype
+
+        img_sizes = image_sizes
+        num_images, num_crops, c, h, w = pixel_values.shape
+        bs = num_images
+        pixel_values = pixel_values.flatten(0, 1)
+
+        img_features = self.get_img_features(
+            pixel_values,
+            image_attention_mask.type(torch.BoolTensor).flatten(0, 1).to(target_device),
+        )
+
+        base_feat_height_target = self.base_feat_height_target
+        base_resolution = self.crop_size
+        base_feat_height_reduction = self.base_feat_height_reduction
+
+        base_feat_height = base_feat_width = int(np.sqrt(img_features.shape[1]))
+        assert (
+            base_feat_height == base_feat_height_target
+            and base_feat_width == base_feat_height_target
+        ), (
+            f"base_feat_height: {base_feat_height}, "
+            f"base_feat_width: {base_feat_width}, "
+            f"expect {base_feat_height_target} features for hd transform"
+        )
+
+        # bs x max_num_crops x (24x24) x C
+        img_features = img_features.view(
+            bs, -1, base_feat_height * base_feat_width, self.image_dim_out
+        )
+        C = self.image_dim_out
+        H = base_feat_height
+
+        output_imgs = []
+        output_len = []
+        # img_sizes is a tensor during training and a list at inference
+        if isinstance(img_sizes, torch.Tensor):
+            img_sizes = img_sizes.view(-1, 2)
+        for _bs in range(bs):
+            h, w = img_sizes[_bs]
+            h = h // base_resolution
+            w = w // base_resolution
+            B_ = h * w
+
+            # 1 x (24x24) x 1024
+            global_img_feature = img_features[_bs, :1]
+
+            # 1 x 12 x 12 x 4096
+            glb_img = (
+                global_img_feature.reshape(1, H, H, C)
+                .reshape(
+                    1,
+                    H // base_feat_height_reduction,
+                    base_feat_height_reduction,
+                    H // base_feat_height_reduction,
+                    base_feat_height_reduction,
+                    C,
+                )
+                .contiguous()
+                .permute(0, 1, 3, 2, 4, 5)
+                .reshape(
+                    1,
+                    H // base_feat_height_reduction,
+                    H // base_feat_height_reduction,
+                    base_feat_height_reduction * base_feat_height_reduction * C,
+                )
+                .contiguous()
+            )
+            temp_glb_GN = self.sub_GN.repeat(1, H // base_feat_height_reduction, 1, 1)
+
+            # 1 x 156 x 4096
+            glb_img = torch.cat([glb_img, temp_glb_GN], dim=2).reshape(
+                1, -1, base_feat_height_reduction * base_feat_height_reduction * C
+            )
+
+            # (max_num_crops-1) x (12x12) x C
+            sub_img = img_features[_bs, 1:]
+            # 16x574x1024
+            # get rid of padding sub_img
+            sub_img = sub_img[:B_]
+
+            # (num_crops, 12, 2, 12, 2, 1024) ->
+            # (num_crops, 12, 12, 2, 2, 1024) -> (num_crops, 12*12, 4*1024)
+            sub_img = (
+                sub_img.reshape(B_, H, H, C)
+                .reshape(
+                    B_,
+                    H // base_feat_height_reduction,
+                    base_feat_height_reduction,
+                    H // base_feat_height_reduction,
+                    base_feat_height_reduction,
+                    C,
+                )
+                .contiguous()
+                .permute(0, 1, 3, 2, 4, 5)
+                .reshape(
+                    B_, -1, base_feat_height_reduction * base_feat_height_reduction * C
+                )
+                .contiguous()
+            )
+            sub_img = (
+                sub_img.reshape(
+                    1,
+                    h,
+                    w,
+                    base_feat_height // base_feat_height_reduction,
+                    base_feat_width // base_feat_height_reduction,
+                    -1,
+                )
+                .permute(0, 1, 3, 2, 4, 5)
+                .reshape(
+                    1,
+                    h * base_feat_height // base_feat_height_reduction,
+                    w * base_feat_width // base_feat_height_reduction,
+                    base_feat_height_reduction * base_feat_height_reduction * C,
+                )
+            )
+
+            if image_attention_mask is not None and len(image_attention_mask) > 0:
+                reshaped_image_attention_mask = (
+                    image_attention_mask[_bs, 1 : B_ + 1, 0::2, 0::2]
+                    .reshape(
+                        1,
+                        h,
+                        w,
+                        base_feat_height // base_feat_height_reduction,
+                        base_feat_width // base_feat_height_reduction,
+                    )
+                    .permute(0, 1, 3, 2, 4)
+                    .reshape(
+                        1,
+                        h * base_feat_height // base_feat_height_reduction,
+                        w * base_feat_width // base_feat_height_reduction,
+                    )
+                )
+                useful_height = int(reshaped_image_attention_mask[0, :, 0].sum().item())
+                useful_width = int(reshaped_image_attention_mask[0, 0, :].sum().item())
+                sub_img = sub_img[:, :useful_height, :useful_width]
+                temp_sub_GN = self.sub_GN.repeat(1, useful_height, 1, 1)
+                temp_len = (
+                    int(image_attention_mask[_bs, : B_ + 1, 0::2, 0::2].sum().item())
+                    + (useful_height + 1)
+                    + base_feat_height // base_feat_height_reduction
+                )
+            else:
+                temp_sub_GN = self.sub_GN.repeat(
+                    1, h * base_feat_height // base_feat_height_reduction, 1, 1
+                )
+                temp_len = int(
+                    (h * w + 1) * self.num_img_tokens
+                    + 1
+                    + (h + 1) * base_feat_height // base_feat_height_reduction
+                )
+
+            sub_img = torch.cat([sub_img, temp_sub_GN], dim=2).reshape(
+                1, -1, base_feat_height_reduction * base_feat_height_reduction * C
+            )
+            # (1, num_img_tokens, 1024*4)
+
+            # glb + sub
+            if self.hd_transform_order == "glb_sub":
+                output_imgs.append(torch.cat([glb_img, self.glb_GN, sub_img], dim=1))
+            elif self.hd_transform_order == "sub_glb":
+                output_imgs.append(torch.cat([sub_img, self.glb_GN, glb_img], dim=1))
+            else:
+                raise NotImplementedError(
+                    f"hd_transform_order = {self.hd_transform_order}, not implemented"
+                )
+
+            # temp_len = int((h*w+1)*144 + 1 + (h+1)*12)
+            assert (
+                temp_len == output_imgs[-1].shape[1]
+            ), f"temp_len: {temp_len}, output_imgs[-1].shape[1]: {output_imgs[-1].shape[1]}"
+
+            output_len.append(temp_len)
+
+        img_set_tensor = []
+        for _output_img in output_imgs:
+            img_feature_proj = self.img_projection(
+                _output_img.to(target_device).to(target_dtype)
+            )
+            img_set_tensor.append(img_feature_proj.squeeze(0))
+
+        return img_set_tensor
+
+
+class Phi4MMForCausalLM(nn.Module):
+    packed_modules_mapping = {
+        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+        "gate_up_proj": ["gate_proj", "up_proj"],
+    }
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+
+        self.language_model = LlamaForCausalLM(
+            config=config, quant_config=quant_config, prefix=prefix
+        )
+
+        self.vision_encoder = Phi4MMImageEncoder(
+            config,
+            quant_config,
+            prefix="model.vision_embed_tokens",
+            model_dir=config._name_or_path,
+        )
+
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        dtype = next(self.vision_encoder.parameters()).dtype
+        pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
+            dtype
+        )
+        image_attention_mask = torch.cat([item.image_emb_mask for item in items], dim=0)
+        image_sizes = torch.cat([item.image_sizes for item in items], dim=0)
+        image_embeds = self.vision_encoder(
+            pixel_values, image_sizes, image_attention_mask
+        )
+        return torch.cat(image_embeds).type(dtype)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        **kwargs: object,
+    ) -> torch.Tensor:
+        hidden_states = general_mm_embed_routine(
+            input_ids=input_ids,
+            forward_batch=forward_batch,
+            language_model=self.language_model,
+            image_data_embedding_func=self.get_image_feature,
+            positions=positions,
+        )
+
+        return hidden_states
+
+    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
+        # Get all special token IDs
+        im_token_id: int = mm_inputs.im_token_id
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens([im_token_id])
+        return pattern.pad_input_tokens(input_ids, mm_inputs)
+
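+    # Checkpoint-to-module naming used below: the HF checkpoint keeps the
+    # vision tower under "model.embed_tokens_extend.image_embed." and the LLM
+    # under "model."; both prefixes are remapped. Audio weights and the unused
+    # ViT tail (encoder layer 26, head, post_layernorm) are skipped entirely.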
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
+            (".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
+            (".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
+        ]
+        prefix_mapping = {
+            "model.embed_tokens_extend.image_embed.": "vision_encoder.",
+            "model.": "language_model.model.",
+        }
+
+        skip_list = [
+            "img_processor.encoder.layers.26",
+            "img_processor.head",
+            "img_processor.post_layernorm",
+            "audio",
+        ]
+
+        def _should_skip(name: str) -> bool:
+            return any(substr in name for substr in skip_list)
+
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            # Skip the unused vision-encoder tail layers and all audio weights
+            if _should_skip(name):
+                continue
+
+            for old_name, new_name in prefix_mapping.items():
+                if name.startswith(old_name):
+                    name = name.replace(old_name, new_name)
+                    break
+
+            # Adapt to VisionAttention
+            name = name.replace(r"self_attn.out_proj", r"self_attn.proj")
+            name = name.replace(r"base_layer.", r"")
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                param = params_dict.get(name)
+                if param is None:
+                    if "lora" not in name:
+                        logger.warning(f"{name} not found in model parameters")
+                    continue
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+
+EntryClass = [Phi4MMForCausalLM]

diff --git a/test/srt/test_vision_openai_server_b.py b/test/srt/test_vision_openai_server_b.py
index 30c9808d3..404a4844b 100644
--- a/test/srt/test_vision_openai_server_b.py
+++ b/test/srt/test_vision_openai_server_b.py
@@ -196,5 +196,31 @@ class TestKimiVLServer(TestOpenAIVisionServer):
         pass
 
 
+class TestPhi4MMServer(TestOpenAIVisionServer):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = "microsoft/Phi-4-multimodal-instruct"
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.api_key = "sk-123456"
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--trust-remote-code",
+                "--mem-fraction-static",
+                "0.75",
+            ],
+        )
+        cls.base_url += "/v1"
+
+    def test_video_chat_completion(self):
+        pass
+
+    def test_multi_images_chat_completion(self):
+        # TODO (lifuhuang): support LoRA to enable Phi4MM multi-image understanding capability.
+        pass
+
+
 if __name__ == "__main__":
     unittest.main()