[Fix] Fix LLaVA on multiple images (#1247)

This commit is contained in:
Lianmin Zheng
2024-08-28 06:33:05 -07:00
committed by GitHub
parent b1a540ec42
commit bf53bf5142
22 changed files with 272 additions and 488 deletions

View File

@@ -23,6 +23,7 @@ import multiprocessing as mp
import os
from typing import Dict, List, Optional, Tuple, Union
import fastapi
import numpy as np
import transformers
import uvloop
@@ -96,21 +97,18 @@ class TokenizerManager:
trust_remote_code=server_args.trust_remote_code,
model_overide_args=model_overide_args,
)
self.is_generation = is_generation_model(
self.hf_config.architectures, self.server_args.is_embedding
)
if server_args.context_length is not None:
self.context_len = server_args.context_length
else:
self.context_len = get_context_length(self.hf_config)
self.context_len = server_args.context_length or get_context_length(
self.hf_config
)
# Create tokenizer
if server_args.skip_tokenizer_init:
self.tokenizer = self.processor = None
else:
if is_multimodal_model(self.model_path):
if is_multimodal_model(self.hf_config.architectures):
self.processor = get_processor(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
@@ -118,6 +116,9 @@ class TokenizerManager:
)
self.tokenizer = self.processor.tokenizer
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# We want to parallelize the image pre-processing so we
# create an executor for it
self.executor = concurrent.futures.ProcessPoolExecutor(
initializer=init_global_processor,
mp_context=mp.get_context("fork"),
@@ -134,12 +135,14 @@ class TokenizerManager:
self.to_create_loop = True
self.rid_to_state: Dict[str, ReqState] = {}
# for update model weights
# For update model weights
self.model_update_lock = asyncio.Lock()
self.model_update_result = None
async def generate_request(
self, obj: Union[GenerateReqInput, EmbeddingReqInput], request=None
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
request: Optional[fastapi.Request] = None,
):
if self.to_create_loop:
self.create_handle_loop()
@@ -160,7 +163,7 @@ class TokenizerManager:
async def _handle_single_request(
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
request,
request: Optional[fastapi.Request] = None,
index: Optional[int] = None,
is_cache_for_prefill: Optional[bool] = False,
):
@@ -182,8 +185,8 @@ class TokenizerManager:
)
if self.is_generation:
pixel_values, image_hash, image_size = await self._get_pixel_values(
obj.image_data
pixel_values, image_hashes, image_sizes = await self._get_pixel_values(
obj.image_data if not_use_index else obj.image_data[index]
)
return_logprob = (
obj.return_logprob if not_use_index else obj.return_logprob[index]
@@ -195,7 +198,6 @@ class TokenizerManager:
)
if return_logprob and logprob_start_len == -1:
logprob_start_len = len(input_ids) - 1
top_logprobs_num = (
obj.top_logprobs_num
if not_use_index
@@ -238,13 +240,14 @@ class TokenizerManager:
sampling_params = SamplingParams(**obj.sampling_params[0])
sampling_params.max_new_tokens = 0
pixel_values, image_hash, image_size = await self._get_pixel_values(
pixel_values, image_hashes, image_sizes = await self._get_pixel_values(
obj.image_data[0]
)
return_logprob = obj.return_logprob[0]
logprob_start_len = obj.logprob_start_len[0]
top_logprobs_num = obj.top_logprobs_num[0]
# Send to the controller
if self.is_generation:
if return_logprob and logprob_start_len == -1:
logprob_start_len = len(input_ids) - 1
@@ -253,8 +256,8 @@ class TokenizerManager:
input_text,
input_ids,
pixel_values,
image_hash,
image_size,
image_hashes,
image_sizes,
sampling_params,
return_logprob,
logprob_start_len,
@@ -268,24 +271,24 @@ class TokenizerManager:
input_ids,
sampling_params,
)
self.send_to_router.send_pyobj(tokenized_obj)
# Recv results
event = asyncio.Event()
state = ReqState([], False, event)
self.rid_to_state[rid] = state
if not is_cache_for_prefill:
async for response in self._wait_for_response(
event, state, obj, rid, request
):
async for response in self._wait_for_response(state, obj, rid, request):
yield response
else:
assert self.is_generation
await self._wait_for_cache_prefill_response(event, state, obj, rid, request)
await self._wait_for_cache_prefill_response(state, obj, rid, request)
yield input_ids
async def _handle_batch_request(
self, obj: Union[GenerateReqInput, EmbeddingReqInput], request
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
request: Optional[fastapi.Request] = None,
):
batch_size = obj.batch_size
if self.is_generation:
@@ -340,8 +343,8 @@ class TokenizerManager:
if self.is_generation:
if obj.return_logprob[index] and obj.logprob_start_len[index] == -1:
obj.logprob_start_len[index] = len(input_ids) - 1
pixel_values, image_hash, image_size = await self._get_pixel_values(
obj.image_data[index]
pixel_values, image_hashes, image_sizes = (
await self._get_pixel_values(obj.image_data[index])
)
tokenized_obj = TokenizedGenerateReqInput(
@@ -349,8 +352,8 @@ class TokenizerManager:
input_text,
input_ids,
pixel_values,
image_hash,
image_size,
image_hashes,
image_sizes,
sampling_params,
obj.return_logprob[index],
obj.logprob_start_len[index],
@@ -372,7 +375,6 @@ class TokenizerManager:
generators.append(
self._wait_for_response(
event,
state,
obj,
rid,
@@ -388,6 +390,7 @@ class TokenizerManager:
tasks = [asyncio.create_task(gen.__anext__()) for gen in generators]
output_list = [None] * len(tasks)
# Recv results
while tasks:
done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
@@ -426,25 +429,18 @@ class TokenizerManager:
sampling_params.verify()
return sampling_params
async def _get_pixel_values(self, image_data):
if image_data is None:
return None, None, None
else:
return await self._get_pixel_values_internal(image_data)
async def _wait_for_response(
self,
event: asyncio.Event,
state: ReqState,
obj: Union[GenerateReqInput, EmbeddingReqInput],
rid: str,
request,
index: int = None,
request: Optional[fastapi.Request] = None,
index: Optional[int] = None,
response_index: int = 0,
):
while True:
try:
await asyncio.wait_for(event.wait(), timeout=4)
await asyncio.wait_for(state.event.wait(), timeout=4)
except asyncio.TimeoutError:
if request is not None and await request.is_disconnected():
for rid in [obj.rid] if obj.is_single else obj.rid:
@@ -478,16 +474,15 @@ class TokenizerManager:
yield out
break
event.clear()
state.event.clear()
yield out
async def _wait_for_cache_prefill_response(
self,
event: asyncio.Event,
state: ReqState,
obj: GenerateReqInput,
rid: str,
request,
request: Optional[fastapi.Request] = None,
):
while True:
try:
@@ -514,7 +509,9 @@ class TokenizerManager:
req = AbortReq(rid)
self.send_to_router.send_pyobj(req)
async def update_weights(self, obj: UpdateWeightReqInput, request):
async def update_weights(
self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
):
if self.to_create_loop:
self.create_handle_loop()
@@ -659,12 +656,11 @@ class TokenizerManager:
)
return top_logprobs
async def _get_pixel_values_internal(self, image_data, aspect_ratio=None):
aspect_ratio = (
getattr(self.hf_config, "image_aspect_ratio", None)
if aspect_ratio is None
else aspect_ratio
)
async def _get_pixel_values(self, image_data: List[Union[str, bytes]]):
if not image_data:
return None, None, None
aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
grid_pinpoints = (
self.hf_config.image_grid_pinpoints
if hasattr(self.hf_config, "image_grid_pinpoints")
@@ -673,35 +669,42 @@ class TokenizerManager:
)
if isinstance(image_data, list) and len(image_data) > 0:
pixel_values, image_hash, image_size = [], [], []
# Multiple images
if len(image_data) > 1:
aspect_ratio = "pad" # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
pixel_values, image_hashes, image_sizes = [], [], []
for img_data in image_data:
pixel_v, image_h, image_s = await self._process_single_image(
img_data, aspect_ratio, grid_pinpoints
)
pixel_values.append(pixel_v)
image_hash.append(image_h)
image_size.append(image_s)
pixel_values = np.stack(pixel_values, axis=0)
image_hashes.append(image_h)
image_sizes.append(image_s)
if isinstance(pixel_values[0], np.ndarray):
pixel_values = np.stack(pixel_values, axis=0)
else:
# A single image
pixel_values, image_hash, image_size = await self._process_single_image(
image_data[0], aspect_ratio, grid_pinpoints
)
image_hash = [image_hash]
image_size = [image_size]
image_hashes = [image_hash]
image_sizes = [image_size]
elif isinstance(image_data, str):
# A single image
pixel_values, image_hash, image_size = await self._process_single_image(
image_data, aspect_ratio, grid_pinpoints
)
image_hash = [image_hash]
image_size = [image_size]
image_hashes = [image_hash]
image_sizes = [image_size]
else:
pixel_values, image_hash, image_size = None, None, None
raise ValueError(f"Invalid image data: {image_data}")
return pixel_values, image_hash, image_size
return pixel_values, image_hashes, image_sizes
async def _process_single_image(self, image_data, aspect_ratio, grid_pinpoints):
async def _process_single_image(
self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str
):
if self.executor is not None:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
@@ -732,12 +735,16 @@ def init_global_processor(server_args: ServerArgs):
def _process_single_image_task(
image_data, image_aspect_ratio=None, image_grid_pinpoints=None, processor=None
image_data: Union[str, bytes],
image_aspect_ratio: Optional[str] = None,
image_grid_pinpoints: Optional[str] = None,
processor=None,
):
try:
processor = processor or global_processor
image, image_size = load_image(image_data)
if image_size is not None:
# It is a video with multiple images
image_hash = hash(image_data)
pixel_values = processor.image_processor(image)["pixel_values"]
for _ in range(len(pixel_values)):
@@ -745,6 +752,7 @@ def _process_single_image_task(
pixel_values = np.stack(pixel_values, axis=0)
return pixel_values, image_hash, image_size
else:
# It is an image
image_hash = hash(image_data)
if image_aspect_ratio == "pad":
image = expand2square(
@@ -754,13 +762,18 @@ def _process_single_image_task(
pixel_values = processor.image_processor(image.convert("RGB"))[
"pixel_values"
][0]
elif image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
elif image_aspect_ratio == "anyres" or (
image_aspect_ratio is not None and "anyres_max" in image_aspect_ratio
):
pixel_values = process_anyres_image(
image, processor.image_processor, image_grid_pinpoints
)
else:
pixel_values = processor.image_processor(image)["pixel_values"][0]
pixel_values = pixel_values.astype(np.float16)
if isinstance(pixel_values, np.ndarray):
pixel_values = pixel_values.astype(np.float16)
return pixel_values, image_hash, image.size
except Exception:
logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())