Detokenize incrementally when streaming (#653)
This commit is contained in:
@@ -136,7 +136,33 @@ class RadixAttention(nn.Module):
|
||||
return self.decode_forward(q, k, v, input_metadata)
|
||||
|
||||
def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
|
||||
key_buffer = input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id)
|
||||
key_buffer[input_metadata.out_cache_loc] = cache_k
|
||||
value_buffer = input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id)
|
||||
value_buffer[input_metadata.out_cache_loc] = cache_v
|
||||
kv_cache = input_metadata.token_to_kv_pool.kv_data[self.layer_id]
|
||||
_store_kv_cache(cache_k, cache_v, kv_cache, input_metadata.out_cache_loc)
|
||||
|
||||
|
||||
try:
|
||||
|
||||
@torch.library.custom_op("mylib::store_kv_cache", mutates_args={"kv_cache"})
|
||||
def _store_kv_cache(
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
cache_loc: torch.Tensor,
|
||||
) -> None:
|
||||
kv_cache[cache_loc, 0] = k
|
||||
kv_cache[cache_loc, 1] = v
|
||||
|
||||
@_store_kv_cache.register_fake
|
||||
def _(k, v, kv_cache, cache_loc):
|
||||
pass
|
||||
|
||||
except:
|
||||
|
||||
def _store_kv_cache(
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
cache_loc: torch.Tensor,
|
||||
) -> None:
|
||||
kv_cache[cache_loc, 0] = k
|
||||
kv_cache[cache_loc, 1] = v
|
||||
|
||||
@@ -82,6 +82,14 @@ class Req:
|
||||
self.input_ids = None # input_ids = origin_input_ids + output_ids
|
||||
|
||||
# For incremental decoding
|
||||
# ----- | --------- read_ids -------|
|
||||
# ----- | surr_ids |
|
||||
# xxxxx | xxxxxxxxxxx | xxxxxxxxxxx |
|
||||
# ----- ^ ----------- ^ ----------- ^
|
||||
# ----- 1 ----------- 2 ----------- 3
|
||||
# 1: surr_offset
|
||||
# 2: read_offset
|
||||
# 3: last token
|
||||
self.decoded_text = ""
|
||||
self.surr_offset = None # Surrounding offset to defeat the cleanup algorithm
|
||||
self.read_offset = None
|
||||
@@ -132,7 +140,7 @@ class Req:
|
||||
return self.finished_reason is not None
|
||||
|
||||
# Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
|
||||
def init_detokenize_incrementally(self):
|
||||
def init_incremental_detokenize(self):
|
||||
first_iter = self.surr_offset is None or self.read_offset is None
|
||||
|
||||
if first_iter:
|
||||
@@ -142,13 +150,11 @@ class Req:
|
||||
)
|
||||
|
||||
all_ids = self.origin_input_ids_unpadded + self.output_ids
|
||||
surr_ids = all_ids[self.surr_offset : self.read_offset]
|
||||
read_ids = all_ids[self.surr_offset :]
|
||||
return all_ids[self.surr_offset :], self.read_offset - self.surr_offset
|
||||
|
||||
return surr_ids, read_ids, len(all_ids)
|
||||
|
||||
def detokenize_incrementally(self, inplace: bool = True):
|
||||
surr_ids, read_ids, num_all_tokens = self.init_detokenize_incrementally()
|
||||
def get_next_inc_detokenization(self):
|
||||
read_ids, read_offset = self.init_incremental_detokenize()
|
||||
surr_ids = read_ids[:read_offset]
|
||||
|
||||
surr_text = self.tokenizer.decode(
|
||||
surr_ids,
|
||||
@@ -162,13 +168,7 @@ class Req:
|
||||
)
|
||||
|
||||
if len(new_text) > len(surr_text) and not new_text.endswith("<EFBFBD>"):
|
||||
new_text = new_text[len(surr_text) :]
|
||||
if inplace:
|
||||
self.decoded_text += new_text
|
||||
self.surr_offset = self.read_offset
|
||||
self.read_offset = num_all_tokens
|
||||
|
||||
return True, new_text
|
||||
return True, new_text[len(surr_text) :]
|
||||
|
||||
return False, ""
|
||||
|
||||
@@ -501,7 +501,7 @@ class Batch:
|
||||
cur_output_ids = req.output_ids
|
||||
|
||||
req.output_ids.extend(suffix_ids)
|
||||
decode_res, new_text = req.detokenize_incrementally(inplace=False)
|
||||
decode_res, new_text = req.get_next_inc_detokenization()
|
||||
if not decode_res:
|
||||
req.output_ids = cur_output_ids
|
||||
continue
|
||||
|
||||
@@ -590,8 +590,8 @@ class ModelTpServer:
|
||||
def handle_finished_requests(self, batch: Batch):
|
||||
output_rids = []
|
||||
decoded_texts = []
|
||||
surr_output_ids = []
|
||||
read_output_ids = []
|
||||
output_read_ids = []
|
||||
output_read_offsets = []
|
||||
output_skip_special_tokens = []
|
||||
output_spaces_between_special_tokens = []
|
||||
output_meta_info = []
|
||||
@@ -615,9 +615,9 @@ class ModelTpServer:
|
||||
):
|
||||
output_rids.append(req.rid)
|
||||
decoded_texts.append(req.decoded_text)
|
||||
surr_ids, read_ids, _ = req.init_detokenize_incrementally()
|
||||
surr_output_ids.append(surr_ids)
|
||||
read_output_ids.append(read_ids)
|
||||
read_ids, read_offset = req.init_incremental_detokenize()
|
||||
output_read_ids.append(read_ids)
|
||||
output_read_offsets.append(read_offset)
|
||||
output_skip_special_tokens.append(
|
||||
req.sampling_params.skip_special_tokens
|
||||
)
|
||||
@@ -654,8 +654,8 @@ class ModelTpServer:
|
||||
BatchTokenIDOut(
|
||||
output_rids,
|
||||
decoded_texts,
|
||||
surr_output_ids,
|
||||
read_output_ids,
|
||||
output_read_ids,
|
||||
output_read_offsets,
|
||||
output_skip_special_tokens,
|
||||
output_spaces_between_special_tokens,
|
||||
output_meta_info,
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
"""DetokenizerManager is a process that detokenizes the token ids."""
|
||||
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import inspect
|
||||
from typing import List
|
||||
|
||||
import uvloop
|
||||
import zmq
|
||||
@@ -16,6 +18,14 @@ from sglang.utils import find_printable_text, get_exception_traceback, graceful_
|
||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class DecodeStatus:
|
||||
decoded_text: str
|
||||
decode_ids: List[int]
|
||||
surr_offset: int
|
||||
read_offset: int
|
||||
|
||||
|
||||
class DetokenizerManager:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -35,19 +45,42 @@ class DetokenizerManager:
|
||||
trust_remote_code=server_args.trust_remote_code,
|
||||
)
|
||||
|
||||
self.decode_status = {}
|
||||
|
||||
async def handle_loop(self):
|
||||
while True:
|
||||
recv_obj: BatchTokenIDOut = await self.recv_from_router.recv_pyobj()
|
||||
assert isinstance(recv_obj, BatchTokenIDOut)
|
||||
bs = len(recv_obj.rids)
|
||||
|
||||
# FIXME: incremental detokenize is not compatible with jump forward
|
||||
# Initialize decode status
|
||||
read_ids, surr_ids = [], []
|
||||
for i in range(bs):
|
||||
rid = recv_obj.rids[i]
|
||||
if rid not in self.decode_status:
|
||||
s = DecodeStatus(
|
||||
decoded_text=recv_obj.decoded_texts[i],
|
||||
decode_ids=recv_obj.decode_ids[i],
|
||||
surr_offset=0,
|
||||
read_offset=recv_obj.read_offsets[i],
|
||||
)
|
||||
self.decode_status[rid] = s
|
||||
else:
|
||||
s = self.decode_status[rid]
|
||||
s.decode_ids = recv_obj.decode_ids[i]
|
||||
|
||||
read_ids.append(s.decode_ids[s.surr_offset :])
|
||||
surr_ids.append(s.decode_ids[s.surr_offset : s.read_offset])
|
||||
|
||||
# TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
|
||||
surr_texts = self.tokenizer.batch_decode(
|
||||
recv_obj.surr_output_ids,
|
||||
surr_ids,
|
||||
skip_special_tokens=recv_obj.skip_special_tokens[0],
|
||||
spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
|
||||
)
|
||||
read_texts = self.tokenizer.batch_decode(
|
||||
recv_obj.read_output_ids,
|
||||
read_ids,
|
||||
skip_special_tokens=recv_obj.skip_special_tokens[0],
|
||||
spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
|
||||
)
|
||||
@@ -55,11 +88,20 @@ class DetokenizerManager:
|
||||
# Trim stop str
|
||||
# TODO(lmzheng): handle the case where multiple stop strs are hit
|
||||
output_strs = []
|
||||
for i in range(len(recv_obj.rids)):
|
||||
for i in range(bs):
|
||||
s = self.decode_status[recv_obj.rids[i]]
|
||||
new_text = read_texts[i][len(surr_texts[i]) :]
|
||||
if recv_obj.finished_reason[i] is None:
|
||||
new_text = find_printable_text(new_text)
|
||||
output_strs.append(recv_obj.decoded_texts[i] + new_text)
|
||||
# Streaming chunk: update the decode status
|
||||
if len(new_text) > 0 and not new_text.endswith("<EFBFBD>"):
|
||||
s.decoded_text = s.decoded_text + new_text
|
||||
s.surr_offset = s.read_offset
|
||||
s.read_offset = len(s.decode_ids)
|
||||
new_text = ""
|
||||
else:
|
||||
new_text = find_printable_text(new_text)
|
||||
|
||||
output_strs.append(s.decoded_text + new_text)
|
||||
|
||||
if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR):
|
||||
pos = output_strs[i].find(recv_obj.finished_reason[i].matched)
|
||||
|
||||
@@ -111,8 +111,8 @@ class TokenizedGenerateReqInput:
|
||||
class BatchTokenIDOut:
|
||||
rids: List[str]
|
||||
decoded_texts: List[str]
|
||||
surr_output_ids: List[List[int]]
|
||||
read_output_ids: List[List[int]]
|
||||
decode_ids: List[int]
|
||||
read_offsets: List[int]
|
||||
skip_special_tokens: List[bool]
|
||||
spaces_between_special_tokens: List[bool]
|
||||
meta_info: List[Dict]
|
||||
|
||||
Reference in New Issue
Block a user