Fix rid state map leak + Refactor .finished (#505)

Co-authored-by: ZX <zx@lbx.dev>
Author: Qubitium
Date: 2024-06-08 04:20:40 +08:00
Committed by: GitHub
Parent: c0ae70c8ed
Commit: f70f72586a
7 changed files with 130 additions and 108 deletions
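The "rid state map leak" in the title refers to per-request state that the tokenizer-side manager keeps in a dict keyed by request id (rid) and that was never removed once a request finished, so the map grew without bound. The diff for that file is not reproduced below; the following is only a minimal sketch of the cleanup pattern, with hypothetical names (ReqState, rid_to_state, on_request_output) used for illustration rather than the actual sglang code:

from dataclasses import dataclass, field
from typing import Dict, List

@dataclass
class ReqState:
    # Hypothetical per-request state kept while a request is in flight.
    out_list: List[dict] = field(default_factory=list)
    finished: bool = False

rid_to_state: Dict[str, ReqState] = {}

def on_request_output(rid: str, out: dict, finished: bool) -> None:
    state = rid_to_state[rid]
    state.out_list.append(out)
    state.finished = finished
    if finished:
        # Without this deletion the map keeps one entry per request forever,
        # which is the kind of leak the commit title refers to.
        del rid_to_state[rid]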


@@ -9,6 +9,7 @@ from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.utils import get_exception_traceback, graceful_registry
+from sglang.srt.managers.controller.infer_batch import FINISH_MATCHED_STR

 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@@ -34,49 +35,47 @@ class DetokenizerManager:
     async def handle_loop(self):
         while True:
-            recv_obj = await self.recv_from_router.recv_pyobj()
+            recv_obj: BatchTokenIDOut = await self.recv_from_router.recv_pyobj()
+            assert isinstance(recv_obj, BatchTokenIDOut)
-            if isinstance(recv_obj, BatchTokenIDOut):
-                output_tokens = recv_obj.output_tokens
+            output_tokens = recv_obj.output_tokens
-                # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
-                output_strs = self.tokenizer.batch_decode(
-                    output_tokens,
-                    skip_special_tokens=recv_obj.skip_special_tokens[0],
-                    spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[
-                        0
-                    ],
-                )
+            # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
+            output_strs = self.tokenizer.batch_decode(
+                output_tokens,
+                skip_special_tokens=recv_obj.skip_special_tokens[0],
+                spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[
+                    0
+                ],
+            )
-                # Trim stop str
-                # TODO(lmzheng): handle the case where multiple stop strs are hit
-                for i in range(len(output_strs)):
-                    if len(output_tokens[i]) > 0:
-                        first_token = self.tokenizer.convert_ids_to_tokens(
-                            int(output_tokens[i][0])
-                        )
-                        if not isinstance(first_token, str):
-                            first_token = first_token.decode("utf-8", errors="ignore")
-                        if first_token.startswith("▁"):
-                            output_strs[i] = " " + output_strs[i]
-                    output_strs[i] = recv_obj.prev_output_strs[i] + output_strs[i]
-                    if recv_obj.hit_stop_str[i] is not None:
-                        pos = output_strs[i].find(recv_obj.hit_stop_str[i])
-                        if pos != -1:
-                            output_strs[i] = output_strs[i][:pos]
-                self.send_to_tokenizer.send_pyobj(
-                    BatchStrOut(
-                        recv_obj.rids,
-                        output_strs,
-                        recv_obj.meta_info,
-                        recv_obj.finished,
-                    )
-                )
-            else:
-                raise ValueError(f"Invalid object: {recv_obj}")
+            # Trim stop str
+            # TODO(lmzheng): handle the case where multiple stop strs are hit
+            for i in range(len(output_strs)):
+                if len(output_tokens[i]) > 0:
+                    first_token = self.tokenizer.convert_ids_to_tokens(
+                        int(output_tokens[i][0])
+                    )
+                    if not isinstance(first_token, str):
+                        first_token = first_token.decode("utf-8", errors="ignore")
+                    if first_token.startswith("▁"):
+                        output_strs[i] = " " + output_strs[i]
+                output_strs[i] = recv_obj.prev_output_strs[i] + output_strs[i]
+                if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR):
+                    pos = output_strs[i].find(recv_obj.finished_reason[i].matched)
+                    if pos != -1:
+                        output_strs[i] = output_strs[i][:pos]
+            self.send_to_tokenizer.send_pyobj(
+                BatchStrOut(
+                    rids=recv_obj.rids,
+                    output_str=output_strs,
+                    meta_info=recv_obj.meta_info,
+                    finished_reason=recv_obj.finished_reason,
+                )
+            )

 def start_detokenizer_process(
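The ".finished" refactor named in the title is visible in the hunk above: the boolean finished flag and the separate hit_stop_str field are replaced by a single finished_reason object per request, and the detokenizer now trims the decoded text only when that reason is a matched stop string. As a rough sketch of what the imported FINISH_MATCHED_STR provides, assuming a simple class hierarchy (the real definition lives in infer_batch.py and may differ in detail):

class BaseFinishReason:
    # Hypothetical base class for the structured per-request finish reason.
    def __init__(self, is_error: bool = False):
        self.is_error = is_error

class FINISH_MATCHED_STR(BaseFinishReason):
    # The request finished because one of its stop strings was matched.
    def __init__(self, matched: str):
        super().__init__()
        self.matched = matched  # the stop string that was hit

def trim_stop_str(output_str: str, finished_reason: BaseFinishReason) -> str:
    # Mirrors the trimming logic in the hunk above: cut the decoded text at
    # the first occurrence of the matched stop string, if any.
    if isinstance(finished_reason, FINISH_MATCHED_STR):
        pos = output_str.find(finished_reason.matched)
        if pos != -1:
            return output_str[:pos]
    return output_str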