[Minor] Improve the function organization in TokenizerManager & improve loggers (#1208)

This commit is contained in:
Lianmin Zheng
2024-08-25 14:46:34 -07:00
committed by GitHub
parent 30b4f771b0
commit 902278008a
12 changed files with 137 additions and 134 deletions

View File

@@ -21,7 +21,7 @@ import dataclasses
import logging
import multiprocessing as mp
import os
from typing import Dict, List, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import transformers
@@ -80,6 +80,7 @@ class TokenizerManager:
):
self.server_args = server_args
# Init inter-process communication
context = zmq.asyncio.Context(2)
self.recv_from_detokenizer = context.socket(zmq.PULL)
self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}")
@@ -87,6 +88,7 @@ class TokenizerManager:
self.send_to_router = context.socket(zmq.PUSH)
self.send_to_router.connect(f"tcp://127.0.0.1:{port_args.controller_port}")
# Read model args
self.model_path = server_args.model_path
self.served_model_name = server_args.served_model_name
self.hf_config = get_config(
@@ -104,6 +106,7 @@ class TokenizerManager:
else:
self.context_len = get_context_length(self.hf_config)
# Create tokenizer
if server_args.skip_tokenizer_init:
self.tokenizer = self.processor = None
else:
@@ -127,6 +130,7 @@ class TokenizerManager:
trust_remote_code=server_args.trust_remote_code,
)
# Store states
self.to_create_loop = True
self.rid_to_state: Dict[str, ReqState] = {}
@@ -134,63 +138,6 @@ class TokenizerManager:
self.model_update_lock = asyncio.Lock()
self.model_update_result = None
async def get_pixel_values(self, image_data, aspect_ratio=None):
    """Preprocess image input into model-ready pixel values.

    Args:
        image_data: a single image reference (str) or a list of them;
            the accepted formats are determined by the module-level
            ``get_pixel_values`` helper, not visible from this block.
        aspect_ratio: optional override; defaults to the model config's
            ``image_aspect_ratio`` attribute (which may be absent/None).

    Returns:
        ``(pixel_values, image_hash, image_size)``; hash and size are
        wrapped in lists, and ``(None, None, None)`` is returned for
        unsupported input types.
    """
    # Use the model config's aspect-ratio policy unless the caller
    # supplied one explicitly.
    aspect_ratio = (
        getattr(self.hf_config, "image_aspect_ratio", None)
        if aspect_ratio is None
        else aspect_ratio
    )
    # Grid pinpoints only apply to "anyres" aspect-ratio modes.
    # NOTE(review): if aspect_ratio resolved to None above, the
    # `"anyres" in aspect_ratio` test raises TypeError — confirm the
    # config always sets image_aspect_ratio when pinpoints exist.
    grid_pinpoints = (
        self.hf_config.image_grid_pinpoints
        if hasattr(self.hf_config, "image_grid_pinpoints")
        and "anyres" in aspect_ratio
        else None
    )
    if isinstance(image_data, list) and len(image_data) > 0:
        # List input: process each image and collect per-image results.
        pixel_values, image_hash, image_size = [], [], []
        if len(image_data) > 1:
            aspect_ratio = "pad"  # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
            for img_data in image_data:
                pixel_v, image_h, image_s = await self._process_single_image(
                    img_data, aspect_ratio, grid_pinpoints
                )
                pixel_values.append(pixel_v)
                image_hash.append(image_h)
                image_size.append(image_s)
            # Stack per-image arrays into one batch array.
            pixel_values = np.stack(pixel_values, axis=0)
        else:
            # Single-element list: process it alone and wrap hash/size
            # in lists so the return shape matches the multi-image case.
            pixel_values, image_hash, image_size = await self._process_single_image(
                image_data[0], aspect_ratio, grid_pinpoints
            )
            image_hash = [image_hash]
            image_size = [image_size]
    elif isinstance(image_data, str):
        # A single image reference.
        pixel_values, image_hash, image_size = await self._process_single_image(
            image_data, aspect_ratio, grid_pinpoints
        )
        image_hash = [image_hash]
        image_size = [image_size]
    else:
        # No image (or unsupported type): signal "no pixels".
        pixel_values, image_hash, image_size = None, None, None
    return pixel_values, image_hash, image_size
async def _process_single_image(self, image_data, aspect_ratio, grid_pinpoints):
    """Preprocess one image, off the event loop when an executor exists.

    Delegates to the module-level ``get_pixel_values`` helper, either via
    ``self.executor`` (so CPU-heavy preprocessing does not block the
    asyncio loop) or inline as a direct call.
    """
    if self.executor is not None:
        loop = asyncio.get_event_loop()
        # NOTE(review): the executor path does not forward self.processor;
        # presumably the worker process uses a global processor set up by
        # its initializer — confirm against the executor's configuration.
        return await loop.run_in_executor(
            self.executor,
            get_pixel_values,
            image_data,
            aspect_ratio,
            grid_pinpoints,
        )
    else:
        # No executor configured: run synchronously with this instance's
        # processor.
        return get_pixel_values(
            image_data, aspect_ratio, grid_pinpoints, self.processor
        )
async def generate_request(
self, obj: Union[GenerateReqInput, EmbeddingReqInput], request=None
):
@@ -198,7 +145,7 @@ class TokenizerManager:
self.create_handle_loop()
while self.model_update_lock.locked():
await asyncio.sleep(0)
await asyncio.sleep(0.001)
obj.post_init()
is_single = obj.is_single
@@ -214,8 +161,8 @@ class TokenizerManager:
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
request,
index=None,
is_cache_for_prefill=False,
index: Optional[int] = None,
is_cache_for_prefill: Optional[bool] = False,
):
if not is_cache_for_prefill: # The normal case with a single prompt
not_use_index = index is None
@@ -235,7 +182,7 @@ class TokenizerManager:
)
if self.is_generation:
pixel_values, image_hash, image_size = await self.get_pixel_values(
pixel_values, image_hash, image_size = await self._get_pixel_values(
obj.image_data
)
return_logprob = (
@@ -345,7 +292,7 @@ class TokenizerManager:
parallel_sample_num = obj.parallel_sample_num
if parallel_sample_num != 1:
# Send prefill requests to cache the common input
# Send prefill requests to cache the common prefix
parallel_sample_num += 1
input_id_result = [] if obj.input_ids is None else None
for i in range(batch_size):
@@ -436,7 +383,6 @@ class TokenizerManager:
)
# Then process the responses based on streaming option
is_stream = hasattr(obj, "stream") and obj.stream
tasks = [asyncio.create_task(gen.__anext__()) for gen in generators]
@@ -482,9 +428,9 @@ class TokenizerManager:
async def _get_pixel_values(self, image_data):
    """Dispatch image preprocessing for a request's image input.

    NOTE(review): each branch below contains two consecutive ``return``
    statements; the second of each pair is unreachable. This appears to
    be diff residue — an old call rendered next to its replacement —
    so exactly one call per branch should remain in the real file.
    """
    if isinstance(image_data, list) and len(image_data) > 0:
        # Only the first image of a list is used here.
        return await self.get_pixel_values(image_data[0])
        return await self._get_pixel_values_internal(image_data[0])
    elif isinstance(image_data, str):
        return await self.get_pixel_values(image_data)
        return await self._get_pixel_values_internal(image_data)
    else:
        # No image input.
        return None, None, None
@@ -563,6 +509,13 @@ class TokenizerManager:
req = FlushCacheReq()
self.send_to_router.send_pyobj(req)
def abort_request(self, rid: str):
    """Abort the in-flight request identified by ``rid``.

    Unknown (or already-finished) ids are ignored; otherwise the local
    request state is dropped and the router is notified via AbortReq.
    """
    if rid not in self.rid_to_state:
        # Nothing tracked under this id — nothing to abort.
        return
    del self.rid_to_state[rid]
    req = AbortReq(rid)
    self.send_to_router.send_pyobj(req)
async def update_weights(self, obj: UpdateWeightReqInput, request):
if self.to_create_loop:
self.create_handle_loop()
@@ -587,13 +540,6 @@ class TokenizerManager:
else:
return False, "Another update is in progress. Please try again later."
def abort_request(self, rid: str):
    """Abort the in-flight request identified by ``rid``.

    NOTE(review): this definition is byte-identical to the other
    ``abort_request`` visible earlier in this file — in a patch this is
    the removed copy of a moved method; only one should exist in the
    final file.
    """
    if rid not in self.rid_to_state:
        # Nothing tracked under this id — nothing to abort.
        return
    del self.rid_to_state[rid]
    req = AbortReq(rid)
    self.send_to_router.send_pyobj(req)
def create_abort_task(self, obj: GenerateReqInput):
# Abort the request if the client is disconnected.
async def abort_request():
@@ -617,6 +563,8 @@ class TokenizerManager:
loop.create_task(self.handle_loop())
async def handle_loop(self):
"""The event loop that handles requests"""
while True:
recv_obj: Union[
BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut, UpdateWeightReqOutput
@@ -713,11 +661,69 @@ class TokenizerManager:
)
return top_logprobs
async def _get_pixel_values_internal(self, image_data, aspect_ratio=None):
    """Preprocess image input into model-ready pixel values.

    Args:
        image_data: a single image reference (str) or a list of them;
            the accepted formats are determined by
            ``_process_single_image_task``, not visible from this block.
        aspect_ratio: optional override; defaults to the model config's
            ``image_aspect_ratio`` attribute (which may be absent/None).

    Returns:
        ``(pixel_values, image_hash, image_size)``; hash and size are
        wrapped in lists, and ``(None, None, None)`` is returned for
        unsupported input types.
    """
    # Use the model config's aspect-ratio policy unless the caller
    # supplied one explicitly.
    aspect_ratio = (
        getattr(self.hf_config, "image_aspect_ratio", None)
        if aspect_ratio is None
        else aspect_ratio
    )
    # Grid pinpoints only apply to "anyres" aspect-ratio modes.
    # NOTE(review): if aspect_ratio resolved to None above, the
    # `"anyres" in aspect_ratio` test raises TypeError — confirm the
    # config always sets image_aspect_ratio when pinpoints exist.
    grid_pinpoints = (
        self.hf_config.image_grid_pinpoints
        if hasattr(self.hf_config, "image_grid_pinpoints")
        and "anyres" in aspect_ratio
        else None
    )
    if isinstance(image_data, list) and len(image_data) > 0:
        # List input: process each image and collect per-image results.
        pixel_values, image_hash, image_size = [], [], []
        if len(image_data) > 1:
            aspect_ratio = "pad"  # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
            for img_data in image_data:
                pixel_v, image_h, image_s = await self._process_single_image(
                    img_data, aspect_ratio, grid_pinpoints
                )
                pixel_values.append(pixel_v)
                image_hash.append(image_h)
                image_size.append(image_s)
            # Stack per-image arrays into one batch array.
            pixel_values = np.stack(pixel_values, axis=0)
        else:
            # Single-element list: process it alone and wrap hash/size
            # in lists so the return shape matches the multi-image case.
            pixel_values, image_hash, image_size = await self._process_single_image(
                image_data[0], aspect_ratio, grid_pinpoints
            )
            image_hash = [image_hash]
            image_size = [image_size]
    elif isinstance(image_data, str):
        # A single image reference.
        pixel_values, image_hash, image_size = await self._process_single_image(
            image_data, aspect_ratio, grid_pinpoints
        )
        image_hash = [image_hash]
        image_size = [image_size]
    else:
        # No image (or unsupported type): signal "no pixels".
        pixel_values, image_hash, image_size = None, None, None
    return pixel_values, image_hash, image_size
async def _process_single_image(self, image_data, aspect_ratio, grid_pinpoints):
    """Preprocess one image, off the event loop when an executor exists.

    Delegates to the module-level ``_process_single_image_task`` helper,
    either via ``self.executor`` (so CPU-heavy preprocessing does not
    block the asyncio loop) or inline as a direct call.
    """
    if self.executor is not None:
        loop = asyncio.get_event_loop()
        # NOTE(review): the executor path does not forward self.processor;
        # presumably the worker process relies on the global processor set
        # by init_global_processor — confirm against the executor setup.
        return await loop.run_in_executor(
            self.executor,
            _process_single_image_task,
            image_data,
            aspect_ratio,
            grid_pinpoints,
        )
    else:
        # No executor configured: run synchronously with this instance's
        # processor.
        return _process_single_image_task(
            image_data, aspect_ratio, grid_pinpoints, self.processor
        )
global global_processor
def init_global_processor(server_args: ServerArgs):
"""Init the global processor for multi modal models."""
global global_processor
transformers.logging.set_verbosity_error()
global_processor = get_processor(
@@ -727,7 +733,7 @@ def init_global_processor(server_args: ServerArgs):
)
def get_pixel_values(
def _process_single_image_task(
image_data, image_aspect_ratio=None, image_grid_pinpoints=None, processor=None
):
try:
@@ -759,4 +765,4 @@ def get_pixel_values(
pixel_values = pixel_values.astype(np.float16)
return pixel_values, image_hash, image.size
except Exception:
print("Exception in TokenizerManager:\n" + get_exception_traceback())
logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())