Files
sglang/python/sglang/srt/managers/detokenizer_manager.py
2024-02-05 16:50:37 +08:00

92 lines
3.2 KiB
Python

import asyncio
import uvloop
import zmq
import zmq.asyncio
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import get_exception_traceback
# Install uvloop's event loop policy at import time so that every asyncio
# event loop created afterwards in this process is a uvloop loop.
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
class DetokenizerManager:
    """Decode batches of generated token ids into text and forward them.

    The manager pulls ``BatchTokenIDOut`` objects from a ZeroMQ PULL socket
    bound to ``port_args.detokenizer_port``, converts the token ids to
    strings with the configured tokenizer, and pushes the resulting
    ``BatchStrOut`` to a PUSH socket connected to
    ``port_args.tokenizer_port``.
    """

    # SentencePiece word-boundary marker (U+2581). Written as an escape so the
    # character cannot be silently lost when the file is re-encoded.
    _WORD_BOUNDARY = "\u2581"

    def __init__(
        self,
        server_args: "ServerArgs",
        port_args: "PortArgs",
    ):
        """Create the ZeroMQ sockets and load the tokenizer.

        Args:
            server_args: Supplies ``tokenizer_path``, ``tokenizer_mode`` and
                ``trust_remote_code`` for ``get_tokenizer``.
            port_args: Supplies ``detokenizer_port`` (bound, incoming) and
                ``tokenizer_port`` (connected, outgoing).
        """
        context = zmq.asyncio.Context(2)

        # Incoming token-id batches.
        self.recv_from_router = context.socket(zmq.PULL)
        self.recv_from_router.bind(f"tcp://127.0.0.1:{port_args.detokenizer_port}")

        # Outgoing decoded strings.
        self.send_to_tokenizer = context.socket(zmq.PUSH)
        self.send_to_tokenizer.connect(f"tcp://127.0.0.1:{port_args.tokenizer_port}")

        self.tokenizer = get_tokenizer(
            server_args.tokenizer_path,
            tokenizer_mode=server_args.tokenizer_mode,
            trust_remote_code=server_args.trust_remote_code,
        )

    def _decode_batch(self, recv_obj):
        """Return the final output string for every request in ``recv_obj``.

        For each request the decoded text is cut at the first occurrence of
        the stop string that was hit (if any), gets a leading space restored
        when its first token starts with the word-boundary marker, and is
        finally prefixed with ``output_and_jump_forward_strs[i]``.

        Args:
            recv_obj: Object exposing ``output_tokens``,
                ``skip_special_tokens``, ``hit_stop_str`` and
                ``output_and_jump_forward_strs``, indexed per request.

        Returns:
            A list with one string per request, in batch order.
        """
        output_tokens = recv_obj.output_tokens

        # TODO(lmzheng): handle skip_special_tokens per request
        decoded = self.tokenizer.batch_decode(
            output_tokens,
            skip_special_tokens=recv_obj.skip_special_tokens[0],
        )

        output_strs = []
        for i, (text, token_ids) in enumerate(zip(decoded, output_tokens)):
            # Trim stop str
            # TODO(lmzheng): handle the case where multiple stop strs are hit
            stop_str = recv_obj.hit_stop_str[i]
            if stop_str is not None:
                pos = text.find(stop_str)
                if pos != -1:
                    text = text[:pos]

            if len(token_ids) > 0:
                first_token = self.tokenizer.convert_ids_to_tokens(int(token_ids[0]))
                if not isinstance(first_token, str):
                    # Some tokenizers return bytes; normalise to str.
                    first_token = first_token.decode("utf-8", errors="ignore")
                # Only add the space when the first token really begins a new
                # word. Testing against an empty prefix would always be true
                # and would prepend a space to every output.
                if first_token.startswith(self._WORD_BOUNDARY):
                    text = " " + text

            output_strs.append(recv_obj.output_and_jump_forward_strs[i] + text)

        return output_strs

    async def handle_loop(self):
        """Receive token-id batches forever, decode them and forward the text.

        Raises:
            ValueError: If a received object is not a ``BatchTokenIDOut``.
        """
        while True:
            recv_obj = await self.recv_from_router.recv_pyobj()
            if not isinstance(recv_obj, BatchTokenIDOut):
                raise ValueError(f"Invalid object: {recv_obj}")

            output_strs = self._decode_batch(recv_obj)

            # NOTE(review): the result of send_pyobj is intentionally left
            # un-awaited to keep the existing behaviour; confirm whether it
            # should be awaited for error propagation / backpressure.
            self.send_to_tokenizer.send_pyobj(
                BatchStrOut(
                    recv_obj.rids,
                    output_strs,
                    recv_obj.meta_info,
                    recv_obj.finished,
                )
            )
def start_detokenizer_process(
    server_args: ServerArgs,
    port_args: PortArgs,
    pipe_writer,
):
    """Process entry point: build a DetokenizerManager and run its loop.

    The outcome of the construction is reported through ``pipe_writer``:
    the exception traceback on failure (the exception is then re-raised),
    or the string ``"init ok"`` on success. After a successful start the
    manager's ``handle_loop`` is run on the asyncio event loop.
    """
    try:
        manager = DetokenizerManager(server_args, port_args)
    except Exception:
        # Report the failure to whoever reads the pipe, then let it propagate.
        pipe_writer.send(get_exception_traceback())
        raise
    else:
        pipe_writer.send("init ok")

    asyncio.get_event_loop().run_until_complete(manager.handle_loop())