Fix rid state map leak + Refactor .finished (#505)
Co-authored-by: ZX <zx@lbx.dev>
This commit is contained in:
@@ -19,7 +19,7 @@ from sglang.srt.managers.io_struct import (
|
||||
FlushCacheReq,
|
||||
TokenizedGenerateReqInput,
|
||||
)
|
||||
from sglang.srt.managers.controller.infer_batch import Batch, FinishReason, ForwardMode, Req
|
||||
from sglang.srt.managers.controller.infer_batch import BaseFinishReason, Batch, FINISH_ABORT, ForwardMode, Req
|
||||
from sglang.srt.managers.controller.model_runner import ModelRunner
|
||||
from sglang.srt.managers.controller.radix_cache import RadixCache
|
||||
from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic
|
||||
@@ -595,20 +595,19 @@ class ModelTpServer:
|
||||
output_rids = []
|
||||
prev_output_strs = []
|
||||
output_tokens = []
|
||||
output_hit_stop_str = []
|
||||
output_skip_special_tokens = []
|
||||
output_spaces_between_special_tokens = []
|
||||
output_meta_info = []
|
||||
output_finished = []
|
||||
output_finished_reason: List[BaseFinishReason] = []
|
||||
finished_indices = []
|
||||
unfinished_indices = []
|
||||
for i, req in enumerate(batch.reqs):
|
||||
if req.finished:
|
||||
if req.finished():
|
||||
finished_indices.append(i)
|
||||
else:
|
||||
unfinished_indices.append(i)
|
||||
|
||||
if req.finished or (
|
||||
if req.finished() or (
|
||||
(
|
||||
req.stream
|
||||
and (
|
||||
@@ -620,7 +619,6 @@ class ModelTpServer:
|
||||
output_rids.append(req.rid)
|
||||
prev_output_strs.append(req.prev_output_str)
|
||||
output_tokens.append(req.output_ids)
|
||||
output_hit_stop_str.append(req.hit_stop_str)
|
||||
output_skip_special_tokens.append(
|
||||
req.sampling_params.skip_special_tokens
|
||||
)
|
||||
@@ -632,8 +630,7 @@ class ModelTpServer:
|
||||
"prompt_tokens": len(req.origin_input_ids),
|
||||
"completion_tokens": len(req.prev_output_ids) + len(req.output_ids),
|
||||
"completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
|
||||
"finish_reason": FinishReason.to_str(req.finish_reason),
|
||||
"hit_stop_str": req.hit_stop_str,
|
||||
"finish_reason": str(req.finished_reason),
|
||||
}
|
||||
if req.return_logprob:
|
||||
(
|
||||
@@ -650,7 +647,7 @@ class ModelTpServer:
|
||||
req.normalized_prompt_logprob,
|
||||
)
|
||||
output_meta_info.append(meta_info)
|
||||
output_finished.append(req.finished)
|
||||
output_finished_reason.append(req.finished_reason)
|
||||
|
||||
# Send to detokenizer
|
||||
if output_rids:
|
||||
@@ -659,11 +656,10 @@ class ModelTpServer:
|
||||
output_rids,
|
||||
prev_output_strs,
|
||||
output_tokens,
|
||||
output_hit_stop_str,
|
||||
output_skip_special_tokens,
|
||||
output_spaces_between_special_tokens,
|
||||
output_meta_info,
|
||||
output_finished,
|
||||
output_finished_reason,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -720,8 +716,7 @@ class ModelTpServer:
|
||||
if self.running_batch:
|
||||
for req in self.running_batch.reqs:
|
||||
if req.rid == recv_req.rid:
|
||||
req.finished = True
|
||||
req.finish_reason = FinishReason.ABORT
|
||||
req.finished_reason = FINISH_ABORT()
|
||||
break
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user