From 5380cd7ea39dedd2657ea18f2ea43d09aa57df6d Mon Sep 17 00:00:00 2001 From: Kiv Chen <34561254+KivenChen@users.noreply.github.com> Date: Tue, 13 May 2025 00:16:10 -0700 Subject: [PATCH] model(vlm): pixtral (#5084) --- .../vision_language_models.md | 1 + examples/runtime/README.md | 5 +- .../llama3_llava_server.py} | 2 +- .../llava_onevision_server.py} | 2 +- examples/runtime/multimodal/pixtral_server.py | 127 +++++ .../qwen_llava_server.py} | 2 +- python/sglang/lang/chat_template.py | 23 +- python/sglang/srt/configs/model_config.py | 1 + python/sglang/srt/conversation.py | 22 +- .../managers/multimodal_processors/llava.py | 46 ++ .../managers/multimodal_processors/pixtral.py | 127 +++++ python/sglang/srt/models/llava.py | 253 +++++++++- python/sglang/srt/models/pixtral.py | 467 ++++++++++++++++++ python/sglang/test/runners.py | 9 +- test/srt/models/test_generation_models.py | 55 ++- test/srt/test_vision_openai_server.py | 22 + 16 files changed, 1125 insertions(+), 39 deletions(-) rename examples/runtime/{llava_onevision/http_llama3_llava_test.py => multimodal/llama3_llava_server.py} (99%) rename examples/runtime/{llava_onevision/http_llava_onevision_test.py => multimodal/llava_onevision_server.py} (99%) create mode 100644 examples/runtime/multimodal/pixtral_server.py rename examples/runtime/{llava_onevision/http_qwen_llava_test.py => multimodal/qwen_llava_server.py} (99%) create mode 100644 python/sglang/srt/managers/multimodal_processors/pixtral.py create mode 100644 python/sglang/srt/models/pixtral.py diff --git a/docs/supported_models/vision_language_models.md b/docs/supported_models/vision_language_models.md index a9f4a8197..d2d9f1d80 100644 --- a/docs/supported_models/vision_language_models.md +++ b/docs/supported_models/vision_language_models.md @@ -20,6 +20,7 @@ python3 -m sglang.launch_server \ | **Janus-Pro** (1B, 7B) | `deepseek-ai/Janus-Pro-7B` | `janus-pro` | DeepSeek’s open-source multimodal model capable of both image understanding and generation. Janus-Pro employs a decoupled architecture for separate visual encoding paths, enhancing performance in both tasks. | | **MiniCPM-V / MiniCPM-o** | `openbmb/MiniCPM-V-2_6` | `minicpmv` | MiniCPM-V (2.6, ~8B) supports image inputs, and MiniCPM-o adds audio/video; these multimodal LLMs are optimized for end-side deployment on mobile/edge devices. | | **Llama 3.2 Vision** (11B) | `meta-llama/Llama-3.2-11B-Vision-Instruct` | `llama_3_vision` | Vision-enabled variant of Llama 3 (11B) that accepts image inputs for visual question answering and other multimodal tasks. | +| **Pixtral** (12B, 124B) | `mistral-community/pixtral-12b` | `mistral` | Pixtral is a vision-language model from Mistral AI that can process both text and images. | | **LLaVA** (v1.5 & v1.6) | *e.g.* `liuhaotian/llava-v1.5-13b` | `vicuna_v1.1` | Open vision-chat models that add an image encoder to LLaMA/Vicuna (e.g. LLaMA2 13B) for following multimodal instruction prompts. | | **LLaVA-NeXT** (8B, 72B) | `lmms-lab/llava-next-72b` | `chatml-llava` | Improved LLaVA models (with an 8B Llama3 version and a 72B version) offering enhanced visual instruction-following and accuracy on multimodal benchmarks. | | **LLaVA-OneVision** | `lmms-lab/llava-onevision-qwen2-7b-ov` | `chatml-llava` | Enhanced LLaVA variant integrating Qwen as the backbone; supports multiple images (and even video frames) as inputs via an OpenAI Vision API-compatible format. 
|
diff --git a/examples/runtime/README.md b/examples/runtime/README.md
index 941863727..18414452f 100644
--- a/examples/runtime/README.md
+++ b/examples/runtime/README.md
@@ -33,9 +33,10 @@ The `hidden_states` folder contains examples on how to extract hidden states usi
 * `hidden_states_engine.py`: An example how to extract hidden states using the Engine API.
 * `hidden_states_server.py`: An example how to extract hidden states using the Server API.
 
-## LLaVA-NeXT
+## Multimodal
+
+SGLang supports multimodal inputs for various model architectures. The `multimodal` folder contains examples showing how to use URLs, files, or encoded data to make requests to multimodal models. Examples include querying the [Llava-OneVision](multimodal/llava_onevision_server.py) model (image, multi-image, video), the LLaVA-NeXT-based [Qwen-Llava](multimodal/qwen_llava_server.py) and [Llama3-Llava](multimodal/llama3_llava_server.py) models (image, multi-image), and Mistral AI's [Pixtral](multimodal/pixtral_server.py) model (image, multi-image).
 
-SGLang support LLaVA-OneVision with single-image, multi-image and video are supported. The folder `llava_onevision` shows how to do this.
 
 ## Token In, Token Out
diff --git a/examples/runtime/llava_onevision/http_llama3_llava_test.py b/examples/runtime/multimodal/llama3_llava_server.py
similarity index 99%
rename from examples/runtime/llava_onevision/http_llama3_llava_test.py
rename to examples/runtime/multimodal/llama3_llava_server.py
index ed0e61631..a8409af71 100644
--- a/examples/runtime/llava_onevision/http_llama3_llava_test.py
+++ b/examples/runtime/multimodal/llama3_llava_server.py
@@ -6,7 +6,7 @@ Usage:
 # Endpoint Service CLI:
 python -m sglang.launch_server --model-path lmms-lab/llama3-llava-next-8b --port=30000
 
-python3 http_llama3_llava_test.py
+python3 llama3_llava_server.py
 
 Output:
 "Friends posing for a fun photo with a life-sized teddy bear, creating a playful and memorable moment."
diff --git a/examples/runtime/llava_onevision/http_llava_onevision_test.py b/examples/runtime/multimodal/llava_onevision_server.py
similarity index 99%
rename from examples/runtime/llava_onevision/http_llava_onevision_test.py
rename to examples/runtime/multimodal/llava_onevision_server.py
index 5c895007f..94a0fee94 100644
--- a/examples/runtime/llava_onevision/http_llava_onevision_test.py
+++ b/examples/runtime/multimodal/llava_onevision_server.py
@@ -3,7 +3,7 @@ Usage:
 python3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-72b-ov --port=30000 --tp-size=8
 
-python3 http_llava_onevision_test.py
+python3 llava_onevision_server.py
 """
 
 import base64
diff --git a/examples/runtime/multimodal/pixtral_server.py b/examples/runtime/multimodal/pixtral_server.py
new file mode 100644
index 000000000..d907de14d
--- /dev/null
+++ b/examples/runtime/multimodal/pixtral_server.py
@@ -0,0 +1,127 @@
+"""
+Usage:
+# Run a Pixtral model with SGLang:
+# HuggingFace:
+python -m sglang.launch_server --model-path mistral-community/pixtral-12b --port=30000
+# ModelScope:
+python -m sglang.launch_server --model-path AI-ModelScope/pixtral-12b --port=30000
+
+# Then test it with:
+python pixtral_server.py
+
+This script tests the Pixtral model with both single and multiple images.
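+
+Note: the number of "[IMG]" placeholders in the prompt must match the number of
+entries passed via `image_data`; the code below repeats IMAGE_TOKEN_SEP once per
+image when building the multi-image prompt.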
+""" + +import argparse +import asyncio +import json + +import aiohttp +import requests + +IMAGE_TOKEN_SEP = "\n[IMG]" +ROUTE = "/generate" + + +async def send_request(url, data, delay=0): + await asyncio.sleep(delay) + async with aiohttp.ClientSession() as session: + async with session.post(url, json=data) as resp: + output = await resp.json() + return output + + +async def test_concurrent(args): + url = f"{args.host}:{args.port}{ROUTE}" + + # Single image test + if args.single_image: + prompt = f"[INST]Describe this image in detail.{IMAGE_TOKEN_SEP}[/INST]" + image_url = "https://picsum.photos/id/237/400/300" + modality = ["image"] + # Multiple images test + else: + image_urls = [ + "https://picsum.photos/id/237/400/300", + "https://picsum.photos/id/27/500/500", + ] + prompt = f"[INST]How many photos are there? Describe each in a very short sentence.{IMAGE_TOKEN_SEP * len(image_urls)}[/INST]" + image_url = image_urls + modality = ["multi-images"] + + response = await send_request( + url, + { + "text": prompt, + "image_data": image_url, + "sampling_params": { + "max_new_tokens": 100, + "temperature": 0.7, + "top_p": 0.9, + }, + "modalities": modality, + }, + ) + + print(f"Response: {response}") + if "text" in response: + print("\nOutput text:", response["text"]) + + +def test_streaming(args): + url = f"{args.host}:{args.port}/generate" + + # Single image test + if args.single_image: + prompt = f"[INST]Describe this image in detail.{IMAGE_TOKEN_SEP}[/INST]" + image_data = "https://picsum.photos/id/237/400/300" + modality = ["image"] + # Multiple images test + else: + image_urls = [ + "https://picsum.photos/id/237/400/300", + "https://picsum.photos/id/27/500/500", + ] + prompt = f"[INST]How many photos are there? Describe each in a very short sentence.{IMAGE_TOKEN_SEP * len(image_urls)}[/INST]" + image_data = image_urls + modality = ["multi-images"] + + pload = { + "text": prompt, + "image_data": image_data, + "sampling_params": {"max_new_tokens": 100, "temperature": 0.7, "top_p": 0.9}, + "modalities": modality, + "stream": True, + } + + response = requests.post(url, json=pload, stream=True) + + print("Streaming response:") + prev = 0 + for chunk in response.iter_lines(decode_unicode=False): + chunk = chunk.decode("utf-8") + if chunk and chunk.startswith("data:"): + if chunk == "data: [DONE]": + break + data = json.loads(chunk[5:].strip("\n")) + output = data["text"].strip() + print(output[prev:], end="", flush=True) + prev = len(output) + print("\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="http://127.0.0.1") + parser.add_argument("--port", type=int, default=30000) + parser.add_argument( + "--single-image", + action="store_true", + help="Test with single image instead of multiple images", + ) + parser.add_argument("--no-stream", action="store_true", help="Don't test streaming") + args = parser.parse_args() + + asyncio.run(test_concurrent(args)) + if not args.no_stream: + test_streaming(args) diff --git a/examples/runtime/llava_onevision/http_qwen_llava_test.py b/examples/runtime/multimodal/qwen_llava_server.py similarity index 99% rename from examples/runtime/llava_onevision/http_qwen_llava_test.py rename to examples/runtime/multimodal/qwen_llava_server.py index 94ec4e079..d8b3226e7 100644 --- a/examples/runtime/llava_onevision/http_qwen_llava_test.py +++ b/examples/runtime/multimodal/qwen_llava_server.py @@ -6,7 +6,7 @@ Usage: # Endpoint Service CLI: python -m sglang.launch_server --model-path 
lmms-lab/llava-next-72b --port=30000 --tp-size=8
 
-python3 http_qwen_llava_test.py
+python3 qwen_llava_server.py
 
 Output:
 "Two children pose with a large teddy bear, one holding a smaller stuffed bear, in a room with an American flag and potted plants."
diff --git a/python/sglang/lang/chat_template.py b/python/sglang/lang/chat_template.py
index 8609f3e58..f309d053d 100644
--- a/python/sglang/lang/chat_template.py
+++ b/python/sglang/lang/chat_template.py
@@ -194,6 +194,21 @@ register_chat_template(
     )
 )
 
+# Reference: https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/blob/main/chat_template.json
+register_chat_template(
+    ChatTemplate(
+        name="mistral",
+        default_system_prompt=None,
+        role_prefix_and_suffix={
+            "system": ("[SYSTEM_PROMPT] ", " [/SYSTEM_PROMPT]"),
+            "user": ("[INST] ", " [/INST]"),
+            "assistant": ("", "</s> "),
+        },
+        stop_str=("</s>",),
+        image_token="[IMG]",
+    )
+)
+
 register_chat_template(
     ChatTemplate(
         name="llama-3-instruct",
@@ -509,13 +524,19 @@ def match_vicuna(model_path: str):
 @register_chat_template_matching_function
 def match_llama2_chat(model_path: str):
     if re.search(
-        r"llama-2.*chat|(mistral|mixtral).*instruct|codellama.*instruct",
+        r"llama-2.*chat|codellama.*instruct",
         model_path,
         re.IGNORECASE,
     ):
         return "llama-2-chat"
 
 
+@register_chat_template_matching_function
+def match_mistral(model_path: str):
+    if re.search(r"pixtral|(mistral|mixtral).*instruct", model_path, re.IGNORECASE):
+        return "mistral"
+
+
 @register_chat_template_matching_function
 def match_llama3_instruct(model_path: str):
     if re.search(r"llama-3.*instruct", model_path, re.IGNORECASE):
diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py
index 85a4f3153..623d38a47 100644
--- a/python/sglang/srt/configs/model_config.py
+++ b/python/sglang/srt/configs/model_config.py
@@ -545,6 +545,7 @@ multimodal_model_archs = [
     "Llama4ForConditionalGeneration",
     "LlavaMistralForCausalLM",
     "LlavaQwenForCausalLM",
+    "LlavaForConditionalGeneration",
     "LlavaVidForCausalLM",
     "MiniCPMO",
     "MiniCPMV",
diff --git a/python/sglang/srt/conversation.py b/python/sglang/srt/conversation.py
index e931bc64a..91aa80fd8 100644
--- a/python/sglang/srt/conversation.py
+++ b/python/sglang/srt/conversation.py
@@ -634,6 +634,20 @@ register_conv_template(
     )
 )
 
+# reference: https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/blob/main/chat_template.json
+register_conv_template(
+    Conversation(
+        name="mistral",
+        system_template="[SYSTEM_PROMPT]\n{system_message}\n[/SYSTEM_PROMPT]\n\n",
+        roles=("[INST]", "[/INST]"),
+        sep_style=SeparatorStyle.LLAMA2,
+        sep=" ",
+        sep2=" </s><s>",
+        stop_str=["[INST]", "[/INST]", "[SYSTEM_PROMPT]", "[/SYSTEM_PROMPT]"],
+        image_token="[IMG]",
+    )
+)
+
 # reference: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct/blob/main/chat_template.json
 register_conv_template(
     Conversation(
@@ -880,13 +894,19 @@ def match_vicuna(model_path: str):
 @register_conv_template_matching_function
 def match_llama2_chat(model_path: str):
     if re.search(
-        r"llama-2.*chat|(mistral|mixtral).*instruct|codellama.*instruct",
+        r"llama-2.*chat|codellama.*instruct",
         model_path,
         re.IGNORECASE,
     ):
         return "llama-2"
 
 
+@register_conv_template_matching_function
+def match_mistral(model_path: str):
+    if re.search(r"pixtral|(mistral|mixtral).*instruct", model_path, re.IGNORECASE):
+        return "mistral"
+
+
 @register_conv_template_matching_function
 def match_deepseek_vl(model_path: str):
     if re.search(r"deepseek.*vl2", model_path, re.IGNORECASE):
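To make the new template concrete, here is a rough sketch (illustrative only, not part of the patch) of the prompt layout implied by the `mistral` role prefixes and suffixes registered above; exact whitespace and turn separators are handled by the template machinery:

# Illustrative sketch: what a one-turn, one-image conversation roughly renders
# to under the "mistral" template defined above.
system = "You are a helpful assistant."
user = "Describe this image. [IMG]"

prompt = (
    f"[SYSTEM_PROMPT] {system} [/SYSTEM_PROMPT]"  # system wrapper
    f"[INST] {user} [/INST]"  # user turn; [IMG] marks where image tokens are expanded
)
# The assistant turn has no prefix and generation stops at the "</s>" stop string.

diff --git 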
a/python/sglang/srt/managers/multimodal_processors/llava.py b/python/sglang/srt/managers/multimodal_processors/llava.py index a1b0dc1f7..c3190c697 100644 --- a/python/sglang/srt/managers/multimodal_processors/llava.py +++ b/python/sglang/srt/managers/multimodal_processors/llava.py @@ -1,14 +1,20 @@ import asyncio +import importlib from typing import List, Optional, Union import numpy as np +from transformers.models.auto.processing_auto import ( + PROCESSOR_MAPPING_NAMES as HF_MAPPING_NAMES, +) +import sglang.srt.managers.multimodal_processor as sgl_mm_processor_utils from sglang.srt.managers.multimodal_processors.base_processor import ( BaseMultimodalProcessor, ) from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem from sglang.srt.mm_utils import expand2square, process_anyres_image from sglang.srt.models.llava import ( + LlavaForConditionalGeneration, LlavaLlamaForCausalLM, LlavaMistralForCausalLM, LlavaQwenForCausalLM, @@ -133,6 +139,7 @@ class LlavaImageProcessor(BaseMultimodalProcessor): img_data, aspect_ratio, grid_pinpoints ) ) + res = await asyncio.gather(*res) for pixel_v, image_h, image_s in res: pixel_values.append(pixel_v) @@ -165,3 +172,42 @@ class LlavaImageProcessor(BaseMultimodalProcessor): ) ], } + + +class LlavaMultimodalProcessor(BaseMultimodalProcessor): + """ + This is a wrapper class used to identify the multimodal processor for Llava architecture models. + """ + + models = [LlavaForConditionalGeneration] + + def _get_sgl_processor_cls(self, model_type: str): + if hf_name := HF_MAPPING_NAMES.get(model_type): + sgl_mm_processor_set = sgl_mm_processor_utils.PROCESSOR_MAPPING.values() + sgl_processor_cls = list( + filter(lambda p: p.__name__ == hf_name, sgl_mm_processor_set) + ) + if sgl_processor_cls: + return sgl_processor_cls[0] + raise ValueError( + f"Cannot find corresponding multimodal processor registered in sglang for model type `{model_type}`" + ) + + def __init__(self, hf_config, server_args, _processor): + assert hasattr(hf_config, "vision_config") + assert hasattr(hf_config, "text_config") + self.vision_config = hf_config.vision_config + self.text_config = hf_config.text_config + self.hf_config = hf_config + + if vision_type := getattr(self.vision_config, "model_type"): + self.inner = self._get_sgl_processor_cls(vision_type)( + hf_config, server_args, _processor + ) + else: + raise ValueError( + f"Required `vision_config.model_type` is not found in hf_config: `{hf_config}`" + ) + + async def process_mm_data_async(self, *args, **kwargs): + return await self.inner.process_mm_data_async(*args, **kwargs) diff --git a/python/sglang/srt/managers/multimodal_processors/pixtral.py b/python/sglang/srt/managers/multimodal_processors/pixtral.py new file mode 100644 index 000000000..07a772cdf --- /dev/null +++ b/python/sglang/srt/managers/multimodal_processors/pixtral.py @@ -0,0 +1,127 @@ +import asyncio +import math +from typing import List, Optional, Union + +import numpy as np +from transformers import PretrainedConfig +from transformers.models.pixtral.image_processing_pixtral import ( + _num_image_tokens as _get_pixtral_hf_num_image_tokens, +) + +from sglang.srt.managers.multimodal_processors.base_processor import ( + BaseMultimodalProcessor, + MultimodalSpecialTokens, +) +from sglang.srt.managers.schedule_batch import ( + Modality, + MultimodalDataItem, + MultimodalInputs, +) +from sglang.srt.models.pixtral import PixtralVisionModel + + +class PixtralProcessor(BaseMultimodalProcessor): + models = [PixtralVisionModel] + + PAD_TOKEN = "" + 
IMG_BREAK_TOKEN_ID = 12 + IMG_END_TOKEN_ID = 13 + + def get_patch_grid_size( + self, + *, + image_width: int, + image_height: int, + ) -> tuple[int, int]: + max_width = max_height = self.image_size + patch_width = patch_height = self.patch_size + + ratio = max(image_width / max_width, image_height / max_height) + + if ratio > 1: + image_width = int(math.floor(image_width / ratio)) + image_height = int(math.floor(image_height / ratio)) + + nrows, ncols = _get_pixtral_hf_num_image_tokens( + (image_height, image_width), + (patch_height, patch_width), + ) + + return ncols, nrows + + def __init__(self, hf_config, server_args, _processor): + super().__init__(hf_config, server_args, _processor) + self.image_token_id = getattr( + hf_config, "image_token_index", PixtralVisionModel.DEFAULT_IMAGE_TOKEN_ID + ) + # Instantiate the patcher logic helper using the class defined above + + self.vision_config = hf_config.vision_config + self.image_size = self.vision_config.image_size + self.patch_size = self.vision_config.patch_size + self.multimodal_tokens = MultimodalSpecialTokens( + image_token=_processor.image_token + ) + _processor.tokenizer.add_special_tokens( + { + "pad_token": getattr(hf_config, "pad_token", self.PAD_TOKEN), + } + ) + + async def _resize(self, image): + num_w_tokens, num_h_tokens = self.get_patch_grid_size( + image_width=image.size[0], + image_height=image.size[1], + ) + new_size = (num_w_tokens * self.patch_size, num_h_tokens * self.patch_size) + return image.resize(new_size) + + async def process_mm_data_async( + self, + image_data: List[Union[str, bytes]], + input_text, + request_obj, + *args, + **kwargs, + ): + if not image_data: + return None + + if isinstance(image_data, str): + image_data = [image_data] + + mm_data = self.load_mm_data( + prompt=input_text, + multimodal_tokens=self.multimodal_tokens, + max_req_input_len=kwargs.get("max_req_input_len", 4096), + image_data=image_data, + return_text=True, + ) + + if mm_data.images: + resize_tasks = [self._resize(image) for image in mm_data.images] + mm_data.images = await asyncio.gather(*resize_tasks) + + processor_output = self.process_mm_data( + input_text=mm_data.input_text, + images=mm_data.images, + ) + + if "pixel_values" in processor_output: + mm_items = [ + MultimodalDataItem( + pixel_values=processor_output["pixel_values"], + image_sizes=processor_output["image_sizes"], + modality=Modality.IMAGE, + ) + ] + + input_ids = processor_output["input_ids"].view(-1).tolist() + processor_output.update( + input_ids=input_ids, + mm_items=mm_items, + # there's no im_start_id for pixtral, only im_token and im_end_token + im_end_id=self.IMG_END_TOKEN_ID, + im_token_id=self.image_token_id, + ) + return processor_output diff --git a/python/sglang/srt/models/llava.py b/python/sglang/srt/models/llava.py index f7beda233..5077211d4 100644 --- a/python/sglang/srt/models/llava.py +++ b/python/sglang/srt/models/llava.py @@ -15,7 +15,8 @@ import math import re -from typing import Iterable, List, Optional, Tuple +from functools import lru_cache +from typing import Dict, Iterable, List, Optional, Tuple, Type, Union import numpy as np import torch @@ -28,10 +29,18 @@ from transformers import ( Qwen2Config, SiglipVisionModel, ) +from transformers.models.auto.modeling_auto import AutoModel, AutoModelForCausalLM from transformers.models.llava.modeling_llava import LlavaMultiModalProjector +# leave till last and symbol only in case circular import +import sglang.srt.models as sgl_models from sglang.srt.layers.quantization.base_config import 
QuantizationConfig -from sglang.srt.managers.schedule_batch import Modality, MultimodalInputs +from sglang.srt.managers.mm_utils import general_mm_embed_routine +from sglang.srt.managers.schedule_batch import ( + Modality, + MultimodalDataItem, + MultimodalInputs, +) from sglang.srt.mm_utils import ( get_anyres_image_grid_shape, unpad_image, @@ -42,7 +51,7 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.llama import LlamaForCausalLM from sglang.srt.models.mistral import MistralForCausalLM from sglang.srt.models.qwen2 import Qwen2ForCausalLM -from sglang.srt.utils import add_prefix, flatten_nested_list +from sglang.srt.utils import add_prefix, flatten_nested_list, logger class LlavaBaseForCausalLM(nn.Module): @@ -114,7 +123,16 @@ class LlavaBaseForCausalLM(nn.Module): image_inputs.image_offsets = offset_list return input_ids - def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor: + def encode_images( + self, pixel_values: Union[torch.Tensor, List[torch.Tensor]] + ) -> torch.Tensor: + """ + encode images by vision tower and multimodal projector + Args: + pixel_values: torch.Tensor or List[torch.Tensor]: each tensor for an input image + Returns: + torch.Tensor: encoded image features from the input image; if multiple, flattened by seq_len axis + """ image_outputs = self.vision_tower(pixel_values, output_hidden_states=True) # NOTE: This is not memory efficient. (output_hidden_states=True) will save all the hidden stated. @@ -583,4 +601,229 @@ class LlavaMistralForCausalLM(LlavaBaseForCausalLM): ) -EntryClass = [LlavaLlamaForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM] +class LlavaForConditionalGeneration(LlavaBaseForCausalLM): + """ + An adaptor class to enable support for multiple mmlm such as mistral-community/pixtral-12b + It follows the structure of (vision_tower, multi_modal_projector, language_model) + + Once a model config is loaded, text_config and vision_config will be extracted, and + LlavaForConditionalGeneration will load the language_model and vision_tower models + according to config. + """ + + MULTIMODAL_PROJECTOR_TYPE = LlavaMultiModalProjector + + def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs): + if hasattr(self.vision_tower, "pad_input_ids"): + return self.vision_tower.pad_input_ids(input_ids, image_inputs) + else: + return super().pad_input_ids(input_ids, image_inputs) + + def _get_sgl_model_cls(self, config, auto_model_type: Type[AutoModel] = AutoModel): + """ + Get the SGLang model implementation class according to config. + + Args: + config: The config object of the model. + auto_model_type: The type of the auto model. + + Returns: + The SGLang model implementation class. + """ + config_cls_name = config.__class__.__name__ + arch_name_mapping = self._config_cls_name_to_arch_name_mapping(auto_model_type) + if arch := arch_name_mapping.get(config_cls_name): + if isinstance(arch, tuple): + arch = arch[0] + logger.warning( + f"Multiple {auto_model_type.__name__} models found for submodule config `{config_cls_name}`, defaulting to [0]: {arch.__name__}" + ) + try: + return sgl_models.registry.ModelRegistry.resolve_model_cls(arch)[0] + except Exception as e: + raise ValueError( + f"{auto_model_type.__name__} found a corresponding model `{arch}` for config class `{config_cls_name}`, but failed to load it from SGLang ModelRegistry. 
\n{e}" + ) + else: + raise ValueError( + f"{auto_model_type.__name__} cannot find a corresponding model for config class `{config_cls_name}`" + ) + + @lru_cache + def _config_cls_name_to_arch_name_mapping( + self, auto_model_type: Type[AutoModel] + ) -> Dict[str, str]: + mapping = {} + for config_cls, archs in auto_model_type._model_mapping.items(): + if isinstance(archs, tuple): + mapping[config_cls.__name__] = tuple(arch.__name__ for arch in archs) + else: + mapping[config_cls.__name__] = archs.__name__ + return mapping + + def __init__( + self, + config: LlavaConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + + assert hasattr(config, "text_config") + assert hasattr(config, "vision_config") + self.config = config + self.text_config = config.text_config + self.vision_config = config.vision_config + + if not hasattr(self.config, "vocab_size"): + self.config.vocab_size = self.config.text_config.vocab_size + if not hasattr(self.config, "image_aspect_ratio"): + self.config.image_aspect_ratio = "anyres" + if not hasattr(self.config, "image_grid_pinpoints"): + # from transformers.models.llava_onevision.configuration_llava_onevision import LlavaOnevisionConfig + # self.config.image_grid_pinpoints = LlavaOnevisionConfig().image_grid_pinpoints + self.config.image_grid_pinpoints = [ + [96, 96], + [224, 224], + [384, 384], + [512, 512], + [768, 768], + [1024, 1024], + ] + if not hasattr(self.config, "mm_patch_merge_type"): + self.config.mm_patch_merge_type = "flat" + if not hasattr(self.config, "image_token_index"): + self.config.image_token_index = 10 + if not hasattr(self.config, "projector_hidden_act"): + self.config.projector_hidden_act = "gelu" + + self.vision_feature_layer = getattr(config, "vision_feature_layer", -1) + self.vision_feature_select_strategy = getattr( + config, "vision_feature_select_strategy", "full" + ) + self.image_size = self.config.vision_config.image_size + self.patch_size = self.config.vision_config.patch_size + + self.mm_patch_merge_type = config.mm_patch_merge_type + self.image_aspect_ratio = config.image_aspect_ratio + self.image_grid_pinpoints = config.image_grid_pinpoints + + self.image_feature_len = int((self.image_size // self.patch_size) ** 2) + + self.multi_modal_projector = self.MULTIMODAL_PROJECTOR_TYPE(config) + + language_model_cls = self._get_sgl_model_cls( + config.text_config, AutoModelForCausalLM + ) + vision_model_cls = self._get_sgl_model_cls(config.vision_config, AutoModel) + self.language_model = language_model_cls( + config.text_config, + quant_config=quant_config, + prefix=add_prefix("language_model", prefix), + ) + self.vision_tower = vision_model_cls( + config.vision_config, + quant_config=quant_config, + prefix=add_prefix("vision_tower", prefix), + ) + + if "unpad" in getattr(config, "mm_patch_merge_type", ""): + self.language_model.model.image_newline = nn.Parameter( + torch.empty(config.text_config.hidden_size, dtype=torch.float16) + ) + + def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: + """Extract features from image inputs. 
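+
+        Each item's pixel_values are run through the vision tower, the hidden
+        state at `vision_feature_layer` is selected, the leading CLS token is
+        dropped for the "default"/"patch" strategies (kept for "full"), and the
+        result is projected by multi_modal_projector before concatenation.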
+ + Args: + items: List of MultimodalDataItem objects containing image data + Note that an item can be either "image" or "multi-images" + + Returns: + torch.Tensor: features from image inputs, concatenated + """ + features = [] + for item in items: + # in each item, we assume pixel_values is always batched + pixel_values, image_sizes = item.pixel_values, item.image_sizes + image_outputs = self.vision_tower( + pixel_values, image_sizes, output_hidden_states=True + ) + selected_image_feature = image_outputs.hidden_states[ + self.vision_feature_layer + ] + + if self.vision_feature_select_strategy in ["default", "patch"]: + selected_image_feature = selected_image_feature[:, 1:] + elif self.vision_feature_select_strategy == "full": + selected_image_feature = selected_image_feature + else: + raise ValueError( + f"Unexpected select feature: {self.vision_feature_select_strategy}" + ) + features.append( + self.multi_modal_projector(selected_image_feature.squeeze(0)) + ) + ret = torch.cat(features, dim=0) + return ret + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + get_embedding: bool = False, + ): + hidden_states = general_mm_embed_routine( + input_ids=input_ids, + forward_batch=forward_batch, + get_embedding=get_embedding, + language_model=self.language_model, + image_data_embedding_func=self.get_image_feature, + placeholder_tokens=None, # using mm_item.pad_value + positions=positions, + ) + + return hidden_states + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + """Load weights for LlavaForConditionalGeneration. + + Unlike the base class implementation, this one doesn't need to handle + weight name remapping as the weights are already properly structured with + 'language_model' and 'vision_tower' prefixes in the safetensors files. + """ + if ( + self.vision_feature_select_strategy == "patch" + or self.vision_feature_select_strategy == "full" + ): + pass + elif self.vision_feature_select_strategy == "cls_patch": + self.image_feature_len += 1 + else: + raise ValueError( + f"Unexpected select feature: {self.vision_feature_select_strategy}" + ) + + # Create dictionaries for direct parameter loading + params_dict = dict(self.named_parameters()) + + # Load weights directly without remapping + for name, loaded_weight in weights: + for part in ("language_model", "vision_tower"): + if name.startswith(part): + name = name[len(part + ".") :] + getattr(self, part).load_weights([(name, loaded_weight)]) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + +EntryClass = [ + LlavaLlamaForCausalLM, + LlavaQwenForCausalLM, + LlavaMistralForCausalLM, + LlavaForConditionalGeneration, +] diff --git a/python/sglang/srt/models/pixtral.py b/python/sglang/srt/models/pixtral.py new file mode 100644 index 000000000..6a5ebab9f --- /dev/null +++ b/python/sglang/srt/models/pixtral.py @@ -0,0 +1,467 @@ +# Copyright 2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +""" +Using mistral-community/pixtral-12b as reference. +""" + +import logging +import math +from typing import Iterable, List, Optional, Set, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import PixtralVisionConfig, PretrainedConfig +from transformers.models.pixtral.modeling_pixtral import PixtralRotaryEmbedding +from transformers.models.pixtral.modeling_pixtral import ( + generate_block_attention_mask as _get_pixtral_attention_mask, +) +from transformers.models.pixtral.modeling_pixtral import position_ids_in_meshgrid + +from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.attention.vision import VisionAttention +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.linear import MergedColumnParallelLinear, RowParallelLinear +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.managers.mm_utils import MultiModalityDataPaddingPatternMultimodalTokens +from sglang.srt.managers.schedule_batch import MultimodalInputs +from sglang.srt.model_loader.weight_utils import default_weight_loader + + +class PixtralHFMLP(nn.Module): + """MLP for PixtralHFVisionModel using SGLang components.""" + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + *, + prefix: str = "", + ) -> None: + super().__init__() + + assert config.intermediate_size is not None + + # Use MergedColumnParallelLinear for gate_up_proj to handle combined weights + self.gate_up_proj = MergedColumnParallelLinear( + input_size=config.hidden_size, + output_sizes=[config.intermediate_size, config.intermediate_size], + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + + self.down_proj = RowParallelLinear( + input_size=config.intermediate_size, + output_size=config.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) + + self.act_fn = SiluAndMul() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + gate_up_output, _ = self.gate_up_proj(x) + + # Apply SiLU activation and multiply + gate_up = self.act_fn(gate_up_output) + + # Project back to hidden size + out, _ = self.down_proj(gate_up) + return out + + +class PixtralHFTransformerBlock(nn.Module): + """Transformer block for PixtralHFVisionModel using SGLang components.""" + + def __init__( + self, + config: PretrainedConfig, + layer_id: int, + quant_config: Optional[QuantizationConfig] = None, + *, + prefix: str = "", + ) -> None: + super().__init__() + + self.layer_id = layer_id + self.attention_norm = RMSNorm(config.hidden_size, eps=1e-5) + + # Use SGLang's VisionAttention instead of vLLM's PixtralHFAttention + self.attention = VisionAttention( + embed_dim=config.hidden_size, + num_heads=config.num_attention_heads, + projection_size=config.hidden_size, + use_qkv_parallel=True, + quant_config=quant_config, + dropout=0.0, + use_context_forward=False, + softmax_in_single_precision=False, + flatten_batch=False, + prefix=f"{prefix}.attention", + ) + + self.feed_forward = PixtralHFMLP( + config, quant_config=quant_config, prefix=f"{prefix}.feed_forward" + ) + + self.ffn_norm = RMSNorm(config.hidden_size, eps=1e-5) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor], + position_embeddings: 
Optional[Tuple[torch.Tensor, torch.Tensor]], + ) -> torch.Tensor: + # Ensure hidden_states has the batch dimension [batch, seq_len, hidden_dim] + batch_size, seq_len, hidden_dim = hidden_states.shape + + # Apply attention norm - normalize along the last dimension + attn_normalized = self.attention_norm(hidden_states.view(-1, hidden_dim)).view( + batch_size, seq_len, hidden_dim + ) + + # Pass through attention layer + attention_output = self.attention( + attn_normalized, + attention_mask=attention_mask, + cu_seqlens=None, + position_embeddings=position_embeddings, + ) + + # Apply first residual connection + hidden_states = hidden_states + attention_output + + # Apply feed-forward norm - normalize along the last dimension + ffn_normalized = self.ffn_norm(hidden_states.view(-1, hidden_dim)).view( + batch_size, seq_len, hidden_dim + ) + + # Pass through feed-forward layer + # First reshape to 2D for the feed-forward network, then reshape back + ffn_output = self.feed_forward(ffn_normalized) + + # Apply second residual connection + output = hidden_states + ffn_output + + return output + + +class PixtralHFTransformer(nn.Module): + """Transformer for PixtralHFVisionModel using SGLang components.""" + + def __init__( + self, + config: PixtralVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + *, + num_hidden_layers_override: Optional[int] = None, + prefix: str = "", + ) -> None: + super().__init__() + + num_hidden_layers = config.num_hidden_layers + if num_hidden_layers_override is not None: + num_hidden_layers = num_hidden_layers_override + + self.layers = nn.ModuleList( + [ + PixtralHFTransformerBlock( + config=config, + layer_id=layer_idx, + quant_config=quant_config, + prefix=f"{prefix}.layers.{layer_idx}", + ) + for layer_idx in range(num_hidden_layers) + ] + ) + + def forward( + self, + x: torch.Tensor, + attention_mask: Optional[torch.Tensor], + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]], + return_all_hidden_states: bool = False, + ) -> Union[torch.Tensor, List[torch.Tensor]]: + """Forward pass through transformer layers. 
+ + Args: + x: Input tensor + attention_mask: Optional attention mask + position_embeddings: Optional position embeddings for rotary attention + return_all_hidden_states: Whether to return all hidden states + + Returns: + Either the final hidden state, or a list of all hidden states if + return_all_hidden_states is True + """ + # For HF model compatibility, always start with the input + hidden_states = x + all_hidden_states = [hidden_states] if return_all_hidden_states else None + + for i, layer in enumerate(self.layers): + hidden_states = layer(hidden_states, attention_mask, position_embeddings) + if return_all_hidden_states: + all_hidden_states.append(hidden_states) + + if return_all_hidden_states: + return all_hidden_states + return hidden_states + + +def resolve_visual_encoder_outputs( + outputs: Union[torch.Tensor, List[torch.Tensor]], + feature_sample_layers: Optional[List[int]], + post_norm: Optional[nn.Module], + num_hidden_layers: int, +) -> torch.Tensor: + """Resolve outputs from visual encoder based on feature_sample_layers.""" + if feature_sample_layers is None: + # Just use the last layer's output + if isinstance(outputs, list): + outputs = outputs[-1] + if post_norm is not None: + outputs = post_norm(outputs) + return outputs + + # Handle the case where we want to use specific layers + if not isinstance(outputs, list): + raise ValueError( + "Expected outputs to be a list when feature_sample_layers is provided" + ) + + # Validate layer indices + for layer_idx in feature_sample_layers: + if layer_idx < 0 or layer_idx > num_hidden_layers: + raise ValueError( + f"Feature sample layer index {layer_idx} is out of range " + f"[0, {num_hidden_layers}]" + ) + + # Collect outputs from specified layers + selected_outputs = [outputs[layer_idx] for layer_idx in feature_sample_layers] + + # Combine the outputs + combined_outputs = torch.cat(selected_outputs, dim=-1) + + if post_norm is not None: + combined_outputs = post_norm(combined_outputs) + + return combined_outputs + + +class PixtralHFVisionModel(nn.Module): + """Hugging Face Pixtral Vision Model implemented using SGLang components.""" + + DEFAULT_IMAGE_TOKEN_ID = 10 + + def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs): + return self.input_padder.pad_input_tokens(input_ids, image_inputs) + + def __init__( + self, + config: PixtralVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + *, + image_token_id: int = DEFAULT_IMAGE_TOKEN_ID, + num_hidden_layers_override: Optional[int] = None, + prefix: str = "", + ) -> None: + super().__init__() + + self.config = config + + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.patch_conv = nn.Conv2d( + in_channels=config.num_channels, + out_channels=config.hidden_size, + kernel_size=config.patch_size, + stride=config.patch_size, + bias=False, + ) + + self.ln_pre = RMSNorm(config.hidden_size, eps=1e-5) + + self.transformer = PixtralHFTransformer( + config, + quant_config, + num_hidden_layers_override=num_hidden_layers_override, + prefix=f"{prefix}.transformer", + ) + + # Check that num_hidden_layers is valid + num_hidden_layers = config.num_hidden_layers + if len(self.transformer.layers) > config.num_hidden_layers: + raise ValueError( + f"The original encoder only has {num_hidden_layers} " + f"layers, but you requested {len(self.transformer.layers)} " + "layers." 
+ ) + + # Initialize patch position embedding + self.image_token_id = image_token_id + self.patch_positional_embedding = PixtralRotaryEmbedding(config) + self.input_padder = MultiModalityDataPaddingPatternMultimodalTokens( + [self.image_token_id] + ) + + @property + def dtype(self): + return next(self.parameters()).dtype + + @property + def device(self): + return next(self.parameters()).device + + def forward( + self, + pixel_values: torch.Tensor, + image_sizes: list[tuple[int, int]], + output_hidden_states: bool = False, + feature_sample_layers: Optional[list[int]] = None, + ) -> Union[torch.Tensor, tuple]: + """ + Args: + pixel_values: [batch_size, C, H, W], padded if multiple images + image_sizes: list of (H, W) for each image in the batch + output_hidden_states: Whether to return all hidden states. + feature_sample_layers: Layer indices whose features should be + concatenated and used as the visual encoder output. If none + are provided, the last layer is used. + + Returns: + A tuple containing: + - hidden_states: Final model outputs (or selected layers if feature_sample_layers given) + - hidden_states tuple (optional): All hidden states if output_hidden_states=True + """ + # batch patch images + embeds_orig = self.patch_conv( + pixel_values.to(device=self.device, dtype=self.dtype) + ) + # crop the embeddings + embeds_2d = [ + embed[..., : h // self.patch_size, : w // self.patch_size] + for embed, (h, w) in zip(embeds_orig, image_sizes) + ] + + # flatten to sequence + embeds_1d = torch.cat([p.flatten(1).T for p in embeds_2d], dim=0) + embeds_featurized = self.ln_pre(embeds_1d).unsqueeze(0) + + # positional embeddings + position_ids = position_ids_in_meshgrid( + embeds_2d, + max_width=self.image_size // self.patch_size, + ).to(self.device) + + # The original PixtralRotaryEmbedding expects 2D input but returns a tuple of tensors (cos, sin) + # These tensors are used by apply_rotary_pos_emb in the transformer blocks + position_embedding = self.patch_positional_embedding( + embeds_featurized, position_ids + ) + attention_mask = _get_pixtral_attention_mask( + [p.shape[-2] * p.shape[-1] for p in embeds_2d], embeds_featurized + ) + + return_all_hidden_states = ( + output_hidden_states or feature_sample_layers is not None + ) + + transformer_outputs = self.transformer( + embeds_featurized, # add batch dimension + attention_mask, + position_embedding, + return_all_hidden_states=return_all_hidden_states, + ) + + # Store all hidden states if requested + all_hidden_states = None + if isinstance(transformer_outputs, list): + all_hidden_states = transformer_outputs + # Use the last layer by default if feature_sample_layers is not specified + if feature_sample_layers is None: + out = transformer_outputs[-1] + else: + # Resolve outputs based on feature sample layers + out = resolve_visual_encoder_outputs( + transformer_outputs, + feature_sample_layers, + None, + self.config.num_hidden_layers, + ) + else: + out = transformer_outputs + + # Format return to be compatible with HuggingFace vision models + if output_hidden_states: + return type( + "VisualOutput", + (), + { + "last_hidden_state": out, + "hidden_states": all_hidden_states, + }, + ) + else: + return out + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: + """Load weights from a HuggingFace checkpoint with proper parameter mapping.""" + params_dict = dict(self.named_parameters()) + + # for (param, weight, shard_id): load weight into param as param's shard_id part + stacked_params_mapping = [ + 
(".attention.qkv_proj", ".attention.q_proj", "q"), + (".attention.qkv_proj", ".attention.k_proj", "k"), + (".attention.qkv_proj", ".attention.v_proj", "v"), + (".feed_forward.gate_up_proj", ".feed_forward.gate_proj", 0), + (".feed_forward.gate_up_proj", ".feed_forward.up_proj", 1), + ] + + # Process each weight + for name, loaded_weight in weights: + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name in name: + # Replace the weight name part with the combined parameter name + transformed_name = name.replace(weight_name, param_name) + if transformed_name in params_dict: + param = params_dict[transformed_name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight, shard_id) + break + else: + if ".attention.o_proj" in name: + alt_name = name.replace(".attention.o_proj", ".attention.proj") + if alt_name in params_dict: + name = alt_name + if name in params_dict: + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + + +class PixtralVisionModel(PixtralHFVisionModel): + pass + + +# Register the model classes for external access +EntryClass = [PixtralVisionModel] diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index 6b4927d96..eece963c7 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -19,7 +19,9 @@ from typing import List, Optional, Tuple, Union import torch import torch.nn.functional as F +import transformers from transformers import ( + AutoConfig, AutoModel, AutoModelForCausalLM, AutoModelForVision2Seq, @@ -211,7 +213,12 @@ class HFRunner: # Load the model and tokenizer if self.model_type == "generation": - self.base_model = AutoModelForCausalLM.from_pretrained( + config = AutoConfig.from_pretrained(model_path) + if model_archs := getattr(config, "architectures"): + model_cls = getattr(transformers, model_archs[0]) + else: + model_cls = AutoModelForCausalLM + self.base_model = model_cls.from_pretrained( model_path, torch_dtype=torch_dtype, trust_remote_code=self.trust_remote_code, diff --git a/test/srt/models/test_generation_models.py b/test/srt/models/test_generation_models.py index 237af60cd..54ba7b8f2 100644 --- a/test/srt/models/test_generation_models.py +++ b/test/srt/models/test_generation_models.py @@ -14,14 +14,15 @@ """ Usage: -To test a specific model: -1. Add it to ALL_OTHER_MODELS -2. Run `ONLY_RUN=Qwen/Qwen2-1.5B python3 -m unittest test_generation_models.TestGenerationModels.test_others` +To test a specific model locally: +1. Add it to ALL_MODELS, for example, `ModelCase("Qwen/Qwen2-1.5B")` +2. 
Run `ONLY_RUN=Qwen/Qwen2-1.5B python3 -m unittest test_generation_models.TestGenerationModels` """ import dataclasses import multiprocessing as mp import os +import random import unittest from typing import List @@ -53,8 +54,9 @@ CI_MODELS = [ ModelCase("google/gemma-2-2b"), ] -# All other models that do not run on the CI -ALL_OTHER_MODELS = [ +# the complete set of models to test sglang's generation model +ALL_MODELS = [ + *CI_MODELS, ModelCase("Qwen/Qwen2-1.5B"), ModelCase("Qwen/Qwen2.5-14B-Instruct"), ModelCase("HuggingFaceTB/SmolLM-135M-Instruct", skip_long_prompt=True), @@ -63,7 +65,7 @@ ALL_OTHER_MODELS = [ "THUDM/glm-4-9b-chat", tp_size=2, trust_remote_code=True, skip_long_prompt=True ), ModelCase("openai-community/gpt2"), - ModelCase("microsoft/Phi-3-small-8k-instruct"), + ModelCase("microsoft/Phi-3-small-8k-instruct", trust_remote_code=True), ModelCase("allenai/OLMo-2-1124-7B-Instruct", skip_long_prompt=True), ModelCase("ibm-granite/granite-3.0-2b-instruct", skip_long_prompt=True), ] @@ -117,9 +119,30 @@ class TestGenerationModels(CustomTestCase): debug_text=f"model_path={model_path} prompts={prompts}", ) + @unittest.skipIf(not is_in_ci(), "Local test should run all models") def test_ci_models(self): for model_case in CI_MODELS: for torch_dtype in TORCH_DTYPES: + prompts = DEFAULT_PROMPTS + + # Skip long prompts for models that do not have a long context + if model_case.skip_long_prompt: + prompts = [p for p in DEFAULT_PROMPTS if len(p) < 1000] + + # Assert the logits and output strs are close + self.assert_close_logits_and_output_strs( + prompts, model_case, torch_dtype + ) + + @unittest.skipIf(is_in_ci(), "CI only runs selected models for simplicity") + def test_all_models(self): + for model_case in ALL_MODELS: + for torch_dtype in TORCH_DTYPES: + if ( + "ONLY_RUN" in os.environ + and os.environ["ONLY_RUN"] != model_case.model_path + ): + continue # Skip long prompts for models that do not have a long context prompts = DEFAULT_PROMPTS @@ -131,26 +154,6 @@ class TestGenerationModels(CustomTestCase): prompts, model_case, torch_dtype ) - def test_others(self): - if is_in_ci(): - return - - for model_case in ALL_OTHER_MODELS: - # Only run a specified model - if ( - "ONLY_RUN" in os.environ - and os.environ["ONLY_RUN"] != model_case.model_path - ): - continue - - # Skip long prompts for models that do not have a long context - prompts = DEFAULT_PROMPTS - if model_case.skip_long_prompt: - prompts = [p for p in DEFAULT_PROMPTS if len(p) < 1000] - - # Assert the logits and output strs are close - self.assert_close_logits_and_output_strs(prompts, model_case, torch.float16) - if __name__ == "__main__": unittest.main() diff --git a/test/srt/test_vision_openai_server.py b/test/srt/test_vision_openai_server.py index 9bd074829..be55f4bcd 100644 --- a/test/srt/test_vision_openai_server.py +++ b/test/srt/test_vision_openai_server.py @@ -642,6 +642,28 @@ class TestMinicpmoServer(TestOpenAIVisionServer): self._test_audio_ambient_completion() +class TestPixtralServer(TestOpenAIVisionServer): + @classmethod + def setUpClass(cls): + cls.model = "mistral-community/pixtral-12b" + cls.base_url = DEFAULT_URL_FOR_TEST + cls.api_key = "sk-123456" + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--trust-remote-code", + "--mem-fraction-static", + "0.73", + ], + ) + cls.base_url += "/v1" + + def test_video_chat_completion(self): + pass + + class TestDeepseekVL2Server(TestOpenAIVisionServer): @classmethod def setUpClass(cls):
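These OpenAI-compatible vision tests (for example `TestPixtralServer` above) exercise the standard chat-completions image interface. As a rough illustration of the request shape they send, assuming a Pixtral server launched as in the examples earlier in this patch (the host, port, API key, and image URL below are placeholders):

# Illustrative client-side sketch; endpoint details are assumptions, not part of the patch.
from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:30000/v1", api_key="sk-123456")

response = client.chat.completions.create(
    model="mistral-community/pixtral-12b",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this image in one sentence."},
                {
                    "type": "image_url",
                    "image_url": {"url": "https://picsum.photos/id/237/400/300"},
                },
            ],
        }
    ],
    max_tokens=64,
)
print(response.choices[0].message.content)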