model(vlm): pixtral (#5084)
This commit is contained in:
@@ -20,6 +20,7 @@ python3 -m sglang.launch_server \
|
|||||||
| **Janus-Pro** (1B, 7B) | `deepseek-ai/Janus-Pro-7B` | `janus-pro` | DeepSeek’s open-source multimodal model capable of both image understanding and generation. Janus-Pro employs a decoupled architecture for separate visual encoding paths, enhancing performance in both tasks. |
|
| **Janus-Pro** (1B, 7B) | `deepseek-ai/Janus-Pro-7B` | `janus-pro` | DeepSeek’s open-source multimodal model capable of both image understanding and generation. Janus-Pro employs a decoupled architecture for separate visual encoding paths, enhancing performance in both tasks. |
|
||||||
| **MiniCPM-V / MiniCPM-o** | `openbmb/MiniCPM-V-2_6` | `minicpmv` | MiniCPM-V (2.6, ~8B) supports image inputs, and MiniCPM-o adds audio/video; these multimodal LLMs are optimized for end-side deployment on mobile/edge devices. |
|
| **MiniCPM-V / MiniCPM-o** | `openbmb/MiniCPM-V-2_6` | `minicpmv` | MiniCPM-V (2.6, ~8B) supports image inputs, and MiniCPM-o adds audio/video; these multimodal LLMs are optimized for end-side deployment on mobile/edge devices. |
|
||||||
| **Llama 3.2 Vision** (11B) | `meta-llama/Llama-3.2-11B-Vision-Instruct` | `llama_3_vision` | Vision-enabled variant of Llama 3 (11B) that accepts image inputs for visual question answering and other multimodal tasks. |
|
| **Llama 3.2 Vision** (11B) | `meta-llama/Llama-3.2-11B-Vision-Instruct` | `llama_3_vision` | Vision-enabled variant of Llama 3 (11B) that accepts image inputs for visual question answering and other multimodal tasks. |
|
||||||
|
| **Pixtral** (12B, 124B) | `mistral-community/pixtral-12b` | `mistral` | Pixtral is a vision-language model from Mistral AI that can process both text and images. |
|
||||||
| **LLaVA** (v1.5 & v1.6) | *e.g.* `liuhaotian/llava-v1.5-13b` | `vicuna_v1.1` | Open vision-chat models that add an image encoder to LLaMA/Vicuna (e.g. LLaMA2 13B) for following multimodal instruction prompts. |
|
| **LLaVA** (v1.5 & v1.6) | *e.g.* `liuhaotian/llava-v1.5-13b` | `vicuna_v1.1` | Open vision-chat models that add an image encoder to LLaMA/Vicuna (e.g. LLaMA2 13B) for following multimodal instruction prompts. |
|
||||||
| **LLaVA-NeXT** (8B, 72B) | `lmms-lab/llava-next-72b` | `chatml-llava` | Improved LLaVA models (with an 8B Llama3 version and a 72B version) offering enhanced visual instruction-following and accuracy on multimodal benchmarks. |
|
| **LLaVA-NeXT** (8B, 72B) | `lmms-lab/llava-next-72b` | `chatml-llava` | Improved LLaVA models (with an 8B Llama3 version and a 72B version) offering enhanced visual instruction-following and accuracy on multimodal benchmarks. |
|
||||||
| **LLaVA-OneVision** | `lmms-lab/llava-onevision-qwen2-7b-ov` | `chatml-llava` | Enhanced LLaVA variant integrating Qwen as the backbone; supports multiple images (and even video frames) as inputs via an OpenAI Vision API-compatible format. |
|
| **LLaVA-OneVision** | `lmms-lab/llava-onevision-qwen2-7b-ov` | `chatml-llava` | Enhanced LLaVA variant integrating Qwen as the backbone; supports multiple images (and even video frames) as inputs via an OpenAI Vision API-compatible format. |
|
||||||
|
|||||||
@@ -33,9 +33,10 @@ The `hidden_states` folder contains examples on how to extract hidden states usi
|
|||||||
* `hidden_states_engine.py`: An example how to extract hidden states using the Engine API.
|
* `hidden_states_engine.py`: An example how to extract hidden states using the Engine API.
|
||||||
* `hidden_states_server.py`: An example how to extract hidden states using the Server API.
|
* `hidden_states_server.py`: An example how to extract hidden states using the Server API.
|
||||||
|
|
||||||
## LLaVA-NeXT
|
## Multimodal
|
||||||
|
|
||||||
|
SGLang supports multimodal inputs for various model architectures. The `multimodal` folder contains examples showing how to use urls, files or encoded data to make requests to multimodal models. Examples include querying the [Llava-OneVision](multimodal/llava_onevision_server.py) model (image, multi-image, video), Llava-backed [Qwen-Llava](multimodal/qwen_llava_server.py) and [Llama3-Llava](multimodal/llama3_llava_server.py) models (image, multi-image), and Mistral AI's [Pixtral](multimodal/pixtral_server.py) (image, multi-image).
|
||||||
|
|
||||||
SGLang support LLaVA-OneVision with single-image, multi-image and video are supported. The folder `llava_onevision` shows how to do this.
|
|
||||||
|
|
||||||
## Token In, Token Out
|
## Token In, Token Out
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ Usage:
|
|||||||
# Endpoint Service CLI:
|
# Endpoint Service CLI:
|
||||||
python -m sglang.launch_server --model-path lmms-lab/llama3-llava-next-8b --port=30000
|
python -m sglang.launch_server --model-path lmms-lab/llama3-llava-next-8b --port=30000
|
||||||
|
|
||||||
python3 http_llama3_llava_test.py
|
python3 llama3_llava_server.py
|
||||||
|
|
||||||
Output:
|
Output:
|
||||||
"Friends posing for a fun photo with a life-sized teddy bear, creating a playful and memorable moment."
|
"Friends posing for a fun photo with a life-sized teddy bear, creating a playful and memorable moment."
|
||||||
@@ -3,7 +3,7 @@ Usage:
|
|||||||
|
|
||||||
python3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-72b-ov --port=30000 --tp-size=8
|
python3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-72b-ov --port=30000 --tp-size=8
|
||||||
|
|
||||||
python3 http_llava_onevision_test.py
|
python3 llava_onevision_server.py
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
127
examples/runtime/multimodal/pixtral_server.py
Normal file
127
examples/runtime/multimodal/pixtral_server.py
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
"""
|
||||||
|
Usage:
|
||||||
|
# Run a Pixtral model with SGLang:
|
||||||
|
# HuggingFace:
|
||||||
|
python -m sglang.launch_server --model-path mistral-community/pixtral-12b --port=30000
|
||||||
|
# ModelScope:
|
||||||
|
python -m sglang.launch_server --model-path AI-ModelScope/pixtral-12b --port=30000
|
||||||
|
|
||||||
|
# Then test it with:
|
||||||
|
python pixtral_server.py
|
||||||
|
|
||||||
|
This script tests Pixtral model with both single and multiple images.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
import requests
|
||||||
|
|
||||||
|
IMAGE_TOKEN_SEP = "\n[IMG]"
|
||||||
|
ROUTE = "/generate"
|
||||||
|
|
||||||
|
|
||||||
|
async def send_request(url, data, delay=0):
|
||||||
|
await asyncio.sleep(delay)
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(url, json=data) as resp:
|
||||||
|
output = await resp.json()
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
async def test_concurrent(args):
|
||||||
|
url = f"{args.host}:{args.port}{ROUTE}"
|
||||||
|
|
||||||
|
# Single image test
|
||||||
|
if args.single_image:
|
||||||
|
prompt = f"<s>[INST]Describe this image in detail.{IMAGE_TOKEN_SEP}[/INST]"
|
||||||
|
image_url = "https://picsum.photos/id/237/400/300"
|
||||||
|
modality = ["image"]
|
||||||
|
# Multiple images test
|
||||||
|
else:
|
||||||
|
image_urls = [
|
||||||
|
"https://picsum.photos/id/237/400/300",
|
||||||
|
"https://picsum.photos/id/27/500/500",
|
||||||
|
]
|
||||||
|
prompt = f"<s>[INST]How many photos are there? Describe each in a very short sentence.{IMAGE_TOKEN_SEP * len(image_urls)}[/INST]"
|
||||||
|
image_url = image_urls
|
||||||
|
modality = ["multi-images"]
|
||||||
|
|
||||||
|
response = await send_request(
|
||||||
|
url,
|
||||||
|
{
|
||||||
|
"text": prompt,
|
||||||
|
"image_data": image_url,
|
||||||
|
"sampling_params": {
|
||||||
|
"max_new_tokens": 100,
|
||||||
|
"temperature": 0.7,
|
||||||
|
"top_p": 0.9,
|
||||||
|
},
|
||||||
|
"modalities": modality,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Response: {response}")
|
||||||
|
if "text" in response:
|
||||||
|
print("\nOutput text:", response["text"])
|
||||||
|
|
||||||
|
|
||||||
|
def test_streaming(args):
|
||||||
|
url = f"{args.host}:{args.port}/generate"
|
||||||
|
|
||||||
|
# Single image test
|
||||||
|
if args.single_image:
|
||||||
|
prompt = f"<s>[INST]Describe this image in detail.{IMAGE_TOKEN_SEP}[/INST]"
|
||||||
|
image_data = "https://picsum.photos/id/237/400/300"
|
||||||
|
modality = ["image"]
|
||||||
|
# Multiple images test
|
||||||
|
else:
|
||||||
|
image_urls = [
|
||||||
|
"https://picsum.photos/id/237/400/300",
|
||||||
|
"https://picsum.photos/id/27/500/500",
|
||||||
|
]
|
||||||
|
prompt = f"<s>[INST]How many photos are there? Describe each in a very short sentence.{IMAGE_TOKEN_SEP * len(image_urls)}[/INST]"
|
||||||
|
image_data = image_urls
|
||||||
|
modality = ["multi-images"]
|
||||||
|
|
||||||
|
pload = {
|
||||||
|
"text": prompt,
|
||||||
|
"image_data": image_data,
|
||||||
|
"sampling_params": {"max_new_tokens": 100, "temperature": 0.7, "top_p": 0.9},
|
||||||
|
"modalities": modality,
|
||||||
|
"stream": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
response = requests.post(url, json=pload, stream=True)
|
||||||
|
|
||||||
|
print("Streaming response:")
|
||||||
|
prev = 0
|
||||||
|
for chunk in response.iter_lines(decode_unicode=False):
|
||||||
|
chunk = chunk.decode("utf-8")
|
||||||
|
if chunk and chunk.startswith("data:"):
|
||||||
|
if chunk == "data: [DONE]":
|
||||||
|
break
|
||||||
|
data = json.loads(chunk[5:].strip("\n"))
|
||||||
|
output = data["text"].strip()
|
||||||
|
print(output[prev:], end="", flush=True)
|
||||||
|
prev = len(output)
|
||||||
|
print("\n")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--host", type=str, default="http://127.0.0.1")
|
||||||
|
parser.add_argument("--port", type=int, default=30000)
|
||||||
|
parser.add_argument(
|
||||||
|
"--single-image",
|
||||||
|
action="store_true",
|
||||||
|
help="Test with single image instead of multiple images",
|
||||||
|
)
|
||||||
|
parser.add_argument("--no-stream", action="store_true", help="Don't test streaming")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
asyncio.run(test_concurrent(args))
|
||||||
|
if not args.no_stream:
|
||||||
|
test_streaming(args)
|
||||||
@@ -6,7 +6,7 @@ Usage:
|
|||||||
# Endpoint Service CLI:
|
# Endpoint Service CLI:
|
||||||
python -m sglang.launch_server --model-path lmms-lab/llava-next-72b --port=30000 --tp-size=8
|
python -m sglang.launch_server --model-path lmms-lab/llava-next-72b --port=30000 --tp-size=8
|
||||||
|
|
||||||
python3 http_qwen_llava_test.py
|
python3 qwen_llava_server.py
|
||||||
|
|
||||||
Output:
|
Output:
|
||||||
"Two children pose with a large teddy bear, one holding a smaller stuffed bear, in a room with an American flag and potted plants."
|
"Two children pose with a large teddy bear, one holding a smaller stuffed bear, in a room with an American flag and potted plants."
|
||||||
@@ -194,6 +194,21 @@ register_chat_template(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Reference: https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/blob/main/chat_template.json
|
||||||
|
register_chat_template(
|
||||||
|
ChatTemplate(
|
||||||
|
name="mistral",
|
||||||
|
default_system_prompt=None,
|
||||||
|
role_prefix_and_suffix={
|
||||||
|
"system": ("[SYSTEM_PROMPT] ", " [/SYSTEM_PROMPT]"),
|
||||||
|
"user": ("[INST] ", " [/INST]"),
|
||||||
|
"assistant": ("", " </s><s>"),
|
||||||
|
},
|
||||||
|
stop_str=("</s>",),
|
||||||
|
image_token="[IMG]",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
register_chat_template(
|
register_chat_template(
|
||||||
ChatTemplate(
|
ChatTemplate(
|
||||||
name="llama-3-instruct",
|
name="llama-3-instruct",
|
||||||
@@ -509,13 +524,19 @@ def match_vicuna(model_path: str):
|
|||||||
@register_chat_template_matching_function
|
@register_chat_template_matching_function
|
||||||
def match_llama2_chat(model_path: str):
|
def match_llama2_chat(model_path: str):
|
||||||
if re.search(
|
if re.search(
|
||||||
r"llama-2.*chat|(mistral|mixtral).*instruct|codellama.*instruct",
|
r"llama-2.*chat|codellama.*instruct",
|
||||||
model_path,
|
model_path,
|
||||||
re.IGNORECASE,
|
re.IGNORECASE,
|
||||||
):
|
):
|
||||||
return "llama-2-chat"
|
return "llama-2-chat"
|
||||||
|
|
||||||
|
|
||||||
|
@register_chat_template_matching_function
|
||||||
|
def match_mistral(model_path: str):
|
||||||
|
if re.search(r"pixtral|(mistral|mixtral).*instruct", model_path, re.IGNORECASE):
|
||||||
|
return "mistral"
|
||||||
|
|
||||||
|
|
||||||
@register_chat_template_matching_function
|
@register_chat_template_matching_function
|
||||||
def match_llama3_instruct(model_path: str):
|
def match_llama3_instruct(model_path: str):
|
||||||
if re.search(r"llama-3.*instruct", model_path, re.IGNORECASE):
|
if re.search(r"llama-3.*instruct", model_path, re.IGNORECASE):
|
||||||
|
|||||||
@@ -545,6 +545,7 @@ multimodal_model_archs = [
|
|||||||
"Llama4ForConditionalGeneration",
|
"Llama4ForConditionalGeneration",
|
||||||
"LlavaMistralForCausalLM",
|
"LlavaMistralForCausalLM",
|
||||||
"LlavaQwenForCausalLM",
|
"LlavaQwenForCausalLM",
|
||||||
|
"LlavaForConditionalGeneration",
|
||||||
"LlavaVidForCausalLM",
|
"LlavaVidForCausalLM",
|
||||||
"MiniCPMO",
|
"MiniCPMO",
|
||||||
"MiniCPMV",
|
"MiniCPMV",
|
||||||
|
|||||||
@@ -634,6 +634,20 @@ register_conv_template(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# reference: https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/blob/main/chat_template.json
|
||||||
|
register_conv_template(
|
||||||
|
Conversation(
|
||||||
|
name="mistral",
|
||||||
|
system_template="[SYSTEM_PROMPT]\n{system_message}\n[/SYSTEM_PROMPT]\n\n",
|
||||||
|
roles=("[INST]", "[/INST]"),
|
||||||
|
sep_style=SeparatorStyle.LLAMA2,
|
||||||
|
sep=" ",
|
||||||
|
sep2=" </s><s>",
|
||||||
|
stop_str=["[INST]", "[/INST]", "[SYSTEM_PROMPT]", "[/SYSTEM_PROMPT]"],
|
||||||
|
image_token="[IMG]",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# reference: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct/blob/main/chat_template.json
|
# reference: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct/blob/main/chat_template.json
|
||||||
register_conv_template(
|
register_conv_template(
|
||||||
Conversation(
|
Conversation(
|
||||||
@@ -880,13 +894,19 @@ def match_vicuna(model_path: str):
|
|||||||
@register_conv_template_matching_function
|
@register_conv_template_matching_function
|
||||||
def match_llama2_chat(model_path: str):
|
def match_llama2_chat(model_path: str):
|
||||||
if re.search(
|
if re.search(
|
||||||
r"llama-2.*chat|(mistral|mixtral).*instruct|codellama.*instruct",
|
r"llama-2.*chat|codellama.*instruct",
|
||||||
model_path,
|
model_path,
|
||||||
re.IGNORECASE,
|
re.IGNORECASE,
|
||||||
):
|
):
|
||||||
return "llama-2"
|
return "llama-2"
|
||||||
|
|
||||||
|
|
||||||
|
@register_conv_template_matching_function
|
||||||
|
def match_mistral(model_path: str):
|
||||||
|
if re.search(r"pixtral|(mistral|mixtral).*instruct", model_path, re.IGNORECASE):
|
||||||
|
return "mistral"
|
||||||
|
|
||||||
|
|
||||||
@register_conv_template_matching_function
|
@register_conv_template_matching_function
|
||||||
def match_deepseek_vl(model_path: str):
|
def match_deepseek_vl(model_path: str):
|
||||||
if re.search(r"deepseek.*vl2", model_path, re.IGNORECASE):
|
if re.search(r"deepseek.*vl2", model_path, re.IGNORECASE):
|
||||||
|
|||||||
@@ -1,14 +1,20 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
import importlib
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from transformers.models.auto.processing_auto import (
|
||||||
|
PROCESSOR_MAPPING_NAMES as HF_MAPPING_NAMES,
|
||||||
|
)
|
||||||
|
|
||||||
|
import sglang.srt.managers.multimodal_processor as sgl_mm_processor_utils
|
||||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||||
BaseMultimodalProcessor,
|
BaseMultimodalProcessor,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
||||||
from sglang.srt.mm_utils import expand2square, process_anyres_image
|
from sglang.srt.mm_utils import expand2square, process_anyres_image
|
||||||
from sglang.srt.models.llava import (
|
from sglang.srt.models.llava import (
|
||||||
|
LlavaForConditionalGeneration,
|
||||||
LlavaLlamaForCausalLM,
|
LlavaLlamaForCausalLM,
|
||||||
LlavaMistralForCausalLM,
|
LlavaMistralForCausalLM,
|
||||||
LlavaQwenForCausalLM,
|
LlavaQwenForCausalLM,
|
||||||
@@ -133,6 +139,7 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
|
|||||||
img_data, aspect_ratio, grid_pinpoints
|
img_data, aspect_ratio, grid_pinpoints
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
res = await asyncio.gather(*res)
|
res = await asyncio.gather(*res)
|
||||||
for pixel_v, image_h, image_s in res:
|
for pixel_v, image_h, image_s in res:
|
||||||
pixel_values.append(pixel_v)
|
pixel_values.append(pixel_v)
|
||||||
@@ -165,3 +172,42 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
|
|||||||
)
|
)
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class LlavaMultimodalProcessor(BaseMultimodalProcessor):
|
||||||
|
"""
|
||||||
|
This is a wrapper class used to identify the multimodal processor for Llava architecture models.
|
||||||
|
"""
|
||||||
|
|
||||||
|
models = [LlavaForConditionalGeneration]
|
||||||
|
|
||||||
|
def _get_sgl_processor_cls(self, model_type: str):
|
||||||
|
if hf_name := HF_MAPPING_NAMES.get(model_type):
|
||||||
|
sgl_mm_processor_set = sgl_mm_processor_utils.PROCESSOR_MAPPING.values()
|
||||||
|
sgl_processor_cls = list(
|
||||||
|
filter(lambda p: p.__name__ == hf_name, sgl_mm_processor_set)
|
||||||
|
)
|
||||||
|
if sgl_processor_cls:
|
||||||
|
return sgl_processor_cls[0]
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot find corresponding multimodal processor registered in sglang for model type `{model_type}`"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, hf_config, server_args, _processor):
|
||||||
|
assert hasattr(hf_config, "vision_config")
|
||||||
|
assert hasattr(hf_config, "text_config")
|
||||||
|
self.vision_config = hf_config.vision_config
|
||||||
|
self.text_config = hf_config.text_config
|
||||||
|
self.hf_config = hf_config
|
||||||
|
|
||||||
|
if vision_type := getattr(self.vision_config, "model_type"):
|
||||||
|
self.inner = self._get_sgl_processor_cls(vision_type)(
|
||||||
|
hf_config, server_args, _processor
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Required `vision_config.model_type` is not found in hf_config: `{hf_config}`"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def process_mm_data_async(self, *args, **kwargs):
|
||||||
|
return await self.inner.process_mm_data_async(*args, **kwargs)
|
||||||
|
|||||||
127
python/sglang/srt/managers/multimodal_processors/pixtral.py
Normal file
127
python/sglang/srt/managers/multimodal_processors/pixtral.py
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
import asyncio
|
||||||
|
import math
|
||||||
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from transformers import PretrainedConfig
|
||||||
|
from transformers.models.pixtral.image_processing_pixtral import (
|
||||||
|
_num_image_tokens as _get_pixtral_hf_num_image_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||||
|
BaseMultimodalProcessor,
|
||||||
|
MultimodalSpecialTokens,
|
||||||
|
)
|
||||||
|
from sglang.srt.managers.schedule_batch import (
|
||||||
|
Modality,
|
||||||
|
MultimodalDataItem,
|
||||||
|
MultimodalInputs,
|
||||||
|
)
|
||||||
|
from sglang.srt.models.pixtral import PixtralVisionModel
|
||||||
|
|
||||||
|
|
||||||
|
class PixtralProcessor(BaseMultimodalProcessor):
|
||||||
|
models = [PixtralVisionModel]
|
||||||
|
|
||||||
|
PAD_TOKEN = "<pad>"
|
||||||
|
IMG_BREAK_TOKEN_ID = 12
|
||||||
|
IMG_END_TOKEN_ID = 13
|
||||||
|
|
||||||
|
def get_patch_grid_size(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
image_width: int,
|
||||||
|
image_height: int,
|
||||||
|
) -> tuple[int, int]:
|
||||||
|
max_width = max_height = self.image_size
|
||||||
|
patch_width = patch_height = self.patch_size
|
||||||
|
|
||||||
|
ratio = max(image_width / max_width, image_height / max_height)
|
||||||
|
|
||||||
|
if ratio > 1:
|
||||||
|
image_width = int(math.floor(image_width / ratio))
|
||||||
|
image_height = int(math.floor(image_height / ratio))
|
||||||
|
|
||||||
|
nrows, ncols = _get_pixtral_hf_num_image_tokens(
|
||||||
|
(image_height, image_width),
|
||||||
|
(patch_height, patch_width),
|
||||||
|
)
|
||||||
|
|
||||||
|
return ncols, nrows
|
||||||
|
|
||||||
|
def __init__(self, hf_config, server_args, _processor):
|
||||||
|
super().__init__(hf_config, server_args, _processor)
|
||||||
|
self.image_token_id = getattr(
|
||||||
|
hf_config, "image_token_index", PixtralVisionModel.DEFAULT_IMAGE_TOKEN_ID
|
||||||
|
)
|
||||||
|
# Instantiate the patcher logic helper using the class defined above
|
||||||
|
|
||||||
|
self.vision_config = hf_config.vision_config
|
||||||
|
self.image_size = self.vision_config.image_size
|
||||||
|
self.patch_size = self.vision_config.patch_size
|
||||||
|
self.multimodal_tokens = MultimodalSpecialTokens(
|
||||||
|
image_token=_processor.image_token
|
||||||
|
)
|
||||||
|
_processor.tokenizer.add_special_tokens(
|
||||||
|
{
|
||||||
|
"pad_token": getattr(hf_config, "pad_token", self.PAD_TOKEN),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _resize(self, image):
|
||||||
|
num_w_tokens, num_h_tokens = self.get_patch_grid_size(
|
||||||
|
image_width=image.size[0],
|
||||||
|
image_height=image.size[1],
|
||||||
|
)
|
||||||
|
new_size = (num_w_tokens * self.patch_size, num_h_tokens * self.patch_size)
|
||||||
|
return image.resize(new_size)
|
||||||
|
|
||||||
|
async def process_mm_data_async(
|
||||||
|
self,
|
||||||
|
image_data: List[Union[str, bytes]],
|
||||||
|
input_text,
|
||||||
|
request_obj,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
if not image_data:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if isinstance(image_data, str):
|
||||||
|
image_data = [image_data]
|
||||||
|
|
||||||
|
mm_data = self.load_mm_data(
|
||||||
|
prompt=input_text,
|
||||||
|
multimodal_tokens=self.multimodal_tokens,
|
||||||
|
max_req_input_len=kwargs.get("max_req_input_len", 4096),
|
||||||
|
image_data=image_data,
|
||||||
|
return_text=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if mm_data.images:
|
||||||
|
resize_tasks = [self._resize(image) for image in mm_data.images]
|
||||||
|
mm_data.images = await asyncio.gather(*resize_tasks)
|
||||||
|
|
||||||
|
processor_output = self.process_mm_data(
|
||||||
|
input_text=mm_data.input_text,
|
||||||
|
images=mm_data.images,
|
||||||
|
)
|
||||||
|
|
||||||
|
if "pixel_values" in processor_output:
|
||||||
|
mm_items = [
|
||||||
|
MultimodalDataItem(
|
||||||
|
pixel_values=processor_output["pixel_values"],
|
||||||
|
image_sizes=processor_output["image_sizes"],
|
||||||
|
modality=Modality.IMAGE,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
input_ids = processor_output["input_ids"].view(-1).tolist()
|
||||||
|
processor_output.update(
|
||||||
|
input_ids=input_ids,
|
||||||
|
mm_items=mm_items,
|
||||||
|
# there's no im_start_id for pixtral, only im_token and im_end_token
|
||||||
|
im_end_id=self.IMG_END_TOKEN_ID,
|
||||||
|
im_token_id=self.image_token_id,
|
||||||
|
)
|
||||||
|
return processor_output
|
||||||
@@ -15,7 +15,8 @@
|
|||||||
|
|
||||||
import math
|
import math
|
||||||
import re
|
import re
|
||||||
from typing import Iterable, List, Optional, Tuple
|
from functools import lru_cache
|
||||||
|
from typing import Dict, Iterable, List, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -28,10 +29,18 @@ from transformers import (
|
|||||||
Qwen2Config,
|
Qwen2Config,
|
||||||
SiglipVisionModel,
|
SiglipVisionModel,
|
||||||
)
|
)
|
||||||
|
from transformers.models.auto.modeling_auto import AutoModel, AutoModelForCausalLM
|
||||||
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
|
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
|
||||||
|
|
||||||
|
# leave till last and symbol only in case circular import
|
||||||
|
import sglang.srt.models as sgl_models
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
from sglang.srt.managers.schedule_batch import Modality, MultimodalInputs
|
from sglang.srt.managers.mm_utils import general_mm_embed_routine
|
||||||
|
from sglang.srt.managers.schedule_batch import (
|
||||||
|
Modality,
|
||||||
|
MultimodalDataItem,
|
||||||
|
MultimodalInputs,
|
||||||
|
)
|
||||||
from sglang.srt.mm_utils import (
|
from sglang.srt.mm_utils import (
|
||||||
get_anyres_image_grid_shape,
|
get_anyres_image_grid_shape,
|
||||||
unpad_image,
|
unpad_image,
|
||||||
@@ -42,7 +51,7 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
|
|||||||
from sglang.srt.models.llama import LlamaForCausalLM
|
from sglang.srt.models.llama import LlamaForCausalLM
|
||||||
from sglang.srt.models.mistral import MistralForCausalLM
|
from sglang.srt.models.mistral import MistralForCausalLM
|
||||||
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
|
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
|
||||||
from sglang.srt.utils import add_prefix, flatten_nested_list
|
from sglang.srt.utils import add_prefix, flatten_nested_list, logger
|
||||||
|
|
||||||
|
|
||||||
class LlavaBaseForCausalLM(nn.Module):
|
class LlavaBaseForCausalLM(nn.Module):
|
||||||
@@ -114,7 +123,16 @@ class LlavaBaseForCausalLM(nn.Module):
|
|||||||
image_inputs.image_offsets = offset_list
|
image_inputs.image_offsets = offset_list
|
||||||
return input_ids
|
return input_ids
|
||||||
|
|
||||||
def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
def encode_images(
|
||||||
|
self, pixel_values: Union[torch.Tensor, List[torch.Tensor]]
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
encode images by vision tower and multimodal projector
|
||||||
|
Args:
|
||||||
|
pixel_values: torch.Tensor or List[torch.Tensor]: each tensor for an input image
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: encoded image features from the input image; if multiple, flattened by seq_len axis
|
||||||
|
"""
|
||||||
image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
|
image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
|
||||||
# NOTE: This is not memory efficient. (output_hidden_states=True) will save all the hidden stated.
|
# NOTE: This is not memory efficient. (output_hidden_states=True) will save all the hidden stated.
|
||||||
|
|
||||||
@@ -583,4 +601,229 @@ class LlavaMistralForCausalLM(LlavaBaseForCausalLM):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
EntryClass = [LlavaLlamaForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM]
|
class LlavaForConditionalGeneration(LlavaBaseForCausalLM):
|
||||||
|
"""
|
||||||
|
An adaptor class to enable support for multiple mmlm such as mistral-community/pixtral-12b
|
||||||
|
It follows the structure of (vision_tower, multi_modal_projector, language_model)
|
||||||
|
|
||||||
|
Once a model config is loaded, text_config and vision_config will be extracted, and
|
||||||
|
LlavaForConditionalGeneration will load the language_model and vision_tower models
|
||||||
|
according to config.
|
||||||
|
"""
|
||||||
|
|
||||||
|
MULTIMODAL_PROJECTOR_TYPE = LlavaMultiModalProjector
|
||||||
|
|
||||||
|
def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
|
||||||
|
if hasattr(self.vision_tower, "pad_input_ids"):
|
||||||
|
return self.vision_tower.pad_input_ids(input_ids, image_inputs)
|
||||||
|
else:
|
||||||
|
return super().pad_input_ids(input_ids, image_inputs)
|
||||||
|
|
||||||
|
def _get_sgl_model_cls(self, config, auto_model_type: Type[AutoModel] = AutoModel):
|
||||||
|
"""
|
||||||
|
Get the SGLang model implementation class according to config.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: The config object of the model.
|
||||||
|
auto_model_type: The type of the auto model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The SGLang model implementation class.
|
||||||
|
"""
|
||||||
|
config_cls_name = config.__class__.__name__
|
||||||
|
arch_name_mapping = self._config_cls_name_to_arch_name_mapping(auto_model_type)
|
||||||
|
if arch := arch_name_mapping.get(config_cls_name):
|
||||||
|
if isinstance(arch, tuple):
|
||||||
|
arch = arch[0]
|
||||||
|
logger.warning(
|
||||||
|
f"Multiple {auto_model_type.__name__} models found for submodule config `{config_cls_name}`, defaulting to [0]: {arch.__name__}"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
return sgl_models.registry.ModelRegistry.resolve_model_cls(arch)[0]
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(
|
||||||
|
f"{auto_model_type.__name__} found a corresponding model `{arch}` for config class `{config_cls_name}`, but failed to load it from SGLang ModelRegistry. \n{e}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"{auto_model_type.__name__} cannot find a corresponding model for config class `{config_cls_name}`"
|
||||||
|
)
|
||||||
|
|
||||||
|
@lru_cache
|
||||||
|
def _config_cls_name_to_arch_name_mapping(
|
||||||
|
self, auto_model_type: Type[AutoModel]
|
||||||
|
) -> Dict[str, str]:
|
||||||
|
mapping = {}
|
||||||
|
for config_cls, archs in auto_model_type._model_mapping.items():
|
||||||
|
if isinstance(archs, tuple):
|
||||||
|
mapping[config_cls.__name__] = tuple(arch.__name__ for arch in archs)
|
||||||
|
else:
|
||||||
|
mapping[config_cls.__name__] = archs.__name__
|
||||||
|
return mapping
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: LlavaConfig,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
assert hasattr(config, "text_config")
|
||||||
|
assert hasattr(config, "vision_config")
|
||||||
|
self.config = config
|
||||||
|
self.text_config = config.text_config
|
||||||
|
self.vision_config = config.vision_config
|
||||||
|
|
||||||
|
if not hasattr(self.config, "vocab_size"):
|
||||||
|
self.config.vocab_size = self.config.text_config.vocab_size
|
||||||
|
if not hasattr(self.config, "image_aspect_ratio"):
|
||||||
|
self.config.image_aspect_ratio = "anyres"
|
||||||
|
if not hasattr(self.config, "image_grid_pinpoints"):
|
||||||
|
# from transformers.models.llava_onevision.configuration_llava_onevision import LlavaOnevisionConfig
|
||||||
|
# self.config.image_grid_pinpoints = LlavaOnevisionConfig().image_grid_pinpoints
|
||||||
|
self.config.image_grid_pinpoints = [
|
||||||
|
[96, 96],
|
||||||
|
[224, 224],
|
||||||
|
[384, 384],
|
||||||
|
[512, 512],
|
||||||
|
[768, 768],
|
||||||
|
[1024, 1024],
|
||||||
|
]
|
||||||
|
if not hasattr(self.config, "mm_patch_merge_type"):
|
||||||
|
self.config.mm_patch_merge_type = "flat"
|
||||||
|
if not hasattr(self.config, "image_token_index"):
|
||||||
|
self.config.image_token_index = 10
|
||||||
|
if not hasattr(self.config, "projector_hidden_act"):
|
||||||
|
self.config.projector_hidden_act = "gelu"
|
||||||
|
|
||||||
|
self.vision_feature_layer = getattr(config, "vision_feature_layer", -1)
|
||||||
|
self.vision_feature_select_strategy = getattr(
|
||||||
|
config, "vision_feature_select_strategy", "full"
|
||||||
|
)
|
||||||
|
self.image_size = self.config.vision_config.image_size
|
||||||
|
self.patch_size = self.config.vision_config.patch_size
|
||||||
|
|
||||||
|
self.mm_patch_merge_type = config.mm_patch_merge_type
|
||||||
|
self.image_aspect_ratio = config.image_aspect_ratio
|
||||||
|
self.image_grid_pinpoints = config.image_grid_pinpoints
|
||||||
|
|
||||||
|
self.image_feature_len = int((self.image_size // self.patch_size) ** 2)
|
||||||
|
|
||||||
|
self.multi_modal_projector = self.MULTIMODAL_PROJECTOR_TYPE(config)
|
||||||
|
|
||||||
|
language_model_cls = self._get_sgl_model_cls(
|
||||||
|
config.text_config, AutoModelForCausalLM
|
||||||
|
)
|
||||||
|
vision_model_cls = self._get_sgl_model_cls(config.vision_config, AutoModel)
|
||||||
|
self.language_model = language_model_cls(
|
||||||
|
config.text_config,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=add_prefix("language_model", prefix),
|
||||||
|
)
|
||||||
|
self.vision_tower = vision_model_cls(
|
||||||
|
config.vision_config,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=add_prefix("vision_tower", prefix),
|
||||||
|
)
|
||||||
|
|
||||||
|
if "unpad" in getattr(config, "mm_patch_merge_type", ""):
|
||||||
|
self.language_model.model.image_newline = nn.Parameter(
|
||||||
|
torch.empty(config.text_config.hidden_size, dtype=torch.float16)
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
    """Project image inputs into the language-model embedding space.

    Args:
        items: List of MultimodalDataItem objects containing image data.
            A single item may hold either one image or multiple images.

    Returns:
        torch.Tensor: projected features for all items, concatenated
        along the token dimension.
    """
    projected_features = []
    for item in items:
        # Within each item, pixel_values is assumed to be batched already.
        vision_out = self.vision_tower(
            item.pixel_values, item.image_sizes, output_hidden_states=True
        )
        selected = vision_out.hidden_states[self.vision_feature_layer]

        if self.vision_feature_select_strategy in ("default", "patch"):
            # Drop the leading CLS token, keep patch tokens only.
            selected = selected[:, 1:]
        elif self.vision_feature_select_strategy != "full":
            # "full" keeps everything; anything else is unsupported.
            raise ValueError(
                f"Unexpected select feature: {self.vision_feature_select_strategy}"
            )

        projected_features.append(self.multi_modal_projector(selected.squeeze(0)))

    return torch.cat(projected_features, dim=0)
|
|
||||||
|
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    forward_batch: ForwardBatch,
    get_embedding: bool = False,
):
    """Run the shared multimodal embedding routine and return its result.

    Token/image embedding fusion and the language-model pass are fully
    delegated to `general_mm_embed_routine`.
    """
    return general_mm_embed_routine(
        input_ids=input_ids,
        forward_batch=forward_batch,
        get_embedding=get_embedding,
        language_model=self.language_model,
        image_data_embedding_func=self.get_image_feature,
        placeholder_tokens=None,  # using mm_item.pad_value
        positions=positions,
    )
|
|
||||||
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    """Load weights for LlavaForConditionalGeneration.

    Unlike the base class implementation, this one doesn't need to handle
    weight name remapping as the weights are already properly structured with
    'language_model' and 'vision_tower' prefixes in the safetensors files.
    """
    # Accept "default" alongside "patch"/"full" so this validation stays
    # consistent with get_image_feature, which treats "default" as an
    # alias of "patch" (both drop only the CLS token, so the feature
    # length is unchanged).
    if self.vision_feature_select_strategy in ("default", "patch", "full"):
        pass
    elif self.vision_feature_select_strategy == "cls_patch":
        # cls_patch keeps the CLS token, so the feature length grows by one.
        # NOTE(review): this mutates state on every call; invoking
        # load_weights twice would increment image_feature_len twice —
        # confirm weights are only loaded once per model instance.
        self.image_feature_len += 1
    else:
        raise ValueError(
            f"Unexpected select feature: {self.vision_feature_select_strategy}"
        )

    # Create dictionaries for direct parameter loading
    params_dict = dict(self.named_parameters())

    # Load weights directly without remapping
    for name, loaded_weight in weights:
        for part in ("language_model", "vision_tower"):
            if name.startswith(part):
                # Strip the "<part>." prefix and delegate to the submodule.
                name = name[len(part + ".") :]
                getattr(self, part).load_weights([(name, loaded_weight)])
                break
        else:
            # Weight belongs to this module directly (e.g. the projector).
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)
|
|
||||||
|
|
||||||
|
# Model classes exposed to SGLang's model registry for external access.
EntryClass = [
    LlavaLlamaForCausalLM,
    LlavaQwenForCausalLM,
    LlavaMistralForCausalLM,
    LlavaForConditionalGeneration,
]
|
||||||
|
|||||||
467
python/sglang/srt/models/pixtral.py
Normal file
467
python/sglang/srt/models/pixtral.py
Normal file
@@ -0,0 +1,467 @@
|
|||||||
|
# Copyright 2024 SGLang Team
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
"""
|
||||||
|
Using mistral-community/pixtral-12b as reference.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from transformers import PixtralVisionConfig, PretrainedConfig
|
||||||
|
from transformers.models.pixtral.modeling_pixtral import PixtralRotaryEmbedding
|
||||||
|
from transformers.models.pixtral.modeling_pixtral import (
|
||||||
|
generate_block_attention_mask as _get_pixtral_attention_mask,
|
||||||
|
)
|
||||||
|
from transformers.models.pixtral.modeling_pixtral import position_ids_in_meshgrid
|
||||||
|
|
||||||
|
from sglang.srt.layers.activation import SiluAndMul
|
||||||
|
from sglang.srt.layers.attention.vision import VisionAttention
|
||||||
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
|
from sglang.srt.layers.linear import MergedColumnParallelLinear, RowParallelLinear
|
||||||
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
|
from sglang.srt.managers.mm_utils import MultiModalityDataPaddingPatternMultimodalTokens
|
||||||
|
from sglang.srt.managers.schedule_batch import MultimodalInputs
|
||||||
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||||
|
|
||||||
|
|
||||||
|
class PixtralHFMLP(nn.Module):
    """Gated (SiLU) MLP for PixtralHFVisionModel built from SGLang layers."""

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        prefix: str = "",
    ) -> None:
        super().__init__()

        assert config.intermediate_size is not None

        # gate_proj and up_proj are fused into a single column-parallel
        # matmul so checkpoints with combined weights load directly.
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=config.hidden_size,
            output_sizes=[config.intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            input_size=config.intermediate_size,
            output_size=config.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Fused gate/up projection, then silu(gate) * up, then down-project.
        fused, _ = self.gate_up_proj(x)
        activated = self.act_fn(fused)
        projected, _ = self.down_proj(activated)
        return projected
|
||||||
|
|
||||||
|
|
||||||
|
class PixtralHFTransformerBlock(nn.Module):
    """Pre-norm transformer block (attention + gated MLP) for the Pixtral
    vision tower, assembled from SGLang components."""

    def __init__(
        self,
        config: PretrainedConfig,
        layer_id: int,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.layer_id = layer_id
        self.attention_norm = RMSNorm(config.hidden_size, eps=1e-5)

        # SGLang's VisionAttention stands in for vLLM's PixtralHFAttention.
        self.attention = VisionAttention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            projection_size=config.hidden_size,
            use_qkv_parallel=True,
            quant_config=quant_config,
            dropout=0.0,
            use_context_forward=False,
            softmax_in_single_precision=False,
            flatten_batch=False,
            prefix=f"{prefix}.attention",
        )
        self.feed_forward = PixtralHFMLP(
            config, quant_config=quant_config, prefix=f"{prefix}.feed_forward"
        )
        self.ffn_norm = RMSNorm(config.hidden_size, eps=1e-5)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]],
    ) -> torch.Tensor:
        # hidden_states is expected as [batch, seq_len, hidden_dim].
        bsz, seqlen, dim = hidden_states.shape

        # RMSNorm normalizes the last dim; flatten to 2D and restore shape.
        normed = self.attention_norm(hidden_states.view(-1, dim)).view(
            bsz, seqlen, dim
        )
        attn_out = self.attention(
            normed,
            attention_mask=attention_mask,
            cu_seqlens=None,
            position_embeddings=position_embeddings,
        )
        # First residual connection.
        hidden_states = hidden_states + attn_out

        normed = self.ffn_norm(hidden_states.view(-1, dim)).view(bsz, seqlen, dim)
        mlp_out = self.feed_forward(normed)
        # Second residual connection.
        return hidden_states + mlp_out
|
||||||
|
|
||||||
|
|
||||||
|
class PixtralHFTransformer(nn.Module):
    """Stack of PixtralHFTransformerBlock layers for PixtralHFVisionModel."""

    def __init__(
        self,
        config: PixtralVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_hidden_layers_override: Optional[int] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        # Callers may truncate the stack (e.g. for feature extraction).
        depth = (
            num_hidden_layers_override
            if num_hidden_layers_override is not None
            else config.num_hidden_layers
        )
        self.layers = nn.ModuleList(
            [
                PixtralHFTransformerBlock(
                    config=config,
                    layer_id=idx,
                    quant_config=quant_config,
                    prefix=f"{prefix}.layers.{idx}",
                )
                for idx in range(depth)
            ]
        )

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]],
        return_all_hidden_states: bool = False,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Run the input through every transformer layer.

        Args:
            x: Input tensor.
            attention_mask: Optional attention mask.
            position_embeddings: Optional rotary position embeddings.
            return_all_hidden_states: Whether to collect per-layer states.

        Returns:
            The final hidden state, or the list of all hidden states
            (input included as element 0) if return_all_hidden_states.
        """
        hidden = x
        # For HF compatibility, the collected list starts with the input.
        collected = [hidden] if return_all_hidden_states else None

        for block in self.layers:
            hidden = block(hidden, attention_mask, position_embeddings)
            if collected is not None:
                collected.append(hidden)

        return collected if collected is not None else hidden
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_visual_encoder_outputs(
    outputs: Union[torch.Tensor, List[torch.Tensor]],
    feature_sample_layers: Optional[List[int]],
    post_norm: Optional[nn.Module],
    num_hidden_layers: int,
) -> torch.Tensor:
    """Select and combine encoder hidden states per feature_sample_layers.

    When feature_sample_layers is None, the last hidden state is used;
    otherwise the listed layer outputs are concatenated along the feature
    dimension. An optional post_norm module is applied to the result.
    """
    if feature_sample_layers is None:
        # Default behavior: final layer only.
        result = outputs[-1] if isinstance(outputs, list) else outputs
        return post_norm(result) if post_norm is not None else result

    if not isinstance(outputs, list):
        raise ValueError(
            "Expected outputs to be a list when feature_sample_layers is provided"
        )

    # Valid indices span [0, num_hidden_layers] inclusive — the stored
    # states include the pre-transformer embedding at index 0.
    for idx in feature_sample_layers:
        if not 0 <= idx <= num_hidden_layers:
            raise ValueError(
                f"Feature sample layer index {idx} is out of range "
                f"[0, {num_hidden_layers}]"
            )

    # Concatenate the requested layers along the feature dimension.
    combined = torch.cat(
        [outputs[idx] for idx in feature_sample_layers], dim=-1
    )
    return post_norm(combined) if post_norm is not None else combined
|
||||||
|
|
||||||
|
|
||||||
|
class PixtralHFVisionModel(nn.Module):
    """Hugging Face Pixtral Vision Model implemented using SGLang components."""

    # Token id that marks image placeholder positions in the prompt.
    DEFAULT_IMAGE_TOKEN_ID = 10

    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
        """Expand image placeholder tokens in `input_ids` via the
        multimodal-token padding pattern configured in __init__."""
        return self.input_padder.pad_input_tokens(input_ids, image_inputs)

    def __init__(
        self,
        config: PixtralVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        image_token_id: int = DEFAULT_IMAGE_TOKEN_ID,
        num_hidden_layers_override: Optional[int] = None,
        prefix: str = "",
    ) -> None:
        """Build the vision tower: patch conv, pre-norm, transformer stack,
        rotary position embedding, and the input-id padding helper.

        Args:
            config: Pixtral vision config (image_size, patch_size, ...).
            quant_config: Optional quantization settings for linear layers.
            image_token_id: Token id used for image placeholders.
            num_hidden_layers_override: Optionally truncate the layer stack.
            prefix: Name prefix for parameter loading.
        """
        super().__init__()

        self.config = config

        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # Non-overlapping conv patchifier: kernel == stride == patch_size.
        self.patch_conv = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=config.hidden_size,
            kernel_size=config.patch_size,
            stride=config.patch_size,
            bias=False,
        )

        self.ln_pre = RMSNorm(config.hidden_size, eps=1e-5)

        self.transformer = PixtralHFTransformer(
            config,
            quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            prefix=f"{prefix}.transformer",
        )

        # Check that num_hidden_layers is valid
        num_hidden_layers = config.num_hidden_layers
        if len(self.transformer.layers) > config.num_hidden_layers:
            raise ValueError(
                f"The original encoder only has {num_hidden_layers} "
                f"layers, but you requested {len(self.transformer.layers)} "
                "layers."
            )

        # Initialize patch position embedding
        self.image_token_id = image_token_id
        self.patch_positional_embedding = PixtralRotaryEmbedding(config)
        self.input_padder = MultiModalityDataPaddingPatternMultimodalTokens(
            [self.image_token_id]
        )

    @property
    def dtype(self):
        # dtype of the first parameter; all parameters share it in practice.
        return next(self.parameters()).dtype

    @property
    def device(self):
        # device of the first parameter; used to move inputs before compute.
        return next(self.parameters()).device

    def forward(
        self,
        pixel_values: torch.Tensor,
        image_sizes: list[tuple[int, int]],
        output_hidden_states: bool = False,
        feature_sample_layers: Optional[list[int]] = None,
    ) -> Union[torch.Tensor, tuple]:
        """
        Args:
            pixel_values: [batch_size, C, H, W], padded if multiple images
            image_sizes: list of (H, W) for each image in the batch
            output_hidden_states: Whether to return all hidden states.
            feature_sample_layers: Layer indices whose features should be
                concatenated and used as the visual encoder output. If none
                are provided, the last layer is used.

        Returns:
            A tuple containing:
            - hidden_states: Final model outputs (or selected layers if feature_sample_layers given)
            - hidden_states tuple (optional): All hidden states if output_hidden_states=True
        """
        # batch patch images
        embeds_orig = self.patch_conv(
            pixel_values.to(device=self.device, dtype=self.dtype)
        )
        # crop the embeddings: images were padded to a common size, so keep
        # only the patch grid corresponding to each image's true (H, W)
        embeds_2d = [
            embed[..., : h // self.patch_size, : w // self.patch_size]
            for embed, (h, w) in zip(embeds_orig, image_sizes)
        ]

        # flatten to sequence: all images concatenated into one token stream
        embeds_1d = torch.cat([p.flatten(1).T for p in embeds_2d], dim=0)
        embeds_featurized = self.ln_pre(embeds_1d).unsqueeze(0)

        # positional embeddings derived from each patch's 2D grid location
        position_ids = position_ids_in_meshgrid(
            embeds_2d,
            max_width=self.image_size // self.patch_size,
        ).to(self.device)

        # The original PixtralRotaryEmbedding expects 2D input but returns a tuple of tensors (cos, sin)
        # These tensors are used by apply_rotary_pos_emb in the transformer blocks
        position_embedding = self.patch_positional_embedding(
            embeds_featurized, position_ids
        )
        # block-diagonal mask so patches only attend within their own image
        attention_mask = _get_pixtral_attention_mask(
            [p.shape[-2] * p.shape[-1] for p in embeds_2d], embeds_featurized
        )

        return_all_hidden_states = (
            output_hidden_states or feature_sample_layers is not None
        )

        transformer_outputs = self.transformer(
            embeds_featurized,  # add batch dimension
            attention_mask,
            position_embedding,
            return_all_hidden_states=return_all_hidden_states,
        )

        # Store all hidden states if requested
        all_hidden_states = None
        if isinstance(transformer_outputs, list):
            all_hidden_states = transformer_outputs
            # Use the last layer by default if feature_sample_layers is not specified
            if feature_sample_layers is None:
                out = transformer_outputs[-1]
            else:
                # Resolve outputs based on feature sample layers
                out = resolve_visual_encoder_outputs(
                    transformer_outputs,
                    feature_sample_layers,
                    None,
                    self.config.num_hidden_layers,
                )
        else:
            out = transformer_outputs

        # Format return to be compatible with HuggingFace vision models
        if output_hidden_states:
            # NOTE(review): type(...) creates a new CLASS whose class
            # attributes hold the outputs; attribute access on it works for
            # callers reading .last_hidden_state / .hidden_states, but an
            # instance (e.g. types.SimpleNamespace) would be more
            # conventional — confirm before changing.
            return type(
                "VisualOutput",
                (),
                {
                    "last_hidden_state": out,
                    "hidden_states": all_hidden_states,
                },
            )
        else:
            return out

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
        """Load weights from a HuggingFace checkpoint with proper parameter mapping."""
        # NOTE(review): annotated to return Set[str] but no value is
        # returned; unmatched weight names are also silently skipped —
        # confirm whether callers rely on a returned set of loaded names.
        params_dict = dict(self.named_parameters())

        # for (param, weight, shard_id): load weight into param as param's shard_id part
        stacked_params_mapping = [
            (".attention.qkv_proj", ".attention.q_proj", "q"),
            (".attention.qkv_proj", ".attention.k_proj", "k"),
            (".attention.qkv_proj", ".attention.v_proj", "v"),
            (".feed_forward.gate_up_proj", ".feed_forward.gate_proj", 0),
            (".feed_forward.gate_up_proj", ".feed_forward.up_proj", 1),
        ]

        # Process each weight
        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name in name:
                    # Replace the weight name part with the combined parameter name
                    transformed_name = name.replace(weight_name, param_name)
                    if transformed_name in params_dict:
                        param = params_dict[transformed_name]
                        weight_loader = getattr(
                            param, "weight_loader", default_weight_loader
                        )
                        # Load this checkpoint tensor into its shard of the
                        # fused parameter (q/k/v or gate/up).
                        weight_loader(param, loaded_weight, shard_id)
                    break
            else:
                # HF checkpoints name the attention output projection
                # "o_proj"; SGLang's VisionAttention calls it "proj".
                if ".attention.o_proj" in name:
                    alt_name = name.replace(".attention.o_proj", ".attention.proj")
                    if alt_name in params_dict:
                        name = alt_name
                if name in params_dict:
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
|
||||||
|
|
||||||
|
|
||||||
|
class PixtralVisionModel(PixtralHFVisionModel):
    # Alias so the model resolves under the "PixtralVisionModel"
    # architecture name; behavior is identical to PixtralHFVisionModel.
    pass


# Register the model classes for external access
EntryClass = [PixtralVisionModel]
|
||||||
@@ -19,7 +19,9 @@ from typing import List, Optional, Tuple, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
import transformers
|
||||||
from transformers import (
|
from transformers import (
|
||||||
|
AutoConfig,
|
||||||
AutoModel,
|
AutoModel,
|
||||||
AutoModelForCausalLM,
|
AutoModelForCausalLM,
|
||||||
AutoModelForVision2Seq,
|
AutoModelForVision2Seq,
|
||||||
@@ -211,7 +213,12 @@ class HFRunner:
|
|||||||
|
|
||||||
# Load the model and tokenizer
|
# Load the model and tokenizer
|
||||||
if self.model_type == "generation":
|
if self.model_type == "generation":
|
||||||
self.base_model = AutoModelForCausalLM.from_pretrained(
|
config = AutoConfig.from_pretrained(model_path)
|
||||||
|
if model_archs := getattr(config, "architectures"):
|
||||||
|
model_cls = getattr(transformers, model_archs[0])
|
||||||
|
else:
|
||||||
|
model_cls = AutoModelForCausalLM
|
||||||
|
self.base_model = model_cls.from_pretrained(
|
||||||
model_path,
|
model_path,
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch_dtype,
|
||||||
trust_remote_code=self.trust_remote_code,
|
trust_remote_code=self.trust_remote_code,
|
||||||
|
|||||||
@@ -14,14 +14,15 @@
|
|||||||
"""
|
"""
|
||||||
Usage:
|
Usage:
|
||||||
|
|
||||||
To test a specific model:
|
To test a specific model locally:
|
||||||
1. Add it to ALL_OTHER_MODELS
|
1. Add it to ALL_MODELS, for example, `ModelCase("Qwen/Qwen2-1.5B")`
|
||||||
2. Run `ONLY_RUN=Qwen/Qwen2-1.5B python3 -m unittest test_generation_models.TestGenerationModels.test_others`
|
2. Run `ONLY_RUN=Qwen/Qwen2-1.5B python3 -m unittest test_generation_models.TestGenerationModels`
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import multiprocessing as mp
|
import multiprocessing as mp
|
||||||
import os
|
import os
|
||||||
|
import random
|
||||||
import unittest
|
import unittest
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
@@ -53,8 +54,9 @@ CI_MODELS = [
|
|||||||
ModelCase("google/gemma-2-2b"),
|
ModelCase("google/gemma-2-2b"),
|
||||||
]
|
]
|
||||||
|
|
||||||
# All other models that do not run on the CI
|
# the complete set of models to test sglang's generation model
|
||||||
ALL_OTHER_MODELS = [
|
ALL_MODELS = [
|
||||||
|
*CI_MODELS,
|
||||||
ModelCase("Qwen/Qwen2-1.5B"),
|
ModelCase("Qwen/Qwen2-1.5B"),
|
||||||
ModelCase("Qwen/Qwen2.5-14B-Instruct"),
|
ModelCase("Qwen/Qwen2.5-14B-Instruct"),
|
||||||
ModelCase("HuggingFaceTB/SmolLM-135M-Instruct", skip_long_prompt=True),
|
ModelCase("HuggingFaceTB/SmolLM-135M-Instruct", skip_long_prompt=True),
|
||||||
@@ -63,7 +65,7 @@ ALL_OTHER_MODELS = [
|
|||||||
"THUDM/glm-4-9b-chat", tp_size=2, trust_remote_code=True, skip_long_prompt=True
|
"THUDM/glm-4-9b-chat", tp_size=2, trust_remote_code=True, skip_long_prompt=True
|
||||||
),
|
),
|
||||||
ModelCase("openai-community/gpt2"),
|
ModelCase("openai-community/gpt2"),
|
||||||
ModelCase("microsoft/Phi-3-small-8k-instruct"),
|
ModelCase("microsoft/Phi-3-small-8k-instruct", trust_remote_code=True),
|
||||||
ModelCase("allenai/OLMo-2-1124-7B-Instruct", skip_long_prompt=True),
|
ModelCase("allenai/OLMo-2-1124-7B-Instruct", skip_long_prompt=True),
|
||||||
ModelCase("ibm-granite/granite-3.0-2b-instruct", skip_long_prompt=True),
|
ModelCase("ibm-granite/granite-3.0-2b-instruct", skip_long_prompt=True),
|
||||||
]
|
]
|
||||||
@@ -117,9 +119,30 @@ class TestGenerationModels(CustomTestCase):
|
|||||||
debug_text=f"model_path={model_path} prompts={prompts}",
|
debug_text=f"model_path={model_path} prompts={prompts}",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@unittest.skipIf(not is_in_ci(), "Local test should run all models")
|
||||||
def test_ci_models(self):
|
def test_ci_models(self):
|
||||||
for model_case in CI_MODELS:
|
for model_case in CI_MODELS:
|
||||||
for torch_dtype in TORCH_DTYPES:
|
for torch_dtype in TORCH_DTYPES:
|
||||||
|
prompts = DEFAULT_PROMPTS
|
||||||
|
|
||||||
|
# Skip long prompts for models that do not have a long context
|
||||||
|
if model_case.skip_long_prompt:
|
||||||
|
prompts = [p for p in DEFAULT_PROMPTS if len(p) < 1000]
|
||||||
|
|
||||||
|
# Assert the logits and output strs are close
|
||||||
|
self.assert_close_logits_and_output_strs(
|
||||||
|
prompts, model_case, torch_dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
@unittest.skipIf(is_in_ci(), "CI only runs selected models for simplicity")
|
||||||
|
def test_all_models(self):
|
||||||
|
for model_case in ALL_MODELS:
|
||||||
|
for torch_dtype in TORCH_DTYPES:
|
||||||
|
if (
|
||||||
|
"ONLY_RUN" in os.environ
|
||||||
|
and os.environ["ONLY_RUN"] != model_case.model_path
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
|
||||||
# Skip long prompts for models that do not have a long context
|
# Skip long prompts for models that do not have a long context
|
||||||
prompts = DEFAULT_PROMPTS
|
prompts = DEFAULT_PROMPTS
|
||||||
@@ -131,26 +154,6 @@ class TestGenerationModels(CustomTestCase):
|
|||||||
prompts, model_case, torch_dtype
|
prompts, model_case, torch_dtype
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_others(self):
|
|
||||||
if is_in_ci():
|
|
||||||
return
|
|
||||||
|
|
||||||
for model_case in ALL_OTHER_MODELS:
|
|
||||||
# Only run a specified model
|
|
||||||
if (
|
|
||||||
"ONLY_RUN" in os.environ
|
|
||||||
and os.environ["ONLY_RUN"] != model_case.model_path
|
|
||||||
):
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Skip long prompts for models that do not have a long context
|
|
||||||
prompts = DEFAULT_PROMPTS
|
|
||||||
if model_case.skip_long_prompt:
|
|
||||||
prompts = [p for p in DEFAULT_PROMPTS if len(p) < 1000]
|
|
||||||
|
|
||||||
# Assert the logits and output strs are close
|
|
||||||
self.assert_close_logits_and_output_strs(prompts, model_case, torch.float16)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -642,6 +642,28 @@ class TestMinicpmoServer(TestOpenAIVisionServer):
|
|||||||
self._test_audio_ambient_completion()
|
self._test_audio_ambient_completion()
|
||||||
|
|
||||||
|
|
||||||
|
class TestPixtralServer(TestOpenAIVisionServer):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls.model = "mistral-community/pixtral-12b"
|
||||||
|
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||||
|
cls.api_key = "sk-123456"
|
||||||
|
cls.process = popen_launch_server(
|
||||||
|
cls.model,
|
||||||
|
cls.base_url,
|
||||||
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
other_args=[
|
||||||
|
"--trust-remote-code",
|
||||||
|
"--mem-fraction-static",
|
||||||
|
"0.73",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
cls.base_url += "/v1"
|
||||||
|
|
||||||
|
def test_video_chat_completion(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class TestDeepseekVL2Server(TestOpenAIVisionServer):
|
class TestDeepseekVL2Server(TestOpenAIVisionServer):
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
|
|||||||
Reference in New Issue
Block a user