Fix rid state map leak + Refractor .finished (#505)
Co-authored-by: ZX <zx@lbx.dev>
This commit is contained in:
@@ -10,6 +10,7 @@ import zmq
|
|||||||
|
|
||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
from sglang.srt.managers.controller.tp_worker import ModelTpClient
|
from sglang.srt.managers.controller.tp_worker import ModelTpClient
|
||||||
|
from sglang.srt.managers.io_struct import BatchTokenIDOut
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
from sglang.utils import get_exception_traceback
|
from sglang.utils import get_exception_traceback
|
||||||
|
|
||||||
@@ -44,6 +45,8 @@ class DataParallelWorkerThread(threading.Thread):
|
|||||||
requests = []
|
requests = []
|
||||||
while not self.request_queue.empty():
|
while not self.request_queue.empty():
|
||||||
requests.append(self.request_queue.get())
|
requests.append(self.request_queue.get())
|
||||||
|
|
||||||
|
out_pyobjs: List[BatchTokenIDOut] = []
|
||||||
try:
|
try:
|
||||||
out_pyobjs = await self.step(requests)
|
out_pyobjs = await self.step(requests)
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -61,7 +64,7 @@ class DataParallelWorkerThread(threading.Thread):
|
|||||||
|
|
||||||
# async sleep for receiving the subsequent request and avoiding cache miss
|
# async sleep for receiving the subsequent request and avoiding cache miss
|
||||||
if len(out_pyobjs) != 0:
|
if len(out_pyobjs) != 0:
|
||||||
has_finished = any([obj.finished for obj in out_pyobjs])
|
has_finished = any([obj.finished_reason is not None for obj in out_pyobjs])
|
||||||
if has_finished:
|
if has_finished:
|
||||||
await asyncio.sleep(self.request_dependency_delay)
|
await asyncio.sleep(self.request_dependency_delay)
|
||||||
await asyncio.sleep(global_config.wait_for_new_request_delay)
|
await asyncio.sleep(global_config.wait_for_new_request_delay)
|
||||||
|
|||||||
@@ -15,25 +15,47 @@ class ForwardMode(IntEnum):
|
|||||||
EXTEND = auto()
|
EXTEND = auto()
|
||||||
DECODE = auto()
|
DECODE = auto()
|
||||||
|
|
||||||
|
class BaseFinishReason:
|
||||||
|
def __init__(self, is_error: bool = False):
|
||||||
|
self.is_error = is_error
|
||||||
|
|
||||||
class FinishReason(IntEnum):
|
def __str__(self):
|
||||||
EOS_TOKEN = auto()
|
raise NotImplementedError("Subclasses must implement this method")
|
||||||
LENGTH = auto()
|
|
||||||
STOP_STR = auto()
|
|
||||||
ABORT = auto()
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def to_str(reason):
|
class FINISH_MATCHED_TOKEN(BaseFinishReason):
|
||||||
if reason == FinishReason.EOS_TOKEN:
|
def __init__(self, matched: int | List[int]):
|
||||||
return None
|
super().__init__()
|
||||||
elif reason == FinishReason.LENGTH:
|
self.matched = matched
|
||||||
return "length"
|
|
||||||
elif reason == FinishReason.STOP_STR:
|
def __str__(self) -> str:
|
||||||
return "stop"
|
return f"FINISH_MATCHED_TOKEN: {self.matched}"
|
||||||
elif reason == FinishReason.ABORT:
|
|
||||||
return "abort"
|
|
||||||
else:
|
class FINISH_LENGTH(BaseFinishReason):
|
||||||
return None
|
def __init__(self, length: int):
|
||||||
|
super().__init__()
|
||||||
|
self.length = length
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return f"FINISH_LENGTH: {self.length}"
|
||||||
|
|
||||||
|
|
||||||
|
class FINISH_MATCHED_STR(BaseFinishReason):
|
||||||
|
def __init__(self, matched: str):
|
||||||
|
super().__init__()
|
||||||
|
self.matched = matched
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return f"FINISH_MATCHED_STR: {self.matched}"
|
||||||
|
|
||||||
|
|
||||||
|
class FINISH_ABORT(BaseFinishReason):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(is_error=True)
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return "FINISH_ABORT"
|
||||||
|
|
||||||
|
|
||||||
class Req:
|
class Req:
|
||||||
@@ -61,11 +83,10 @@ class Req:
|
|||||||
self.sampling_params = None
|
self.sampling_params = None
|
||||||
self.stream = False
|
self.stream = False
|
||||||
|
|
||||||
# Check finish
|
|
||||||
self.tokenizer = None
|
self.tokenizer = None
|
||||||
self.finished = False
|
|
||||||
self.finish_reason = None
|
# Check finish
|
||||||
self.hit_stop_str = None
|
self.finished_reason = None
|
||||||
|
|
||||||
# Prefix info
|
# Prefix info
|
||||||
self.extend_input_len = 0
|
self.extend_input_len = 0
|
||||||
@@ -90,6 +111,10 @@ class Req:
|
|||||||
self.regex_fsm_state = 0
|
self.regex_fsm_state = 0
|
||||||
self.jump_forward_map = None
|
self.jump_forward_map = None
|
||||||
|
|
||||||
|
# whether request reached finished condition
|
||||||
|
def finished(self) -> bool:
|
||||||
|
return self.finished_reason is not None
|
||||||
|
|
||||||
def partial_decode(self, ids):
|
def partial_decode(self, ids):
|
||||||
first_token = self.tokenizer.convert_ids_to_tokens(ids[0])
|
first_token = self.tokenizer.convert_ids_to_tokens(ids[0])
|
||||||
first_token = (
|
first_token = (
|
||||||
@@ -101,23 +126,21 @@ class Req:
|
|||||||
return self.sampling_params.max_new_tokens
|
return self.sampling_params.max_new_tokens
|
||||||
|
|
||||||
def check_finished(self):
|
def check_finished(self):
|
||||||
if self.finished:
|
if self.finished():
|
||||||
return
|
return
|
||||||
|
|
||||||
if (
|
if (
|
||||||
len(self.prev_output_ids) + len(self.output_ids)
|
len(self.prev_output_ids) + len(self.output_ids)
|
||||||
>= self.sampling_params.max_new_tokens
|
>= self.sampling_params.max_new_tokens
|
||||||
):
|
):
|
||||||
self.finished = True
|
self.finished_reason = FINISH_LENGTH(len(self.prev_output_ids) + len(self.output_ids))
|
||||||
self.finish_reason = FinishReason.LENGTH
|
|
||||||
return
|
return
|
||||||
|
|
||||||
if (
|
if (
|
||||||
self.output_ids[-1] == self.tokenizer.eos_token_id
|
self.output_ids[-1] == self.tokenizer.eos_token_id
|
||||||
and self.sampling_params.ignore_eos == False
|
and not self.sampling_params.ignore_eos
|
||||||
):
|
):
|
||||||
self.finished = True
|
self.finished_reason = FINISH_MATCHED_TOKEN(matched=self.tokenizer.eos_token_id)
|
||||||
self.finish_reason = FinishReason.EOS_TOKEN
|
|
||||||
return
|
return
|
||||||
|
|
||||||
if len(self.sampling_params.stop_strs) > 0:
|
if len(self.sampling_params.stop_strs) > 0:
|
||||||
@@ -128,9 +151,7 @@ class Req:
|
|||||||
for stop_str in self.sampling_params.stop_strs:
|
for stop_str in self.sampling_params.stop_strs:
|
||||||
# FIXME: (minor) try incremental match in prev_output_str
|
# FIXME: (minor) try incremental match in prev_output_str
|
||||||
if stop_str in tail_str or stop_str in self.prev_output_str:
|
if stop_str in tail_str or stop_str in self.prev_output_str:
|
||||||
self.finished = True
|
self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
|
||||||
self.finish_reason = FinishReason.STOP_STR
|
|
||||||
self.hit_stop_str = stop_str
|
|
||||||
return
|
return
|
||||||
|
|
||||||
def jump_forward_and_retokenize(self, jump_forward_str, next_state):
|
def jump_forward_and_retokenize(self, jump_forward_str, next_state):
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ class ControllerSingle:
|
|||||||
# async sleep for receiving the subsequent request and avoiding cache miss
|
# async sleep for receiving the subsequent request and avoiding cache miss
|
||||||
slept = False
|
slept = False
|
||||||
if len(out_pyobjs) != 0:
|
if len(out_pyobjs) != 0:
|
||||||
has_finished = any([obj.finished for obj in out_pyobjs])
|
has_finished = any([obj.finished_reason is not None for obj in out_pyobjs])
|
||||||
if has_finished:
|
if has_finished:
|
||||||
if self.request_dependency_delay > 0:
|
if self.request_dependency_delay > 0:
|
||||||
slept = True
|
slept = True
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ from sglang.srt.managers.io_struct import (
|
|||||||
FlushCacheReq,
|
FlushCacheReq,
|
||||||
TokenizedGenerateReqInput,
|
TokenizedGenerateReqInput,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.controller.infer_batch import Batch, FinishReason, ForwardMode, Req
|
from sglang.srt.managers.controller.infer_batch import BaseFinishReason, Batch, FINISH_ABORT, ForwardMode, Req
|
||||||
from sglang.srt.managers.controller.model_runner import ModelRunner
|
from sglang.srt.managers.controller.model_runner import ModelRunner
|
||||||
from sglang.srt.managers.controller.radix_cache import RadixCache
|
from sglang.srt.managers.controller.radix_cache import RadixCache
|
||||||
from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic
|
from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic
|
||||||
@@ -595,20 +595,19 @@ class ModelTpServer:
|
|||||||
output_rids = []
|
output_rids = []
|
||||||
prev_output_strs = []
|
prev_output_strs = []
|
||||||
output_tokens = []
|
output_tokens = []
|
||||||
output_hit_stop_str = []
|
|
||||||
output_skip_special_tokens = []
|
output_skip_special_tokens = []
|
||||||
output_spaces_between_special_tokens = []
|
output_spaces_between_special_tokens = []
|
||||||
output_meta_info = []
|
output_meta_info = []
|
||||||
output_finished = []
|
output_finished_reason: List[BaseFinishReason] = []
|
||||||
finished_indices = []
|
finished_indices = []
|
||||||
unfinished_indices = []
|
unfinished_indices = []
|
||||||
for i, req in enumerate(batch.reqs):
|
for i, req in enumerate(batch.reqs):
|
||||||
if req.finished:
|
if req.finished():
|
||||||
finished_indices.append(i)
|
finished_indices.append(i)
|
||||||
else:
|
else:
|
||||||
unfinished_indices.append(i)
|
unfinished_indices.append(i)
|
||||||
|
|
||||||
if req.finished or (
|
if req.finished() or (
|
||||||
(
|
(
|
||||||
req.stream
|
req.stream
|
||||||
and (
|
and (
|
||||||
@@ -620,7 +619,6 @@ class ModelTpServer:
|
|||||||
output_rids.append(req.rid)
|
output_rids.append(req.rid)
|
||||||
prev_output_strs.append(req.prev_output_str)
|
prev_output_strs.append(req.prev_output_str)
|
||||||
output_tokens.append(req.output_ids)
|
output_tokens.append(req.output_ids)
|
||||||
output_hit_stop_str.append(req.hit_stop_str)
|
|
||||||
output_skip_special_tokens.append(
|
output_skip_special_tokens.append(
|
||||||
req.sampling_params.skip_special_tokens
|
req.sampling_params.skip_special_tokens
|
||||||
)
|
)
|
||||||
@@ -632,8 +630,7 @@ class ModelTpServer:
|
|||||||
"prompt_tokens": len(req.origin_input_ids),
|
"prompt_tokens": len(req.origin_input_ids),
|
||||||
"completion_tokens": len(req.prev_output_ids) + len(req.output_ids),
|
"completion_tokens": len(req.prev_output_ids) + len(req.output_ids),
|
||||||
"completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
|
"completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
|
||||||
"finish_reason": FinishReason.to_str(req.finish_reason),
|
"finish_reason": str(req.finished_reason),
|
||||||
"hit_stop_str": req.hit_stop_str,
|
|
||||||
}
|
}
|
||||||
if req.return_logprob:
|
if req.return_logprob:
|
||||||
(
|
(
|
||||||
@@ -650,7 +647,7 @@ class ModelTpServer:
|
|||||||
req.normalized_prompt_logprob,
|
req.normalized_prompt_logprob,
|
||||||
)
|
)
|
||||||
output_meta_info.append(meta_info)
|
output_meta_info.append(meta_info)
|
||||||
output_finished.append(req.finished)
|
output_finished_reason.append(req.finished_reason)
|
||||||
|
|
||||||
# Send to detokenizer
|
# Send to detokenizer
|
||||||
if output_rids:
|
if output_rids:
|
||||||
@@ -659,11 +656,10 @@ class ModelTpServer:
|
|||||||
output_rids,
|
output_rids,
|
||||||
prev_output_strs,
|
prev_output_strs,
|
||||||
output_tokens,
|
output_tokens,
|
||||||
output_hit_stop_str,
|
|
||||||
output_skip_special_tokens,
|
output_skip_special_tokens,
|
||||||
output_spaces_between_special_tokens,
|
output_spaces_between_special_tokens,
|
||||||
output_meta_info,
|
output_meta_info,
|
||||||
output_finished,
|
output_finished_reason,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -720,8 +716,7 @@ class ModelTpServer:
|
|||||||
if self.running_batch:
|
if self.running_batch:
|
||||||
for req in self.running_batch.reqs:
|
for req in self.running_batch.reqs:
|
||||||
if req.rid == recv_req.rid:
|
if req.rid == recv_req.rid:
|
||||||
req.finished = True
|
req.finished_reason = FINISH_ABORT()
|
||||||
req.finish_reason = FinishReason.ABORT
|
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from sglang.srt.hf_transformers_utils import get_tokenizer
|
|||||||
from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
|
from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
from sglang.utils import get_exception_traceback, graceful_registry
|
from sglang.utils import get_exception_traceback, graceful_registry
|
||||||
|
from sglang.srt.managers.controller.infer_batch import FINISH_MATCHED_STR
|
||||||
|
|
||||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||||
|
|
||||||
@@ -34,49 +35,47 @@ class DetokenizerManager:
|
|||||||
|
|
||||||
async def handle_loop(self):
|
async def handle_loop(self):
|
||||||
while True:
|
while True:
|
||||||
recv_obj = await self.recv_from_router.recv_pyobj()
|
recv_obj: BatchTokenIDOut = await self.recv_from_router.recv_pyobj()
|
||||||
|
assert isinstance(recv_obj, BatchTokenIDOut)
|
||||||
|
|
||||||
if isinstance(recv_obj, BatchTokenIDOut):
|
output_tokens = recv_obj.output_tokens
|
||||||
output_tokens = recv_obj.output_tokens
|
|
||||||
|
|
||||||
# TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
|
# TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
|
||||||
output_strs = self.tokenizer.batch_decode(
|
output_strs = self.tokenizer.batch_decode(
|
||||||
output_tokens,
|
output_tokens,
|
||||||
skip_special_tokens=recv_obj.skip_special_tokens[0],
|
skip_special_tokens=recv_obj.skip_special_tokens[0],
|
||||||
spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[
|
spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[
|
||||||
0
|
0
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Trim stop str
|
# Trim stop str
|
||||||
# TODO(lmzheng): handle the case where multiple stop strs are hit
|
# TODO(lmzheng): handle the case where multiple stop strs are hit
|
||||||
for i in range(len(output_strs)):
|
for i in range(len(output_strs)):
|
||||||
if len(output_tokens[i]) > 0:
|
if len(output_tokens[i]) > 0:
|
||||||
first_token = self.tokenizer.convert_ids_to_tokens(
|
first_token = self.tokenizer.convert_ids_to_tokens(
|
||||||
int(output_tokens[i][0])
|
int(output_tokens[i][0])
|
||||||
)
|
|
||||||
if not isinstance(first_token, str):
|
|
||||||
first_token = first_token.decode("utf-8", errors="ignore")
|
|
||||||
if first_token.startswith("▁"):
|
|
||||||
output_strs[i] = " " + output_strs[i]
|
|
||||||
|
|
||||||
output_strs[i] = recv_obj.prev_output_strs[i] + output_strs[i]
|
|
||||||
|
|
||||||
if recv_obj.hit_stop_str[i] is not None:
|
|
||||||
pos = output_strs[i].find(recv_obj.hit_stop_str[i])
|
|
||||||
if pos != -1:
|
|
||||||
output_strs[i] = output_strs[i][:pos]
|
|
||||||
|
|
||||||
self.send_to_tokenizer.send_pyobj(
|
|
||||||
BatchStrOut(
|
|
||||||
recv_obj.rids,
|
|
||||||
output_strs,
|
|
||||||
recv_obj.meta_info,
|
|
||||||
recv_obj.finished,
|
|
||||||
)
|
)
|
||||||
|
if not isinstance(first_token, str):
|
||||||
|
first_token = first_token.decode("utf-8", errors="ignore")
|
||||||
|
if first_token.startswith("▁"):
|
||||||
|
output_strs[i] = " " + output_strs[i]
|
||||||
|
|
||||||
|
output_strs[i] = recv_obj.prev_output_strs[i] + output_strs[i]
|
||||||
|
|
||||||
|
if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR):
|
||||||
|
pos = output_strs[i].find(recv_obj.finished_reason[i].matched)
|
||||||
|
if pos != -1:
|
||||||
|
output_strs[i] = output_strs[i][:pos]
|
||||||
|
|
||||||
|
self.send_to_tokenizer.send_pyobj(
|
||||||
|
BatchStrOut(
|
||||||
|
rids=recv_obj.rids,
|
||||||
|
output_str=output_strs,
|
||||||
|
meta_info=recv_obj.meta_info,
|
||||||
|
finished_reason=recv_obj.finished_reason,
|
||||||
)
|
)
|
||||||
else:
|
)
|
||||||
raise ValueError(f"Invalid object: {recv_obj}")
|
|
||||||
|
|
||||||
|
|
||||||
def start_detokenizer_process(
|
def start_detokenizer_process(
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ from dataclasses import dataclass
|
|||||||
from typing import Dict, List, Optional, Union
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
from sglang.srt.sampling_params import SamplingParams
|
from sglang.srt.sampling_params import SamplingParams
|
||||||
|
from sglang.srt.managers.controller.infer_batch import BaseFinishReason
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -105,21 +106,19 @@ class TokenizedGenerateReqInput:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class BatchTokenIDOut:
|
class BatchTokenIDOut:
|
||||||
rids: List[str]
|
rids: List[str]
|
||||||
prev_output_strs : List[str]
|
prev_output_strs: List[str]
|
||||||
output_tokens: List[List[int]]
|
output_tokens: List[List[int]]
|
||||||
hit_stop_str: List[Optional[str]]
|
|
||||||
skip_special_tokens: List[bool]
|
skip_special_tokens: List[bool]
|
||||||
spaces_between_special_tokens: List[bool]
|
spaces_between_special_tokens: List[bool]
|
||||||
meta_info: List[Dict]
|
meta_info: List[Dict]
|
||||||
finished: List[bool]
|
finished_reason: List[BaseFinishReason]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BatchStrOut:
|
class BatchStrOut:
|
||||||
rids: List[str]
|
rids: List[str]
|
||||||
output_str: List[str]
|
output_str: List[str]
|
||||||
meta_info: List[Dict]
|
meta_info: List[Dict]
|
||||||
finished: List[bool]
|
finished_reason: List[BaseFinishReason]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -134,4 +133,4 @@ class AbortReq:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DetokenizeReqInput:
|
class DetokenizeReqInput:
|
||||||
input_ids: List[int]
|
input_ids: List[int]
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import dataclasses
|
|||||||
import logging
|
import logging
|
||||||
import multiprocessing as mp
|
import multiprocessing as mp
|
||||||
import os
|
import os
|
||||||
from typing import List
|
from typing import List, Dict
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import transformers
|
import transformers
|
||||||
@@ -26,6 +26,7 @@ from sglang.srt.managers.io_struct import (
|
|||||||
GenerateReqInput,
|
GenerateReqInput,
|
||||||
TokenizedGenerateReqInput,
|
TokenizedGenerateReqInput,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.managers.io_struct import BatchTokenIDOut
|
||||||
from sglang.srt.mm_utils import expand2square, process_anyres_image
|
from sglang.srt.mm_utils import expand2square, process_anyres_image
|
||||||
from sglang.srt.sampling_params import SamplingParams
|
from sglang.srt.sampling_params import SamplingParams
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
@@ -89,7 +90,7 @@ class TokenizerManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.to_create_loop = True
|
self.to_create_loop = True
|
||||||
self.rid_to_state = {} # Dict[str -> ReqState]
|
self.rid_to_state: Dict[str, ReqState] = {}
|
||||||
|
|
||||||
async def get_pixel_values(self, image_data):
|
async def get_pixel_values(self, image_data):
|
||||||
aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
|
aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
|
||||||
@@ -183,12 +184,17 @@ class TokenizerManager:
|
|||||||
if self.server_args.log_requests and state.finished:
|
if self.server_args.log_requests and state.finished:
|
||||||
logger.info(f"in={obj.text}, out={out}")
|
logger.info(f"in={obj.text}, out={out}")
|
||||||
|
|
||||||
yield out
|
|
||||||
state.out_list = []
|
state.out_list = []
|
||||||
if state.finished:
|
if state.finished:
|
||||||
del self.rid_to_state[rid]
|
del self.rid_to_state[rid]
|
||||||
|
|
||||||
|
yield out
|
||||||
|
|
||||||
break
|
break
|
||||||
|
|
||||||
event.clear()
|
event.clear()
|
||||||
|
|
||||||
|
yield out
|
||||||
else:
|
else:
|
||||||
if obj.stream:
|
if obj.stream:
|
||||||
raise ValueError("Do not support stream for batch mode.")
|
raise ValueError("Do not support stream for batch mode.")
|
||||||
@@ -298,24 +304,23 @@ class TokenizerManager:
|
|||||||
|
|
||||||
async def handle_loop(self):
|
async def handle_loop(self):
|
||||||
while True:
|
while True:
|
||||||
recv_obj = await self.recv_from_detokenizer.recv_pyobj()
|
recv_obj: BatchTokenIDOut = await self.recv_from_detokenizer.recv_pyobj()
|
||||||
|
assert isinstance(recv_obj, BatchStrOut)
|
||||||
|
|
||||||
if isinstance(recv_obj, BatchStrOut):
|
for i, rid in enumerate(recv_obj.rids):
|
||||||
for i, rid in enumerate(recv_obj.rids):
|
state = self.rid_to_state.get(rid, None)
|
||||||
state = self.rid_to_state.get(rid, None)
|
if state is None:
|
||||||
if state is None:
|
continue
|
||||||
continue
|
|
||||||
|
recv_obj.meta_info[i]["id"] = rid
|
||||||
|
out_dict = {
|
||||||
|
"text": recv_obj.output_str[i],
|
||||||
|
"meta_info": recv_obj.meta_info[i],
|
||||||
|
}
|
||||||
|
state.out_list.append(out_dict)
|
||||||
|
state.finished = recv_obj.finished_reason[i] is not None
|
||||||
|
state.event.set()
|
||||||
|
|
||||||
recv_obj.meta_info[i]["id"] = rid
|
|
||||||
out_dict = {
|
|
||||||
"text": recv_obj.output_str[i],
|
|
||||||
"meta_info": recv_obj.meta_info[i],
|
|
||||||
}
|
|
||||||
state.out_list.append(out_dict)
|
|
||||||
state.finished = recv_obj.finished[i]
|
|
||||||
state.event.set()
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Invalid object: {recv_obj}.")
|
|
||||||
|
|
||||||
def convert_logprob_style(
|
def convert_logprob_style(
|
||||||
self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs
|
self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs
|
||||||
|
|||||||
Reference in New Issue
Block a user