# sglang/python/sglang/srt/managers/tokenizer_manager.py
import asyncio
import concurrent.futures
import dataclasses
import logging
import multiprocessing as mp
import os
from typing import List
import numpy as np
import transformers
import uvloop
import zmq
import zmq.asyncio
from sglang.srt.hf_transformers_utils import (
get_config,
get_context_length,
get_processor,
get_tokenizer,
)
from sglang.srt.managers.io_struct import (
BatchStrOut,
DetokenizeReqInput,
FlushCacheReq,
GenerateReqInput,
TokenizedGenerateReqInput,
)
from sglang.srt.mm_utils import expand2square, process_anyres_image
from sglang.srt.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import get_exception_traceback, is_multimodal_model, load_image
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
logger = logging.getLogger(__name__)
@dataclasses.dataclass
class ReqState:
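    """Per-request state: accumulated outputs, a finished flag, and an event that wakes the waiting coroutine."""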
out_list: List
finished: bool
event: asyncio.Event
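# Module-level HF processor handle, set in each image-preprocessing worker by init_global_processor.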
global global_processor
def init_global_processor(server_args: ServerArgs):
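    """Process pool initializer: load the HF processor once per worker process."""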
global global_processor
transformers.logging.set_verbosity_error()
global_processor = get_processor(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
)
def get_pixel_values(
image_data, image_aspect_ratio=None, image_grid_pinpoints=None, processor=None
):
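    """Load an image and run the HF image processor on it, honoring the
    model's image_aspect_ratio ("pad" or "anyres").

    Returns (pixel_values, image_hash, image_size).
    """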
try:
processor = processor or global_processor
image, image_size = load_image(image_data)
if image_size is not None:
image_hash = hash(image_data)
pixel_values = processor.image_processor(image)["pixel_values"]
pixel_values = [v.astype(np.float16) for v in pixel_values]
pixel_values = np.stack(pixel_values, axis=0)
return pixel_values, image_hash, image_size
else:
image_hash = hash(image_data)
if image_aspect_ratio == "pad":
image = expand2square(
image, tuple(int(x * 255) for x in processor.image_processor.image_mean)
)
pixel_values = processor.image_processor(image)["pixel_values"][0]
elif image_aspect_ratio == "anyres":
pixel_values = process_anyres_image(
image, processor.image_processor, image_grid_pinpoints
)
else:
pixel_values = processor.image_processor(image)["pixel_values"][0]
pixel_values = pixel_values.astype(np.float16)
return pixel_values, image_hash, image.size
except Exception:
print("Exception in TokenizerManager:\n" + get_exception_traceback())
class TokenizerManager:
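    """Tokenizes incoming generate requests, forwards them to the router over
    ZMQ, and relays detokenized outputs back to the waiting callers.
    """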
def __init__(
self,
server_args: ServerArgs,
port_args: PortArgs,
model_overide_args: dict = None,
):
self.server_args = server_args
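        # ZMQ sockets: pull detokenized results from the detokenizer, push tokenized requests to the router.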
context = zmq.asyncio.Context(2)
self.recv_from_detokenizer = context.socket(zmq.PULL)
self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}")
self.send_to_router = context.socket(zmq.PUSH)
self.send_to_router.connect(f"tcp://127.0.0.1:{port_args.router_port}")
self.model_path = server_args.model_path
self.hf_config = get_config(
self.model_path,
trust_remote_code=server_args.trust_remote_code,
model_overide_args=model_overide_args,
)
self.context_len = get_context_length(self.hf_config)
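        # Multimodal models need the full HF processor; image preprocessing is offloaded to a fork-based process pool.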
if is_multimodal_model(self.model_path):
self.processor = get_processor(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
)
self.tokenizer = self.processor.tokenizer
os.environ["TOKENIZERS_PARALLELISM"] = "false"
self.executor = concurrent.futures.ProcessPoolExecutor(
initializer=init_global_processor,
mp_context=mp.get_context("fork"),
initargs=(server_args,),
)
else:
self.tokenizer = get_tokenizer(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
)
self.to_create_loop = True
self.rid_to_state = {} # Dict[str -> ReqState]
async def get_pixel_values(self, image_data):
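        """Preprocess one image for this model, using the process pool when one was created."""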
aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
grid_pinpoints = (
self.hf_config.image_grid_pinpoints if aspect_ratio == "anyres" else None
)
if self.executor is not None:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self.executor,
get_pixel_values,
image_data,
aspect_ratio,
grid_pinpoints,
)
else:
return get_pixel_values(
image_data, aspect_ratio, grid_pinpoints, self.processor
)
async def generate_request(self, obj: GenerateReqInput):
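        """Tokenize a GenerateReqInput (single or batch), send it to the router,
        and yield outputs as the detokenizer reports them.
        """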
if self.to_create_loop:
await self.create_handle_loop()
is_single = obj.is_single
if is_single:
rid = obj.rid
if obj.input_ids is None:
input_ids = self.tokenizer.encode(obj.text)
else:
input_ids = obj.input_ids
if len(input_ids) >= self.context_len:
raise ValueError(
f"The input ({len(input_ids)} tokens) is longer than the "
f"model's context length ({self.context_len} tokens)"
)
sampling_params = SamplingParams(**obj.sampling_params)
if sampling_params.max_new_tokens != 0:
sampling_params.normalize(self.tokenizer)
sampling_params.verify()
if isinstance(obj.image_data, list) and len(obj.image_data) > 0:
pixel_values, image_hash, image_size = await self.get_pixel_values(
obj.image_data[0]
)
elif isinstance(obj.image_data, str):
pixel_values, image_hash, image_size = await self.get_pixel_values(
obj.image_data
)
else:
pixel_values, image_hash, image_size = None, None, None
tokenized_obj = TokenizedGenerateReqInput(
rid=rid,
input_text=obj.text,
input_ids=input_ids,
pixel_values=pixel_values,
image_hash=image_hash,
image_size=image_size,
sampling_params=sampling_params,
return_logprob=obj.return_logprob,
logprob_start_len=obj.logprob_start_len,
top_logprobs_num=obj.top_logprobs_num,
stream=obj.stream,
)
self.send_to_router.send_pyobj(tokenized_obj)
event = asyncio.Event()
state = ReqState([], False, event)
self.rid_to_state[rid] = state
while True:
await event.wait()
out = self.convert_logprob_style(state.out_list[-1],
obj.return_logprob,
obj.top_logprobs_num,
obj.return_text_in_logprobs)
if self.server_args.log_requests and state.finished:
logger.info(f"in={obj.text}, out={out}")
yield out
state.out_list = []
if state.finished:
del self.rid_to_state[rid]
break
event.clear()
else:
assert obj.stream is False
if obj.input_ids is None:
bs = len(obj.text)
else:
bs = len(obj.input_ids)
for i in range(bs):
rid = obj.rid[i]
if obj.input_ids is None:
input_text = obj.text[i]
input_ids = self.tokenizer.encode(obj.text[i])
else:
input_text = None
input_ids = obj.input_ids[i]
sampling_params = SamplingParams(**obj.sampling_params[i])
if sampling_params.max_new_tokens != 0:
sampling_params.normalize(self.tokenizer)
sampling_params.verify()
if obj.image_data[i] is None:
pixel_values, image_hash, image_size = None, None, None
else:
pixel_values, image_hash, image_size = await self.get_pixel_values(
obj.image_data[i]
)
tokenized_obj = TokenizedGenerateReqInput(
rid=rid,
input_text=input_text,
input_ids=input_ids,
pixel_values=pixel_values,
image_hash=image_hash,
image_size=image_size,
sampling_params=sampling_params,
return_logprob=obj.return_logprob[i],
logprob_start_len=obj.logprob_start_len[i],
top_logprobs_num=obj.top_logprobs_num[i],
stream=obj.stream,
)
self.send_to_router.send_pyobj(tokenized_obj)
event = asyncio.Event()
state = ReqState([], False, event)
self.rid_to_state[rid] = state
output_list = []
for i in range(bs):
rid = obj.rid[i]
state = self.rid_to_state[rid]
await state.event.wait()
output_list.append(
self.convert_logprob_style(state.out_list[-1],
obj.return_logprob[i],
obj.top_logprobs_num[i],
obj.return_text_in_logprobs))
assert state.finished
del self.rid_to_state[rid]
yield output_list
async def flush_cache(self):
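        """Send a FlushCacheReq to the router."""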
flush_cache_req = FlushCacheReq()
self.send_to_router.send_pyobj(flush_cache_req)
async def create_handle_loop(self):
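        """Start the background task that receives outputs from the detokenizer."""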
self.to_create_loop = False
loop = asyncio.get_event_loop()
loop.create_task(self.handle_loop())
async def handle_loop(self):
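        """Receive BatchStrOut messages from the detokenizer and dispatch them to the waiting request states."""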
while True:
recv_obj = await self.recv_from_detokenizer.recv_pyobj()
if isinstance(recv_obj, BatchStrOut):
for i, rid in enumerate(recv_obj.rids):
recv_obj.meta_info[i]["id"] = rid
out_dict = {
"text": recv_obj.output_str[i],
"meta_info": recv_obj.meta_info[i],
}
state = self.rid_to_state[rid]
state.out_list.append(out_dict)
state.finished = recv_obj.finished[i]
state.event.set()
else:
raise ValueError(f"Invalid object: {recv_obj}")
def convert_logprob_style(self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs):
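        """Optionally attach decoded token text to the logprob entries in meta_info."""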
if return_logprob:
ret["meta_info"]["prefill_token_logprobs"] = self.detokenize_logprob_tokens(
ret["meta_info"]["prefill_token_logprobs"], return_text_in_logprobs
)
ret["meta_info"]["decode_token_logprobs"] = self.detokenize_logprob_tokens(
ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs
)
if top_logprobs_num > 0:
ret["meta_info"]["prefill_top_logprobs"] = self.detokenize_top_logprobs_tokens(
ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
)
ret["meta_info"]["decode_top_logprobs"] = self.detokenize_top_logprobs_tokens(
ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
)
return ret
def detokenize_logprob_tokens(self, token_logprobs, decode_to_text):
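        """Convert (logprob, token_id) pairs into (logprob, token_id, token_text) triples; token_text is None unless decode_to_text is set."""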
if not decode_to_text:
return [(logprob, token_id, None) for logprob, token_id in token_logprobs]
token_ids = [tid for _, tid in token_logprobs]
token_texts = self.tokenizer.batch_decode(token_ids)
return [
(logprob, token_id, token_text)
for (logprob, token_id), token_text in zip(token_logprobs, token_texts)
]
def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text):
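        """Apply detokenize_logprob_tokens to each position's top-logprob list in place."""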
for i, t in enumerate(top_logprobs):
if t:
top_logprobs[i] = self.detokenize_logprob_tokens(t, decode_to_text)
return top_logprobs