Fix jump forward when streaming (#665)
This commit is contained in:
@@ -90,6 +90,7 @@ class Req:
|
|||||||
# 1: surr_offset
|
# 1: surr_offset
|
||||||
# 2: read_offset
|
# 2: read_offset
|
||||||
# 3: last token
|
# 3: last token
|
||||||
|
self.vid = 0 # version id to sync decode status with in detokenizer_manager
|
||||||
self.decoded_text = ""
|
self.decoded_text = ""
|
||||||
self.surr_offset = None # Surrounding offset to defeat the cleanup algorithm
|
self.surr_offset = None # Surrounding offset to defeat the cleanup algorithm
|
||||||
self.read_offset = None
|
self.read_offset = None
|
||||||
@@ -520,6 +521,9 @@ class Batch:
|
|||||||
req.output_ids = cur_output_ids
|
req.output_ids = cur_output_ids
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# The decode status has diverged from detokenizer_manager
|
||||||
|
req.vid += 1
|
||||||
|
|
||||||
# insert the old request into tree_cache
|
# insert the old request into tree_cache
|
||||||
if req_pool_indices_cpu is None:
|
if req_pool_indices_cpu is None:
|
||||||
req_pool_indices_cpu = self.req_pool_indices.tolist()
|
req_pool_indices_cpu = self.req_pool_indices.tolist()
|
||||||
|
|||||||
@@ -589,6 +589,7 @@ class ModelTpServer:
|
|||||||
|
|
||||||
def handle_finished_requests(self, batch: Batch):
|
def handle_finished_requests(self, batch: Batch):
|
||||||
output_rids = []
|
output_rids = []
|
||||||
|
output_vids = []
|
||||||
decoded_texts = []
|
decoded_texts = []
|
||||||
output_read_ids = []
|
output_read_ids = []
|
||||||
output_read_offsets = []
|
output_read_offsets = []
|
||||||
@@ -614,6 +615,7 @@ class ModelTpServer:
|
|||||||
)
|
)
|
||||||
):
|
):
|
||||||
output_rids.append(req.rid)
|
output_rids.append(req.rid)
|
||||||
|
output_vids.append(req.vid)
|
||||||
decoded_texts.append(req.decoded_text)
|
decoded_texts.append(req.decoded_text)
|
||||||
read_ids, read_offset = req.init_incremental_detokenize()
|
read_ids, read_offset = req.init_incremental_detokenize()
|
||||||
output_read_ids.append(read_ids)
|
output_read_ids.append(read_ids)
|
||||||
@@ -653,6 +655,7 @@ class ModelTpServer:
|
|||||||
self.out_pyobjs.append(
|
self.out_pyobjs.append(
|
||||||
BatchTokenIDOut(
|
BatchTokenIDOut(
|
||||||
output_rids,
|
output_rids,
|
||||||
|
output_vids,
|
||||||
decoded_texts,
|
decoded_texts,
|
||||||
output_read_ids,
|
output_read_ids,
|
||||||
output_read_offsets,
|
output_read_offsets,
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
|||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class DecodeStatus:
|
class DecodeStatus:
|
||||||
|
vid: int
|
||||||
decoded_text: str
|
decoded_text: str
|
||||||
decode_ids: List[int]
|
decode_ids: List[int]
|
||||||
surr_offset: int
|
surr_offset: int
|
||||||
@@ -53,13 +54,14 @@ class DetokenizerManager:
|
|||||||
assert isinstance(recv_obj, BatchTokenIDOut)
|
assert isinstance(recv_obj, BatchTokenIDOut)
|
||||||
bs = len(recv_obj.rids)
|
bs = len(recv_obj.rids)
|
||||||
|
|
||||||
# FIXME: incremental detokenize is not compatible with jump forward
|
|
||||||
# Initialize decode status
|
# Initialize decode status
|
||||||
read_ids, surr_ids = [], []
|
read_ids, surr_ids = [], []
|
||||||
for i in range(bs):
|
for i in range(bs):
|
||||||
rid = recv_obj.rids[i]
|
rid = recv_obj.rids[i]
|
||||||
if rid not in self.decode_status:
|
vid = recv_obj.vids[i]
|
||||||
|
if rid not in self.decode_status or self.decode_status[rid].vid != vid:
|
||||||
s = DecodeStatus(
|
s = DecodeStatus(
|
||||||
|
vid=vid,
|
||||||
decoded_text=recv_obj.decoded_texts[i],
|
decoded_text=recv_obj.decoded_texts[i],
|
||||||
decode_ids=recv_obj.decode_ids[i],
|
decode_ids=recv_obj.decode_ids[i],
|
||||||
surr_offset=0,
|
surr_offset=0,
|
||||||
|
|||||||
@@ -111,6 +111,7 @@ class TokenizedGenerateReqInput:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class BatchTokenIDOut:
|
class BatchTokenIDOut:
|
||||||
rids: List[str]
|
rids: List[str]
|
||||||
|
vids: List[int]
|
||||||
decoded_texts: List[str]
|
decoded_texts: List[str]
|
||||||
decode_ids: List[int]
|
decode_ids: List[int]
|
||||||
read_offsets: List[int]
|
read_offsets: List[int]
|
||||||
|
|||||||
Reference in New Issue
Block a user