model(vlm): pixtral (#5084)
This commit is contained in:
@@ -1,14 +1,20 @@
|
||||
import asyncio
|
||||
import importlib
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from transformers.models.auto.processing_auto import (
|
||||
PROCESSOR_MAPPING_NAMES as HF_MAPPING_NAMES,
|
||||
)
|
||||
|
||||
import sglang.srt.managers.multimodal_processor as sgl_mm_processor_utils
|
||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||
BaseMultimodalProcessor,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
||||
from sglang.srt.mm_utils import expand2square, process_anyres_image
|
||||
from sglang.srt.models.llava import (
|
||||
LlavaForConditionalGeneration,
|
||||
LlavaLlamaForCausalLM,
|
||||
LlavaMistralForCausalLM,
|
||||
LlavaQwenForCausalLM,
|
||||
@@ -133,6 +139,7 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
|
||||
img_data, aspect_ratio, grid_pinpoints
|
||||
)
|
||||
)
|
||||
|
||||
res = await asyncio.gather(*res)
|
||||
for pixel_v, image_h, image_s in res:
|
||||
pixel_values.append(pixel_v)
|
||||
@@ -165,3 +172,42 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
|
||||
)
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
class LlavaMultimodalProcessor(BaseMultimodalProcessor):
    """
    This is a wrapper class used to identify the multimodal processor for Llava architecture models.
    """

    models = [LlavaForConditionalGeneration]

    def _get_sgl_processor_cls(self, model_type: str):
        """Resolve the sglang multimodal processor class for an HF model type.

        Maps ``model_type`` to the HF processor class name via
        ``HF_MAPPING_NAMES``, then looks for a registered sglang processor
        whose class name matches it.

        Raises:
            ValueError: if no matching sglang processor is registered.
        """
        if hf_name := HF_MAPPING_NAMES.get(model_type):
            sgl_mm_processor_set = sgl_mm_processor_utils.PROCESSOR_MAPPING.values()
            sgl_processor_cls = list(
                filter(lambda p: p.__name__ == hf_name, sgl_mm_processor_set)
            )
            if sgl_processor_cls:
                return sgl_processor_cls[0]
        raise ValueError(
            f"Cannot find corresponding multimodal processor registered in sglang for model type `{model_type}`"
        )

    def __init__(self, hf_config, server_args, _processor):
        assert hasattr(hf_config, "vision_config")
        assert hasattr(hf_config, "text_config")
        self.vision_config = hf_config.vision_config
        self.text_config = hf_config.text_config
        self.hf_config = hf_config

        # Pass a default of None: a bare getattr would raise AttributeError
        # when `model_type` is absent, making the ValueError branch below
        # unreachable for that case.
        if vision_type := getattr(self.vision_config, "model_type", None):
            # Delegate all real work to the processor matching the vision tower.
            self.inner = self._get_sgl_processor_cls(vision_type)(
                hf_config, server_args, _processor
            )
        else:
            raise ValueError(
                f"Required `vision_config.model_type` is not found in hf_config: `{hf_config}`"
            )

    async def process_mm_data_async(self, *args, **kwargs):
        """Forward async multimodal processing to the wrapped inner processor."""
        return await self.inner.process_mm_data_async(*args, **kwargs)
|
||||
|
||||
127
python/sglang/srt/managers/multimodal_processors/pixtral.py
Normal file
127
python/sglang/srt/managers/multimodal_processors/pixtral.py
Normal file
@@ -0,0 +1,127 @@
|
||||
import asyncio
|
||||
import math
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from transformers import PretrainedConfig
|
||||
from transformers.models.pixtral.image_processing_pixtral import (
|
||||
_num_image_tokens as _get_pixtral_hf_num_image_tokens,
|
||||
)
|
||||
|
||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||
BaseMultimodalProcessor,
|
||||
MultimodalSpecialTokens,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import (
|
||||
Modality,
|
||||
MultimodalDataItem,
|
||||
MultimodalInputs,
|
||||
)
|
||||
from sglang.srt.models.pixtral import PixtralVisionModel
|
||||
|
||||
|
||||
class PixtralProcessor(BaseMultimodalProcessor):
    """Multimodal processor for Pixtral vision-language models.

    Resizes each input image so its sides are exact multiples of the vision
    patch size, then runs the underlying HF processor to produce pixel
    values and token ids.
    """

    models = [PixtralVisionModel]

    # Pad token string registered on the tokenizer when hf_config does not
    # provide one.
    PAD_TOKEN = "<pad>"
    # NOTE(review): hard-coded token ids — presumably match the Pixtral
    # tokenizer vocabulary; confirm against the model's tokenizer config.
    IMG_BREAK_TOKEN_ID = 12
    IMG_END_TOKEN_ID = 13

    def get_patch_grid_size(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> tuple[int, int]:
        """Return the patch grid as (columns, rows) for the given image size.

        The image is scaled down (if needed) so neither side exceeds
        ``self.image_size``; the grid is then derived from
        ``self.patch_size`` via the HF pixtral helper.
        """
        max_width = max_height = self.image_size
        patch_width = patch_height = self.patch_size

        # Factor by which the image exceeds the size limit (>1 means too big).
        ratio = max(image_width / max_width, image_height / max_height)

        if ratio > 1:
            # Downscale, flooring so the result stays within the limit.
            image_width = int(math.floor(image_width / ratio))
            image_height = int(math.floor(image_height / ratio))

        nrows, ncols = _get_pixtral_hf_num_image_tokens(
            (image_height, image_width),
            (patch_height, patch_width),
        )

        # Returned width-first (columns, rows).
        return ncols, nrows

    def __init__(self, hf_config, server_args, _processor):
        super().__init__(hf_config, server_args, _processor)
        # Fall back to the model's default image token id when the config
        # does not define `image_token_index`.
        self.image_token_id = getattr(
            hf_config, "image_token_index", PixtralVisionModel.DEFAULT_IMAGE_TOKEN_ID
        )
        # Instantiate the patcher logic helper using the class defined above

        self.vision_config = hf_config.vision_config
        self.image_size = self.vision_config.image_size
        self.patch_size = self.vision_config.patch_size
        self.multimodal_tokens = MultimodalSpecialTokens(
            image_token=_processor.image_token
        )
        # Ensure the tokenizer has a pad token available for batching.
        _processor.tokenizer.add_special_tokens(
            {
                "pad_token": getattr(hf_config, "pad_token", self.PAD_TOKEN),
            }
        )

    async def _resize(self, image):
        """Resize *image* so each side is an exact multiple of the patch size."""
        num_w_tokens, num_h_tokens = self.get_patch_grid_size(
            image_width=image.size[0],
            image_height=image.size[1],
        )
        new_size = (num_w_tokens * self.patch_size, num_h_tokens * self.patch_size)
        return image.resize(new_size)

    async def process_mm_data_async(
        self,
        image_data: List[Union[str, bytes]],
        input_text,
        request_obj,
        *args,
        **kwargs,
    ):
        """Load, resize, and process images for a request.

        Returns the HF processor output augmented with ``mm_items`` and a
        flattened ``input_ids`` list, or ``None`` when there is no image
        data.
        """
        if not image_data:
            return None

        # Normalize a single image reference into a list.
        if isinstance(image_data, str):
            image_data = [image_data]

        mm_data = self.load_mm_data(
            prompt=input_text,
            multimodal_tokens=self.multimodal_tokens,
            max_req_input_len=kwargs.get("max_req_input_len", 4096),
            image_data=image_data,
            return_text=True,
        )

        if mm_data.images:
            # Resize all images concurrently to patch-aligned dimensions.
            resize_tasks = [self._resize(image) for image in mm_data.images]
            mm_data.images = await asyncio.gather(*resize_tasks)

        processor_output = self.process_mm_data(
            input_text=mm_data.input_text,
            images=mm_data.images,
        )

        if "pixel_values" in processor_output:
            mm_items = [
                MultimodalDataItem(
                    pixel_values=processor_output["pixel_values"],
                    image_sizes=processor_output["image_sizes"],
                    modality=Modality.IMAGE,
                )
            ]

            # Flatten the token tensor into a plain Python list.
            input_ids = processor_output["input_ids"].view(-1).tolist()
            processor_output.update(
                input_ids=input_ids,
                mm_items=mm_items,
                # there's no im_start_id for pixtral, only im_token and im_end_token
                im_end_id=self.IMG_END_TOKEN_ID,
                im_token_id=self.image_token_id,
            )
        return processor_output
|
||||
Reference in New Issue
Block a user