Fix rid state map leak + Refactor .finished (#505)

Co-authored-by: ZX <zx@lbx.dev>
This commit is contained in:
Qubitium
2024-06-08 04:20:40 +08:00
committed by GitHub
parent c0ae70c8ed
commit f70f72586a
7 changed files with 130 additions and 108 deletions

View File

@@ -19,7 +19,7 @@ from sglang.srt.managers.io_struct import (
FlushCacheReq,
TokenizedGenerateReqInput,
)
from sglang.srt.managers.controller.infer_batch import Batch, FinishReason, ForwardMode, Req
from sglang.srt.managers.controller.infer_batch import BaseFinishReason, Batch, FINISH_ABORT, ForwardMode, Req
from sglang.srt.managers.controller.model_runner import ModelRunner
from sglang.srt.managers.controller.radix_cache import RadixCache
from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic
@@ -595,20 +595,19 @@ class ModelTpServer:
output_rids = []
prev_output_strs = []
output_tokens = []
output_hit_stop_str = []
output_skip_special_tokens = []
output_spaces_between_special_tokens = []
output_meta_info = []
output_finished = []
output_finished_reason: List[BaseFinishReason] = []
finished_indices = []
unfinished_indices = []
for i, req in enumerate(batch.reqs):
if req.finished:
if req.finished():
finished_indices.append(i)
else:
unfinished_indices.append(i)
if req.finished or (
if req.finished() or (
(
req.stream
and (
@@ -620,7 +619,6 @@ class ModelTpServer:
output_rids.append(req.rid)
prev_output_strs.append(req.prev_output_str)
output_tokens.append(req.output_ids)
output_hit_stop_str.append(req.hit_stop_str)
output_skip_special_tokens.append(
req.sampling_params.skip_special_tokens
)
@@ -632,8 +630,7 @@ class ModelTpServer:
"prompt_tokens": len(req.origin_input_ids),
"completion_tokens": len(req.prev_output_ids) + len(req.output_ids),
"completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
"finish_reason": FinishReason.to_str(req.finish_reason),
"hit_stop_str": req.hit_stop_str,
"finish_reason": str(req.finished_reason),
}
if req.return_logprob:
(
@@ -650,7 +647,7 @@ class ModelTpServer:
req.normalized_prompt_logprob,
)
output_meta_info.append(meta_info)
output_finished.append(req.finished)
output_finished_reason.append(req.finished_reason)
# Send to detokenizer
if output_rids:
@@ -659,11 +656,10 @@ class ModelTpServer:
output_rids,
prev_output_strs,
output_tokens,
output_hit_stop_str,
output_skip_special_tokens,
output_spaces_between_special_tokens,
output_meta_info,
output_finished,
output_finished_reason,
)
)
@@ -720,8 +716,7 @@ class ModelTpServer:
if self.running_batch:
for req in self.running_batch.reqs:
if req.rid == recv_req.rid:
req.finished = True
req.finish_reason = FinishReason.ABORT
req.finished_reason = FINISH_ABORT()
break