Fix rid state map leak + Refactor .finished (#505)

Co-authored-by: ZX <zx@lbx.dev>
This commit is contained in:
Qubitium
2024-06-08 04:20:40 +08:00
committed by GitHub
parent c0ae70c8ed
commit f70f72586a
7 changed files with 130 additions and 108 deletions

View File

@@ -4,7 +4,7 @@ import dataclasses
import logging
import multiprocessing as mp
import os
from typing import List
from typing import List, Dict
import numpy as np
import transformers
@@ -26,6 +26,7 @@ from sglang.srt.managers.io_struct import (
GenerateReqInput,
TokenizedGenerateReqInput,
)
from sglang.srt.managers.io_struct import BatchTokenIDOut
from sglang.srt.mm_utils import expand2square, process_anyres_image
from sglang.srt.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
@@ -89,7 +90,7 @@ class TokenizerManager:
)
self.to_create_loop = True
self.rid_to_state = {} # Dict[str -> ReqState]
self.rid_to_state: Dict[str, ReqState] = {}
async def get_pixel_values(self, image_data):
aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
@@ -183,12 +184,17 @@ class TokenizerManager:
if self.server_args.log_requests and state.finished:
logger.info(f"in={obj.text}, out={out}")
yield out
state.out_list = []
if state.finished:
del self.rid_to_state[rid]
yield out
break
event.clear()
yield out
else:
if obj.stream:
raise ValueError("Do not support stream for batch mode.")
@@ -298,24 +304,23 @@ class TokenizerManager:
async def handle_loop(self):
while True:
recv_obj = await self.recv_from_detokenizer.recv_pyobj()
recv_obj: BatchTokenIDOut = await self.recv_from_detokenizer.recv_pyobj()
assert isinstance(recv_obj, BatchStrOut)
if isinstance(recv_obj, BatchStrOut):
for i, rid in enumerate(recv_obj.rids):
state = self.rid_to_state.get(rid, None)
if state is None:
continue
for i, rid in enumerate(recv_obj.rids):
state = self.rid_to_state.get(rid, None)
if state is None:
continue
recv_obj.meta_info[i]["id"] = rid
out_dict = {
"text": recv_obj.output_str[i],
"meta_info": recv_obj.meta_info[i],
}
state.out_list.append(out_dict)
state.finished = recv_obj.finished_reason[i] is not None
state.event.set()
recv_obj.meta_info[i]["id"] = rid
out_dict = {
"text": recv_obj.output_str[i],
"meta_info": recv_obj.meta_info[i],
}
state.out_list.append(out_dict)
state.finished = recv_obj.finished[i]
state.event.set()
else:
raise ValueError(f"Invalid object: {recv_obj}.")
def convert_logprob_style(
self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs