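"""TokenizerManager: the tokenizer-side frontend of the SGLang runtime.

It tokenizes incoming generate requests (including multimodal image
preprocessing), forwards them to the router over ZMQ, and relays
detokenized outputs back to waiting callers.
"""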
import asyncio
import concurrent.futures
import dataclasses
import logging
import multiprocessing as mp
import os
from typing import Dict, List

import numpy as np
import transformers
import uvloop
import zmq
import zmq.asyncio

from sglang.srt.hf_transformers_utils import (
    get_config,
    get_context_length,
    get_processor,
    get_tokenizer,
)
from sglang.srt.managers.io_struct import (
    BatchStrOut,
    FlushCacheReq,
    GenerateReqInput,
    TokenizedGenerateReqInput,
)
from sglang.srt.mm_utils import expand2square, process_anyres_image
from sglang.srt.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import is_multimodal_model, load_image
from sglang.utils import get_exception_traceback

# Use uvloop as the asyncio event loop implementation.
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

logger = logging.getLogger(__name__)

@dataclasses.dataclass
class ReqState:
    """Per-request state shared between generate_request and handle_loop."""

    out_list: List  # Outputs accumulated but not yet yielded to the caller.
    finished: bool  # Whether the request has finished generating.
    event: asyncio.Event  # Signals that new output is available.


# Set in each ProcessPoolExecutor worker by init_global_processor.
# (A bare module-level `global` statement is a no-op, so declare the name
# explicitly instead.)
global_processor = None

def init_global_processor(server_args: ServerArgs):
    """Initialize the image processor inside a ProcessPoolExecutor worker."""
    global global_processor

    transformers.logging.set_verbosity_error()
    global_processor = get_processor(
        server_args.tokenizer_path,
        tokenizer_mode=server_args.tokenizer_mode,
        trust_remote_code=server_args.trust_remote_code,
    )
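
# Note: `get_pixel_values` below runs either in the caller's process (when an
# explicit `processor` is passed) or inside a ProcessPoolExecutor worker, where
# it falls back to the `global_processor` installed by `init_global_processor`.
# Image preprocessing is CPU-bound, so offloading it keeps the event loop free.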

def get_pixel_values(
    image_data, image_aspect_ratio=None, image_grid_pinpoints=None, processor=None
):
    try:
        processor = processor or global_processor
        image, image_size = load_image(image_data)
        if image_size is not None:
            # Multiple frames: cast each frame to float16 and stack them.
            image_hash = hash(image_data)
            pixel_values = processor.image_processor(image)["pixel_values"]
            pixel_values = [v.astype(np.float16) for v in pixel_values]
            pixel_values = np.stack(pixel_values, axis=0)
            return pixel_values, image_hash, image_size
        else:
            image_hash = hash(image_data)
            if image_aspect_ratio == "pad":
                # Pad to a square using the processor's mean color.
                image = expand2square(
                    image,
                    tuple(int(x * 255) for x in processor.image_processor.image_mean),
                )
                pixel_values = processor.image_processor(image)["pixel_values"][0]
            elif image_aspect_ratio == "anyres":
                pixel_values = process_anyres_image(
                    image, processor.image_processor, image_grid_pinpoints
                )
            else:
                pixel_values = processor.image_processor(image)["pixel_values"][0]
            pixel_values = pixel_values.astype(np.float16)
            return pixel_values, image_hash, image.size
    except Exception:
        logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())
        # Match the caller's 3-tuple unpacking even on failure.
        return None, None, None

class TokenizerManager:
    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
        model_overide_args: dict = None,
    ):
        self.server_args = server_args

        # ZMQ wiring: pull detokenized results, push tokenized requests.
        context = zmq.asyncio.Context(2)
        self.recv_from_detokenizer = context.socket(zmq.PULL)
        self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}")

        self.send_to_router = context.socket(zmq.PUSH)
        self.send_to_router.connect(f"tcp://127.0.0.1:{port_args.router_port}")

        self.model_path = server_args.model_path
        self.hf_config = get_config(
            self.model_path,
            trust_remote_code=server_args.trust_remote_code,
            model_overide_args=model_overide_args,
        )
        self.context_len = get_context_length(self.hf_config)

        # Default to None so attribute checks are safe for text-only models.
        self.processor = None
        self.executor = None
        if is_multimodal_model(self.model_path):
            self.processor = get_processor(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
            )
            self.tokenizer = self.processor.tokenizer
            os.environ["TOKENIZERS_PARALLELISM"] = "false"
            # Image preprocessing is CPU-bound; run it in forked worker
            # processes, each holding its own copy of the processor.
            self.executor = concurrent.futures.ProcessPoolExecutor(
                initializer=init_global_processor,
                mp_context=mp.get_context("fork"),
                initargs=(server_args,),
            )
        else:
            self.tokenizer = get_tokenizer(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
            )

        # The background receive loop is created lazily, on the first request.
        self.to_create_loop = True
        self.rid_to_state: Dict[str, ReqState] = {}
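
    # Message flow, as wired in __init__ (the detokenizer peer is assumed to
    # connect a matching PUSH socket to tokenizer_port):
    #
    #   TokenizerManager --PUSH--> router        tcp://127.0.0.1:<router_port>
    #   detokenizer      --PUSH--> this manager  tcp://127.0.0.1:<tokenizer_port>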

    async def get_pixel_values(self, image_data):
        aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
        grid_pinpoints = (
            self.hf_config.image_grid_pinpoints if aspect_ratio == "anyres" else None
        )
        if self.executor is not None:
            # Offload CPU-bound preprocessing to the worker pool.
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(
                self.executor,
                get_pixel_values,
                image_data,
                aspect_ratio,
                grid_pinpoints,
            )
        else:
            return get_pixel_values(
                image_data, aspect_ratio, grid_pinpoints, self.processor
            )

    async def generate_request(self, obj: GenerateReqInput):
        if self.to_create_loop:
            await self.create_handle_loop()

        is_single = obj.is_single
        if is_single:
            rid = obj.rid

            if obj.input_ids is None:
                input_ids = self.tokenizer.encode(obj.text)
            else:
                input_ids = obj.input_ids

            if len(input_ids) >= self.context_len:
                raise ValueError(
                    f"The input ({len(input_ids)} tokens) is longer than the "
                    f"model's context length ({self.context_len} tokens)."
                )

            sampling_params = SamplingParams(**obj.sampling_params)
            if sampling_params.max_new_tokens != 0:
                sampling_params.normalize(self.tokenizer)
                sampling_params.verify()

            # Only the first image of a list is used.
            if isinstance(obj.image_data, list) and len(obj.image_data) > 0:
                pixel_values, image_hash, image_size = await self.get_pixel_values(
                    obj.image_data[0]
                )
            elif isinstance(obj.image_data, str):
                pixel_values, image_hash, image_size = await self.get_pixel_values(
                    obj.image_data
                )
            else:
                pixel_values, image_hash, image_size = None, None, None

            tokenized_obj = TokenizedGenerateReqInput(
                rid=rid,
                input_text=obj.text,
                input_ids=input_ids,
                pixel_values=pixel_values,
                image_hash=image_hash,
                image_size=image_size,
                sampling_params=sampling_params,
                return_logprob=obj.return_logprob,
                logprob_start_len=obj.logprob_start_len,
                top_logprobs_num=obj.top_logprobs_num,
                stream=obj.stream,
            )
            self.send_to_router.send_pyobj(tokenized_obj)

            # Register the request, then stream outputs as handle_loop
            # signals new results.
            event = asyncio.Event()
            state = ReqState([], False, event)
            self.rid_to_state[rid] = state

            while True:
                await event.wait()
                out = self.convert_logprob_style(
                    state.out_list[-1],
                    obj.return_logprob,
                    obj.top_logprobs_num,
                    obj.return_text_in_logprobs,
                )

                if self.server_args.log_requests and state.finished:
                    logger.info(f"in={obj.text}, out={out}")

                yield out
                state.out_list = []
                if state.finished:
                    del self.rid_to_state[rid]
                    break
                event.clear()
        else:
            # Batch request: tokenize and dispatch every sub-request first,
            # then collect each final result. Streaming is not supported here.
            assert obj.stream is False

            if obj.input_ids is None:
                bs = len(obj.text)
            else:
                bs = len(obj.input_ids)

            for i in range(bs):
                rid = obj.rid[i]

                if obj.input_ids is None:
                    input_text = obj.text[i]
                    input_ids = self.tokenizer.encode(obj.text[i])
                else:
                    input_text = None
                    input_ids = obj.input_ids[i]

                sampling_params = SamplingParams(**obj.sampling_params[i])
                if sampling_params.max_new_tokens != 0:
                    sampling_params.normalize(self.tokenizer)
                    sampling_params.verify()

                if obj.image_data[i] is None:
                    pixel_values, image_hash, image_size = None, None, None
                else:
                    pixel_values, image_hash, image_size = await self.get_pixel_values(
                        obj.image_data[i]
                    )

                tokenized_obj = TokenizedGenerateReqInput(
                    rid=rid,
                    input_text=input_text,
                    input_ids=input_ids,
                    pixel_values=pixel_values,
                    image_hash=image_hash,
                    image_size=image_size,
                    sampling_params=sampling_params,
                    return_logprob=obj.return_logprob[i],
                    logprob_start_len=obj.logprob_start_len[i],
                    top_logprobs_num=obj.top_logprobs_num[i],
                    stream=obj.stream,
                )
                self.send_to_router.send_pyobj(tokenized_obj)

                event = asyncio.Event()
                state = ReqState([], False, event)
                self.rid_to_state[rid] = state

            output_list = []
            for i in range(bs):
                rid = obj.rid[i]
                state = self.rid_to_state[rid]
                await state.event.wait()
                output_list.append(
                    self.convert_logprob_style(
                        state.out_list[-1],
                        obj.return_logprob[i],
                        obj.top_logprobs_num[i],
                        obj.return_text_in_logprobs,
                    )
                )
                assert state.finished
                del self.rid_to_state[rid]

            yield output_list
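
    # Usage sketch (illustrative only; `manager` and the GenerateReqInput
    # fields shown are assumptions, not guarantees of this module's API):
    #
    #   async for out in manager.generate_request(
    #       GenerateReqInput(text="Hello", sampling_params={"max_new_tokens": 16})
    #   ):
    #       print(out["text"], out["meta_info"])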

    async def flush_cache(self):
        flush_cache_req = FlushCacheReq()
        self.send_to_router.send_pyobj(flush_cache_req)

    async def create_handle_loop(self):
        self.to_create_loop = False
        loop = asyncio.get_running_loop()
        loop.create_task(self.handle_loop())

    async def handle_loop(self):
        """Receive outputs from the detokenizer and wake up waiting requests."""
        while True:
            recv_obj = await self.recv_from_detokenizer.recv_pyobj()

            if isinstance(recv_obj, BatchStrOut):
                for i, rid in enumerate(recv_obj.rids):
                    recv_obj.meta_info[i]["id"] = rid
                    out_dict = {
                        "text": recv_obj.output_str[i],
                        "meta_info": recv_obj.meta_info[i],
                    }
                    state = self.rid_to_state[rid]
                    state.out_list.append(out_dict)
                    state.finished = recv_obj.finished[i]
                    state.event.set()
            else:
                raise ValueError(f"Invalid object: {recv_obj}")
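
    # The helpers below post-process logprobs in place. Each entry starts as a
    # (logprob, token_id) pair; when return_text_in_logprobs is set, the token
    # id is also decoded, producing (logprob, token_id, token_text) triples.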

    def convert_logprob_style(
        self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs
    ):
        if return_logprob:
            ret["meta_info"]["prefill_token_logprobs"] = self.detokenize_logprob_tokens(
                ret["meta_info"]["prefill_token_logprobs"], return_text_in_logprobs
            )
            ret["meta_info"]["decode_token_logprobs"] = self.detokenize_logprob_tokens(
                ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs
            )
            if top_logprobs_num > 0:
                ret["meta_info"]["prefill_top_logprobs"] = (
                    self.detokenize_top_logprobs_tokens(
                        ret["meta_info"]["prefill_top_logprobs"],
                        return_text_in_logprobs,
                    )
                )
                ret["meta_info"]["decode_top_logprobs"] = (
                    self.detokenize_top_logprobs_tokens(
                        ret["meta_info"]["decode_top_logprobs"],
                        return_text_in_logprobs,
                    )
                )
        return ret

    def detokenize_logprob_tokens(self, token_logprobs, decode_to_text):
        if not decode_to_text:
            return [(logprob, token_id, None) for logprob, token_id in token_logprobs]

        token_ids = [tid for _, tid in token_logprobs]
        token_texts = self.tokenizer.batch_decode(token_ids)
        return [
            (logprob, token_id, token_text)
            for (logprob, token_id), token_text in zip(token_logprobs, token_texts)
        ]

    def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text):
        for i, t in enumerate(top_logprobs):
            if t:
                top_logprobs[i] = self.detokenize_logprob_tokens(t, decode_to_text)
        return top_logprobs