Fix rid state map leak + Refactor .finished (#505)
Co-authored-by: ZX <zx@lbx.dev>
This commit is contained in:
@@ -4,7 +4,7 @@ import dataclasses
|
||||
import logging
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
from typing import List
|
||||
from typing import List, Dict
|
||||
|
||||
import numpy as np
|
||||
import transformers
|
||||
@@ -26,6 +26,7 @@ from sglang.srt.managers.io_struct import (
|
||||
GenerateReqInput,
|
||||
TokenizedGenerateReqInput,
|
||||
)
|
||||
from sglang.srt.managers.io_struct import BatchTokenIDOut
|
||||
from sglang.srt.mm_utils import expand2square, process_anyres_image
|
||||
from sglang.srt.sampling_params import SamplingParams
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
@@ -89,7 +90,7 @@ class TokenizerManager:
|
||||
)
|
||||
|
||||
self.to_create_loop = True
|
||||
self.rid_to_state = {} # Dict[str -> ReqState]
|
||||
self.rid_to_state: Dict[str, ReqState] = {}
|
||||
|
||||
async def get_pixel_values(self, image_data):
|
||||
aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
|
||||
@@ -183,12 +184,17 @@ class TokenizerManager:
|
||||
if self.server_args.log_requests and state.finished:
|
||||
logger.info(f"in={obj.text}, out={out}")
|
||||
|
||||
yield out
|
||||
state.out_list = []
|
||||
if state.finished:
|
||||
del self.rid_to_state[rid]
|
||||
|
||||
yield out
|
||||
|
||||
break
|
||||
|
||||
event.clear()
|
||||
|
||||
yield out
|
||||
else:
|
||||
if obj.stream:
|
||||
raise ValueError("Do not support stream for batch mode.")
|
||||
@@ -298,24 +304,23 @@ class TokenizerManager:
|
||||
|
||||
async def handle_loop(self):
|
||||
while True:
|
||||
recv_obj = await self.recv_from_detokenizer.recv_pyobj()
|
||||
recv_obj: BatchTokenIDOut = await self.recv_from_detokenizer.recv_pyobj()
|
||||
assert isinstance(recv_obj, BatchStrOut)
|
||||
|
||||
if isinstance(recv_obj, BatchStrOut):
|
||||
for i, rid in enumerate(recv_obj.rids):
|
||||
state = self.rid_to_state.get(rid, None)
|
||||
if state is None:
|
||||
continue
|
||||
for i, rid in enumerate(recv_obj.rids):
|
||||
state = self.rid_to_state.get(rid, None)
|
||||
if state is None:
|
||||
continue
|
||||
|
||||
recv_obj.meta_info[i]["id"] = rid
|
||||
out_dict = {
|
||||
"text": recv_obj.output_str[i],
|
||||
"meta_info": recv_obj.meta_info[i],
|
||||
}
|
||||
state.out_list.append(out_dict)
|
||||
state.finished = recv_obj.finished_reason[i] is not None
|
||||
state.event.set()
|
||||
|
||||
recv_obj.meta_info[i]["id"] = rid
|
||||
out_dict = {
|
||||
"text": recv_obj.output_str[i],
|
||||
"meta_info": recv_obj.meta_info[i],
|
||||
}
|
||||
state.out_list.append(out_dict)
|
||||
state.finished = recv_obj.finished[i]
|
||||
state.event.set()
|
||||
else:
|
||||
raise ValueError(f"Invalid object: {recv_obj}.")
|
||||
|
||||
def convert_logprob_style(
|
||||
self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs
|
||||
|
||||
Reference in New Issue
Block a user