Detokenize incrementally when streaming (#653)
@@ -136,7 +136,33 @@ class RadixAttention(nn.Module):
             return self.decode_forward(q, k, v, input_metadata)
 
     def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
-        key_buffer = input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id)
-        key_buffer[input_metadata.out_cache_loc] = cache_k
-        value_buffer = input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id)
-        value_buffer[input_metadata.out_cache_loc] = cache_v
+        kv_cache = input_metadata.token_to_kv_pool.kv_data[self.layer_id]
+        _store_kv_cache(cache_k, cache_v, kv_cache, input_metadata.out_cache_loc)
+
+
+try:
+
+    @torch.library.custom_op("mylib::store_kv_cache", mutates_args={"kv_cache"})
+    def _store_kv_cache(
+        k: torch.Tensor,
+        v: torch.Tensor,
+        kv_cache: torch.Tensor,
+        cache_loc: torch.Tensor,
+    ) -> None:
+        kv_cache[cache_loc, 0] = k
+        kv_cache[cache_loc, 1] = v
+
+    @_store_kv_cache.register_fake
+    def _(k, v, kv_cache, cache_loc):
+        pass
+
+except:
+
+    def _store_kv_cache(
+        k: torch.Tensor,
+        v: torch.Tensor,
+        kv_cache: torch.Tensor,
+        cache_loc: torch.Tensor,
+    ) -> None:
+        kv_cache[cache_loc, 0] = k
+        kv_cache[cache_loc, 1] = v
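Note: the try/except above registers the in-place KV-cache write as a PyTorch custom op when the running build provides torch.library.custom_op (2.4+), with a fake (meta) kernel so the mutation remains visible to torch.compile, and falls back to a plain eager function otherwise. Below is a minimal, self-contained sketch of the same pattern; the op name, shapes, and sizes are illustrative assumptions, not sglang's actual KV pool layout.

import torch

try:
    # Newer PyTorch: wrap the mutation in a custom op so graph capture treats
    # it as an opaque call that writes into kv_cache.
    @torch.library.custom_op("demo::store_kv", mutates_args={"kv_cache"})
    def store_kv(
        k: torch.Tensor, v: torch.Tensor, kv_cache: torch.Tensor, loc: torch.Tensor
    ) -> None:
        kv_cache[loc, 0] = k
        kv_cache[loc, 1] = v

    @store_kv.register_fake
    def _(k, v, kv_cache, loc):
        # Fake (meta) kernel: the op returns nothing, so there is nothing to fake.
        pass

except AttributeError:
    # Older PyTorch without torch.library.custom_op: plain eager fallback.
    def store_kv(k, v, kv_cache, loc):
        kv_cache[loc, 0] = k
        kv_cache[loc, 1] = v

# Assumed toy layout: 8 cache slots x (K, V) x 4 heads x 16 head dim.
kv_cache = torch.zeros(8, 2, 4, 16)
store_kv(torch.randn(3, 4, 16), torch.randn(3, 4, 16), kv_cache, torch.tensor([5, 6, 7]))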
@@ -82,6 +82,14 @@ class Req:
         self.input_ids = None  # input_ids = origin_input_ids + output_ids
 
         # For incremental decoding
+        # ----- | --------- read_ids -------|
+        # ----- | surr_ids  |
+        # xxxxx | xxxxxxxxxxx | xxxxxxxxxxx |
+        # ----- ^ ----------- ^ ----------- ^
+        # ----- 1 ----------- 2 ----------- 3
+        # 1: surr_offset
+        # 2: read_offset
+        # 3: last token
         self.decoded_text = ""
         self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
         self.read_offset = None
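The surr_offset/read_offset diagram above is the heart of the change: BPE and SentencePiece tokenizers need some left context to place spaces and multi-byte characters correctly, so each request keeps a "surrounding" window that has already been committed and only emits the text that appears after it. A hedged, minimal sketch of that scheme using Hugging Face transformers directly; "gpt2" is only an example checkpoint and the 5-token window is an assumed constant, not sglang's.

from transformers import AutoTokenizer


class IncrementalDetokenizer:
    def __init__(self, tokenizer, prompt_ids):
        self.tokenizer = tokenizer
        self.all_ids = list(prompt_ids)
        self.decoded_text = ""
        # Keep a few tokens of left context so BPE spacing stays stable.
        self.read_offset = len(self.all_ids)
        self.surr_offset = max(self.read_offset - 5, 0)

    def append(self, token_id):
        """Add one generated token and return the text committed so far."""
        self.all_ids.append(token_id)
        surr_text = self.tokenizer.decode(self.all_ids[self.surr_offset : self.read_offset])
        new_text = self.tokenizer.decode(self.all_ids[self.surr_offset :])
        # Only commit when the tail is not a half-decoded character ("�").
        if len(new_text) > len(surr_text) and not new_text.endswith("�"):
            self.decoded_text += new_text[len(surr_text) :]
            self.surr_offset = self.read_offset
            self.read_offset = len(self.all_ids)
        return self.decoded_text


tok = AutoTokenizer.from_pretrained("gpt2")  # example checkpoint
detok = IncrementalDetokenizer(tok, tok.encode("Hello"))
for tid in tok.encode(" world, how are you?"):
    print(repr(detok.append(tid)))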
@@ -132,7 +140,7 @@ class Req:
         return self.finished_reason is not None
 
     # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
-    def init_detokenize_incrementally(self):
+    def init_incremental_detokenize(self):
         first_iter = self.surr_offset is None or self.read_offset is None
 
         if first_iter:
@@ -142,13 +150,11 @@ class Req:
             )
 
         all_ids = self.origin_input_ids_unpadded + self.output_ids
-        surr_ids = all_ids[self.surr_offset : self.read_offset]
-        read_ids = all_ids[self.surr_offset :]
-
-        return surr_ids, read_ids, len(all_ids)
-
-    def detokenize_incrementally(self, inplace: bool = True):
-        surr_ids, read_ids, num_all_tokens = self.init_detokenize_incrementally()
+        return all_ids[self.surr_offset :], self.read_offset - self.surr_offset
+
+    def get_next_inc_detokenization(self):
+        read_ids, read_offset = self.init_incremental_detokenize()
+        surr_ids = read_ids[:read_offset]
 
         surr_text = self.tokenizer.decode(
             surr_ids,
@@ -162,13 +168,7 @@ class Req:
         )
 
         if len(new_text) > len(surr_text) and not new_text.endswith("�"):
-            new_text = new_text[len(surr_text) :]
-            if inplace:
-                self.decoded_text += new_text
-                self.surr_offset = self.read_offset
-                self.read_offset = num_all_tokens
-
-            return True, new_text
+            return True, new_text[len(surr_text) :]
 
         return False, ""
 
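The two hunks above split the old detokenize_incrementally(inplace=...) into init_incremental_detokenize(), which returns (read_ids, read_offset), and get_next_inc_detokenization(), which turns that into a (success, text_delta) pair without mutating the request, so callers such as the jump-forward path can probe a decode and discard it. A condensed, hedged restatement as a free function, assuming a Hugging Face style decode():

def next_inc_detokenization(tokenizer, read_ids, read_offset, skip_special_tokens=True):
    # Decode the committed surrounding window and the full window, then return
    # only the text that extends past the surrounding part.
    surr_ids = read_ids[:read_offset]
    surr_text = tokenizer.decode(surr_ids, skip_special_tokens=skip_special_tokens)
    new_text = tokenizer.decode(read_ids, skip_special_tokens=skip_special_tokens)
    if len(new_text) > len(surr_text) and not new_text.endswith("�"):
        return True, new_text[len(surr_text) :]
    return False, ""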
@@ -501,7 +501,7 @@ class Batch:
                 cur_output_ids = req.output_ids
 
                 req.output_ids.extend(suffix_ids)
-                decode_res, new_text = req.detokenize_incrementally(inplace=False)
+                decode_res, new_text = req.get_next_inc_detokenization()
                 if not decode_res:
                     req.output_ids = cur_output_ids
                     continue
@@ -590,8 +590,8 @@ class ModelTpServer:
     def handle_finished_requests(self, batch: Batch):
         output_rids = []
         decoded_texts = []
-        surr_output_ids = []
-        read_output_ids = []
+        output_read_ids = []
+        output_read_offsets = []
         output_skip_special_tokens = []
         output_spaces_between_special_tokens = []
         output_meta_info = []
@@ -615,9 +615,9 @@ class ModelTpServer:
             ):
                 output_rids.append(req.rid)
                 decoded_texts.append(req.decoded_text)
-                surr_ids, read_ids, _ = req.init_detokenize_incrementally()
-                surr_output_ids.append(surr_ids)
-                read_output_ids.append(read_ids)
+                read_ids, read_offset = req.init_incremental_detokenize()
+                output_read_ids.append(read_ids)
+                output_read_offsets.append(read_offset)
                 output_skip_special_tokens.append(
                     req.sampling_params.skip_special_tokens
                 )
@@ -654,8 +654,8 @@ class ModelTpServer:
                 BatchTokenIDOut(
                     output_rids,
                     decoded_texts,
-                    surr_output_ids,
-                    read_output_ids,
+                    output_read_ids,
+                    output_read_offsets,
                     output_skip_special_tokens,
                     output_spaces_between_special_tokens,
                     output_meta_info,
@@ -1,7 +1,9 @@
 """DetokenizerManager is a process that detokenizes the token ids."""
 
 import asyncio
+import dataclasses
 import inspect
+from typing import List
 
 import uvloop
 import zmq
@@ -16,6 +18,14 @@ from sglang.utils import find_printable_text, get_exception_traceback, graceful_
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
 
+@dataclasses.dataclass
+class DecodeStatus:
+    decoded_text: str
+    decode_ids: List[int]
+    surr_offset: int
+    read_offset: int
+
+
 class DetokenizerManager:
     def __init__(
         self,
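DecodeStatus moves the per-request detokenization state out of the scheduler and into the detokenizer process, keyed by request id. A small sketch of the bookkeeping this enables; the field names mirror the dataclass above, but the helper itself is illustrative, not part of sglang.

import dataclasses
from typing import Dict, List


@dataclasses.dataclass
class DecodeStatus:
    decoded_text: str
    decode_ids: List[int]
    surr_offset: int
    read_offset: int


decode_status: Dict[str, DecodeStatus] = {}


def track(rid: str, decoded_text: str, decode_ids: List[int], read_offset: int) -> DecodeStatus:
    # First chunk for a request creates its status; later chunks only refresh
    # the token ids. Offsets advance elsewhere, once the new text is known to
    # be complete (see the streaming update further down).
    if rid not in decode_status:
        decode_status[rid] = DecodeStatus(
            decoded_text=decoded_text,
            decode_ids=decode_ids,
            surr_offset=0,
            read_offset=read_offset,
        )
    else:
        decode_status[rid].decode_ids = decode_ids
    return decode_status[rid]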
@@ -35,19 +45,42 @@ class DetokenizerManager:
             trust_remote_code=server_args.trust_remote_code,
         )
 
+        self.decode_status = {}
+
     async def handle_loop(self):
         while True:
             recv_obj: BatchTokenIDOut = await self.recv_from_router.recv_pyobj()
             assert isinstance(recv_obj, BatchTokenIDOut)
+            bs = len(recv_obj.rids)
+
+            # FIXME: incremental detokenize is not compatible with jump forward
+            # Initialize decode status
+            read_ids, surr_ids = [], []
+            for i in range(bs):
+                rid = recv_obj.rids[i]
+                if rid not in self.decode_status:
+                    s = DecodeStatus(
+                        decoded_text=recv_obj.decoded_texts[i],
+                        decode_ids=recv_obj.decode_ids[i],
+                        surr_offset=0,
+                        read_offset=recv_obj.read_offsets[i],
+                    )
+                    self.decode_status[rid] = s
+                else:
+                    s = self.decode_status[rid]
+                    s.decode_ids = recv_obj.decode_ids[i]
+
+                read_ids.append(s.decode_ids[s.surr_offset :])
+                surr_ids.append(s.decode_ids[s.surr_offset : s.read_offset])
 
             # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
             surr_texts = self.tokenizer.batch_decode(
-                recv_obj.surr_output_ids,
+                surr_ids,
                 skip_special_tokens=recv_obj.skip_special_tokens[0],
                 spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
             )
             read_texts = self.tokenizer.batch_decode(
-                recv_obj.read_output_ids,
+                read_ids,
                 skip_special_tokens=recv_obj.skip_special_tokens[0],
                 spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
             )
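Decoding is still done in two batched calls, one over the surrounding windows and one over the full read windows, and the new text per request is the string difference between the two. A hedged sketch of that step in isolation; "gpt2" is only an example checkpoint.

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # example checkpoint


def batch_new_text(decode_ids_per_req, surr_offsets, read_offsets):
    # Slice each request's ids into the committed window and the full window,
    # decode both in batches, and keep only the text beyond the committed part.
    surr_ids = [ids[s:r] for ids, s, r in zip(decode_ids_per_req, surr_offsets, read_offsets)]
    read_ids = [ids[s:] for ids, s in zip(decode_ids_per_req, surr_offsets)]
    surr_texts = tok.batch_decode(surr_ids, skip_special_tokens=True)
    read_texts = tok.batch_decode(read_ids, skip_special_tokens=True)
    return [rt[len(st):] for st, rt in zip(surr_texts, read_texts)]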
@@ -55,11 +88,20 @@ class DetokenizerManager:
             # Trim stop str
             # TODO(lmzheng): handle the case where multiple stop strs are hit
             output_strs = []
-            for i in range(len(recv_obj.rids)):
+            for i in range(bs):
+                s = self.decode_status[recv_obj.rids[i]]
                 new_text = read_texts[i][len(surr_texts[i]) :]
                 if recv_obj.finished_reason[i] is None:
-                    new_text = find_printable_text(new_text)
-                    output_strs.append(recv_obj.decoded_texts[i] + new_text)
+                    # Streaming chunk: update the decode status
+                    if len(new_text) > 0 and not new_text.endswith("�"):
+                        s.decoded_text = s.decoded_text + new_text
+                        s.surr_offset = s.read_offset
+                        s.read_offset = len(s.decode_ids)
+                        new_text = ""
+                    else:
+                        new_text = find_printable_text(new_text)
+
+                output_strs.append(s.decoded_text + new_text)
 
                 if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR):
                     pos = output_strs[i].find(recv_obj.finished_reason[i].matched)
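For streaming chunks the detokenizer either commits the new text and advances both offsets, or, when the tail may still be a half-decoded character, keeps the offsets and surfaces only a printable prefix. A hedged sketch of that decision; find_printable_text is assumed to behave like sglang.utils.find_printable_text (trimming an unfinished tail), and s is a DecodeStatus as introduced above.

REPLACEMENT_CHAR = "\ufffd"  # rendered as "�" when a byte sequence is incomplete


def commit_or_hold(s, new_text, find_printable_text):
    # Complete chunk: fold it into the committed text and advance the windows.
    if len(new_text) > 0 and not new_text.endswith(REPLACEMENT_CHAR):
        s.decoded_text += new_text
        s.surr_offset = s.read_offset
        s.read_offset = len(s.decode_ids)
        new_text = ""
    else:
        # Possibly mid-character: emit only the printable prefix, keep offsets.
        new_text = find_printable_text(new_text)
    return s.decoded_text + new_text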
@@ -111,8 +111,8 @@ class TokenizedGenerateReqInput:
 class BatchTokenIDOut:
     rids: List[str]
     decoded_texts: List[str]
-    surr_output_ids: List[List[int]]
-    read_output_ids: List[List[int]]
+    decode_ids: List[int]
+    read_offsets: List[int]
     skip_special_tokens: List[bool]
     spaces_between_special_tokens: List[bool]
     meta_info: List[Dict]