diff --git a/python/sglang/srt/managers/controller/infer_batch.py b/python/sglang/srt/managers/controller/infer_batch.py
index ac6bf5d62..5eed985d3 100644
--- a/python/sglang/srt/managers/controller/infer_batch.py
+++ b/python/sglang/srt/managers/controller/infer_batch.py
@@ -90,6 +90,7 @@ class Req:
         # 1: surr_offset
         # 2: read_offset
         # 3: last token
+        self.vid = 0  # version id to sync decode status with in detokenizer_manager
         self.decoded_text = ""
         self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
         self.read_offset = None
@@ -520,6 +521,9 @@ class Batch:
                 req.output_ids = cur_output_ids
                 continue
 
+            # The decode status has diverged from detokenizer_manager
+            req.vid += 1
+
             # insert the old request into tree_cache
             if req_pool_indices_cpu is None:
                 req_pool_indices_cpu = self.req_pool_indices.tolist()
diff --git a/python/sglang/srt/managers/controller/tp_worker.py b/python/sglang/srt/managers/controller/tp_worker.py
index ab189c27e..183a7a786 100644
--- a/python/sglang/srt/managers/controller/tp_worker.py
+++ b/python/sglang/srt/managers/controller/tp_worker.py
@@ -589,6 +589,7 @@ class ModelTpServer:
 
     def handle_finished_requests(self, batch: Batch):
         output_rids = []
+        output_vids = []
         decoded_texts = []
         output_read_ids = []
         output_read_offsets = []
@@ -614,6 +615,7 @@ class ModelTpServer:
                 )
             ):
                 output_rids.append(req.rid)
+                output_vids.append(req.vid)
                 decoded_texts.append(req.decoded_text)
                 read_ids, read_offset = req.init_incremental_detokenize()
                 output_read_ids.append(read_ids)
@@ -653,6 +655,7 @@ class ModelTpServer:
             self.out_pyobjs.append(
                 BatchTokenIDOut(
                     output_rids,
+                    output_vids,
                     decoded_texts,
                     output_read_ids,
                     output_read_offsets,
diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py
index 046cb37b6..ecf493625 100644
--- a/python/sglang/srt/managers/detokenizer_manager.py
+++ b/python/sglang/srt/managers/detokenizer_manager.py
@@ -20,6 +20,7 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
 @dataclasses.dataclass
 class DecodeStatus:
+    vid: int
     decoded_text: str
     decode_ids: List[int]
     surr_offset: int
@@ -53,13 +54,14 @@ class DetokenizerManager:
             assert isinstance(recv_obj, BatchTokenIDOut)
             bs = len(recv_obj.rids)
 
-            # FIXME: incremental detokenize is not compatible with jump forward
             # Initialize decode status
             read_ids, surr_ids = [], []
             for i in range(bs):
                 rid = recv_obj.rids[i]
-                if rid not in self.decode_status:
+                vid = recv_obj.vids[i]
+                if rid not in self.decode_status or self.decode_status[rid].vid != vid:
                     s = DecodeStatus(
+                        vid=vid,
                         decoded_text=recv_obj.decoded_texts[i],
                         decode_ids=recv_obj.decode_ids[i],
                         surr_offset=0,
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index d1ad4b097..89de9b1c3 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -111,6 +111,7 @@ class TokenizedGenerateReqInput:
 @dataclass
 class BatchTokenIDOut:
     rids: List[str]
+    vids: List[int]
     decoded_texts: List[str]
     decode_ids: List[int]
     read_offsets: List[int]