Process images in parallel (#1539)
This commit is contained in:
187
python/sglang/srt/managers/image_processor.py
Normal file
187
python/sglang/srt/managers/image_processor.py
Normal file
@@ -0,0 +1,187 @@
|
||||
# TODO: also move pad_input_ids into this module
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import logging
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import transformers
|
||||
|
||||
from sglang.srt.hf_transformers_utils import get_processor
|
||||
from sglang.srt.mm_utils import expand2square, process_anyres_image
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import load_image
|
||||
from sglang.utils import get_exception_traceback
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Per-process multimodal processor. It stays ``None`` until
# ``init_global_processor`` assigns it (the process-pool workers do this in
# their initializer). A module-level ``global`` statement would not bind the
# name at all, so it is defined explicitly here.
global_processor = None
|
||||
|
||||
|
||||
def init_global_processor(server_args: ServerArgs):
    """Load the multimodal processor and store it in ``global_processor``.

    The transformers logging verbosity is switched to "error" first; the
    processor is then built with ``get_processor`` from the tokenizer path,
    tokenizer mode and ``trust_remote_code`` flag carried by ``server_args``.
    """
    global global_processor

    transformers.logging.set_verbosity_error()

    tokenizer_path = server_args.tokenizer_path
    processor_options = dict(
        tokenizer_mode=server_args.tokenizer_mode,
        trust_remote_code=server_args.trust_remote_code,
    )
    global_processor = get_processor(tokenizer_path, **processor_options)
|
||||
|
||||
|
||||
class BaseImageProcessor(ABC):
    """Abstract interface shared by all image processors in this module."""

    @abstractmethod
    async def process_images_async(self, image_data, **kwargs):
        """Asynchronously process ``image_data``; subclasses must override.

        The base implementation does nothing and returns ``None``.
        """
|
||||
|
||||
|
||||
class DummyImageProcessor(BaseImageProcessor):
    """Placeholder processor that never produces any image inputs."""

    async def process_images_async(self, *args, **kwargs):
        """Accept and ignore any arguments; always return ``None``."""
        return None
|
||||
|
||||
|
||||
class LlavaImageProcessor(BaseImageProcessor):
    """Image pre-processor for LLaVA-style models backed by a process pool.

    Every pool worker builds its own processor through
    ``init_global_processor``; individual images are converted to
    ``pixel_values`` inside the workers so that the images of one request can
    be processed concurrently.
    """

    def __init__(self, hf_config, server_args, _image_processor):
        """Store the model config and start the worker pool.

        Args:
            hf_config: Model config; ``image_aspect_ratio`` and
                ``image_grid_pinpoints`` are read from it when present.
            server_args: Forwarded to ``init_global_processor`` in each worker.
            _image_processor: Image processor used when a task runs in the
                current process instead of in the pool.
        """
        self.hf_config = hf_config
        self._image_processor = _image_processor

        # Environment variables are always strings, but ProcessPoolExecutor
        # needs an int (or None), so convert the override explicitly.
        cpu_count_override = os.environ.get("SGLANG_CPU_COUNT")
        max_workers = (
            int(cpu_count_override) if cpu_count_override else os.cpu_count()
        )
        self.executor = concurrent.futures.ProcessPoolExecutor(
            initializer=init_global_processor,
            mp_context=mp.get_context("fork"),
            initargs=(server_args,),
            max_workers=max_workers,
        )

    @staticmethod
    def _process_single_image_task(
        image_data: Union[str, bytes],
        image_aspect_ratio: Optional[str] = None,
        image_grid_pinpoints: Optional[str] = None,
        image_processor=None,
    ):
        """Load one image (or video) and convert it to float16 pixel values.

        Args:
            image_data: Raw input handed to ``load_image``.
            image_aspect_ratio: ``"pad"``, ``"anyres"``, a value containing
                ``"anyres_max"``, or anything else for plain processing.
            image_grid_pinpoints: Passed to ``process_anyres_image`` in the
                anyres modes.
            image_processor: Processor to use; defaults to the
                ``image_processor`` of the per-process ``global_processor``.

        Returns:
            ``(pixel_values, image_hash, size)`` on success. If anything in
            the processing raises, the traceback is logged and ``None`` is
            returned, so callers that unpack the result will fail.
        """
        image_processor = image_processor or global_processor.image_processor

        try:
            image, image_size = load_image(image_data)
            image_hash = hash(image_data)

            if image_size is not None:
                # It is a video with multiple images
                frames = image_processor(image)["pixel_values"]
                pixel_values = np.stack(
                    [frame.astype(np.float16) for frame in frames], axis=0
                )
                return pixel_values, image_hash, image_size

            # It is an image
            if image_aspect_ratio == "pad":
                image = expand2square(
                    image,
                    tuple(int(x * 255) for x in image_processor.image_mean),
                )
                pixel_values = image_processor(image.convert("RGB"))[
                    "pixel_values"
                ][0]
            elif image_aspect_ratio == "anyres" or (
                image_aspect_ratio is not None
                and "anyres_max" in image_aspect_ratio
            ):
                pixel_values = process_anyres_image(
                    image, image_processor, image_grid_pinpoints
                )
            else:
                pixel_values = image_processor(image)["pixel_values"][0]

            if isinstance(pixel_values, np.ndarray):
                pixel_values = pixel_values.astype(np.float16)

            return pixel_values, image_hash, image.size
        except Exception:
            logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())
            return None

    async def _process_single_image(
        self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str
    ):
        """Run ``_process_single_image_task`` for one image.

        The task is submitted to the process pool when one exists; otherwise
        it runs inline in the current process.
        """
        if self.executor is not None:
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(
                self.executor,
                LlavaImageProcessor._process_single_image_task,
                image_data,
                aspect_ratio,
                grid_pinpoints,
            )

        # Without a pool the task runs in this process, where the global
        # processor is only set by the pool initializer; hand over the
        # processor stored on this instance instead.
        return self._process_single_image_task(
            image_data, aspect_ratio, grid_pinpoints, self._image_processor
        )

    async def process_images_async(
        self, image_data: List[Union[str, bytes]], request_obj
    ):
        """Process the image inputs of one request.

        Args:
            image_data: A list of images or a single ``str`` image.
            request_obj: Request whose ``modalities`` attribute is copied into
                the result.

        Returns:
            ``None`` when ``image_data`` is empty, otherwise a dict with the
            keys ``pixel_values``, ``image_hashes``, ``image_sizes`` and
            ``modalities``.

        Raises:
            ValueError: If ``image_data`` is neither a list nor a ``str``.
        """
        if not image_data:
            return None

        aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
        # Guard against a config that defines grid pinpoints but no aspect
        # ratio: `"anyres" in None` would raise TypeError.
        grid_pinpoints = (
            self.hf_config.image_grid_pinpoints
            if hasattr(self.hf_config, "image_grid_pinpoints")
            and aspect_ratio is not None
            and "anyres" in aspect_ratio
            else None
        )

        if isinstance(image_data, list) and len(image_data) > 1:
            # Multiple images
            aspect_ratio = "pad"  # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
            results = await asyncio.gather(
                *(
                    self._process_single_image(img_data, aspect_ratio, grid_pinpoints)
                    for img_data in image_data
                )
            )
            pixel_values, image_hashes, image_sizes = [], [], []
            for pixel_v, image_h, image_s in results:
                pixel_values.append(pixel_v)
                image_hashes.append(image_h)
                image_sizes.append(image_s)

            if isinstance(pixel_values[0], np.ndarray):
                pixel_values = np.stack(pixel_values, axis=0)
        elif isinstance(image_data, (list, str)):
            # A single image, given either as a one-element list or as a str
            single_image = image_data[0] if isinstance(image_data, list) else image_data
            pixel_values, image_hash, image_size = await self._process_single_image(
                single_image, aspect_ratio, grid_pinpoints
            )
            image_hashes = [image_hash]
            image_sizes = [image_size]
        else:
            raise ValueError(f"Invalid image data: {image_data}")

        return {
            "pixel_values": pixel_values,
            "image_hashes": image_hashes,
            "image_sizes": image_sizes,
            "modalities": request_obj.modalities,
        }
|
||||
|
||||
|
||||
def get_image_processor(
    hf_config, server_args: ServerArgs, _image_processor
) -> BaseImageProcessor:
    """Create the image processor for a multimodal model.

    Currently always builds a ``LlavaImageProcessor`` from the given model
    config, server arguments and underlying image processor.
    """
    processor = LlavaImageProcessor(hf_config, server_args, _image_processor)
    return processor
|
||||
|
||||
|
||||
def get_dummy_image_processor():
    """Return a fresh ``DummyImageProcessor`` placeholder instance."""
    placeholder = DummyImageProcessor()
    return placeholder
|
||||
@@ -16,17 +16,13 @@ limitations under the License.
|
||||
"""TokenizerManager is a process that tokenizes the text."""
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import dataclasses
|
||||
import json
|
||||
import logging
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import fastapi
|
||||
import numpy as np
|
||||
import transformers
|
||||
import uvloop
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
@@ -38,6 +34,10 @@ from sglang.srt.hf_transformers_utils import (
|
||||
get_processor,
|
||||
get_tokenizer,
|
||||
)
|
||||
from sglang.srt.managers.image_processor import (
|
||||
get_dummy_image_processor,
|
||||
get_image_processor,
|
||||
)
|
||||
from sglang.srt.managers.io_struct import (
|
||||
AbortReq,
|
||||
BatchEmbeddingOut,
|
||||
@@ -53,11 +53,9 @@ from sglang.srt.managers.io_struct import (
|
||||
UpdateWeightReqInput,
|
||||
UpdateWeightReqOutput,
|
||||
)
|
||||
from sglang.srt.mm_utils import expand2square, process_anyres_image
|
||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.utils import is_generation_model, is_multimodal_model, load_image
|
||||
from sglang.utils import get_exception_traceback
|
||||
from sglang.srt.utils import is_generation_model, is_multimodal_model
|
||||
|
||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||
|
||||
@@ -105,6 +103,8 @@ class TokenizerManager:
|
||||
self.context_len = server_args.context_length or get_context_length(
|
||||
self.hf_config
|
||||
)
|
||||
# Create image processor placeholder
|
||||
self.image_processor = get_dummy_image_processor()
|
||||
|
||||
# Create tokenizer
|
||||
if server_args.skip_tokenizer_init:
|
||||
@@ -119,13 +119,9 @@ class TokenizerManager:
|
||||
self.tokenizer = self.processor.tokenizer
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
# We want to parallelize the image pre-processing so we
|
||||
# create an executor for it
|
||||
self.executor = concurrent.futures.ProcessPoolExecutor(
|
||||
initializer=init_global_processor,
|
||||
mp_context=mp.get_context("fork"),
|
||||
initargs=(server_args,),
|
||||
max_workers=os.environ.get("SGLANG_CPU_COUNT", os.cpu_count()),
|
||||
# We want to parallelize the image pre-processing so we create an executor for it
|
||||
self.image_processor = get_image_processor(
|
||||
self.hf_config, server_args, self.processor.image_processor
|
||||
)
|
||||
else:
|
||||
self.tokenizer = get_tokenizer(
|
||||
@@ -194,8 +190,8 @@ class TokenizerManager:
|
||||
)
|
||||
|
||||
if self.is_generation:
|
||||
image_inputs = await self._get_image_inputs(
|
||||
obj, obj.image_data if not_use_index else obj.image_data[index]
|
||||
image_inputs = await self.image_processor.process_images_async(
|
||||
obj.image_data if not_use_index else obj.image_data[index], obj
|
||||
)
|
||||
return_logprob = (
|
||||
obj.return_logprob if not_use_index else obj.return_logprob[index]
|
||||
@@ -247,7 +243,9 @@ class TokenizerManager:
|
||||
|
||||
sampling_params = SamplingParams(**obj.sampling_params[0])
|
||||
sampling_params.max_new_tokens = 0
|
||||
image_inputs = await self._get_image_inputs(obj, obj.image_data[0])
|
||||
image_inputs = await self.image_processor.process_images_async(
|
||||
obj.image_data[0], obj
|
||||
)
|
||||
return_logprob = obj.return_logprob[0]
|
||||
logprob_start_len = obj.logprob_start_len[0]
|
||||
top_logprobs_num = obj.top_logprobs_num[0]
|
||||
@@ -362,8 +360,8 @@ class TokenizerManager:
|
||||
sampling_params = self._get_sampling_params(obj.sampling_params[index])
|
||||
|
||||
if self.is_generation:
|
||||
image_inputs = await self._get_image_inputs(
|
||||
obj, obj.image_data[index]
|
||||
image_inputs = await self.image_processor.process_images_async(
|
||||
obj.image_data[index], obj
|
||||
)
|
||||
|
||||
tokenized_obj = TokenizedGenerateReqInput(
|
||||
@@ -686,131 +684,3 @@ class TokenizerManager:
|
||||
token_top_logprobs, decode_to_text
|
||||
)
|
||||
return top_logprobs
|
||||
|
||||
async def _get_image_inputs(self, obj, image_data: List[Union[str, bytes]]):
|
||||
if not image_data:
|
||||
return None
|
||||
|
||||
# TODO: move this into a processor for each vision architecture
|
||||
aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
|
||||
grid_pinpoints = (
|
||||
self.hf_config.image_grid_pinpoints
|
||||
if hasattr(self.hf_config, "image_grid_pinpoints")
|
||||
and "anyres" in aspect_ratio
|
||||
else None
|
||||
)
|
||||
|
||||
if isinstance(image_data, list) and len(image_data) > 0:
|
||||
# Multiple images
|
||||
if len(image_data) > 1:
|
||||
aspect_ratio = "pad" # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
|
||||
pixel_values, image_hashes, image_sizes = [], [], []
|
||||
for img_data in image_data:
|
||||
pixel_v, image_h, image_s = await self._process_single_image(
|
||||
img_data, aspect_ratio, grid_pinpoints
|
||||
)
|
||||
pixel_values.append(pixel_v)
|
||||
image_hashes.append(image_h)
|
||||
image_sizes.append(image_s)
|
||||
|
||||
if isinstance(pixel_values[0], np.ndarray):
|
||||
pixel_values = np.stack(pixel_values, axis=0)
|
||||
else:
|
||||
# A single image
|
||||
pixel_values, image_hash, image_size = await self._process_single_image(
|
||||
image_data[0], aspect_ratio, grid_pinpoints
|
||||
)
|
||||
image_hashes = [image_hash]
|
||||
image_sizes = [image_size]
|
||||
elif isinstance(image_data, str):
|
||||
# A single image
|
||||
pixel_values, image_hash, image_size = await self._process_single_image(
|
||||
image_data, aspect_ratio, grid_pinpoints
|
||||
)
|
||||
image_hashes = [image_hash]
|
||||
image_sizes = [image_size]
|
||||
else:
|
||||
raise ValueError(f"Invalid image data: {image_data}")
|
||||
|
||||
return {
|
||||
"pixel_values": pixel_values,
|
||||
"image_hashes": image_hashes,
|
||||
"image_sizes": image_sizes,
|
||||
"modalities": obj.modalities,
|
||||
}
|
||||
|
||||
async def _process_single_image(
    self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str
):
    """Run ``_process_single_image_task`` for one image.

    Without an executor the task runs inline with ``self.processor``;
    otherwise it is submitted to ``self.executor`` via the event loop.
    """
    if self.executor is None:
        return _process_single_image_task(
            image_data, aspect_ratio, grid_pinpoints, self.processor
        )

    event_loop = asyncio.get_event_loop()
    task_args = (image_data, aspect_ratio, grid_pinpoints)
    return await event_loop.run_in_executor(
        self.executor, _process_single_image_task, *task_args
    )
|
||||
|
||||
|
||||
# Per-process multimodal processor. It stays ``None`` until
# ``init_global_processor`` assigns it. A module-level ``global`` statement
# would not bind the name at all, so it is defined explicitly here.
global_processor = None
|
||||
|
||||
|
||||
def init_global_processor(server_args: ServerArgs):
    """Create the processor for multimodal models and publish it globally.

    Sets the transformers logging verbosity to "error", then assigns the
    result of ``get_processor`` (built from the tokenizer path, tokenizer mode
    and ``trust_remote_code`` flag in ``server_args``) to ``global_processor``.
    """
    global global_processor

    transformers.logging.set_verbosity_error()

    path = server_args.tokenizer_path
    mode = server_args.tokenizer_mode
    trust_remote = server_args.trust_remote_code
    global_processor = get_processor(
        path, tokenizer_mode=mode, trust_remote_code=trust_remote
    )
|
||||
|
||||
|
||||
def _process_single_image_task(
    image_data: Union[str, bytes],
    image_aspect_ratio: Optional[str] = None,
    image_grid_pinpoints: Optional[str] = None,
    processor=None,
):
    """Load one image (or video) and convert it to float16 pixel values.

    Uses ``processor`` when given, else the per-process ``global_processor``.
    Returns ``(pixel_values, image_hash, size)``; if anything raises, the
    traceback is logged and the function implicitly returns ``None``.
    """
    try:
        active_processor = processor or global_processor
        image, image_size = load_image(image_data)
        image_hash = hash(image_data)
        to_pixels = active_processor.image_processor

        if image_size is not None:
            # It is a video with multiple images
            frames = to_pixels(image)["pixel_values"]
            for idx, frame in enumerate(frames):
                frames[idx] = frame.astype(np.float16)
            return np.stack(frames, axis=0), image_hash, image_size

        # It is an image
        if image_aspect_ratio == "pad":
            fill_color = tuple(int(x * 255) for x in to_pixels.image_mean)
            image = expand2square(image, fill_color)
            pixel_values = to_pixels(image.convert("RGB"))["pixel_values"][0]
        elif image_aspect_ratio == "anyres" or (
            image_aspect_ratio is not None and "anyres_max" in image_aspect_ratio
        ):
            pixel_values = process_anyres_image(
                image, to_pixels, image_grid_pinpoints
            )
        else:
            pixel_values = to_pixels(image)["pixel_values"][0]

        if isinstance(pixel_values, np.ndarray):
            pixel_values = pixel_values.astype(np.float16)

        return pixel_values, image_hash, image.size
    except Exception:
        logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())
|
||||
|
||||
Reference in New Issue
Block a user