refactor: multimodal data (#4754)

Mick
2025-04-01 00:57:51 +08:00
committed by GitHub
parent c7457191a0
commit 5cb552b1d4
36 changed files with 989 additions and 1138 deletions

View File

@@ -8,18 +8,10 @@ from typing import Optional
import numpy as np
import PIL
import transformers
from decord import VideoReader, cpu
from PIL import Image
from sglang.srt.utils import load_audio, load_image, logger
global global_processor
def get_global_processor():
global global_processor
return global_processor
from sglang.srt.utils import encode_video, load_audio, load_image, logger
@dataclasses.dataclass
@@ -27,9 +19,6 @@ class BaseMultiModalProcessorOutput:
# input_text, with each frame of video/image represented with an image_token
input_text: str
mm_data_hashes: Optional[list[int]]
# images
image_sizes: Optional[list[int]]
# frames loaded from image and video, in given order
images: Optional[list[PIL.Image]] = None
@@ -37,7 +26,7 @@ class BaseMultiModalProcessorOutput:
audios: Optional[list[np.ndarray]] = None
def normalize(self):
for field_name in ["data_hashes", "image_sizes", "images", "audios"]:
for field_name in ["image_sizes", "images", "audios"]:
field = getattr(self, field_name, None)
if field is not None and isinstance(field, list) and len(field) == 0:
setattr(self, field_name, None)
@@ -67,28 +56,35 @@ class BaseMultimodalProcessor(ABC):
# FIXME: not accurate, model and image specific
self.NUM_TOKEN_PER_FRAME = 330
# Initialize global processor first
init_global_processor(self, server_args)
self.executor = concurrent.futures.ProcessPoolExecutor(
initializer=init_global_processor,
self.io_executor = concurrent.futures.ThreadPoolExecutor(
max_workers=int(os.environ.get("SGLANG_IO_WORKERS", 4))
)
self.cpu_executor = concurrent.futures.ProcessPoolExecutor(
mp_context=mp.get_context("fork"),
initargs=(
self,
server_args,
),
max_workers=int(os.environ.get("SGLANG_CPU_COUNT", os.cpu_count())),
max_workers=int(os.environ.get("SGLANG_CPU_WORKERS", os.cpu_count())),
)
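
The hunk above replaces the single process pool with two executors: a thread pool for IO-bound work (fetching and decoding media, where the GIL is released while blocking) and a fork-context process pool for CPU-bound preprocessing. A minimal standalone sketch of the split, under those assumptions (the env-var names follow the diff; the two tasks are hypothetical stand-ins):

    import concurrent.futures
    import multiprocessing as mp
    import os

    io_executor = concurrent.futures.ThreadPoolExecutor(
        max_workers=int(os.environ.get("SGLANG_IO_WORKERS", 4))
    )
    cpu_executor = concurrent.futures.ProcessPoolExecutor(
        # "fork" lets workers inherit already-constructed state; not available on Windows
        mp_context=mp.get_context("fork"),
        max_workers=int(os.environ.get("SGLANG_CPU_WORKERS", os.cpu_count())),
    )

    def fetch(url: str) -> bytes:  # hypothetical IO-bound task
        import urllib.request
        with urllib.request.urlopen(url) as resp:
            return resp.read()

    def preprocess(raw: bytes) -> int:  # hypothetical CPU-bound stand-in
        return len(raw)

    if __name__ == "__main__":
        raw = io_executor.submit(fetch, "https://example.com").result()
        print(cpu_executor.submit(preprocess, raw).result())
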
def _build_processor(self, server_args):
"""Init the global processor for multi modal models."""
from sglang.srt.hf_transformers_utils import get_processor
def process_mm_data(
self, input_text, images=None, videos=None, audios=None, **kwargs
):
"""
Process multimodal data with a transformers AutoProcessor.
"""
if images is not None:
kwargs["images"] = images
if videos is not None:
kwargs["videos"] = videos
if audios is not None:
kwargs["audios"] = audios
return get_processor(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
processor = self._processor
result = processor.__call__(
text=[input_text],
padding=True,
return_tensors="pt",
**kwargs,
)
return result
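
With the processor held on the instance, the old global-fetch dance reduces to a direct call. A hedged usage sketch of the helper above (the checkpoint name is illustrative; only modalities that are actually present get forwarded, since some HF processors mishandle explicit None kwargs):

    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")  # example only

    def process_mm_data(input_text, images=None, videos=None, audios=None, **kwargs):
        for key, value in (("images", images), ("videos", videos), ("audios", audios)):
            if value is not None:
                kwargs[key] = value
        return processor(text=[input_text], padding=True, return_tensors="pt", **kwargs)

    # Text-only and image-bearing requests go through the same entry point:
    # out = process_mm_data("<image> describe this", images=[pil_image])
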
@abstractmethod
async def process_mm_data_async(
@@ -115,33 +111,9 @@ class BaseMultimodalProcessor(ABC):
return estimated_frames_list
@staticmethod
def encode_video(video_path, frame_count_limit=None):
if not os.path.exists(video_path):
logger.error(f"Video {video_path} does not exist")
return []
if frame_count_limit == 0:
return []
def uniform_sample(l, n):
gap = len(l) / n
idxs = [int(i * gap + gap / 2) for i in range(n)]
return [l[i] for i in idxs]
vr = VideoReader(video_path, ctx=cpu(0))
sample_fps = round(vr.get_avg_fps() / 1) # FPS
frame_indices = [i for i in range(0, len(vr), sample_fps)]
if frame_count_limit is not None and len(frame_indices) > frame_count_limit:
frame_indices = uniform_sample(frame_indices, frame_count_limit)
frames = vr.get_batch(frame_indices).asnumpy()
frames = [Image.fromarray(v.astype("uint8")) for v in frames]
return frames
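
The body above moves to sglang.srt.utils.encode_video (see the new import at the top of this file); the sampling math is unchanged: decode at roughly one frame per second, then cap the count by center-of-bucket uniform sampling. A decord-free sketch of just the index selection, with a worked example:

    def uniform_sample(seq, n):
        # n indices spaced evenly, each taken from the middle of its bucket
        gap = len(seq) / n
        return [seq[int(i * gap + gap / 2)] for i in range(n)]

    # A 300-frame clip at 30 fps yields one index per second...
    frame_indices = list(range(0, 300, 30))  # [0, 30, ..., 270]
    # ...then a frame_count_limit of 4 keeps frames spread across the clip:
    print(uniform_sample(frame_indices, 4))  # [30, 90, 180, 240]
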
def load_mm_data(
self,
input_ids: list[int],
prompt: str,
multimodal_tokens: MultimodalSpecialTokens,
max_req_input_len: int,
image_data: Optional[list] = None,
@@ -167,11 +139,13 @@ class BaseMultimodalProcessor(ABC):
else:
multimodal_tokens.image_token = multimodal_tokens.image_token
if isinstance(input_ids, list) and return_text:
assert len(input_ids) and isinstance(input_ids[0], int)
input_text = self._processor.tokenizer.decode(input_ids)
assert isinstance(prompt, str)
if isinstance(prompt, list) and return_text:
assert len(prompt) and isinstance(prompt[0], int)
prompt = self._processor.tokenizer.decode(prompt)
else:
input_text = input_ids
prompt = prompt
if return_text:
import re
@@ -181,7 +155,7 @@ class BaseMultimodalProcessor(ABC):
+ ")"
)
# split text into list of normal text and special tokens
text_parts = re.split(pattern, input_text)
text_parts = re.split(pattern, prompt)
# TODO(mick): load from server_args, env, or sampling_params
MAX_NUM_FRAMES = 30
@@ -217,7 +191,7 @@ class BaseMultimodalProcessor(ABC):
):
# video
path = image_file[len("video:") :]
frames = BaseMultimodalProcessor.encode_video(
frames = encode_video(
path, frame_count_limit=frames_to_process
)
else:
@@ -254,19 +228,9 @@ class BaseMultimodalProcessor(ABC):
raise RuntimeError(f"An exception occurred while loading images: {e}")
out = BaseMultiModalProcessorOutput(
mm_data_hashes=hashes,
image_sizes=image_sizes,
images=images,
audios=audios,
input_text=new_text,
)
out.normalize()
return out
def init_global_processor(sglang_processor: BaseMultimodalProcessor, server_args):
"""
Init the global processor for multimodal models."""
global global_processor
transformers.logging.set_verbosity_error()
global_processor = sglang_processor._build_processor(server_args=server_args)
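
For contrast, the mechanism deleted here: the old pool used an initializer to materialize a module-level global in every worker process, which each per-model task then re-fetched via get_global_processor(). With the fork context, workers inherit the already-built self._processor, so both the global and its accessor can go. A minimal sketch of the removed pattern, under those assumptions:

    import concurrent.futures

    _global_processor = None

    def _init_worker(processor):
        # Runs once per worker process; plants the per-process global.
        global _global_processor
        _global_processor = processor

    executor = concurrent.futures.ProcessPoolExecutor(
        max_workers=2,
        initializer=_init_worker,
        initargs=("expensive-to-build processor",),  # stand-in object
    )

    def _task():
        return _global_processor  # every task reads the worker's global

    if __name__ == "__main__":
        print(executor.submit(_task).result())
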

View File

@@ -1,10 +1,9 @@
import asyncio
from typing import List, Union
from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor,
get_global_processor,
)
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.models.clip import CLIPModel
from sglang.srt.utils import load_image
@@ -15,29 +14,6 @@ class ClipImageProcessor(BaseMultimodalProcessor):
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
@staticmethod
def _process_single_image_task(images, input_text):
# 'input_ids', 'attention_mask', 'pixel_values', 'aspect_ratio_ids', 'aspect_ratio_mask', 'cross_attention_mask'
return get_global_processor()(
images=images, text=input_text, return_tensors="pt"
)
async def _process_single_image(self, images, input_text):
if self.executor is not None:
loop = asyncio.get_event_loop()
image_inputs = await loop.run_in_executor(
self.executor,
ClipImageProcessor._process_single_image_task,
images,
input_text,
)
else:
image_inputs = self._processor(
images=images, text=[input_text], return_tensors="pt"
)
return image_inputs
async def process_mm_data_async(
self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
):
@@ -56,8 +32,13 @@ class ClipImageProcessor(BaseMultimodalProcessor):
else:
images = load_image(image_data[0])[0]
image_inputs = await self._process_single_image(images, input_text)
image_inputs = self.process_mm_data(input_text=input_text, images=images)
image_inputs["data_hashes"] = [hash(str(image_data))]
image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
image_inputs["mm_items"] = [
MultimodalDataItem(
pixel_values=image_inputs["pixel_values"], modality=Modality.IMAGE
)
]
return image_inputs
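
The same collapse repeats across the per-model processors below: the executor round-trip becomes one synchronous process_mm_data call, and raw tensors are wrapped into explicit MultimodalDataItem entries. A hedged sketch of the resulting shape (the dataclass here is a stand-in for sglang.srt.managers.schedule_batch.MultimodalDataItem, not its real definition):

    import dataclasses
    import enum
    from typing import Any

    class Modality(enum.Enum):
        IMAGE = enum.auto()
        AUDIO = enum.auto()

    @dataclasses.dataclass
    class MultimodalDataItem:  # illustrative stand-in
        modality: Modality
        pixel_values: Any = None

    processor_out = {"input_ids": [[101, 32000, 102]], "pixel_values": object()}
    processor_out["mm_items"] = [
        MultimodalDataItem(pixel_values=processor_out["pixel_values"],
                           modality=Modality.IMAGE)
    ]
    processor_out["input_ids"] = processor_out["input_ids"][0]  # unbatch single sequence
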

View File

@@ -16,15 +16,14 @@
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import asyncio
import torch
from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor,
MultimodalSpecialTokens,
get_global_processor,
)
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.models.deepseek_vl2 import DeepseekVL2ForCausalLM
@@ -35,51 +34,6 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
super().__init__(hf_config, server_args, _processor)
self.IMAGE_TOKEN = "<image>"
@staticmethod
def _process_images_task(image, input_text, max_req_input_len):
processor = get_global_processor()
res = processor.__call__(
conversations=input_text, images=image, max_req_input_len=max_req_input_len
)
image_token_id = processor.image_token_id
res["im_token_id"] = image_token_id
return res
async def _process_images(self, image_data, input_text, max_req_input_len):
if self.executor is not None:
loop = asyncio.get_event_loop()
image_inputs = await loop.run_in_executor(
self.executor,
DeepseekVL2ImageProcessor._process_images_task,
image_data,
input_text,
max_req_input_len,
)
else:
image_inputs = self._process_images_task(
image_data, input_text, max_req_input_len
)
return image_inputs
async def process_mm_data_async(
self, image_data, input_ids, request_obj, max_req_input_len, *args, **kwargs
):
@@ -89,8 +43,6 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
if not isinstance(image_data, list):
image_data = [image_data]
images, image_sizes = [], []
image_token = self.IMAGE_TOKEN
base_output = self.load_mm_data(
input_ids,
@@ -98,8 +50,11 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
max_req_input_len=max_req_input_len,
)
res = await self._process_images(
base_output.images, base_output.input_text, max_req_input_len
res = self.process_mm_data(
input_text=base_output.input_text,
images=base_output.images,
max_req_input_len=max_req_input_len,
conversations=base_output.input_text,
)
images_seq_mask = res["images_seq_mask"]
images_spatial_crop = res["images_spatial_crop"]
@@ -107,13 +62,17 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
batched_images_spatial_crop.append(images_spatial_crop)
batched_images_spatial_crop = torch.stack(batched_images_spatial_crop, dim=0)
items = []
item = MultimodalDataItem(
pixel_values=res["images"],
modality=Modality.IMAGE,
image_emb_mask=images_seq_mask,
image_spatial_crop=batched_images_spatial_crop,
)
items += [item]
return {
"mm_items": items,
"input_ids": res["input_ids"].tolist(),
"pixel_values": res["images"],
"im_token_id": res["im_token_id"],
"data_hashes": base_output.mm_data_hashes,
"image_sizes": image_sizes,
"images_emb_mask": images_seq_mask,
"image_spatial_crop": batched_images_spatial_crop,
"modalities": request_obj.modalities or ["image"],
"im_token_id": self._processor.image_token_id,
}

View File

@@ -7,8 +7,8 @@ from sglang.srt.managers.multimodal_processor import (
)
from sglang.srt.managers.multimodal_processors.base_processor import (
MultimodalSpecialTokens,
get_global_processor,
)
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.models.gemma3_mm import Gemma3ForConditionalGeneration
# Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma3/image_processing_gemma3_fast.py
@@ -25,28 +25,6 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
self.IM_START_TOKEN_ID = hf_config.boi_token_index
self.IM_END_TOKEN_ID = hf_config.eoi_token_index
async def _process_single_image(self, images, input_text) -> dict:
if isinstance(images, list) and len(images) == 0:
images = None
processor = get_global_processor()
result = processor.__call__(
text=[input_text],
images=images,
padding=True,
return_tensors="pt",
# if RGBA, this needs to be set
# images_kwargs={
# "input_data_format": ChannelDimension.FIRST
# }
)
pixel_values = getattr(result, "pixel_values", None)
return {
"input_ids": result.input_ids,
"pixel_values": pixel_values,
}
async def process_mm_data_async(
self,
image_data: List[Union[str, bytes]],
@@ -63,21 +41,28 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
image_token = self.IMAGE_TOKEN
base_output = self.load_mm_data(
input_ids=input_ids,
prompt=input_ids,
image_data=image_data,
multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
max_req_input_len=max_req_input_len,
discard_alpha_channel=True,
)
ret = await self._process_single_image(
ret = self.process_mm_data(
input_text=base_output.input_text, images=base_output.images
)
items = []
for i, image in enumerate(base_output.images):
item = MultimodalDataItem(
pixel_values=ret["pixel_values"][i],
modality=Modality.IMAGE,
)
items += [item]
return {
"mm_items": items,
"input_ids": ret["input_ids"].flatten().tolist(),
"pixel_values": ret["pixel_values"],
"data_hashes": base_output.mm_data_hashes,
"im_start_id": self.IM_START_TOKEN_ID,
"im_end_id": self.IM_END_TOKEN_ID,
}
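
Note the packing choice above: Gemma3 emits one MultimodalDataItem per input image, indexing into the processor's stacked pixel_values, rather than a single batched item. A small sketch of that slicing with a stand-in tensor:

    import torch

    pixel_values = torch.randn(3, 3, 224, 224)  # (num_images, C, H, W), stand-in
    items = [pixel_values[i] for i in range(pixel_values.shape[0])]
    assert all(t.shape == (3, 224, 224) for t in items)  # one unbatched tensor per image
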

View File

@@ -1,11 +1,10 @@
import asyncio
from typing import List, Union
from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor,
MultimodalSpecialTokens,
get_global_processor,
)
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.models.deepseek_janus_pro import MultiModalityCausalLM
@@ -15,37 +14,6 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
@staticmethod
def _process_images_task(images, input_text):
processor = get_global_processor()
result = processor.__call__(
prompt=input_text, images=images, return_tensors="pt"
)
return {
"input_ids": result["input_ids"],
"pixel_values": result["pixel_values"],
"images_emb_mask": result["images_emb_mask"],
"im_start_id": processor.image_start_id,
"im_end_id": processor.image_end_id,
"im_token_id": processor.image_id,
}
async def _process_images(self, images, input_text):
if self.executor is not None:
loop = asyncio.get_event_loop()
image_inputs = await loop.run_in_executor(
self.executor,
JanusProImageProcessor._process_images_task,
images,
input_text,
)
else:
image_inputs = self._processor(
images=images, text=input_text, return_tensors="pt"
)
return image_inputs
async def process_mm_data_async(
self,
image_data: List[Union[str, bytes]],
@@ -60,25 +28,31 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
if not isinstance(image_data, list):
image_data = [image_data]
processor = self._processor
base_out = self.load_mm_data(
input_ids=input_ids,
prompt=input_ids,
image_data=image_data,
multimodal_tokens=MultimodalSpecialTokens(
image_token="<image_placeholder>"
),
multimodal_tokens=MultimodalSpecialTokens(image_token=processor.image_tag),
max_req_input_len=max_req_input_len,
)
images = base_out.images
res = await self._process_images(images=images, input_text=base_out.input_text)
# print(res)
# print(base_out)
# print("", res["images_emb_mask"].shape)
res = self.process_mm_data(
input_text=base_out.input_text,
prompt=base_out.input_text,
images=images,
)
return {
"mm_items": [
MultimodalDataItem(
pixel_values=res["pixel_values"],
image_emb_mask=res["images_emb_mask"],
modality=Modality.IMAGE,
)
],
"input_ids": res["input_ids"].flatten().tolist(),
"pixel_values": res["pixel_values"],
"images_emb_mask": res["images_emb_mask"],
"data_hashes": base_out.mm_data_hashes,
"im_start_id": res["im_start_id"],
"im_end_id": res["im_end_id"],
"im_token_id": res["im_token_id"],
"im_start_id": processor.image_start_id,
"im_end_id": processor.image_end_id,
"im_token_id": processor.image_id,
}

View File

@@ -5,8 +5,8 @@ import numpy as np
from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor,
get_global_processor,
)
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.mm_utils import expand2square, process_anyres_image
from sglang.srt.models.llava import LlavaMistralForCausalLM, LlavaQwenForCausalLM
from sglang.srt.models.llavavid import LlavaVidForCausalLM
@@ -25,11 +25,10 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
image_data: Union[str, bytes],
image_aspect_ratio: Optional[str] = None,
image_grid_pinpoints: Optional[str] = None,
image_processor=None,
processor=None,
):
processor = get_global_processor()
image_processor = image_processor or processor.image_processor
image_processor = processor.image_processor
try:
image, image_size = load_image(image_data)
@@ -72,18 +71,22 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
async def _process_single_image(
self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str
):
if self.executor is not None:
if self.cpu_executor is not None:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self.executor,
self.cpu_executor,
LlavaImageProcessor._process_single_image_task,
image_data,
aspect_ratio,
grid_pinpoints,
self._processor,
)
else:
return self._process_single_image_task(
image_data, aspect_ratio, grid_pinpoints
image_data,
aspect_ratio,
grid_pinpoints,
self._processor,
)
async def process_mm_data_async(
@@ -134,14 +137,22 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
pixel_values, image_hash, image_size = await self._process_single_image(
image_data[0], aspect_ratio, grid_pinpoints
)
data_hashes = [image_hash]
image_sizes = [image_size]
else:
raise ValueError(f"Invalid image data: {image_data}")
modality = Modality.IMAGE
if isinstance(request_obj.modalities, list):
if request_obj.modalities[0] == "multi-images":
modality = Modality.MULTI_IMAGES
elif request_obj.modalities[0] == "video":
modality = Modality.VIDEO
return {
"pixel_values": pixel_values,
"data_hashes": data_hashes,
"image_sizes": image_sizes,
"modalities": request_obj.modalities or ["image"],
"mm_items": [
MultimodalDataItem(
pixel_values=pixel_values,
image_sizes=image_sizes,
modality=modality,
)
],
}
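
Llava is the one processor in this diff that keeps the process pool (its anyres cropping is CPU-heavy), so everything crossing the pool boundary must be picklable: a @staticmethod task plus explicit arguments, including the processor that previously arrived via the worker global. A minimal sketch of the pattern from an asyncio caller, assuming a fork-capable platform:

    import asyncio
    import concurrent.futures
    import multiprocessing as mp

    class Worker:
        def __init__(self):
            self.cpu_executor = concurrent.futures.ProcessPoolExecutor(
                mp_context=mp.get_context("fork"), max_workers=2
            )

        @staticmethod
        def _task(data, scale):  # resolvable by qualified name, hence picklable
            return [x * scale for x in data]

        async def run(self, data):
            loop = asyncio.get_running_loop()
            # Positional args are pickled and shipped to the worker process.
            return await loop.run_in_executor(self.cpu_executor, Worker._task, data, 3)

    if __name__ == "__main__":
        print(asyncio.run(Worker().run([1, 2])))  # [3, 6]
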

View File

@@ -1,13 +1,13 @@
import asyncio
from typing import List, Union
import torch
from transformers import BaseImageProcessorFast
from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor,
MultimodalSpecialTokens,
get_global_processor,
)
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.models.minicpmo import MiniCPMO
from sglang.srt.models.minicpmv import MiniCPMV
@@ -21,19 +21,23 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
self.image_token = "(<image>./</image>)"
self.audio_token = "(<audio>./</audio>)"
@staticmethod
def _process_data_task(input_text, images=None, audios=None):
def process_data_task(self, input_text, images=None, audios=None):
if isinstance(images, list) and len(images) == 0:
images = None
if isinstance(audios, list) and len(audios) == 0:
audios = None
result = get_global_processor().__call__(
processor = self._processor
args = {}
if isinstance(processor, BaseImageProcessorFast):
args["device"] = "cuda"
result = self._processor.__call__(
text=input_text,
images=images,
audios=audios,
return_tensors="pt",
chunk_input=True,
**args,
)
return {
"input_ids": result.input_ids,
@@ -44,23 +48,6 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
"audio_bounds": getattr(result, "audio_bounds", None),
}
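
The BaseImageProcessorFast check above routes preprocessing to the GPU when the checkpoint ships a fast (torchvision-backed) image processor, which accepts a device argument. A hedged sketch of the dispatch (checkpoint name illustrative; slow processors get no extra kwarg):

    from transformers import AutoImageProcessor, BaseImageProcessorFast

    image_processor = AutoImageProcessor.from_pretrained(
        "llava-hf/llava-1.5-7b-hf", use_fast=True  # example checkpoint
    )
    kwargs = {}
    if isinstance(image_processor, BaseImageProcessorFast):
        kwargs["device"] = "cuda"  # fast processors can tensorize directly on GPU
    # outputs = image_processor(images=[...], return_tensors="pt", **kwargs)
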
async def _process_data(self, images, input_text, audios=None):
if self.executor is not None:
loop = asyncio.get_event_loop()
multimodal_data_inputs = await loop.run_in_executor(
self.executor,
MiniCPMMultimodalProcessor._process_data_task,
input_text,
images,
audios,
)
else:
multimodal_data_inputs = self._processor(
images=images, text=input_text, audios=audios, return_tensors="pt"
)
return multimodal_data_inputs
async def process_mm_data_async(
self,
image_data: List[Union[str, bytes]],
@@ -77,7 +64,7 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
audio_data = [audio_data]
base_output = self.load_mm_data(
input_ids=input_ids,
prompt=input_ids,
max_req_input_len=max_req_input_len,
audio_data=audio_data,
image_data=image_data,
@@ -88,9 +75,9 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
if base_output is None:
return None
res = await self._process_data(
images=base_output.images,
res = self.process_mm_data(
input_text=base_output.input_text,
images=base_output.images,
audios=base_output.audios,
)
@@ -142,23 +129,33 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
tgt_sizes_flat += [tgt_n]
pixel_values = pixel_values_flat
if len(tgt_sizes_flat) == 0:
tgt_sizes = None
else:
tgt_sizes = torch.stack(tgt_sizes_flat)
if not isinstance(res["audio_features"], list):
res["audio_features"] = [res["audio_features"]]
items = []
if len(pixel_values) != 0:
item = MultimodalDataItem(
pixel_values=pixel_values,
tgt_size=tgt_sizes_flat,
modality=Modality.IMAGE,
)
items += [item]
if (
"audio_features" in res
and res["audio_features"] is not None
and len(res["audio_features"]) != 0
):
item = MultimodalDataItem(
audio_features=[res["audio_features"]],
audio_feature_lens=res["audio_feature_lens"],
modality=Modality.AUDIO,
)
items += [item]
return {
"mm_items": items,
"input_ids": res["input_ids"].flatten().tolist(),
"pixel_values": pixel_values,
"tgt_sizes": tgt_sizes,
"data_hashes": base_output.mm_data_hashes,
"modalities": request_obj.modalities or ["image"],
"audio_start_id": audio_start_id,
"audio_end_id": audio_end_id,
"audio_features": res["audio_features"],
"audio_bounds": res["audio_bounds"],
"audio_feature_lens": res["audio_feature_lens"],
"im_token_id": im_token_id,
"im_start_id": tokenizer.im_start_id,
"im_end_id": tokenizer.im_end_id,

View File

@@ -1,10 +1,9 @@
import asyncio
from typing import List, Union
from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor,
get_global_processor,
)
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.models.mllama import MllamaForConditionalGeneration
from sglang.srt.utils import load_image
@@ -15,25 +14,6 @@ class MllamaImageProcessor(BaseMultimodalProcessor):
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
@staticmethod
def _process_single_image_task(images, input_text):
# 'input_ids', 'attention_mask', 'pixel_values', 'aspect_ratio_ids', 'aspect_ratio_mask', 'cross_attention_mask'
return get_global_processor()(images, input_text, return_tensors="pt")
async def _process_single_image(self, images, input_text):
if self.executor is not None:
loop = asyncio.get_event_loop()
image_inputs = await loop.run_in_executor(
self.executor,
MllamaImageProcessor._process_single_image_task,
images,
input_text,
)
else:
image_inputs = self._processor(images, input_text, return_tensors="pt")
return image_inputs
async def process_mm_data_async(
self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
):
@@ -52,8 +32,15 @@ class MllamaImageProcessor(BaseMultimodalProcessor):
else:
images = load_image(image_data[0])[0]
image_inputs = await self._process_single_image(images, input_text)
image_inputs["data_hashes"] = [hash(str(image_data))]
image_inputs = self.process_mm_data(input_text=input_text, images=images)
image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
image_inputs["mm_items"] = [
MultimodalDataItem(
pixel_values=image_inputs["pixel_values"],
aspect_ratio_id=image_inputs["aspect_ratio_ids"],
aspect_ratio_mask=image_inputs["aspect_ratio_mask"],
modality=Modality.IMAGE,
)
]
return image_inputs

View File

@@ -1,6 +1,5 @@
import asyncio
import math
import time
from typing import List, Union
import torch
@@ -11,8 +10,8 @@ from sglang.srt.managers.multimodal_processor import (
)
from sglang.srt.managers.multimodal_processors.base_processor import (
MultimodalSpecialTokens,
get_global_processor,
)
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
from sglang.srt.models.qwen2_vl import Qwen2VLForConditionalGeneration
@@ -34,45 +33,15 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
self.MAX_PIXELS = 16384 * 28 * 28
self.MAX_RATIO = 200
@staticmethod
def _process_images_task(images, input_text, _hf_config):
if isinstance(images, list) and len(images) == 0:
images = None
result = get_global_processor().__call__(
text=[input_text], images=images, padding=True, return_tensors="pt"
)
return {
"input_ids": result.input_ids,
"pixel_values": getattr(result, "pixel_values", None),
"image_grid_thw": getattr(result, "image_grid_thw", None),
"second_per_grid_ts": getattr(result, "second_per_grid_ts", None),
"video_grid_thws": getattr(result, "video_grid_thws", None),
}
async def _process_single_image(self, images, input_text) -> dict:
if self.executor is not None:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self.executor,
Qwen2_5VLImageProcessor._process_images_task,
images,
input_text,
self.hf_config,
)
else:
return self._process_images_task(images, input_text, self.hf_config)
async def process_mm_data_async(
self,
image_data: List[Union[str, bytes]],
input_ids,
prompt,
request_obj,
max_req_input_len,
*args,
**kwargs,
):
start = time.time()
if not image_data:
return None
if isinstance(image_data, str):
@@ -80,7 +49,7 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
image_token = self.IMAGE_TOKEN
base_output = self.load_mm_data(
input_ids=input_ids,
prompt=prompt,
image_data=image_data,
multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
max_req_input_len=max_req_input_len,
@@ -144,24 +113,32 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
return math.floor(number / factor) * factor
images = [resize_image(image) for image in base_output.images]
async def resize_image_async(image):
return resize_image(image)
ret = await self._process_single_image(
images=images, input_text=base_output.input_text
resize_tasks = [resize_image_async(image) for image in base_output.images]
resized_images = await asyncio.gather(*resize_tasks)
ret = self.process_mm_data(
input_text=base_output.input_text,
images=resized_images,
)
image_grid_thws = torch.concat([ret["image_grid_thw"]])
video_grid_thws = None
return {
"input_ids": ret["input_ids"].flatten().tolist(),
"pixel_values": ret["pixel_values"],
"data_hashes": base_output.mm_data_hashes,
"modalities": request_obj.modalities or ["image"],
"image_grid_thws": image_grid_thws,
"video_grid_thws": video_grid_thws,
"mm_items": [
MultimodalDataItem(
pixel_values=ret["pixel_values"],
image_grid_thws=image_grid_thws,
# TODO
video_grid_thws=None,
second_per_grid_ts=ret.get("second_per_grid_ts", None),
modality=Modality.IMAGE,
)
],
"im_start_id": self.IM_START_TOKEN_ID,
"im_end_id": self.IM_END_TOKEN_ID,
"im_token_id": self.image_token_id,
"video_token_id": self.video_token_id,
"second_per_grid_ts": ret["second_per_grid_ts"],
}
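
For reference, the constraints visible in this file (the MAX_PIXELS and MAX_RATIO bounds plus the floor-by-factor helper whose docstring appears above) implement Qwen2-VL-style "smart resize": image dimensions snap down to multiples of the 28-pixel patch grid while total pixels stay in range, which is what makes the image_grid_thw bookkeeping possible. A worked sketch under those assumptions:

    import math

    def floor_by_factor(number: float, factor: int) -> int:
        # Largest integer <= number that is divisible by factor.
        return math.floor(number / factor) * factor

    PATCH = 28
    MAX_PIXELS = 16384 * 28 * 28  # from the constants in this processor

    h, w = 1000, 1530
    if h * w > MAX_PIXELS:  # downscale, preserving aspect ratio
        scale = math.sqrt(MAX_PIXELS / (h * w))
        h, w = h * scale, w * scale
    h, w = floor_by_factor(h, PATCH), floor_by_factor(w, PATCH)
    print(h, w, h // PATCH, w // PATCH)  # 980 1512 -> a 35 x 54 patch grid
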