Fix jump forward when streaming (#665)
This commit is contained in:
@@ -90,6 +90,7 @@ class Req:
|
||||
# 1: surr_offset
|
||||
# 2: read_offset
|
||||
# 3: last token
|
||||
self.vid = 0 # version id to sync decode status with in detokenizer_manager
|
||||
self.decoded_text = ""
|
||||
self.surr_offset = None # Surrounding offset to defeat the cleanup algorithm
|
||||
self.read_offset = None
|
||||
@@ -520,6 +521,9 @@ class Batch:
|
||||
req.output_ids = cur_output_ids
|
||||
continue
|
||||
|
||||
# The decode status has diverged from detokenizer_manager
|
||||
req.vid += 1
|
||||
|
||||
# insert the old request into tree_cache
|
||||
if req_pool_indices_cpu is None:
|
||||
req_pool_indices_cpu = self.req_pool_indices.tolist()
|
||||
|
||||
@@ -589,6 +589,7 @@ class ModelTpServer:
|
||||
|
||||
def handle_finished_requests(self, batch: Batch):
|
||||
output_rids = []
|
||||
output_vids = []
|
||||
decoded_texts = []
|
||||
output_read_ids = []
|
||||
output_read_offsets = []
|
||||
@@ -614,6 +615,7 @@ class ModelTpServer:
|
||||
)
|
||||
):
|
||||
output_rids.append(req.rid)
|
||||
output_vids.append(req.vid)
|
||||
decoded_texts.append(req.decoded_text)
|
||||
read_ids, read_offset = req.init_incremental_detokenize()
|
||||
output_read_ids.append(read_ids)
|
||||
@@ -653,6 +655,7 @@ class ModelTpServer:
|
||||
self.out_pyobjs.append(
|
||||
BatchTokenIDOut(
|
||||
output_rids,
|
||||
output_vids,
|
||||
decoded_texts,
|
||||
output_read_ids,
|
||||
output_read_offsets,
|
||||
|
||||
@@ -20,6 +20,7 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||
|
||||
@dataclasses.dataclass
|
||||
class DecodeStatus:
|
||||
vid: int
|
||||
decoded_text: str
|
||||
decode_ids: List[int]
|
||||
surr_offset: int
|
||||
@@ -53,13 +54,14 @@ class DetokenizerManager:
|
||||
assert isinstance(recv_obj, BatchTokenIDOut)
|
||||
bs = len(recv_obj.rids)
|
||||
|
||||
# FIXME: incremental detokenize is not compatible with jump forward
|
||||
# Initialize decode status
|
||||
read_ids, surr_ids = [], []
|
||||
for i in range(bs):
|
||||
rid = recv_obj.rids[i]
|
||||
if rid not in self.decode_status:
|
||||
vid = recv_obj.vids[i]
|
||||
if rid not in self.decode_status or self.decode_status[rid].vid != vid:
|
||||
s = DecodeStatus(
|
||||
vid=vid,
|
||||
decoded_text=recv_obj.decoded_texts[i],
|
||||
decode_ids=recv_obj.decode_ids[i],
|
||||
surr_offset=0,
|
||||
|
||||
@@ -111,6 +111,7 @@ class TokenizedGenerateReqInput:
|
||||
@dataclass
|
||||
class BatchTokenIDOut:
|
||||
rids: List[str]
|
||||
vids: List[int]
|
||||
decoded_texts: List[str]
|
||||
decode_ids: List[int]
|
||||
read_offsets: List[int]
|
||||
|
||||
Reference in New Issue
Block a user