Detokenize incrementally when streaming (#653)

Liangsheng Yin
2024-07-18 17:57:40 -07:00
committed by GitHub
parent 21ba3a88a1
commit a9ef49c12c
5 changed files with 101 additions and 33 deletions


@@ -1,7 +1,9 @@
 """DetokenizerManager is a process that detokenizes the token ids."""

 import asyncio
+import dataclasses
 import inspect
+from typing import List

 import uvloop
 import zmq
@@ -16,6 +18,14 @@ from sglang.utils import find_printable_text, get_exception_traceback, graceful_
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())


+@dataclasses.dataclass
+class DecodeStatus:
+    decoded_text: str
+    decode_ids: List[int]
+    surr_offset: int
+    read_offset: int
+
+
 class DetokenizerManager:
     def __init__(
         self,
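Why per-request decode state is needed at all: byte-level BPE tokenizers can split a single multi-byte character across several tokens, so decoding each new token in isolation yields U+FFFD replacement characters mid-stream. A minimal illustration of the failure mode, assuming a HuggingFace GPT-2 tokenizer (illustration only; the manager uses the server's own tokenizer):

```python
# Why naive per-token decoding breaks streaming: byte-level BPE may split
# one character across several tokens. GPT-2 is used purely for
# illustration; it is not part of this commit.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

ids = tokenizer.encode("🤖")                 # the emoji spans multiple tokens
print([tokenizer.decode([t]) for t in ids])  # fragments print as '�'
print(tokenizer.decode(ids))                 # decoded together: '🤖'
```

`DecodeStatus` tracks, per request, the full text emitted so far (`decoded_text`), all token ids seen so far (`decode_ids`), and two cursors into them: `surr_offset`, the start of the "surrogate" window of already-emitted tokens that must be re-decoded alongside new ones, and `read_offset`, where the new tokens begin.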
@@ -35,19 +45,42 @@ class DetokenizerManager:
             trust_remote_code=server_args.trust_remote_code,
         )

+        self.decode_status = {}
+
     async def handle_loop(self):
         while True:
             recv_obj: BatchTokenIDOut = await self.recv_from_router.recv_pyobj()
             assert isinstance(recv_obj, BatchTokenIDOut)
+            bs = len(recv_obj.rids)

+            # FIXME: incremental detokenize is not compatible with jump forward
+            # Initialize decode status
+            read_ids, surr_ids = [], []
+            for i in range(bs):
+                rid = recv_obj.rids[i]
+                if rid not in self.decode_status:
+                    s = DecodeStatus(
+                        decoded_text=recv_obj.decoded_texts[i],
+                        decode_ids=recv_obj.decode_ids[i],
+                        surr_offset=0,
+                        read_offset=recv_obj.read_offsets[i],
+                    )
+                    self.decode_status[rid] = s
+                else:
+                    s = self.decode_status[rid]
+                    s.decode_ids = recv_obj.decode_ids[i]
+
+                read_ids.append(s.decode_ids[s.surr_offset :])
+                surr_ids.append(s.decode_ids[s.surr_offset : s.read_offset])
+
+            # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
             surr_texts = self.tokenizer.batch_decode(
-                recv_obj.surr_output_ids,
+                surr_ids,
                 skip_special_tokens=recv_obj.skip_special_tokens[0],
                 spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
             )
             read_texts = self.tokenizer.batch_decode(
-                recv_obj.read_output_ids,
+                read_ids,
                 skip_special_tokens=recv_obj.skip_special_tokens[0],
                 spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
             )
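The two `batch_decode` calls above compute a sliding-window diff for every request at once: re-decode the surrogate window (`decode_ids[surr_offset:read_offset]`) together with the new tokens, then keep whatever text extends past the window. A per-request sketch of the same computation, assuming a HuggingFace-style `tokenizer.decode()`; the helper name is illustrative, not part of this commit:

```python
def incremental_decode(tokenizer, s):
    """Sketch: per-request equivalent of the batched surr/read decode."""
    # Re-decode the surrogate window so tokens that merge with earlier
    # ones (e.g. bytes of one character) render correctly...
    surr_text = tokenizer.decode(s.decode_ids[s.surr_offset : s.read_offset])
    # ...then decode the same window extended by the newly arrived tokens.
    read_text = tokenizer.decode(s.decode_ids[s.surr_offset :])
    # The freshly generated text is whatever extends past the window.
    return read_text[len(surr_text) :]
```

Batching the two decodes keeps the per-step tokenizer cost at two calls instead of 2·bs, at the price of sharing the `skip_special_tokens` settings of the first request across the whole batch (hence the TODO above).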
@@ -55,11 +88,20 @@
             # Trim stop str
             # TODO(lmzheng): handle the case where multiple stop strs are hit
             output_strs = []
-            for i in range(len(recv_obj.rids)):
+            for i in range(bs):
+                s = self.decode_status[recv_obj.rids[i]]
                 new_text = read_texts[i][len(surr_texts[i]) :]
                 if recv_obj.finished_reason[i] is None:
-                    new_text = find_printable_text(new_text)
-                output_strs.append(recv_obj.decoded_texts[i] + new_text)
+                    # Streaming chunk: update the decode status
+                    if len(new_text) > 0 and not new_text.endswith("�"):
+                        s.decoded_text = s.decoded_text + new_text
+                        s.surr_offset = s.read_offset
+                        s.read_offset = len(s.decode_ids)
+                        new_text = ""
+                    else:
+                        new_text = find_printable_text(new_text)
+                output_strs.append(s.decoded_text + new_text)

                 if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR):
                     pos = output_strs[i].find(recv_obj.finished_reason[i].matched)
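Put together, the streaming branch is a flush rule: if the step's delta decoded to complete UTF-8 (no trailing '�'), commit it to `decoded_text` and slide both cursors forward; otherwise leave the cursors alone and surface only the printable prefix, retrying the held-back tail next step. A condensed sketch of that rule in isolation; `flush_chunk` is an illustrative name, while `find_printable_text` is the real helper imported from `sglang.utils` and `s` is a `DecodeStatus`:

```python
from sglang.utils import find_printable_text

def flush_chunk(s, new_text):
    """Sketch: return the full output string to stream for this step."""
    if len(new_text) > 0 and not new_text.endswith("�"):
        # The delta is complete UTF-8: commit it, then advance the windows
        # so the next step re-decodes only recent tokens.
        s.decoded_text += new_text
        s.surr_offset = s.read_offset
        s.read_offset = len(s.decode_ids)
        return s.decoded_text
    # Incomplete tail: keep the offsets unchanged and emit only what is
    # safely printable; the rest decodes cleanly once more tokens arrive.
    return s.decoded_text + find_printable_text(new_text)
```

Finished requests skip this branch (their `finished_reason` is set), so their final output includes the complete delta before the stop-string trimming below.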