From f70f72586ad26c1738a0d6dc6fbcaa878997b68c Mon Sep 17 00:00:00 2001 From: Qubitium <417764+Qubitium@users.noreply.github.com> Date: Sat, 8 Jun 2024 04:20:40 +0800 Subject: [PATCH] Fix rid state map leak + Refractor .finished (#505) Co-authored-by: ZX --- .../srt/managers/controller/dp_worker.py | 5 +- .../srt/managers/controller/infer_batch.py | 81 ++++++++++++------- .../srt/managers/controller/manager_single.py | 2 +- .../srt/managers/controller/tp_worker.py | 21 ++--- .../srt/managers/detokenizer_manager.py | 75 +++++++++-------- python/sglang/srt/managers/io_struct.py | 11 ++- .../sglang/srt/managers/tokenizer_manager.py | 43 +++++----- 7 files changed, 130 insertions(+), 108 deletions(-) diff --git a/python/sglang/srt/managers/controller/dp_worker.py b/python/sglang/srt/managers/controller/dp_worker.py index 3e300a17a..a1b67396d 100644 --- a/python/sglang/srt/managers/controller/dp_worker.py +++ b/python/sglang/srt/managers/controller/dp_worker.py @@ -10,6 +10,7 @@ import zmq from sglang.global_config import global_config from sglang.srt.managers.controller.tp_worker import ModelTpClient +from sglang.srt.managers.io_struct import BatchTokenIDOut from sglang.srt.server_args import PortArgs, ServerArgs from sglang.utils import get_exception_traceback @@ -44,6 +45,8 @@ class DataParallelWorkerThread(threading.Thread): requests = [] while not self.request_queue.empty(): requests.append(self.request_queue.get()) + + out_pyobjs: List[BatchTokenIDOut] = [] try: out_pyobjs = await self.step(requests) except Exception: @@ -61,7 +64,7 @@ class DataParallelWorkerThread(threading.Thread): # async sleep for receiving the subsequent request and avoiding cache miss if len(out_pyobjs) != 0: - has_finished = any([obj.finished for obj in out_pyobjs]) + has_finished = any([obj.finished_reason is not None for obj in out_pyobjs]) if has_finished: await asyncio.sleep(self.request_dependency_delay) await asyncio.sleep(global_config.wait_for_new_request_delay) diff --git a/python/sglang/srt/managers/controller/infer_batch.py b/python/sglang/srt/managers/controller/infer_batch.py index 6b82c9f07..410f8a230 100644 --- a/python/sglang/srt/managers/controller/infer_batch.py +++ b/python/sglang/srt/managers/controller/infer_batch.py @@ -15,25 +15,47 @@ class ForwardMode(IntEnum): EXTEND = auto() DECODE = auto() +class BaseFinishReason: + def __init__(self, is_error: bool = False): + self.is_error = is_error -class FinishReason(IntEnum): - EOS_TOKEN = auto() - LENGTH = auto() - STOP_STR = auto() - ABORT = auto() + def __str__(self): + raise NotImplementedError("Subclasses must implement this method") - @staticmethod - def to_str(reason): - if reason == FinishReason.EOS_TOKEN: - return None - elif reason == FinishReason.LENGTH: - return "length" - elif reason == FinishReason.STOP_STR: - return "stop" - elif reason == FinishReason.ABORT: - return "abort" - else: - return None + +class FINISH_MATCHED_TOKEN(BaseFinishReason): + def __init__(self, matched: int | List[int]): + super().__init__() + self.matched = matched + + def __str__(self) -> str: + return f"FINISH_MATCHED_TOKEN: {self.matched}" + + +class FINISH_LENGTH(BaseFinishReason): + def __init__(self, length: int): + super().__init__() + self.length = length + + def __str__(self) -> str: + return f"FINISH_LENGTH: {self.length}" + + +class FINISH_MATCHED_STR(BaseFinishReason): + def __init__(self, matched: str): + super().__init__() + self.matched = matched + + def __str__(self) -> str: + return f"FINISH_MATCHED_STR: {self.matched}" + + +class FINISH_ABORT(BaseFinishReason): + def __init__(self): + super().__init__(is_error=True) + + def __str__(self) -> str: + return "FINISH_ABORT" class Req: @@ -61,11 +83,10 @@ class Req: self.sampling_params = None self.stream = False - # Check finish self.tokenizer = None - self.finished = False - self.finish_reason = None - self.hit_stop_str = None + + # Check finish + self.finished_reason = None # Prefix info self.extend_input_len = 0 @@ -90,6 +111,10 @@ class Req: self.regex_fsm_state = 0 self.jump_forward_map = None + # whether request reached finished condition + def finished(self) -> bool: + return self.finished_reason is not None + def partial_decode(self, ids): first_token = self.tokenizer.convert_ids_to_tokens(ids[0]) first_token = ( @@ -101,23 +126,21 @@ class Req: return self.sampling_params.max_new_tokens def check_finished(self): - if self.finished: + if self.finished(): return if ( len(self.prev_output_ids) + len(self.output_ids) >= self.sampling_params.max_new_tokens ): - self.finished = True - self.finish_reason = FinishReason.LENGTH + self.finished_reason = FINISH_LENGTH(len(self.prev_output_ids) + len(self.output_ids)) return if ( self.output_ids[-1] == self.tokenizer.eos_token_id - and self.sampling_params.ignore_eos == False + and not self.sampling_params.ignore_eos ): - self.finished = True - self.finish_reason = FinishReason.EOS_TOKEN + self.finished_reason = FINISH_MATCHED_TOKEN(matched=self.tokenizer.eos_token_id) return if len(self.sampling_params.stop_strs) > 0: @@ -128,9 +151,7 @@ class Req: for stop_str in self.sampling_params.stop_strs: # FIXME: (minor) try incremental match in prev_output_str if stop_str in tail_str or stop_str in self.prev_output_str: - self.finished = True - self.finish_reason = FinishReason.STOP_STR - self.hit_stop_str = stop_str + self.finished_reason = FINISH_MATCHED_STR(matched=stop_str) return def jump_forward_and_retokenize(self, jump_forward_str, next_state): diff --git a/python/sglang/srt/managers/controller/manager_single.py b/python/sglang/srt/managers/controller/manager_single.py index 227b8a7b7..7b39a56de 100644 --- a/python/sglang/srt/managers/controller/manager_single.py +++ b/python/sglang/srt/managers/controller/manager_single.py @@ -45,7 +45,7 @@ class ControllerSingle: # async sleep for receiving the subsequent request and avoiding cache miss slept = False if len(out_pyobjs) != 0: - has_finished = any([obj.finished for obj in out_pyobjs]) + has_finished = any([obj.finished_reason is not None for obj in out_pyobjs]) if has_finished: if self.request_dependency_delay > 0: slept = True diff --git a/python/sglang/srt/managers/controller/tp_worker.py b/python/sglang/srt/managers/controller/tp_worker.py index 7fb5e1b3b..8343429db 100644 --- a/python/sglang/srt/managers/controller/tp_worker.py +++ b/python/sglang/srt/managers/controller/tp_worker.py @@ -19,7 +19,7 @@ from sglang.srt.managers.io_struct import ( FlushCacheReq, TokenizedGenerateReqInput, ) -from sglang.srt.managers.controller.infer_batch import Batch, FinishReason, ForwardMode, Req +from sglang.srt.managers.controller.infer_batch import BaseFinishReason, Batch, FINISH_ABORT, ForwardMode, Req from sglang.srt.managers.controller.model_runner import ModelRunner from sglang.srt.managers.controller.radix_cache import RadixCache from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic @@ -595,20 +595,19 @@ class ModelTpServer: output_rids = [] prev_output_strs = [] output_tokens = [] - output_hit_stop_str = [] output_skip_special_tokens = [] output_spaces_between_special_tokens = [] output_meta_info = [] - output_finished = [] + output_finished_reason: List[BaseFinishReason] = [] finished_indices = [] unfinished_indices = [] for i, req in enumerate(batch.reqs): - if req.finished: + if req.finished(): finished_indices.append(i) else: unfinished_indices.append(i) - if req.finished or ( + if req.finished() or ( ( req.stream and ( @@ -620,7 +619,6 @@ class ModelTpServer: output_rids.append(req.rid) prev_output_strs.append(req.prev_output_str) output_tokens.append(req.output_ids) - output_hit_stop_str.append(req.hit_stop_str) output_skip_special_tokens.append( req.sampling_params.skip_special_tokens ) @@ -632,8 +630,7 @@ class ModelTpServer: "prompt_tokens": len(req.origin_input_ids), "completion_tokens": len(req.prev_output_ids) + len(req.output_ids), "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward, - "finish_reason": FinishReason.to_str(req.finish_reason), - "hit_stop_str": req.hit_stop_str, + "finish_reason": str(req.finished_reason), } if req.return_logprob: ( @@ -650,7 +647,7 @@ class ModelTpServer: req.normalized_prompt_logprob, ) output_meta_info.append(meta_info) - output_finished.append(req.finished) + output_finished_reason.append(req.finished_reason) # Send to detokenizer if output_rids: @@ -659,11 +656,10 @@ class ModelTpServer: output_rids, prev_output_strs, output_tokens, - output_hit_stop_str, output_skip_special_tokens, output_spaces_between_special_tokens, output_meta_info, - output_finished, + output_finished_reason, ) ) @@ -720,8 +716,7 @@ class ModelTpServer: if self.running_batch: for req in self.running_batch.reqs: if req.rid == recv_req.rid: - req.finished = True - req.finish_reason = FinishReason.ABORT + req.finished_reason = FINISH_ABORT() break diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index 4774dba33..c77625eb9 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -9,6 +9,7 @@ from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut from sglang.srt.server_args import PortArgs, ServerArgs from sglang.utils import get_exception_traceback, graceful_registry +from sglang.srt.managers.controller.infer_batch import FINISH_MATCHED_STR asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) @@ -34,49 +35,47 @@ class DetokenizerManager: async def handle_loop(self): while True: - recv_obj = await self.recv_from_router.recv_pyobj() + recv_obj: BatchTokenIDOut = await self.recv_from_router.recv_pyobj() + assert isinstance(recv_obj, BatchTokenIDOut) - if isinstance(recv_obj, BatchTokenIDOut): - output_tokens = recv_obj.output_tokens + output_tokens = recv_obj.output_tokens - # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request - output_strs = self.tokenizer.batch_decode( - output_tokens, - skip_special_tokens=recv_obj.skip_special_tokens[0], - spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[ - 0 - ], - ) + # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request + output_strs = self.tokenizer.batch_decode( + output_tokens, + skip_special_tokens=recv_obj.skip_special_tokens[0], + spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[ + 0 + ], + ) - # Trim stop str - # TODO(lmzheng): handle the case where multiple stop strs are hit - for i in range(len(output_strs)): - if len(output_tokens[i]) > 0: - first_token = self.tokenizer.convert_ids_to_tokens( - int(output_tokens[i][0]) - ) - if not isinstance(first_token, str): - first_token = first_token.decode("utf-8", errors="ignore") - if first_token.startswith("▁"): - output_strs[i] = " " + output_strs[i] - - output_strs[i] = recv_obj.prev_output_strs[i] + output_strs[i] - - if recv_obj.hit_stop_str[i] is not None: - pos = output_strs[i].find(recv_obj.hit_stop_str[i]) - if pos != -1: - output_strs[i] = output_strs[i][:pos] - - self.send_to_tokenizer.send_pyobj( - BatchStrOut( - recv_obj.rids, - output_strs, - recv_obj.meta_info, - recv_obj.finished, + # Trim stop str + # TODO(lmzheng): handle the case where multiple stop strs are hit + for i in range(len(output_strs)): + if len(output_tokens[i]) > 0: + first_token = self.tokenizer.convert_ids_to_tokens( + int(output_tokens[i][0]) ) + if not isinstance(first_token, str): + first_token = first_token.decode("utf-8", errors="ignore") + if first_token.startswith("▁"): + output_strs[i] = " " + output_strs[i] + + output_strs[i] = recv_obj.prev_output_strs[i] + output_strs[i] + + if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR): + pos = output_strs[i].find(recv_obj.finished_reason[i].matched) + if pos != -1: + output_strs[i] = output_strs[i][:pos] + + self.send_to_tokenizer.send_pyobj( + BatchStrOut( + rids=recv_obj.rids, + output_str=output_strs, + meta_info=recv_obj.meta_info, + finished_reason=recv_obj.finished_reason, ) - else: - raise ValueError(f"Invalid object: {recv_obj}") + ) def start_detokenizer_process( diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index a07042b46..004308c3b 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -3,6 +3,7 @@ from dataclasses import dataclass from typing import Dict, List, Optional, Union from sglang.srt.sampling_params import SamplingParams +from sglang.srt.managers.controller.infer_batch import BaseFinishReason @dataclass @@ -105,21 +106,19 @@ class TokenizedGenerateReqInput: @dataclass class BatchTokenIDOut: rids: List[str] - prev_output_strs : List[str] + prev_output_strs: List[str] output_tokens: List[List[int]] - hit_stop_str: List[Optional[str]] skip_special_tokens: List[bool] spaces_between_special_tokens: List[bool] meta_info: List[Dict] - finished: List[bool] - + finished_reason: List[BaseFinishReason] @dataclass class BatchStrOut: rids: List[str] output_str: List[str] meta_info: List[Dict] - finished: List[bool] + finished_reason: List[BaseFinishReason] @dataclass @@ -134,4 +133,4 @@ class AbortReq: @dataclass class DetokenizeReqInput: - input_ids: List[int] \ No newline at end of file + input_ids: List[int] diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 482347153..38f07739e 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -4,7 +4,7 @@ import dataclasses import logging import multiprocessing as mp import os -from typing import List +from typing import List, Dict import numpy as np import transformers @@ -26,6 +26,7 @@ from sglang.srt.managers.io_struct import ( GenerateReqInput, TokenizedGenerateReqInput, ) +from sglang.srt.managers.io_struct import BatchTokenIDOut from sglang.srt.mm_utils import expand2square, process_anyres_image from sglang.srt.sampling_params import SamplingParams from sglang.srt.server_args import PortArgs, ServerArgs @@ -89,7 +90,7 @@ class TokenizerManager: ) self.to_create_loop = True - self.rid_to_state = {} # Dict[str -> ReqState] + self.rid_to_state: Dict[str, ReqState] = {} async def get_pixel_values(self, image_data): aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None) @@ -183,12 +184,17 @@ class TokenizerManager: if self.server_args.log_requests and state.finished: logger.info(f"in={obj.text}, out={out}") - yield out state.out_list = [] if state.finished: del self.rid_to_state[rid] + + yield out + break + event.clear() + + yield out else: if obj.stream: raise ValueError("Do not support stream for batch mode.") @@ -298,24 +304,23 @@ class TokenizerManager: async def handle_loop(self): while True: - recv_obj = await self.recv_from_detokenizer.recv_pyobj() + recv_obj: BatchTokenIDOut = await self.recv_from_detokenizer.recv_pyobj() + assert isinstance(recv_obj, BatchStrOut) - if isinstance(recv_obj, BatchStrOut): - for i, rid in enumerate(recv_obj.rids): - state = self.rid_to_state.get(rid, None) - if state is None: - continue + for i, rid in enumerate(recv_obj.rids): + state = self.rid_to_state.get(rid, None) + if state is None: + continue + + recv_obj.meta_info[i]["id"] = rid + out_dict = { + "text": recv_obj.output_str[i], + "meta_info": recv_obj.meta_info[i], + } + state.out_list.append(out_dict) + state.finished = recv_obj.finished_reason[i] is not None + state.event.set() - recv_obj.meta_info[i]["id"] = rid - out_dict = { - "text": recv_obj.output_str[i], - "meta_info": recv_obj.meta_info[i], - } - state.out_list.append(out_dict) - state.finished = recv_obj.finished[i] - state.event.set() - else: - raise ValueError(f"Invalid object: {recv_obj}.") def convert_logprob_style( self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs