[Fix] Fix llava on multi images (#1247)
This commit is contained in:
@@ -23,6 +23,7 @@ import multiprocessing as mp
|
||||
import os
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import fastapi
|
||||
import numpy as np
|
||||
import transformers
|
||||
import uvloop
|
||||
@@ -96,21 +97,18 @@ class TokenizerManager:
|
||||
trust_remote_code=server_args.trust_remote_code,
|
||||
model_overide_args=model_overide_args,
|
||||
)
|
||||
|
||||
self.is_generation = is_generation_model(
|
||||
self.hf_config.architectures, self.server_args.is_embedding
|
||||
)
|
||||
|
||||
if server_args.context_length is not None:
|
||||
self.context_len = server_args.context_length
|
||||
else:
|
||||
self.context_len = get_context_length(self.hf_config)
|
||||
self.context_len = server_args.context_length or get_context_length(
|
||||
self.hf_config
|
||||
)
|
||||
|
||||
# Create tokenizer
|
||||
if server_args.skip_tokenizer_init:
|
||||
self.tokenizer = self.processor = None
|
||||
else:
|
||||
if is_multimodal_model(self.model_path):
|
||||
if is_multimodal_model(self.hf_config.architectures):
|
||||
self.processor = get_processor(
|
||||
server_args.tokenizer_path,
|
||||
tokenizer_mode=server_args.tokenizer_mode,
|
||||
@@ -118,6 +116,9 @@ class TokenizerManager:
|
||||
)
|
||||
self.tokenizer = self.processor.tokenizer
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
# We want to parallelize the image pre-processing so we
|
||||
# create an executor for it
|
||||
self.executor = concurrent.futures.ProcessPoolExecutor(
|
||||
initializer=init_global_processor,
|
||||
mp_context=mp.get_context("fork"),
|
||||
@@ -134,12 +135,14 @@ class TokenizerManager:
|
||||
self.to_create_loop = True
|
||||
self.rid_to_state: Dict[str, ReqState] = {}
|
||||
|
||||
# for update model weights
|
||||
# For update model weights
|
||||
self.model_update_lock = asyncio.Lock()
|
||||
self.model_update_result = None
|
||||
|
||||
async def generate_request(
|
||||
self, obj: Union[GenerateReqInput, EmbeddingReqInput], request=None
|
||||
self,
|
||||
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
||||
request: Optional[fastapi.Request] = None,
|
||||
):
|
||||
if self.to_create_loop:
|
||||
self.create_handle_loop()
|
||||
@@ -160,7 +163,7 @@ class TokenizerManager:
|
||||
async def _handle_single_request(
|
||||
self,
|
||||
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
||||
request,
|
||||
request: Optional[fastapi.Request] = None,
|
||||
index: Optional[int] = None,
|
||||
is_cache_for_prefill: Optional[bool] = False,
|
||||
):
|
||||
@@ -182,8 +185,8 @@ class TokenizerManager:
|
||||
)
|
||||
|
||||
if self.is_generation:
|
||||
pixel_values, image_hash, image_size = await self._get_pixel_values(
|
||||
obj.image_data
|
||||
pixel_values, image_hashes, image_sizes = await self._get_pixel_values(
|
||||
obj.image_data if not_use_index else obj.image_data[index]
|
||||
)
|
||||
return_logprob = (
|
||||
obj.return_logprob if not_use_index else obj.return_logprob[index]
|
||||
@@ -195,7 +198,6 @@ class TokenizerManager:
|
||||
)
|
||||
if return_logprob and logprob_start_len == -1:
|
||||
logprob_start_len = len(input_ids) - 1
|
||||
|
||||
top_logprobs_num = (
|
||||
obj.top_logprobs_num
|
||||
if not_use_index
|
||||
@@ -238,13 +240,14 @@ class TokenizerManager:
|
||||
|
||||
sampling_params = SamplingParams(**obj.sampling_params[0])
|
||||
sampling_params.max_new_tokens = 0
|
||||
pixel_values, image_hash, image_size = await self._get_pixel_values(
|
||||
pixel_values, image_hashes, image_sizes = await self._get_pixel_values(
|
||||
obj.image_data[0]
|
||||
)
|
||||
return_logprob = obj.return_logprob[0]
|
||||
logprob_start_len = obj.logprob_start_len[0]
|
||||
top_logprobs_num = obj.top_logprobs_num[0]
|
||||
|
||||
# Send to the controller
|
||||
if self.is_generation:
|
||||
if return_logprob and logprob_start_len == -1:
|
||||
logprob_start_len = len(input_ids) - 1
|
||||
@@ -253,8 +256,8 @@ class TokenizerManager:
|
||||
input_text,
|
||||
input_ids,
|
||||
pixel_values,
|
||||
image_hash,
|
||||
image_size,
|
||||
image_hashes,
|
||||
image_sizes,
|
||||
sampling_params,
|
||||
return_logprob,
|
||||
logprob_start_len,
|
||||
@@ -268,24 +271,24 @@ class TokenizerManager:
|
||||
input_ids,
|
||||
sampling_params,
|
||||
)
|
||||
|
||||
self.send_to_router.send_pyobj(tokenized_obj)
|
||||
|
||||
# Recv results
|
||||
event = asyncio.Event()
|
||||
state = ReqState([], False, event)
|
||||
self.rid_to_state[rid] = state
|
||||
if not is_cache_for_prefill:
|
||||
async for response in self._wait_for_response(
|
||||
event, state, obj, rid, request
|
||||
):
|
||||
async for response in self._wait_for_response(state, obj, rid, request):
|
||||
yield response
|
||||
else:
|
||||
assert self.is_generation
|
||||
await self._wait_for_cache_prefill_response(event, state, obj, rid, request)
|
||||
await self._wait_for_cache_prefill_response(state, obj, rid, request)
|
||||
yield input_ids
|
||||
|
||||
async def _handle_batch_request(
|
||||
self, obj: Union[GenerateReqInput, EmbeddingReqInput], request
|
||||
self,
|
||||
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
||||
request: Optional[fastapi.Request] = None,
|
||||
):
|
||||
batch_size = obj.batch_size
|
||||
if self.is_generation:
|
||||
@@ -340,8 +343,8 @@ class TokenizerManager:
|
||||
if self.is_generation:
|
||||
if obj.return_logprob[index] and obj.logprob_start_len[index] == -1:
|
||||
obj.logprob_start_len[index] = len(input_ids) - 1
|
||||
pixel_values, image_hash, image_size = await self._get_pixel_values(
|
||||
obj.image_data[index]
|
||||
pixel_values, image_hashes, image_sizes = (
|
||||
await self._get_pixel_values(obj.image_data[index])
|
||||
)
|
||||
|
||||
tokenized_obj = TokenizedGenerateReqInput(
|
||||
@@ -349,8 +352,8 @@ class TokenizerManager:
|
||||
input_text,
|
||||
input_ids,
|
||||
pixel_values,
|
||||
image_hash,
|
||||
image_size,
|
||||
image_hashes,
|
||||
image_sizes,
|
||||
sampling_params,
|
||||
obj.return_logprob[index],
|
||||
obj.logprob_start_len[index],
|
||||
@@ -372,7 +375,6 @@ class TokenizerManager:
|
||||
|
||||
generators.append(
|
||||
self._wait_for_response(
|
||||
event,
|
||||
state,
|
||||
obj,
|
||||
rid,
|
||||
@@ -388,6 +390,7 @@ class TokenizerManager:
|
||||
tasks = [asyncio.create_task(gen.__anext__()) for gen in generators]
|
||||
output_list = [None] * len(tasks)
|
||||
|
||||
# Recv results
|
||||
while tasks:
|
||||
done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
|
||||
|
||||
@@ -426,25 +429,18 @@ class TokenizerManager:
|
||||
sampling_params.verify()
|
||||
return sampling_params
|
||||
|
||||
async def _get_pixel_values(self, image_data):
|
||||
if image_data is None:
|
||||
return None, None, None
|
||||
else:
|
||||
return await self._get_pixel_values_internal(image_data)
|
||||
|
||||
async def _wait_for_response(
|
||||
self,
|
||||
event: asyncio.Event,
|
||||
state: ReqState,
|
||||
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
||||
rid: str,
|
||||
request,
|
||||
index: int = None,
|
||||
request: Optional[fastapi.Request] = None,
|
||||
index: Optional[int] = None,
|
||||
response_index: int = 0,
|
||||
):
|
||||
while True:
|
||||
try:
|
||||
await asyncio.wait_for(event.wait(), timeout=4)
|
||||
await asyncio.wait_for(state.event.wait(), timeout=4)
|
||||
except asyncio.TimeoutError:
|
||||
if request is not None and await request.is_disconnected():
|
||||
for rid in [obj.rid] if obj.is_single else obj.rid:
|
||||
@@ -478,16 +474,15 @@ class TokenizerManager:
|
||||
yield out
|
||||
break
|
||||
|
||||
event.clear()
|
||||
state.event.clear()
|
||||
yield out
|
||||
|
||||
async def _wait_for_cache_prefill_response(
|
||||
self,
|
||||
event: asyncio.Event,
|
||||
state: ReqState,
|
||||
obj: GenerateReqInput,
|
||||
rid: str,
|
||||
request,
|
||||
request: Optional[fastapi.Request] = None,
|
||||
):
|
||||
while True:
|
||||
try:
|
||||
@@ -514,7 +509,9 @@ class TokenizerManager:
|
||||
req = AbortReq(rid)
|
||||
self.send_to_router.send_pyobj(req)
|
||||
|
||||
async def update_weights(self, obj: UpdateWeightReqInput, request):
|
||||
async def update_weights(
|
||||
self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
|
||||
):
|
||||
if self.to_create_loop:
|
||||
self.create_handle_loop()
|
||||
|
||||
@@ -659,12 +656,11 @@ class TokenizerManager:
|
||||
)
|
||||
return top_logprobs
|
||||
|
||||
async def _get_pixel_values_internal(self, image_data, aspect_ratio=None):
|
||||
aspect_ratio = (
|
||||
getattr(self.hf_config, "image_aspect_ratio", None)
|
||||
if aspect_ratio is None
|
||||
else aspect_ratio
|
||||
)
|
||||
async def _get_pixel_values(self, image_data: List[Union[str, bytes]]):
|
||||
if not image_data:
|
||||
return None, None, None
|
||||
|
||||
aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
|
||||
grid_pinpoints = (
|
||||
self.hf_config.image_grid_pinpoints
|
||||
if hasattr(self.hf_config, "image_grid_pinpoints")
|
||||
@@ -673,35 +669,42 @@ class TokenizerManager:
|
||||
)
|
||||
|
||||
if isinstance(image_data, list) and len(image_data) > 0:
|
||||
pixel_values, image_hash, image_size = [], [], []
|
||||
# Multiple images
|
||||
if len(image_data) > 1:
|
||||
aspect_ratio = "pad" # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
|
||||
pixel_values, image_hashes, image_sizes = [], [], []
|
||||
for img_data in image_data:
|
||||
pixel_v, image_h, image_s = await self._process_single_image(
|
||||
img_data, aspect_ratio, grid_pinpoints
|
||||
)
|
||||
pixel_values.append(pixel_v)
|
||||
image_hash.append(image_h)
|
||||
image_size.append(image_s)
|
||||
pixel_values = np.stack(pixel_values, axis=0)
|
||||
image_hashes.append(image_h)
|
||||
image_sizes.append(image_s)
|
||||
|
||||
if isinstance(pixel_values[0], np.ndarray):
|
||||
pixel_values = np.stack(pixel_values, axis=0)
|
||||
else:
|
||||
# A single image
|
||||
pixel_values, image_hash, image_size = await self._process_single_image(
|
||||
image_data[0], aspect_ratio, grid_pinpoints
|
||||
)
|
||||
image_hash = [image_hash]
|
||||
image_size = [image_size]
|
||||
image_hashes = [image_hash]
|
||||
image_sizes = [image_size]
|
||||
elif isinstance(image_data, str):
|
||||
# A single image
|
||||
pixel_values, image_hash, image_size = await self._process_single_image(
|
||||
image_data, aspect_ratio, grid_pinpoints
|
||||
)
|
||||
image_hash = [image_hash]
|
||||
image_size = [image_size]
|
||||
image_hashes = [image_hash]
|
||||
image_sizes = [image_size]
|
||||
else:
|
||||
pixel_values, image_hash, image_size = None, None, None
|
||||
raise ValueError(f"Invalid image data: {image_data}")
|
||||
|
||||
return pixel_values, image_hash, image_size
|
||||
return pixel_values, image_hashes, image_sizes
|
||||
|
||||
async def _process_single_image(self, image_data, aspect_ratio, grid_pinpoints):
|
||||
async def _process_single_image(
|
||||
self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str
|
||||
):
|
||||
if self.executor is not None:
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(
|
||||
@@ -732,12 +735,16 @@ def init_global_processor(server_args: ServerArgs):
|
||||
|
||||
|
||||
def _process_single_image_task(
|
||||
image_data, image_aspect_ratio=None, image_grid_pinpoints=None, processor=None
|
||||
image_data: Union[str, bytes],
|
||||
image_aspect_ratio: Optional[str] = None,
|
||||
image_grid_pinpoints: Optional[str] = None,
|
||||
processor=None,
|
||||
):
|
||||
try:
|
||||
processor = processor or global_processor
|
||||
image, image_size = load_image(image_data)
|
||||
if image_size is not None:
|
||||
# It is a video with multiple images
|
||||
image_hash = hash(image_data)
|
||||
pixel_values = processor.image_processor(image)["pixel_values"]
|
||||
for _ in range(len(pixel_values)):
|
||||
@@ -745,6 +752,7 @@ def _process_single_image_task(
|
||||
pixel_values = np.stack(pixel_values, axis=0)
|
||||
return pixel_values, image_hash, image_size
|
||||
else:
|
||||
# It is an image
|
||||
image_hash = hash(image_data)
|
||||
if image_aspect_ratio == "pad":
|
||||
image = expand2square(
|
||||
@@ -754,13 +762,18 @@ def _process_single_image_task(
|
||||
pixel_values = processor.image_processor(image.convert("RGB"))[
|
||||
"pixel_values"
|
||||
][0]
|
||||
elif image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
|
||||
elif image_aspect_ratio == "anyres" or (
|
||||
image_aspect_ratio is not None and "anyres_max" in image_aspect_ratio
|
||||
):
|
||||
pixel_values = process_anyres_image(
|
||||
image, processor.image_processor, image_grid_pinpoints
|
||||
)
|
||||
else:
|
||||
pixel_values = processor.image_processor(image)["pixel_values"][0]
|
||||
pixel_values = pixel_values.astype(np.float16)
|
||||
|
||||
if isinstance(pixel_values, np.ndarray):
|
||||
pixel_values = pixel_values.astype(np.float16)
|
||||
|
||||
return pixel_values, image_hash, image.size
|
||||
except Exception:
|
||||
logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())
|
||||
|
||||
Reference in New Issue
Block a user