release initial code

Co-authored-by: Ying Sheng <sqy1415@gmail.com>
Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com>
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
Co-authored-by: parasol-aser <3848358+parasol-aser@users.noreply.github.com>
Co-authored-by: LiviaSun <33578456+ChuyueSun@users.noreply.github.com>
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
This commit is contained in:
Lianmin Zheng
2024-01-08 04:37:50 +00:00
parent f6d40df0ee
commit 22085081bb
145 changed files with 17802 additions and 2 deletions

View File

@@ -0,0 +1,219 @@
import asyncio
import concurrent.futures
import dataclasses
import os
from typing import List
import numpy as np
import transformers
import uvloop
import zmq
import zmq.asyncio
from sglang.srt.hf_transformers_utils import (
get_config,
get_context_length,
get_processor,
get_tokenizer,
)
from sglang.srt.managers.io_struct import (
BatchStrOut,
GenerateReqInput,
TokenizedGenerateReqInput,
)
from sglang.srt.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import get_exception_traceback, is_multimodal_model, load_image
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@dataclasses.dataclass
class ReqState:
out_list: List
finished: bool
event: asyncio.Event
lock: asyncio.Lock
global global_processor
def init_global_processor(server_args: ServerArgs):
global global_processor
transformers.logging.set_verbosity_error()
global_processor = get_processor(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
)
def get_pixel_values(image_data, processor=None):
try:
processor = processor or global_processor
image = load_image(image_data)
image_hash = hash(image_data)
pixel_values = processor.image_processor(image)["pixel_values"][0]
pixel_values = pixel_values.astype(np.float16)
return pixel_values, image_hash
except Exception:
print("Exception in TokenizerManager:\n" + get_exception_traceback())
class TokenizerManager:
def __init__(
self,
server_args: ServerArgs,
port_args: PortArgs,
):
context = zmq.asyncio.Context(2)
self.recv_from_detokenizer = context.socket(zmq.PULL)
self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}")
self.send_to_router = context.socket(zmq.PUSH)
self.send_to_router.connect(f"tcp://127.0.0.1:{port_args.router_port}")
self.model_path = server_args.model_path
self.hf_config = get_config(
self.model_path, trust_remote_code=server_args.trust_remote_code
)
self.context_len = get_context_length(self.hf_config)
if is_multimodal_model(self.model_path):
self.processor = get_processor(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
)
self.tokenizer = self.processor.tokenizer
os.environ["TOKENIZERS_PARALLELISM"] = "false"
self.executor = concurrent.futures.ProcessPoolExecutor(
initializer=init_global_processor, initargs=(server_args,)
)
else:
self.tokenizer = get_tokenizer(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
)
self.to_create_loop = True
self.rid_to_state = {} # Dict[str -> ReqState]
async def get_pixel_values(self, image_data):
if self.executor is not None:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self.executor, get_pixel_values, image_data
)
else:
return get_pixel_values(image_data, self.processor)
async def generate_request(self, obj: GenerateReqInput):
if self.to_create_loop:
await self.create_handle_loop()
is_single = isinstance(obj.text, str)
if is_single:
rid = obj.rid
input_ids = self.tokenizer.encode(obj.text)
sampling_params = SamplingParams(**obj.sampling_params)
if sampling_params.max_new_tokens != 0:
sampling_params.normalize(self.tokenizer)
sampling_params.verify()
if obj.image_data is None:
pixel_values, image_hash = None, None
else:
pixel_values, image_hash = await self.get_pixel_values(obj.image_data)
tokenized_obj = TokenizedGenerateReqInput(
rid=rid,
input_ids=input_ids,
pixel_values=pixel_values,
image_hash=image_hash,
sampling_params=sampling_params,
return_normalized_logprob=obj.return_normalized_logprob,
normalized_logprob_start_len=obj.normalized_logprob_start_len,
stream=obj.stream,
)
self.send_to_router.send_pyobj(tokenized_obj)
lock = asyncio.Lock()
event = asyncio.Event()
state = ReqState([], False, event, lock)
self.rid_to_state[rid] = state
while True:
await event.wait()
yield state.out_list[-1]
state.out_list = []
if state.finished:
del self.rid_to_state[rid]
break
event.clear()
else:
assert obj.stream is False
bs = len(obj.text)
for i in range(bs):
rid = obj.rid[i]
input_ids = self.tokenizer.encode(obj.text[i])
sampling_params = SamplingParams(**obj.sampling_params[i])
if sampling_params.max_new_tokens != 0:
sampling_params.normalize(self.tokenizer)
sampling_params.verify()
if obj.image_data[i] is None:
pixel_values, image_hash = None, None
else:
pixel_values, image_hash = await self.get_pixel_values(
obj.image_data[i]
)
tokenized_obj = TokenizedGenerateReqInput(
rid=rid,
input_ids=input_ids,
pixel_values=pixel_values,
image_hash=image_hash,
sampling_params=sampling_params,
return_normalized_logprob=obj.return_normalized_logprob[i],
normalized_logprob_start_len=obj.normalized_logprob_start_len[i],
stream=obj.stream,
)
self.send_to_router.send_pyobj(tokenized_obj)
lock = asyncio.Lock()
event = asyncio.Event()
state = ReqState([], False, event, lock)
self.rid_to_state[rid] = state
output_list = []
for i in range(bs):
rid = obj.rid[i]
state = self.rid_to_state[rid]
await state.event.wait()
output_list.append(state.out_list[-1])
assert state.finished
del self.rid_to_state[rid]
yield output_list
async def create_handle_loop(self):
self.to_create_loop = False
loop = asyncio.get_event_loop()
loop.create_task(self.handle_loop())
async def handle_loop(self):
while True:
recv_obj = await self.recv_from_detokenizer.recv_pyobj()
if isinstance(recv_obj, BatchStrOut):
for i, rid in enumerate(recv_obj.rids):
recv_obj.meta_info[i]["id"] = rid
out_dict = {
"text": recv_obj.output_str[i],
"meta_info": recv_obj.meta_info[i],
}
state = self.rid_to_state[rid]
state.out_list.append(out_dict)
state.finished = recv_obj.finished[i]
state.event.set()
else:
raise ValueError(f"Invalid object: {recv_obj}")