2024-07-28 23:07:12 +10:00
|
|
|
"""
|
|
|
|
|
Copyright 2023-2024 SGLang Team
|
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
you may not use this file except in compliance with the License.
|
|
|
|
|
You may obtain a copy of the License at
|
|
|
|
|
|
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
|
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License.
|
|
|
|
|
"""
|
|
|
|
|
|
2024-06-08 02:06:52 -07:00
|
|
|
"""TokenizerManager is a process that tokenizes the text."""
|
2024-06-12 21:48:40 -07:00
|
|
|
|
2024-01-08 04:37:50 +00:00
|
|
|
import asyncio
|
|
|
|
|
import concurrent.futures
|
|
|
|
|
import dataclasses
|
2024-09-09 04:14:11 -07:00
|
|
|
import json
|
2024-05-12 06:41:32 -07:00
|
|
|
import logging
|
2024-01-24 10:35:31 +00:00
|
|
|
import multiprocessing as mp
|
2024-01-08 04:37:50 +00:00
|
|
|
import os
|
2024-08-25 14:46:34 -07:00
|
|
|
from typing import Dict, List, Optional, Tuple, Union
|
2024-01-08 04:37:50 +00:00
|
|
|
|
2024-08-28 06:33:05 -07:00
|
|
|
import fastapi
|
2024-01-08 04:37:50 +00:00
|
|
|
import numpy as np
|
|
|
|
|
import transformers
|
|
|
|
|
import uvloop
|
|
|
|
|
import zmq
|
|
|
|
|
import zmq.asyncio
|
2024-05-20 18:41:21 -07:00
|
|
|
from fastapi import BackgroundTasks
|
2024-04-22 22:38:09 +08:00
|
|
|
|
2024-01-08 04:37:50 +00:00
|
|
|
from sglang.srt.hf_transformers_utils import (
|
|
|
|
|
get_config,
|
|
|
|
|
get_context_length,
|
|
|
|
|
get_processor,
|
|
|
|
|
get_tokenizer,
|
|
|
|
|
)
|
|
|
|
|
from sglang.srt.managers.io_struct import (
|
2024-05-17 05:49:31 -07:00
|
|
|
AbortReq,
|
2024-08-08 16:31:19 -07:00
|
|
|
BatchEmbeddingOut,
|
2024-01-08 04:37:50 +00:00
|
|
|
BatchStrOut,
|
2024-06-12 21:48:40 -07:00
|
|
|
BatchTokenIDOut,
|
2024-08-08 16:31:19 -07:00
|
|
|
EmbeddingReqInput,
|
2024-01-29 17:05:42 -08:00
|
|
|
FlushCacheReq,
|
2024-01-08 04:37:50 +00:00
|
|
|
GenerateReqInput,
|
2024-09-27 23:32:11 -07:00
|
|
|
RewardReqInput,
|
2024-08-08 16:31:19 -07:00
|
|
|
TokenizedEmbeddingReqInput,
|
2024-01-08 04:37:50 +00:00
|
|
|
TokenizedGenerateReqInput,
|
2024-09-27 23:32:11 -07:00
|
|
|
TokenizedRewardReqInput,
|
2024-08-20 13:48:24 -07:00
|
|
|
UpdateWeightReqInput,
|
|
|
|
|
UpdateWeightReqOutput,
|
2024-01-08 04:37:50 +00:00
|
|
|
)
|
2024-01-24 01:51:21 -08:00
|
|
|
from sglang.srt.mm_utils import expand2square, process_anyres_image
|
2024-08-21 16:48:24 -07:00
|
|
|
from sglang.srt.sampling.sampling_params import SamplingParams
|
2024-01-08 04:37:50 +00:00
|
|
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
2024-08-08 16:31:19 -07:00
|
|
|
from sglang.srt.utils import is_generation_model, is_multimodal_model, load_image
|
2024-05-16 18:07:30 -07:00
|
|
|
from sglang.utils import get_exception_traceback
|
2024-01-08 04:37:50 +00:00
|
|
|
|
|
|
|
|
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
|
|
|
|
|
2024-05-12 06:41:32 -07:00
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
2024-01-08 04:37:50 +00:00
|
|
|
|
|
|
|
|
@dataclasses.dataclass
class ReqState:
    """Store the state of a request."""

    # Outputs accumulated for this request since the last consumer read.
    out_list: List
    # Whether the scheduler reported a finish reason for this request.
    finished: bool
    # Set by handle_loop whenever a new output is appended to `out_list`.
    event: asyncio.Event
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TokenizerManager:
|
2024-08-21 19:24:36 -07:00
|
|
|
"""TokenizerManager is a process that tokenizes the text."""
|
|
|
|
|
|
2024-01-08 04:37:50 +00:00
|
|
|
def __init__(
|
|
|
|
|
self,
|
|
|
|
|
server_args: ServerArgs,
|
|
|
|
|
port_args: PortArgs,
|
|
|
|
|
):
|
2024-03-11 20:06:52 +08:00
|
|
|
self.server_args = server_args
|
|
|
|
|
|
2024-08-25 14:46:34 -07:00
|
|
|
# Init inter-process communication
|
2024-01-08 04:37:50 +00:00
|
|
|
context = zmq.asyncio.Context(2)
|
|
|
|
|
self.recv_from_detokenizer = context.socket(zmq.PULL)
|
|
|
|
|
self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}")
|
|
|
|
|
|
2024-09-29 02:36:12 -07:00
|
|
|
self.send_to_scheduler = context.socket(zmq.PUSH)
|
|
|
|
|
self.send_to_scheduler.connect(f"tcp://127.0.0.1:{port_args.scheduler_port}")
|
2024-01-08 04:37:50 +00:00
|
|
|
|
2024-08-25 14:46:34 -07:00
|
|
|
# Read model args
|
2024-01-08 04:37:50 +00:00
|
|
|
self.model_path = server_args.model_path
|
2024-08-02 08:13:51 +08:00
|
|
|
self.served_model_name = server_args.served_model_name
|
2024-01-08 04:37:50 +00:00
|
|
|
self.hf_config = get_config(
|
2024-05-14 07:57:00 +08:00
|
|
|
self.model_path,
|
|
|
|
|
trust_remote_code=server_args.trust_remote_code,
|
2024-09-09 04:14:11 -07:00
|
|
|
model_override_args=json.loads(server_args.json_model_override_args),
|
2024-01-08 04:37:50 +00:00
|
|
|
)
|
2024-08-26 01:29:12 +08:00
|
|
|
self.is_generation = is_generation_model(
|
|
|
|
|
self.hf_config.architectures, self.server_args.is_embedding
|
|
|
|
|
)
|
2024-08-28 06:33:05 -07:00
|
|
|
self.context_len = server_args.context_length or get_context_length(
|
|
|
|
|
self.hf_config
|
|
|
|
|
)
|
2024-01-08 04:37:50 +00:00
|
|
|
|
2024-08-25 14:46:34 -07:00
|
|
|
# Create tokenizer
|
2024-08-10 03:14:13 +08:00
|
|
|
if server_args.skip_tokenizer_init:
|
|
|
|
|
self.tokenizer = self.processor = None
|
2024-01-08 04:37:50 +00:00
|
|
|
else:
|
2024-08-28 06:33:05 -07:00
|
|
|
if is_multimodal_model(self.hf_config.architectures):
|
2024-08-10 03:14:13 +08:00
|
|
|
self.processor = get_processor(
|
|
|
|
|
server_args.tokenizer_path,
|
|
|
|
|
tokenizer_mode=server_args.tokenizer_mode,
|
|
|
|
|
trust_remote_code=server_args.trust_remote_code,
|
|
|
|
|
)
|
|
|
|
|
self.tokenizer = self.processor.tokenizer
|
|
|
|
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
2024-08-28 06:33:05 -07:00
|
|
|
|
|
|
|
|
# We want to parallelize the image pre-processing so we
|
|
|
|
|
# create an executor for it
|
2024-08-10 03:14:13 +08:00
|
|
|
self.executor = concurrent.futures.ProcessPoolExecutor(
|
|
|
|
|
initializer=init_global_processor,
|
|
|
|
|
mp_context=mp.get_context("fork"),
|
|
|
|
|
initargs=(server_args,),
|
2024-09-22 17:20:26 +08:00
|
|
|
max_workers=os.environ.get("SGLANG_CPU_COUNT", os.cpu_count()),
|
2024-08-10 03:14:13 +08:00
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
self.tokenizer = get_tokenizer(
|
|
|
|
|
server_args.tokenizer_path,
|
|
|
|
|
tokenizer_mode=server_args.tokenizer_mode,
|
|
|
|
|
trust_remote_code=server_args.trust_remote_code,
|
|
|
|
|
)
|
2024-01-08 04:37:50 +00:00
|
|
|
|
2024-08-25 14:46:34 -07:00
|
|
|
# Store states
|
2024-01-08 04:37:50 +00:00
|
|
|
self.to_create_loop = True
|
2024-06-12 21:48:40 -07:00
|
|
|
self.rid_to_state: Dict[str, ReqState] = {}
|
2024-01-08 04:37:50 +00:00
|
|
|
|
2024-08-28 06:33:05 -07:00
|
|
|
# For update model weights
|
2024-08-20 13:48:24 -07:00
|
|
|
self.model_update_lock = asyncio.Lock()
|
|
|
|
|
self.model_update_result = None
|
|
|
|
|
|
2024-08-08 16:31:19 -07:00
|
|
|
async def generate_request(
|
2024-08-28 06:33:05 -07:00
|
|
|
self,
|
2024-09-27 23:32:11 -07:00
|
|
|
obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
|
2024-08-28 06:33:05 -07:00
|
|
|
request: Optional[fastapi.Request] = None,
|
2024-08-08 16:31:19 -07:00
|
|
|
):
|
2024-01-08 04:37:50 +00:00
|
|
|
if self.to_create_loop:
|
2024-05-17 05:49:31 -07:00
|
|
|
self.create_handle_loop()
|
2024-01-08 04:37:50 +00:00
|
|
|
|
2024-08-20 13:48:24 -07:00
|
|
|
while self.model_update_lock.locked():
|
2024-08-25 14:46:34 -07:00
|
|
|
await asyncio.sleep(0.001)
|
2024-08-20 13:48:24 -07:00
|
|
|
|
2024-05-17 05:49:31 -07:00
|
|
|
obj.post_init()
|
2024-05-12 12:29:00 -10:00
|
|
|
is_single = obj.is_single
|
|
|
|
|
|
2024-07-20 14:10:01 +08:00
|
|
|
if is_single:
|
|
|
|
|
async for response in self._handle_single_request(obj, request):
|
|
|
|
|
yield response
|
|
|
|
|
else:
|
|
|
|
|
async for response in self._handle_batch_request(obj, request):
|
|
|
|
|
yield response
|
2024-01-30 23:12:33 +09:00
|
|
|
|
2024-07-27 05:05:15 -07:00
|
|
|
    async def _handle_single_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
        request: Optional[fastapi.Request] = None,
        index: Optional[int] = None,
        is_cache_for_prefill: Optional[bool] = False,
    ):
        """Tokenize one request, send it to the scheduler, and yield responses.

        When `index` is given, `obj` holds batched fields and this handles the
        `index`-th element. When `is_cache_for_prefill` is True, a zero-token
        prefill request is sent to cache the shared prefix for parallel
        sampling, and the tokenized ids are yielded instead of model output.
        """
        if not is_cache_for_prefill:  # The normal case with a single prompt
            not_use_index = index is None

            rid = obj.rid if not_use_index else obj.rid[index]
            input_text = obj.text if not_use_index else obj.text[index]
            if hasattr(obj, "conv"):
                # reward model
                assert self.tokenizer is not None
                conv = obj.conv if not_use_index else obj.conv[index]
                input_text = self.tokenizer.apply_chat_template(conv, tokenize=False)
                input_ids = self.tokenizer.encode(input_text)
            elif obj.input_ids is None:
                assert self.tokenizer is not None
                input_ids = self.tokenizer.encode(input_text)
            else:
                # Caller supplied pre-tokenized input ids.
                input_ids = obj.input_ids if not_use_index else obj.input_ids[index]

            self._validate_input_length(input_ids)

            sampling_params = self._get_sampling_params(
                obj.sampling_params if not_use_index else obj.sampling_params[index]
            )

            if self.is_generation:
                image_inputs = await self._get_image_inputs(
                    obj, obj.image_data if not_use_index else obj.image_data[index]
                )
                return_logprob = (
                    obj.return_logprob if not_use_index else obj.return_logprob[index]
                )
                logprob_start_len = (
                    obj.logprob_start_len
                    if not_use_index
                    else obj.logprob_start_len[index]
                )
                top_logprobs_num = (
                    obj.top_logprobs_num
                    if not_use_index
                    else obj.top_logprobs_num[index]
                )
        else:  # A prefill request to cache the common prompt for parallel sampling
            assert self.is_generation
            if obj.text is not None:
                if isinstance(obj.text, list):
                    input_text = obj.text[index]
                    rid = obj.rid[index]
                else:
                    input_text = obj.text
                    rid = obj.rid[0]
                if self.tokenizer is not None:
                    input_ids = self.tokenizer.encode(input_text)
                else:
                    assert obj.input_ids is not None
                    input_ids = obj.input_ids
                    if isinstance(obj.input_ids, list) and isinstance(
                        obj.input_ids[0], list
                    ):
                        # when obj["input_ids"] is List[List[int]]
                        input_ids = obj.input_ids[index]
                        rid = obj.rid[index]
                    else:
                        input_ids = obj.input_ids
                        rid = obj.rid[0]
            else:
                input_text = None
                if isinstance(obj.input_ids, list) and isinstance(
                    obj.input_ids[0], list
                ):
                    # when obj["input_ids"] is List[List[int]]
                    input_ids = obj.input_ids[index]
                    rid = obj.rid[index]
                else:
                    input_ids = obj.input_ids
                    rid = obj.rid[0]

            # max_new_tokens=0 => prefill-only request (caches the prefix).
            sampling_params = SamplingParams(**obj.sampling_params[0])
            sampling_params.max_new_tokens = 0
            image_inputs = await self._get_image_inputs(obj, obj.image_data[0])
            return_logprob = obj.return_logprob[0]
            logprob_start_len = obj.logprob_start_len[0]
            top_logprobs_num = obj.top_logprobs_num[0]

        # Send to the controller
        if self.is_generation:
            tokenized_obj = TokenizedGenerateReqInput(
                rid,
                input_text,
                input_ids,
                image_inputs,
                sampling_params,
                return_logprob,
                logprob_start_len,
                top_logprobs_num,
                obj.stream,
                (
                    obj.lora_path[index]
                    if isinstance(obj.lora_path, list)
                    else obj.lora_path
                ),
            )
        elif isinstance(obj, EmbeddingReqInput):
            tokenized_obj = TokenizedEmbeddingReqInput(
                rid,
                input_text,
                input_ids,
                sampling_params,
            )
        else:
            assert isinstance(obj, RewardReqInput)
            tokenized_obj = TokenizedRewardReqInput(
                rid,
                input_text,
                input_ids,
                sampling_params,
            )
        self.send_to_scheduler.send_pyobj(tokenized_obj)

        # Recv results
        event = asyncio.Event()
        state = ReqState([], False, event)
        self.rid_to_state[rid] = state
        if not is_cache_for_prefill:
            async for response in self._wait_for_response(state, obj, rid, request):
                yield response
        else:
            assert self.is_generation
            await self._wait_for_cache_prefill_response(state, obj, rid, request)
            yield input_ids
|
2024-07-20 14:10:01 +08:00
|
|
|
|
2024-08-10 08:39:05 -07:00
|
|
|
    async def _handle_batch_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
        request: Optional[fastapi.Request] = None,
    ):
        """Tokenize and dispatch a batch of requests; yield their responses.

        For parallel sampling (n > 1), first sends prefill-only requests to
        cache each prompt's prefix, then fans out the sampled copies. Streams
        results as they arrive when `obj.stream` is set; otherwise yields one
        list with all outputs ordered by submission index.
        """
        batch_size = obj.batch_size
        if self.is_generation:
            parallel_sample_num = obj.parallel_sample_num

            if parallel_sample_num != 1:
                # Send prefill requests to cache the common prefix
                parallel_sample_num += 1
                input_id_result = [] if obj.input_ids is None else None
                for i in range(batch_size):
                    async for input_id in self._handle_single_request(
                        obj, request, index=i, is_cache_for_prefill=True
                    ):
                        if input_id_result is not None:
                            input_id_result.append(input_id)
                # Reuse the ids tokenized during prefill for the sampled copies.
                if input_id_result is not None and len(input_id_result) > 1:
                    obj.input_ids = input_id_result
                elif input_id_result is not None:
                    obj.input_ids = input_id_result[0]
        else:
            parallel_sample_num = 1

        # First send out all requests
        generators = []
        for i in range(batch_size):
            for j in range(parallel_sample_num):
                if j == 0 and parallel_sample_num != 1:
                    # Slot 0 of each prompt was consumed by the prefill request.
                    continue
                index = i * parallel_sample_num + j
                if parallel_sample_num != 1:
                    # Here when using parallel sampling we should consider prefill stage so the index is : j + i * (parallel_sample_num-1) + batch_size - 1
                    index += batch_size - 1 - i
                rid = obj.rid[index]
                if parallel_sample_num == 1:
                    ## select operation
                    if hasattr(obj, "conv"):
                        # reward model
                        conv = obj.conv[i]
                        input_text = self.tokenizer.apply_chat_template(
                            conv, tokenize=False
                        )
                        input_ids = self.tokenizer.encode(input_text)
                    elif obj.input_ids is None:
                        input_text = obj.text[i]
                        input_ids = self.tokenizer.encode(input_text)
                    else:
                        input_text = None
                        input_ids = obj.input_ids[i]
                else:
                    assert obj.input_ids is not None
                    if batch_size == 1:
                        input_text = None
                        input_ids = obj.input_ids
                    else:
                        input_text = None
                        input_ids = obj.input_ids[i]
                sampling_params = self._get_sampling_params(obj.sampling_params[index])

                if self.is_generation:
                    image_inputs = await self._get_image_inputs(
                        obj, obj.image_data[index]
                    )

                    tokenized_obj = TokenizedGenerateReqInput(
                        rid,
                        input_text,
                        input_ids,
                        image_inputs,
                        sampling_params,
                        obj.return_logprob[index],
                        obj.logprob_start_len[index],
                        obj.top_logprobs_num[index],
                        obj.stream,
                        (
                            obj.lora_path[index]
                            if isinstance(obj.lora_path, list)
                            else obj.lora_path
                        ),
                    )
                elif isinstance(obj, EmbeddingReqInput):
                    tokenized_obj = TokenizedEmbeddingReqInput(
                        rid,
                        input_text,
                        input_ids,
                        sampling_params,
                    )
                else:
                    assert isinstance(obj, RewardReqInput)
                    tokenized_obj = TokenizedRewardReqInput(
                        rid,
                        input_text,
                        input_ids,
                        sampling_params,
                    )
                self.send_to_scheduler.send_pyobj(tokenized_obj)

                event = asyncio.Event()
                state = ReqState([], False, event)
                self.rid_to_state[rid] = state

                generators.append(
                    self._wait_for_response(
                        state,
                        obj,
                        rid,
                        request,
                        index=index,
                        response_index=len(generators),
                    )
                )

        # Then process the responses based on streaming option
        is_stream = hasattr(obj, "stream") and obj.stream

        tasks = [asyncio.create_task(gen.__anext__()) for gen in generators]
        output_list = [None] * len(tasks)

        # Recv results
        while tasks:
            done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)

            for task in done:
                cur_index = tasks.index(task)

                try:
                    result = task.result()

                    if is_stream:
                        yield result
                    else:
                        # "index" was stamped by _wait_for_response; it keeps
                        # non-stream results in submission order.
                        output_list[result["index"]] = result

                    tasks[cur_index] = asyncio.create_task(
                        generators[cur_index].__anext__()
                    )
                except StopAsyncIteration:
                    del generators[cur_index]
                    del tasks[cur_index]

        if not is_stream:
            yield output_list
|
2024-07-20 14:10:01 +08:00
|
|
|
|
2024-07-27 05:05:15 -07:00
|
|
|
def _validate_input_length(self, input_ids: List[int]):
|
2024-07-20 14:10:01 +08:00
|
|
|
if len(input_ids) >= self.context_len:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
f"The input ({len(input_ids)} tokens) is longer than the "
|
|
|
|
|
f"model's context length ({self.context_len} tokens)."
|
|
|
|
|
)
|
|
|
|
|
|
2024-07-27 05:05:15 -07:00
|
|
|
def _get_sampling_params(self, sampling_params_data: dict):
|
2024-07-20 14:10:01 +08:00
|
|
|
sampling_params = SamplingParams(**sampling_params_data)
|
|
|
|
|
if sampling_params.max_new_tokens != 0:
|
|
|
|
|
sampling_params.normalize(self.tokenizer)
|
|
|
|
|
sampling_params.verify()
|
|
|
|
|
return sampling_params
|
|
|
|
|
|
2024-07-27 05:05:15 -07:00
|
|
|
    async def _wait_for_response(
        self,
        state: ReqState,
        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
        rid: str,
        request: Optional[fastapi.Request] = None,
        index: Optional[int] = None,
        response_index: int = 0,
    ):
        """Yield outputs for `rid` as they arrive; abort on client disconnect.

        Uses a 4s wake-up timeout purely to poll for disconnection; a timeout
        does not fail the request. Finishes (and cleans up state) when the
        scheduler reports a finish reason.
        """
        while True:
            try:
                await asyncio.wait_for(state.event.wait(), timeout=4)
            except asyncio.TimeoutError:
                if request is not None and await request.is_disconnected():
                    # Abort every rid of the (possibly batched) request.
                    for rid in [obj.rid] if obj.is_single else obj.rid:
                        self.abort_request(rid)
                    raise ValueError(f"Abort request {rid}")
                continue

            if self.is_generation:
                out = self.convert_logprob_style(
                    state.out_list[-1],
                    obj.return_logprob if index is None else obj.return_logprob[index],
                    (
                        obj.top_logprobs_num
                        if index is None
                        else obj.top_logprobs_num[index]
                    ),
                    obj.return_text_in_logprobs,
                )
            else:  # isinstance(obj, (EmbeddingReqInput, RewardReqInput))
                out = state.out_list[-1]

            # Lets _handle_batch_request place this result at its slot.
            out["index"] = response_index

            # Log requests
            if self.server_args.log_requests and state.finished:
                logger.info(f"in={obj}, out={out}")

            state.out_list = []
            if state.finished:
                del self.rid_to_state[rid]
                yield out
                break

            state.event.clear()
            yield out
|
|
|
|
|
|
2024-07-27 05:05:15 -07:00
|
|
|
    async def _wait_for_cache_prefill_response(
        self,
        state: ReqState,
        obj: GenerateReqInput,
        rid: str,
        request: Optional[fastapi.Request] = None,
    ):
        """Block until the prefill-only request finishes, then drop its state.

        The 4s timeout is only a poll interval for client-disconnect checks.
        """
        while True:
            try:
                await asyncio.wait_for(state.event.wait(), timeout=4)
                break
            except asyncio.TimeoutError:
                if request is not None and await request.is_disconnected():
                    for rid in obj.rid:
                        self.abort_request(rid)
                    raise ValueError(f"Abort request {rid}")
                continue

        # A max_new_tokens=0 request produces exactly one (final) output.
        assert state.finished
        del self.rid_to_state[rid]
|
2024-01-08 04:37:50 +00:00
|
|
|
|
2024-05-17 05:49:31 -07:00
|
|
|
def flush_cache(self):
|
|
|
|
|
req = FlushCacheReq()
|
2024-09-29 02:36:12 -07:00
|
|
|
self.send_to_scheduler.send_pyobj(req)
|
2024-01-26 13:32:59 +08:00
|
|
|
|
2024-08-25 14:46:34 -07:00
|
|
|
def abort_request(self, rid: str):
|
|
|
|
|
if rid not in self.rid_to_state:
|
|
|
|
|
return
|
|
|
|
|
del self.rid_to_state[rid]
|
|
|
|
|
req = AbortReq(rid)
|
2024-09-29 02:36:12 -07:00
|
|
|
self.send_to_scheduler.send_pyobj(req)
|
2024-08-25 14:46:34 -07:00
|
|
|
|
2024-08-28 06:33:05 -07:00
|
|
|
    async def update_weights(
        self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
    ):
        """Swap model weights in place; returns (success, message).

        Only one update may run at a time; the held lock also pauses new
        generate requests (see generate_request). In-flight requests are
        drained before the update is sent.
        """
        if self.to_create_loop:
            self.create_handle_loop()

        # default the load format to the server_args
        if obj.load_format is None:
            obj.load_format = self.server_args.load_format

        if not self.model_update_lock.locked():
            async with self.model_update_lock:
                # wait for the previous generation requests to finish
                while len(self.rid_to_state) > 0:
                    await asyncio.sleep(0)
                self.send_to_scheduler.send_pyobj(obj)
                # Resolved by handle_loop when UpdateWeightReqOutput arrives.
                self.model_update_result = asyncio.Future()
                result = await self.model_update_result
                if result.success:
                    self.server_args.model_path = obj.model_path
                    self.server_args.load_format = obj.load_format
                    self.model_path = obj.model_path
                return result.success, result.message
        else:
            return False, "Another update is in progress. Please try again later."
|
|
|
|
|
|
2024-06-08 02:06:52 -07:00
|
|
|
def create_abort_task(self, obj: GenerateReqInput):
|
2024-05-20 18:41:21 -07:00
|
|
|
# Abort the request if the client is disconnected.
|
|
|
|
|
async def abort_request():
|
|
|
|
|
await asyncio.sleep(3)
|
|
|
|
|
if obj.is_single:
|
|
|
|
|
self.abort_request(obj.rid)
|
|
|
|
|
else:
|
2024-08-13 20:47:22 +08:00
|
|
|
for rid in obj.rid:
|
2024-05-20 18:41:21 -07:00
|
|
|
self.abort_request(rid)
|
|
|
|
|
|
|
|
|
|
background_tasks = BackgroundTasks()
|
|
|
|
|
background_tasks.add_task(abort_request)
|
|
|
|
|
return background_tasks
|
|
|
|
|
|
2024-05-17 05:49:31 -07:00
|
|
|
def create_handle_loop(self):
|
2024-08-24 08:02:23 -07:00
|
|
|
if not self.to_create_loop:
|
|
|
|
|
return
|
|
|
|
|
|
2024-01-08 04:37:50 +00:00
|
|
|
self.to_create_loop = False
|
|
|
|
|
loop = asyncio.get_event_loop()
|
|
|
|
|
loop.create_task(self.handle_loop())
|
|
|
|
|
|
|
|
|
|
    async def handle_loop(self):
        """The event loop that handles requests"""

        while True:
            recv_obj: Union[
                BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut, UpdateWeightReqOutput
            ] = await self.recv_from_detokenizer.recv_pyobj()

            if isinstance(recv_obj, UpdateWeightReqOutput):
                # Resolve the future awaited by update_weights.
                self.model_update_result.set_result(recv_obj)
                continue

            assert isinstance(
                recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
            ), f"Unexpected obj received: {type(recv_obj)}"

            for i, rid in enumerate(recv_obj.rids):
                state = self.rid_to_state.get(rid, None)
                if state is None:
                    # Request was aborted or already cleaned up; drop output.
                    continue

                recv_obj.meta_info[i]["id"] = rid
                if isinstance(recv_obj, BatchStrOut):
                    out_dict = {
                        "text": recv_obj.output_strs[i],
                        "meta_info": recv_obj.meta_info[i],
                    }
                elif isinstance(recv_obj, BatchTokenIDOut):
                    # decode_ids is flat across the batch; slice out request i.
                    read_start = 0 if i == 0 else recv_obj.read_offsets[i - 1]
                    out_dict = {
                        "token_ids": recv_obj.decode_ids[
                            read_start : recv_obj.read_offsets[i]
                        ],
                        "meta_info": recv_obj.meta_info[i],
                    }

                else:
                    assert isinstance(recv_obj, BatchEmbeddingOut)
                    out_dict = {
                        "embedding": recv_obj.embeddings[i],
                        "meta_info": recv_obj.meta_info[i],
                    }
                state.out_list.append(out_dict)
                state.finished = recv_obj.finished_reason[i] is not None
                # Wake any _wait_for_response waiting on this request.
                state.event.set()
|
2024-05-17 05:49:31 -07:00
|
|
|
|
2024-05-14 22:40:46 +08:00
|
|
|
    def convert_logprob_style(
        self,
        ret: dict,
        return_logprob: bool,
        top_logprobs_num: int,
        return_text_in_logprobs: bool,
    ):
        """Optionally attach decoded token text to logprob entries in `ret`.

        Mutates ret["meta_info"] in place and returns `ret`.
        """
        if return_logprob:
            ret["meta_info"]["input_token_logprobs"] = self.detokenize_logprob_tokens(
                ret["meta_info"]["input_token_logprobs"], return_text_in_logprobs
            )
            ret["meta_info"]["output_token_logprobs"] = self.detokenize_logprob_tokens(
                ret["meta_info"]["output_token_logprobs"], return_text_in_logprobs
            )

            if top_logprobs_num > 0:
                ret["meta_info"]["input_top_logprobs"] = (
                    self.detokenize_top_logprobs_tokens(
                        ret["meta_info"]["input_top_logprobs"],
                        return_text_in_logprobs,
                    )
                )
                ret["meta_info"]["output_top_logprobs"] = (
                    self.detokenize_top_logprobs_tokens(
                        ret["meta_info"]["output_top_logprobs"], return_text_in_logprobs
                    )
                )
        return ret
|
|
|
|
|
|
2024-07-28 05:22:14 -07:00
|
|
|
def detokenize_logprob_tokens(
|
|
|
|
|
self, token_logprobs: List[Tuple[float, int]], decode_to_text: bool
|
|
|
|
|
):
|
2024-09-29 02:36:12 -07:00
|
|
|
# TODO(lianmin): This should run on DetokenizerManager
|
2024-05-12 04:54:07 -07:00
|
|
|
if not decode_to_text:
|
|
|
|
|
return [(logprob, token_id, None) for logprob, token_id in token_logprobs]
|
|
|
|
|
|
2024-08-10 03:14:13 +08:00
|
|
|
assert self.tokenizer is not None
|
2024-05-12 04:54:07 -07:00
|
|
|
token_ids = [tid for _, tid in token_logprobs]
|
|
|
|
|
token_texts = self.tokenizer.batch_decode(token_ids)
|
|
|
|
|
return [
|
|
|
|
|
(logprob, token_id, token_text)
|
|
|
|
|
for (logprob, token_id), token_text, in zip(token_logprobs, token_texts)
|
|
|
|
|
]
|
|
|
|
|
|
2024-07-27 05:05:15 -07:00
|
|
|
def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool):
|
2024-07-28 05:22:14 -07:00
|
|
|
# TODO: The current implementation only batches the detokenization for top-k tokens per single position.
|
|
|
|
|
# We should batch all top-k tokens in all positions.
|
|
|
|
|
for i, token_top_logprobs in enumerate(top_logprobs):
|
|
|
|
|
if token_top_logprobs:
|
|
|
|
|
top_logprobs[i] = self.detokenize_logprob_tokens(
|
|
|
|
|
token_top_logprobs, decode_to_text
|
|
|
|
|
)
|
2024-05-12 04:54:07 -07:00
|
|
|
return top_logprobs
|
2024-05-17 05:49:31 -07:00
|
|
|
|
2024-09-28 23:28:55 -07:00
|
|
|
async def _get_image_inputs(self, obj, image_data: List[Union[str, bytes]]):
|
2024-08-28 06:33:05 -07:00
|
|
|
if not image_data:
|
2024-09-28 23:28:55 -07:00
|
|
|
return None
|
2024-08-28 06:33:05 -07:00
|
|
|
|
2024-09-28 23:28:55 -07:00
|
|
|
# TODO: move this into a processor for each vision architecture
|
2024-08-28 06:33:05 -07:00
|
|
|
aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
|
2024-08-25 14:46:34 -07:00
|
|
|
grid_pinpoints = (
|
|
|
|
|
self.hf_config.image_grid_pinpoints
|
|
|
|
|
if hasattr(self.hf_config, "image_grid_pinpoints")
|
|
|
|
|
and "anyres" in aspect_ratio
|
|
|
|
|
else None
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if isinstance(image_data, list) and len(image_data) > 0:
|
2024-08-28 06:33:05 -07:00
|
|
|
# Multiple images
|
2024-08-25 14:46:34 -07:00
|
|
|
if len(image_data) > 1:
|
|
|
|
|
aspect_ratio = "pad" # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
|
2024-08-28 06:33:05 -07:00
|
|
|
pixel_values, image_hashes, image_sizes = [], [], []
|
2024-08-25 14:46:34 -07:00
|
|
|
for img_data in image_data:
|
|
|
|
|
pixel_v, image_h, image_s = await self._process_single_image(
|
|
|
|
|
img_data, aspect_ratio, grid_pinpoints
|
|
|
|
|
)
|
|
|
|
|
pixel_values.append(pixel_v)
|
2024-08-28 06:33:05 -07:00
|
|
|
image_hashes.append(image_h)
|
|
|
|
|
image_sizes.append(image_s)
|
|
|
|
|
|
|
|
|
|
if isinstance(pixel_values[0], np.ndarray):
|
|
|
|
|
pixel_values = np.stack(pixel_values, axis=0)
|
2024-08-25 14:46:34 -07:00
|
|
|
else:
|
2024-08-28 06:33:05 -07:00
|
|
|
# A single image
|
2024-08-25 14:46:34 -07:00
|
|
|
pixel_values, image_hash, image_size = await self._process_single_image(
|
|
|
|
|
image_data[0], aspect_ratio, grid_pinpoints
|
|
|
|
|
)
|
2024-08-28 06:33:05 -07:00
|
|
|
image_hashes = [image_hash]
|
|
|
|
|
image_sizes = [image_size]
|
2024-08-25 14:46:34 -07:00
|
|
|
elif isinstance(image_data, str):
|
2024-08-28 06:33:05 -07:00
|
|
|
# A single image
|
2024-08-25 14:46:34 -07:00
|
|
|
pixel_values, image_hash, image_size = await self._process_single_image(
|
|
|
|
|
image_data, aspect_ratio, grid_pinpoints
|
|
|
|
|
)
|
2024-08-28 06:33:05 -07:00
|
|
|
image_hashes = [image_hash]
|
|
|
|
|
image_sizes = [image_size]
|
2024-08-25 14:46:34 -07:00
|
|
|
else:
|
2024-08-28 06:33:05 -07:00
|
|
|
raise ValueError(f"Invalid image data: {image_data}")
|
2024-08-25 14:46:34 -07:00
|
|
|
|
2024-09-28 23:28:55 -07:00
|
|
|
return {
|
|
|
|
|
"pixel_values": pixel_values,
|
|
|
|
|
"image_hashes": image_hashes,
|
|
|
|
|
"image_sizes": image_sizes,
|
|
|
|
|
"modalities": obj.modalities,
|
|
|
|
|
}
|
2024-08-25 14:46:34 -07:00
|
|
|
|
2024-08-28 06:33:05 -07:00
|
|
|
async def _process_single_image(
|
|
|
|
|
self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str
|
|
|
|
|
):
|
2024-08-25 14:46:34 -07:00
|
|
|
if self.executor is not None:
|
|
|
|
|
loop = asyncio.get_event_loop()
|
|
|
|
|
return await loop.run_in_executor(
|
|
|
|
|
self.executor,
|
|
|
|
|
_process_single_image_task,
|
|
|
|
|
image_data,
|
|
|
|
|
aspect_ratio,
|
|
|
|
|
grid_pinpoints,
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
return _process_single_image_task(
|
|
|
|
|
image_data, aspect_ratio, grid_pinpoints, self.processor
|
|
|
|
|
)
|
|
|
|
|
|
2024-05-17 05:49:31 -07:00
|
|
|
|
|
|
|
|
# Module-level handle for the HF processor used by multi-modal models; it is
# populated (per process) by `init_global_processor` below.
# NOTE(review): `global` at module scope is a no-op — kept here only as a
# declaration of intent.
global global_processor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def init_global_processor(server_args: ServerArgs):
    """Initialize the global processor for multi-modal models.

    Stores the processor built from `server_args` in the module-level
    `global_processor` variable.
    """
    global global_processor

    # Suppress verbose transformers logging while the processor loads.
    transformers.logging.set_verbosity_error()

    processor = get_processor(
        server_args.tokenizer_path,
        tokenizer_mode=server_args.tokenizer_mode,
        trust_remote_code=server_args.trust_remote_code,
    )
    global_processor = processor
|
|
|
|
|
|
|
|
|
|
|
2024-08-25 14:46:34 -07:00
|
|
|
def _process_single_image_task(
    image_data: Union[str, bytes],
    image_aspect_ratio: Optional[str] = None,
    image_grid_pinpoints: Optional[str] = None,
    processor=None,
):
    """Decode one image (or video frame stack) into pixel values.

    Runs either inline or inside an executor worker; when `processor` is None
    the worker-local `global_processor` (set by `init_global_processor`) is
    used instead.

    Returns (pixel_values, image_hash, image_size) on success.
    NOTE(review): on exception this function only logs (as visible here) and
    so implicitly returns None — confirm callers handle a None result.
    """
    try:
        # Fall back to the per-process global processor set by
        # `init_global_processor` when no processor is passed in.
        processor = processor or global_processor
        image, image_size = load_image(image_data)
        if image_size is not None:
            # `load_image` returned an explicit size: it is a video with
            # multiple frames, processed as a batch.
            image_hash = hash(image_data)
            pixel_values = processor.image_processor(image)["pixel_values"]
            # Downcast each frame to fp16, presumably to shrink the payload
            # crossing the process boundary — TODO confirm.
            for _ in range(len(pixel_values)):
                pixel_values[_] = pixel_values[_].astype(np.float16)
            pixel_values = np.stack(pixel_values, axis=0)
            return pixel_values, image_hash, image_size
        else:
            # It is a single still image.
            image_hash = hash(image_data)
            if image_aspect_ratio == "pad":
                # Pad to a square using the processor's mean color so the
                # original aspect ratio is preserved.
                image = expand2square(
                    image,
                    tuple(int(x * 255) for x in processor.image_processor.image_mean),
                )
                pixel_values = processor.image_processor(image.convert("RGB"))[
                    "pixel_values"
                ][0]
            elif image_aspect_ratio == "anyres" or (
                image_aspect_ratio is not None and "anyres_max" in image_aspect_ratio
            ):
                # "anyres"/"anyres_max" modes: tile the image per the
                # configured grid pinpoints.
                pixel_values = process_anyres_image(
                    image, processor.image_processor, image_grid_pinpoints
                )
            else:
                # Default: let the processor handle resizing/normalization.
                pixel_values = processor.image_processor(image)["pixel_values"][0]

            # Only numpy outputs are downcast; other tensor types pass through.
            if isinstance(pixel_values, np.ndarray):
                pixel_values = pixel_values.astype(np.float16)

        return pixel_values, image_hash, image.size
    except Exception:
        logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())
|