[Minor] Improve the function organization in TokenizerManager & improve loggers (#1208)
@@ -21,7 +21,7 @@ import dataclasses
import logging
import multiprocessing as mp
import os
from typing import Dict, List, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import transformers
@@ -80,6 +80,7 @@ class TokenizerManager:
):
self.server_args = server_args

# Init inter-process communication
context = zmq.asyncio.Context(2)
self.recv_from_detokenizer = context.socket(zmq.PULL)
self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}")
@@ -87,6 +88,7 @@ class TokenizerManager:
self.send_to_router = context.socket(zmq.PUSH)
self.send_to_router.connect(f"tcp://127.0.0.1:{port_args.controller_port}")
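
Side note (not part of the diff): the lines above are the TokenizerManager's inter-process wiring. It pulls detokenized results on a ZMQ PULL socket bound to the tokenizer port and pushes requests to the controller/router over a PUSH socket. A minimal, self-contained sketch of the same pyzmq asyncio pattern; the port number and payload are made up, and the PUSH socket is looped back to the PULL socket so the snippet runs on its own:

    import asyncio

    import zmq
    import zmq.asyncio


    async def main():
        ctx = zmq.asyncio.Context(2)

        # Results come back on a PULL socket (stands in for the tokenizer port).
        recv_sock = ctx.socket(zmq.PULL)
        recv_sock.bind("tcp://127.0.0.1:30001")

        # Requests go out on a PUSH socket; here it loops back for the demo,
        # while the real code connects to the controller port instead.
        send_sock = ctx.socket(zmq.PUSH)
        send_sock.connect("tcp://127.0.0.1:30001")

        await send_sock.send_pyobj({"rid": "req-0", "text": "hello"})
        print(await recv_sock.recv_pyobj())


    asyncio.run(main())
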

# Read model args
self.model_path = server_args.model_path
self.served_model_name = server_args.served_model_name
self.hf_config = get_config(
@@ -104,6 +106,7 @@ class TokenizerManager:
else:
self.context_len = get_context_length(self.hf_config)

# Create tokenizer
if server_args.skip_tokenizer_init:
self.tokenizer = self.processor = None
else:
@@ -127,6 +130,7 @@ class TokenizerManager:
trust_remote_code=server_args.trust_remote_code,
)

# Store states
self.to_create_loop = True
self.rid_to_state: Dict[str, ReqState] = {}

@@ -134,63 +138,6 @@ class TokenizerManager:
self.model_update_lock = asyncio.Lock()
self.model_update_result = None

async def get_pixel_values(self, image_data, aspect_ratio=None):
aspect_ratio = (
getattr(self.hf_config, "image_aspect_ratio", None)
if aspect_ratio is None
else aspect_ratio
)
grid_pinpoints = (
self.hf_config.image_grid_pinpoints
if hasattr(self.hf_config, "image_grid_pinpoints")
and "anyres" in aspect_ratio
else None
)

if isinstance(image_data, list) and len(image_data) > 0:
pixel_values, image_hash, image_size = [], [], []
if len(image_data) > 1:
aspect_ratio = "pad" # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
for img_data in image_data:
pixel_v, image_h, image_s = await self._process_single_image(
img_data, aspect_ratio, grid_pinpoints
)
pixel_values.append(pixel_v)
image_hash.append(image_h)
image_size.append(image_s)
pixel_values = np.stack(pixel_values, axis=0)
else:
pixel_values, image_hash, image_size = await self._process_single_image(
image_data[0], aspect_ratio, grid_pinpoints
)
image_hash = [image_hash]
image_size = [image_size]
elif isinstance(image_data, str):
pixel_values, image_hash, image_size = await self._process_single_image(
image_data, aspect_ratio, grid_pinpoints
)
image_hash = [image_hash]
image_size = [image_size]
else:
pixel_values, image_hash, image_size = None, None, None

return pixel_values, image_hash, image_size

async def _process_single_image(self, image_data, aspect_ratio, grid_pinpoints):
if self.executor is not None:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self.executor,
get_pixel_values,
image_data,
aspect_ratio,
grid_pinpoints,
)
else:
return get_pixel_values(
image_data, aspect_ratio, grid_pinpoints, self.processor
)
)
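
For context (illustrative, not from the diff): `_process_single_image` either calls the preprocessing function inline or offloads it to `self.executor` with `run_in_executor`, so CPU-heavy image work does not stall the event loop that serves requests. A standalone sketch of that pattern, with a hypothetical `preprocess` function standing in for the real worker:

    import asyncio
    from concurrent.futures import ProcessPoolExecutor


    def preprocess(image_path: str) -> int:
        # Hypothetical CPU-bound stand-in for the real image preprocessing.
        return len(image_path)


    async def process_single_image(executor, image_path: str) -> int:
        if executor is not None:
            loop = asyncio.get_running_loop()
            # Offload the blocking work so the event loop keeps serving requests.
            return await loop.run_in_executor(executor, preprocess, image_path)
        # Fallback: run inline, blocking the loop for the duration of the call.
        return preprocess(image_path)


    async def main():
        with ProcessPoolExecutor(max_workers=2) as executor:
            print(await asyncio.gather(
                process_single_image(executor, "a.png"),
                process_single_image(executor, "b.png"),
            ))


    if __name__ == "__main__":
        asyncio.run(main())
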

async def generate_request(
self, obj: Union[GenerateReqInput, EmbeddingReqInput], request=None
):
@@ -198,7 +145,7 @@ class TokenizerManager:
self.create_handle_loop()

while self.model_update_lock.locked():
await asyncio.sleep(0)
await asyncio.sleep(0.001)
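
The change above swaps `asyncio.sleep(0)` for `asyncio.sleep(0.001)`: sleeping zero seconds only yields to the event loop and re-checks the lock immediately, so waiting requests busy-spin while a weight update holds `model_update_lock`; a 1 ms sleep polls far less often at no practical latency cost. An illustrative sketch (names are hypothetical, not from this diff):

    import asyncio


    async def wait_for_update(model_update_lock: asyncio.Lock):
        # Poll roughly every millisecond instead of spinning with sleep(0).
        while model_update_lock.locked():
            await asyncio.sleep(0.001)


    async def main():
        lock = asyncio.Lock()

        async def update_weights():
            async with lock:
                await asyncio.sleep(0.05)  # simulate a weight reload

        await asyncio.gather(update_weights(), wait_for_update(lock))
        print("request proceeds after the update finishes")


    asyncio.run(main())
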

obj.post_init()
is_single = obj.is_single
@@ -214,8 +161,8 @@ class TokenizerManager:
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
request,
index=None,
is_cache_for_prefill=False,
index: Optional[int] = None,
is_cache_for_prefill: Optional[bool] = False,
):
if not is_cache_for_prefill: # The normal case with a single prompt
not_use_index = index is None
@@ -235,7 +182,7 @@ class TokenizerManager:
)

if self.is_generation:
pixel_values, image_hash, image_size = await self.get_pixel_values(
pixel_values, image_hash, image_size = await self._get_pixel_values(
obj.image_data
)
return_logprob = (
@@ -345,7 +292,7 @@ class TokenizerManager:
parallel_sample_num = obj.parallel_sample_num

if parallel_sample_num != 1:
# Send prefill requests to cache the common input
# Send prefill requests to cache the common prefix
parallel_sample_num += 1
input_id_result = [] if obj.input_ids is None else None
for i in range(batch_size):
@@ -436,7 +383,6 @@ class TokenizerManager:
)

# Then process the responses based on streaming option

is_stream = hasattr(obj, "stream") and obj.stream

tasks = [asyncio.create_task(gen.__anext__()) for gen in generators]
@@ -482,9 +428,9 @@ class TokenizerManager:

async def _get_pixel_values(self, image_data):
if isinstance(image_data, list) and len(image_data) > 0:
return await self.get_pixel_values(image_data[0])
return await self._get_pixel_values_internal(image_data[0])
elif isinstance(image_data, str):
return await self.get_pixel_values(image_data)
return await self._get_pixel_values_internal(image_data)
else:
return None, None, None

@@ -563,6 +509,13 @@ class TokenizerManager:
req = FlushCacheReq()
self.send_to_router.send_pyobj(req)

def abort_request(self, rid: str):
if rid not in self.rid_to_state:
return
del self.rid_to_state[rid]
req = AbortReq(rid)
self.send_to_router.send_pyobj(req)

async def update_weights(self, obj: UpdateWeightReqInput, request):
if self.to_create_loop:
self.create_handle_loop()
@@ -587,13 +540,6 @@ class TokenizerManager:
else:
return False, "Another update is in progress. Please try again later."

def abort_request(self, rid: str):
if rid not in self.rid_to_state:
return
del self.rid_to_state[rid]
req = AbortReq(rid)
self.send_to_router.send_pyobj(req)

def create_abort_task(self, obj: GenerateReqInput):
# Abort the request if the client is disconnected.
async def abort_request():
@@ -617,6 +563,8 @@ class TokenizerManager:
loop.create_task(self.handle_loop())

async def handle_loop(self):
"""The event loop that handles requests"""

while True:
recv_obj: Union[
BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut, UpdateWeightReqOutput
@@ -713,11 +661,69 @@ class TokenizerManager:
)
return top_logprobs

async def _get_pixel_values_internal(self, image_data, aspect_ratio=None):
aspect_ratio = (
getattr(self.hf_config, "image_aspect_ratio", None)
if aspect_ratio is None
else aspect_ratio
)
grid_pinpoints = (
self.hf_config.image_grid_pinpoints
if hasattr(self.hf_config, "image_grid_pinpoints")
and "anyres" in aspect_ratio
else None
)

if isinstance(image_data, list) and len(image_data) > 0:
pixel_values, image_hash, image_size = [], [], []
if len(image_data) > 1:
aspect_ratio = "pad" # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
for img_data in image_data:
pixel_v, image_h, image_s = await self._process_single_image(
img_data, aspect_ratio, grid_pinpoints
)
pixel_values.append(pixel_v)
image_hash.append(image_h)
image_size.append(image_s)
pixel_values = np.stack(pixel_values, axis=0)
else:
pixel_values, image_hash, image_size = await self._process_single_image(
image_data[0], aspect_ratio, grid_pinpoints
)
image_hash = [image_hash]
image_size = [image_size]
elif isinstance(image_data, str):
pixel_values, image_hash, image_size = await self._process_single_image(
image_data, aspect_ratio, grid_pinpoints
)
image_hash = [image_hash]
image_size = [image_size]
else:
pixel_values, image_hash, image_size = None, None, None

return pixel_values, image_hash, image_size

async def _process_single_image(self, image_data, aspect_ratio, grid_pinpoints):
if self.executor is not None:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self.executor,
_process_single_image_task,
image_data,
aspect_ratio,
grid_pinpoints,
)
else:
return _process_single_image_task(
image_data, aspect_ratio, grid_pinpoints, self.processor
)


global global_processor


def init_global_processor(server_args: ServerArgs):
"""Init the global processor for multi modal models."""
global global_processor
transformers.logging.set_verbosity_error()
global_processor = get_processor(
@@ -727,7 +733,7 @@ def init_global_processor(server_args: ServerArgs):
)
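
A note on the module-level `global_processor` plus `init_global_processor` pair (the sketch below is illustrative, not the project's code): this is the standard way to give executor worker processes a HuggingFace processor that is expensive or awkward to pickle; each worker builds its own copy once via the pool's initializer instead of receiving it with every task. A hedged sketch with a fake processor and hypothetical names:

    import concurrent.futures

    global_processor = None


    def init_global_processor(model_path: str):
        """Runs once per worker process; builds the per-process processor."""
        global global_processor
        # Stand-in for get_processor(...): built inside the worker because the
        # real processor is not cheaply picklable across process boundaries.
        global_processor = f"processor-for-{model_path}"


    def process_image_task(image_name: str) -> str:
        # Worker-side task that uses the processor created by the initializer.
        return f"{global_processor} handled {image_name}"


    if __name__ == "__main__":
        with concurrent.futures.ProcessPoolExecutor(
            max_workers=2,
            initializer=init_global_processor,
            initargs=("my-model",),
        ) as executor:
            print(list(executor.map(process_image_task, ["a.png", "b.png"])))
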


def get_pixel_values(
def _process_single_image_task(
image_data, image_aspect_ratio=None, image_grid_pinpoints=None, processor=None
):
try:
@@ -759,4 +765,4 @@ def get_pixel_values(
pixel_values = pixel_values.astype(np.float16)
return pixel_values, image_hash, image.size
except Exception:
print("Exception in TokenizerManager:\n" + get_exception_traceback())
logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())
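
This last hunk is the "improve loggers" half of the title: the bare `print` of the traceback becomes `logger.error`, so the exception flows through the logging configuration (levels, formatting, handlers) like the rest of the server's output. A minimal sketch of the module-level logger pattern this relies on, with `traceback.format_exc()` standing in for `get_exception_traceback()`:

    import logging
    import traceback

    # Module-level logger, normally created once near the top of the file.
    logger = logging.getLogger(__name__)


    def do_work(image_data):
        raise ValueError("boom")  # force an exception for the demo


    def handle(image_data):
        try:
            return do_work(image_data)
        except Exception:
            # Log the formatted traceback at ERROR level instead of printing it,
            # so handlers, formatters, and level filtering all apply.
            logger.error("Exception in TokenizerManager:\n" + traceback.format_exc())


    if __name__ == "__main__":
        logging.basicConfig(level=logging.INFO)
        handle(None)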