Process image in parallel (#1539)

Liangsheng Yin
2024-09-29 18:52:43 -07:00
committed by GitHub
parent f86c1e611f
commit 55b974f96f
2 changed files with 204 additions and 147 deletions


@@ -16,17 +16,13 @@ limitations under the License.
 """TokenizerManager is a process that tokenizes the text."""

 import asyncio
-import concurrent.futures
 import dataclasses
 import json
 import logging
-import multiprocessing as mp
 import os
 from typing import Dict, List, Optional, Tuple, Union

 import fastapi
-import numpy as np
-import transformers
 import uvloop
 import zmq
 import zmq.asyncio
@@ -38,6 +34,10 @@ from sglang.srt.hf_transformers_utils import (
     get_processor,
     get_tokenizer,
 )
+from sglang.srt.managers.image_processor import (
+    get_dummy_image_processor,
+    get_image_processor,
+)
 from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchEmbeddingOut,
@@ -53,11 +53,9 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightReqInput,
     UpdateWeightReqOutput,
 )
-from sglang.srt.mm_utils import expand2square, process_anyres_image
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import is_generation_model, is_multimodal_model, load_image
-from sglang.utils import get_exception_traceback
+from sglang.srt.utils import is_generation_model, is_multimodal_model

 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@@ -105,6 +103,8 @@ class TokenizerManager:
         self.context_len = server_args.context_length or get_context_length(
             self.hf_config
         )
+        # Create image processor placeholder
+        self.image_processor = get_dummy_image_processor()

         # Create tokenizer
         if server_args.skip_tokenizer_init:
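Note: the placeholder guarantees that self.image_processor is always defined, so request handlers can await it unconditionally; for multimodal models it is overwritten by a real processor in the next hunk. The diff does not include image_processor.py itself, so the following minimal shape of the dummy is an assumption:

class DummyImageProcessor:
    async def process_images_async(self, image_data, request_obj):
        # Text-only models never carry image data, so there is
        # nothing to preprocess.
        return None


def get_dummy_image_processor():
    return DummyImageProcessor()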
@@ -119,13 +119,9 @@ class TokenizerManager:
             self.tokenizer = self.processor.tokenizer
             os.environ["TOKENIZERS_PARALLELISM"] = "false"

-            # We want to parallelize the image pre-processing so we
-            # create an executor for it
-            self.executor = concurrent.futures.ProcessPoolExecutor(
-                initializer=init_global_processor,
-                mp_context=mp.get_context("fork"),
-                initargs=(server_args,),
-                max_workers=os.environ.get("SGLANG_CPU_COUNT", os.cpu_count()),
+            # We want to parallelize the image pre-processing so we create an executor for it
+            self.image_processor = get_image_processor(
+                self.hf_config, server_args, self.processor.image_processor
             )
         else:
             self.tokenizer = get_tokenizer(
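The executor deleted above is the core pattern that moves into the new image-processor module: a fork-context ProcessPoolExecutor whose initializer loads the heavy HuggingFace processor once per worker (it is too expensive to pickle per task), driven from the event loop with run_in_executor. A minimal, runnable distillation of that pattern, with illustrative names:

import asyncio
import concurrent.futures
import multiprocessing as mp
import os

global_processor = None  # one heavy object per worker process


def _init_worker():
    # Runs once per worker at startup (the role init_global_processor plays
    # in the deleted code), so tasks never re-load or pickle the processor.
    global global_processor
    global_processor = object()  # stand-in for get_processor(server_args...)


def _preprocess(image_bytes: bytes) -> int:
    # Stand-in for CPU-bound image preprocessing using global_processor.
    return len(image_bytes)


async def main():
    executor = concurrent.futures.ProcessPoolExecutor(
        initializer=_init_worker,
        mp_context=mp.get_context("fork"),  # POSIX-only
        # Note: os.environ.get returns a str when the variable is set, so
        # max_workers needs an explicit cast (the deleted code skipped it).
        max_workers=int(os.environ.get("SGLANG_CPU_COUNT", os.cpu_count())),
    )
    loop = asyncio.get_running_loop()
    print(await loop.run_in_executor(executor, _preprocess, b"raw image bytes"))


if __name__ == "__main__":
    asyncio.run(main())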
@@ -194,8 +190,8 @@ class TokenizerManager:
             )

         if self.is_generation:
-            image_inputs = await self._get_image_inputs(
-                obj, obj.image_data if not_use_index else obj.image_data[index]
+            image_inputs = await self.image_processor.process_images_async(
+                obj.image_data if not_use_index else obj.image_data[index], obj
             )
             return_logprob = (
                 obj.return_logprob if not_use_index else obj.return_logprob[index]
@@ -247,7 +243,9 @@ class TokenizerManager:
             sampling_params = SamplingParams(**obj.sampling_params[0])
             sampling_params.max_new_tokens = 0

-            image_inputs = await self._get_image_inputs(obj, obj.image_data[0])
+            image_inputs = await self.image_processor.process_images_async(
+                obj.image_data[0], obj
+            )
             return_logprob = obj.return_logprob[0]
             logprob_start_len = obj.logprob_start_len[0]
             top_logprobs_num = obj.top_logprobs_num[0]
@@ -362,8 +360,8 @@ class TokenizerManager:
         sampling_params = self._get_sampling_params(obj.sampling_params[index])

         if self.is_generation:
-            image_inputs = await self._get_image_inputs(
-                obj, obj.image_data[index]
+            image_inputs = await self.image_processor.process_images_async(
+                obj.image_data[index], obj
             )

             tokenized_obj = TokenizedGenerateReqInput(
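The three call sites above (single requests, embedding requests, and batch items) now share one awaitable entry point, process_images_async(image_data, request_obj). Its return contract is visible in the deleted _get_image_inputs below; sketched here as an abstract interface, with the base-class name an assumption:

import abc
from typing import List, Optional, Union


class BaseImageProcessor(abc.ABC):
    @abc.abstractmethod
    async def process_images_async(
        self,
        image_data: Union[str, bytes, List[Union[str, bytes]]],
        request_obj,
    ) -> Optional[dict]:
        """Return a dict with pixel_values, image_hashes, image_sizes, and
        modalities, or None when the request carries no image data."""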
@@ -686,131 +684,3 @@ class TokenizerManager:
                 token_top_logprobs, decode_to_text
             )
         return top_logprobs
-
-    async def _get_image_inputs(self, obj, image_data: List[Union[str, bytes]]):
-        if not image_data:
-            return None
-
-        # TODO: move this into a processor for each vision architecture
-        aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
-        grid_pinpoints = (
-            self.hf_config.image_grid_pinpoints
-            if hasattr(self.hf_config, "image_grid_pinpoints")
-            and "anyres" in aspect_ratio
-            else None
-        )
-
-        if isinstance(image_data, list) and len(image_data) > 0:
-            # Multiple images
-            if len(image_data) > 1:
-                aspect_ratio = "pad"  # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
-                pixel_values, image_hashes, image_sizes = [], [], []
-                for img_data in image_data:
-                    pixel_v, image_h, image_s = await self._process_single_image(
-                        img_data, aspect_ratio, grid_pinpoints
-                    )
-                    pixel_values.append(pixel_v)
-                    image_hashes.append(image_h)
-                    image_sizes.append(image_s)
-
-                if isinstance(pixel_values[0], np.ndarray):
-                    pixel_values = np.stack(pixel_values, axis=0)
-            else:
-                # A single image
-                pixel_values, image_hash, image_size = await self._process_single_image(
-                    image_data[0], aspect_ratio, grid_pinpoints
-                )
-                image_hashes = [image_hash]
-                image_sizes = [image_size]
-        elif isinstance(image_data, str):
-            # A single image
-            pixel_values, image_hash, image_size = await self._process_single_image(
-                image_data, aspect_ratio, grid_pinpoints
-            )
-            image_hashes = [image_hash]
-            image_sizes = [image_size]
-        else:
-            raise ValueError(f"Invalid image data: {image_data}")
-
-        return {
-            "pixel_values": pixel_values,
-            "image_hashes": image_hashes,
-            "image_sizes": image_sizes,
-            "modalities": obj.modalities,
-        }
-
-    async def _process_single_image(
-        self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str
-    ):
-        if self.executor is not None:
-            loop = asyncio.get_event_loop()
-            return await loop.run_in_executor(
-                self.executor,
-                _process_single_image_task,
-                image_data,
-                aspect_ratio,
-                grid_pinpoints,
-            )
-        else:
-            return _process_single_image_task(
-                image_data, aspect_ratio, grid_pinpoints, self.processor
-            )
-
-
-global global_processor
-
-
-def init_global_processor(server_args: ServerArgs):
-    """Init the global processor for multi modal models."""
-    global global_processor
-    transformers.logging.set_verbosity_error()
-    global_processor = get_processor(
-        server_args.tokenizer_path,
-        tokenizer_mode=server_args.tokenizer_mode,
-        trust_remote_code=server_args.trust_remote_code,
-    )
-
-
-def _process_single_image_task(
-    image_data: Union[str, bytes],
-    image_aspect_ratio: Optional[str] = None,
-    image_grid_pinpoints: Optional[str] = None,
-    processor=None,
-):
-    try:
-        processor = processor or global_processor
-        image, image_size = load_image(image_data)
-        if image_size is not None:
-            # It is a video with multiple images
-            image_hash = hash(image_data)
-            pixel_values = processor.image_processor(image)["pixel_values"]
-            for _ in range(len(pixel_values)):
-                pixel_values[_] = pixel_values[_].astype(np.float16)
-            pixel_values = np.stack(pixel_values, axis=0)
-            return pixel_values, image_hash, image_size
-        else:
-            # It is an image
-            image_hash = hash(image_data)
-            if image_aspect_ratio == "pad":
-                image = expand2square(
-                    image,
-                    tuple(int(x * 255) for x in processor.image_processor.image_mean),
-                )
-                pixel_values = processor.image_processor(image.convert("RGB"))[
-                    "pixel_values"
-                ][0]
-            elif image_aspect_ratio == "anyres" or (
-                image_aspect_ratio is not None and "anyres_max" in image_aspect_ratio
-            ):
-                pixel_values = process_anyres_image(
-                    image, processor.image_processor, image_grid_pinpoints
-                )
-            else:
-                pixel_values = processor.image_processor(image)["pixel_values"][0]
-
-            if isinstance(pixel_values, np.ndarray):
-                pixel_values = pixel_values.astype(np.float16)
-
-            return pixel_values, image_hash, image.size
-    except Exception:
-        logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())
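Everything deleted here moves into sglang/srt/managers/image_processor.py, the second changed file, which this view does not show. Per the commit title, the point is parallelism: the old loop awaited _process_single_image one image at a time, while a per-model processor that owns its own pool can fan a multi-image request out concurrently. A runnable sketch of one plausible regrouping; the class name, pool sizing, gather-based fan-out, and the stubbed task body are all assumptions:

import asyncio
import concurrent.futures
import multiprocessing as mp
from typing import List, Optional, Union

import numpy as np


def _process_single_image_task(image_data, aspect_ratio, grid_pinpoints):
    # Stub for the deleted worker task: decode and preprocess one image in a
    # separate process. A fixed fake tensor keeps the sketch self-contained.
    pixel_values = np.zeros((3, 336, 336), dtype=np.float16)
    return pixel_values, hash(image_data), (336, 336)


class LlavaImageProcessor:
    """Assumed new home for the deleted helpers."""

    def __init__(self, hf_config=None, server_args=None, hf_image_processor=None):
        self.hf_config = hf_config
        self.hf_image_processor = hf_image_processor
        # The pool deleted from TokenizerManager.__init__ lives here now.
        self.executor = concurrent.futures.ProcessPoolExecutor(
            mp_context=mp.get_context("fork"), max_workers=4
        )

    async def _process_single_image(self, image_data, aspect_ratio, grid_pinpoints):
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            self.executor,
            _process_single_image_task,
            image_data,
            aspect_ratio,
            grid_pinpoints,
        )

    async def process_images_async(
        self, image_data: Union[str, bytes, List[Union[str, bytes]]], request_obj=None
    ) -> Optional[dict]:
        if not image_data:
            return None
        if not isinstance(image_data, list):
            image_data = [image_data]
        # Interleaved-image / video mode: pad instead of anyres, as above.
        aspect_ratio = "pad" if len(image_data) > 1 else None
        # The old code awaited images one by one; gather runs them in parallel.
        results = await asyncio.gather(
            *(self._process_single_image(img, aspect_ratio, None) for img in image_data)
        )
        pixel_values = [r[0] for r in results]
        if isinstance(pixel_values[0], np.ndarray):
            pixel_values = np.stack(pixel_values, axis=0)
        return {
            "pixel_values": pixel_values,
            "image_hashes": [r[1] for r in results],
            "image_sizes": [r[2] for r in results],
            "modalities": getattr(request_obj, "modalities", None),
        }


if __name__ == "__main__":
    out = asyncio.run(
        LlavaImageProcessor().process_images_async([b"img-a", b"img-b"])
    )
    print(out["pixel_values"].shape, out["image_hashes"])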