From a9ef49c12ccd1c36fb225b8831f8a434d90485f4 Mon Sep 17 00:00:00 2001
From: Liangsheng Yin
Date: Thu, 18 Jul 2024 17:57:40 -0700
Subject: [PATCH] Detokenize incrementally when streaming (#653)

---
 python/sglang/srt/layers/radix_attention.py  | 34 ++++++++++--
 .../srt/managers/controller/infer_batch.py   | 30 +++++------
 .../srt/managers/controller/tp_worker.py     | 14 ++---
 .../srt/managers/detokenizer_manager.py      | 52 +++++++++++++++++--
 python/sglang/srt/managers/io_struct.py      |  4 +-
 5 files changed, 101 insertions(+), 33 deletions(-)

diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py
index 7f57c6a96..bf2ca9ba2 100644
--- a/python/sglang/srt/layers/radix_attention.py
+++ b/python/sglang/srt/layers/radix_attention.py
@@ -136,7 +136,33 @@ class RadixAttention(nn.Module):
         return self.decode_forward(q, k, v, input_metadata)
 
     def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
-        key_buffer = input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id)
-        key_buffer[input_metadata.out_cache_loc] = cache_k
-        value_buffer = input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id)
-        value_buffer[input_metadata.out_cache_loc] = cache_v
+        kv_cache = input_metadata.token_to_kv_pool.kv_data[self.layer_id]
+        _store_kv_cache(cache_k, cache_v, kv_cache, input_metadata.out_cache_loc)
+
+
+try:
+
+    @torch.library.custom_op("mylib::store_kv_cache", mutates_args={"kv_cache"})
+    def _store_kv_cache(
+        k: torch.Tensor,
+        v: torch.Tensor,
+        kv_cache: torch.Tensor,
+        cache_loc: torch.Tensor,
+    ) -> None:
+        kv_cache[cache_loc, 0] = k
+        kv_cache[cache_loc, 1] = v
+
+    @_store_kv_cache.register_fake
+    def _(k, v, kv_cache, cache_loc):
+        pass
+
+except AttributeError:
+
+    def _store_kv_cache(
+        k: torch.Tensor,
+        v: torch.Tensor,
+        kv_cache: torch.Tensor,
+        cache_loc: torch.Tensor,
+    ) -> None:
+        kv_cache[cache_loc, 0] = k
+        kv_cache[cache_loc, 1] = v
diff --git a/python/sglang/srt/managers/controller/infer_batch.py b/python/sglang/srt/managers/controller/infer_batch.py
index 191bf388f..ac6bf5d62 100644
--- a/python/sglang/srt/managers/controller/infer_batch.py
+++ b/python/sglang/srt/managers/controller/infer_batch.py
@@ -82,6 +82,14 @@ class Req:
         self.input_ids = None  # input_ids = origin_input_ids + output_ids
 
         # For incremental decoding
+        # ----- | --------- read_ids -------|
+        # ----- |  surr_ids  |
+        # xxxxx | xxxxxxxxxxx | xxxxxxxxxxx |
+        # ----- ^ ----------- ^ ----------- ^
+        # ----- 1 ----------- 2 ----------- 3
+        # 1: surr_offset
+        # 2: read_offset
+        # 3: last token
         self.decoded_text = ""
         self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
         self.read_offset = None
@@ -132,7 +140,7 @@ class Req:
         return self.finished_reason is not None
 
     # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
-    def init_detokenize_incrementally(self):
+    def init_incremental_detokenize(self):
         first_iter = self.surr_offset is None or self.read_offset is None
 
         if first_iter:
@@ -142,13 +150,11 @@ class Req:
             )
 
         all_ids = self.origin_input_ids_unpadded + self.output_ids
-        surr_ids = all_ids[self.surr_offset : self.read_offset]
-        read_ids = all_ids[self.surr_offset :]
+        return all_ids[self.surr_offset :], self.read_offset - self.surr_offset
 
-        return surr_ids, read_ids, len(all_ids)
-
-    def detokenize_incrementally(self, inplace: bool = True):
-        surr_ids, read_ids, num_all_tokens = self.init_detokenize_incrementally()
+    def get_next_inc_detokenization(self):
+        read_ids, read_offset = self.init_incremental_detokenize()
+        surr_ids = read_ids[:read_offset]
 
         surr_text = self.tokenizer.decode(
             surr_ids,
@@ -162,13 +168,7 @@ class Req:
         )
 
         if len(new_text) > len(surr_text) and not new_text.endswith("�"):
-            new_text = new_text[len(surr_text) :]
-            if inplace:
-                self.decoded_text += new_text
-                self.surr_offset = self.read_offset
-                self.read_offset = num_all_tokens
-
-            return True, new_text
+            return True, new_text[len(surr_text) :]
 
         return False, ""
 
@@ -501,7 +501,7 @@ class Batch:
                 cur_output_ids = req.output_ids
 
             req.output_ids.extend(suffix_ids)
-            decode_res, new_text = req.detokenize_incrementally(inplace=False)
+            decode_res, new_text = req.get_next_inc_detokenization()
             if not decode_res:
                 req.output_ids = cur_output_ids
                 continue
diff --git a/python/sglang/srt/managers/controller/tp_worker.py b/python/sglang/srt/managers/controller/tp_worker.py
index 14a557e27..ab189c27e 100644
--- a/python/sglang/srt/managers/controller/tp_worker.py
+++ b/python/sglang/srt/managers/controller/tp_worker.py
@@ -590,8 +590,8 @@ class ModelTpServer:
     def handle_finished_requests(self, batch: Batch):
         output_rids = []
         decoded_texts = []
-        surr_output_ids = []
-        read_output_ids = []
+        output_read_ids = []
+        output_read_offsets = []
         output_skip_special_tokens = []
         output_spaces_between_special_tokens = []
         output_meta_info = []
@@ -615,9 +615,9 @@ class ModelTpServer:
             ):
                 output_rids.append(req.rid)
                 decoded_texts.append(req.decoded_text)
-                surr_ids, read_ids, _ = req.init_detokenize_incrementally()
-                surr_output_ids.append(surr_ids)
-                read_output_ids.append(read_ids)
+                read_ids, read_offset = req.init_incremental_detokenize()
+                output_read_ids.append(read_ids)
+                output_read_offsets.append(read_offset)
                 output_skip_special_tokens.append(
                     req.sampling_params.skip_special_tokens
                 )
@@ -654,8 +654,8 @@ class ModelTpServer:
                 BatchTokenIDOut(
                     output_rids,
                     decoded_texts,
-                    surr_output_ids,
-                    read_output_ids,
+                    output_read_ids,
+                    output_read_offsets,
                     output_skip_special_tokens,
                     output_spaces_between_special_tokens,
                     output_meta_info,
diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py
index 3e0183b1b..046cb37b6 100644
--- a/python/sglang/srt/managers/detokenizer_manager.py
+++ b/python/sglang/srt/managers/detokenizer_manager.py
@@ -1,7 +1,9 @@
 """DetokenizerManager is a process that detokenizes the token ids."""
 
 import asyncio
+import dataclasses
 import inspect
+from typing import List
 
 import uvloop
 import zmq
@@ -16,6 +18,14 @@ from sglang.utils import find_printable_text, get_exception_traceback, graceful_
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
+@dataclasses.dataclass
+class DecodeStatus:
+    decoded_text: str
+    decode_ids: List[int]
+    surr_offset: int
+    read_offset: int
+
+
 class DetokenizerManager:
     def __init__(
         self,
@@ -35,19 +45,42 @@ class DetokenizerManager:
                 trust_remote_code=server_args.trust_remote_code,
             )
 
+        self.decode_status = {}
+
     async def handle_loop(self):
         while True:
            recv_obj: BatchTokenIDOut = await self.recv_from_router.recv_pyobj()
            assert isinstance(recv_obj, BatchTokenIDOut)
+            bs = len(recv_obj.rids)
+
+            # FIXME: incremental detokenize is not compatible with jump forward
+            # Initialize decode status
+            read_ids, surr_ids = [], []
+            for i in range(bs):
+                rid = recv_obj.rids[i]
+                if rid not in self.decode_status:
+                    s = DecodeStatus(
+                        decoded_text=recv_obj.decoded_texts[i],
+                        decode_ids=recv_obj.decode_ids[i],
+                        surr_offset=0,
+                        read_offset=recv_obj.read_offsets[i],
+                    )
+                    self.decode_status[rid] = s
+                else:
+                    s = self.decode_status[rid]
+                    s.decode_ids = recv_obj.decode_ids[i]
+
+                read_ids.append(s.decode_ids[s.surr_offset :])
+                surr_ids.append(s.decode_ids[s.surr_offset : s.read_offset])
 
             # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
             surr_texts = self.tokenizer.batch_decode(
-                recv_obj.surr_output_ids,
+                surr_ids,
                 skip_special_tokens=recv_obj.skip_special_tokens[0],
                 spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
             )
             read_texts = self.tokenizer.batch_decode(
-                recv_obj.read_output_ids,
+                read_ids,
                 skip_special_tokens=recv_obj.skip_special_tokens[0],
                 spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
             )
@@ -55,11 +88,20 @@ class DetokenizerManager:
             # Trim stop str
             # TODO(lmzheng): handle the case where multiple stop strs are hit
             output_strs = []
-            for i in range(len(recv_obj.rids)):
+            for i in range(bs):
+                s = self.decode_status[recv_obj.rids[i]]
                 new_text = read_texts[i][len(surr_texts[i]) :]
                 if recv_obj.finished_reason[i] is None:
-                    new_text = find_printable_text(new_text)
-                output_strs.append(recv_obj.decoded_texts[i] + new_text)
+                    # Streaming chunk: update the decode status
+                    if len(new_text) > 0 and not new_text.endswith("�"):
+                        s.decoded_text = s.decoded_text + new_text
+                        s.surr_offset = s.read_offset
+                        s.read_offset = len(s.decode_ids)
+                        new_text = ""
+                    else:
+                        new_text = find_printable_text(new_text)
+
+                output_strs.append(s.decoded_text + new_text)
 
                 if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR):
                     pos = output_strs[i].find(recv_obj.finished_reason[i].matched)
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 7b26a4f2d..f0240f6dc 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -111,8 +111,8 @@ class TokenizedGenerateReqInput:
 class BatchTokenIDOut:
     rids: List[str]
     decoded_texts: List[str]
-    surr_output_ids: List[List[int]]
-    read_output_ids: List[List[int]]
+    decode_ids: List[List[int]]
+    read_offsets: List[int]
     skip_special_tokens: List[bool]
     spaces_between_special_tokens: List[bool]
     meta_info: List[Dict]
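
A note on the incremental detokenization scheme the patch implements: `decode_ids` is decoded through two windows, `[surr_offset:read_offset]` (text already emitted) and `[surr_offset:]` (the same window plus new tokens). The difference between the two decoded strings is the newly printable text, and it is committed only once it no longer ends in U+FFFD, i.e. once the trailing tokens form a complete character. The following is a minimal standalone sketch of that idea, not sglang code: it assumes a Hugging Face tokenizer ("gpt2" is an arbitrary choice) and the helper name `stream_decode` is hypothetical.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # arbitrary tokenizer choice

def stream_decode(all_ids, surr_offset, read_offset):
    """Return (new_text, surr_offset, read_offset) after seeing all_ids."""
    # Decode the already-emitted window and the window extended by new tokens.
    surr_text = tokenizer.decode(all_ids[surr_offset:read_offset])
    read_text = tokenizer.decode(all_ids[surr_offset:])
    new_text = read_text[len(surr_text):]
    # Commit only when the tail decodes cleanly, mirroring the
    # `not new_text.endswith("�")` check in the patch.
    if len(new_text) > 0 and not new_text.endswith("\ufffd"):
        return new_text, read_offset, len(all_ids)
    return "", surr_offset, read_offset

# Feed token ids one at a time, as a decoding loop would.
ids, surr, read = [], 0, 0
for tid in tokenizer.encode("héllo wörld"):
    ids.append(tid)
    chunk, surr, read = stream_decode(ids, surr, read)
    if chunk:
        print(chunk, end="", flush=True)
print()

Because `surr_offset` advances whenever text is committed, each step re-decodes only a small surrounding window, so per-token cost stays bounded instead of growing with the total output length.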
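Similarly, the try/except around `_store_kv_cache` in radix_attention.py is a compatibility shim: on recent PyTorch, `torch.library.custom_op` registers the in-place KV write as an opaque operator whose `register_fake` implementation is a no-op, so tracing machinery such as `torch.compile` does not have to reason about the data-dependent index assignment; builds that predate the API (it appeared around PyTorch 2.4 — treat that version bound as an assumption) fall back to the plain eager function. A minimal sketch of the same pattern with a toy operator; `mylib::inplace_scale` is an illustrative name, not part of the patch:

import torch

try:
    # Recent PyTorch: register an opaque custom op that mutates `buf` in place.
    @torch.library.custom_op("mylib::inplace_scale", mutates_args={"buf"})
    def inplace_scale(buf: torch.Tensor, factor: float) -> None:
        buf.mul_(factor)

    # Fake (meta) implementation used during tracing: the op returns nothing
    # and changes no shapes or dtypes, so it is a no-op.
    @inplace_scale.register_fake
    def _(buf, factor):
        pass

except AttributeError:
    # Older PyTorch without torch.library.custom_op: plain eager fallback.
    def inplace_scale(buf: torch.Tensor, factor: float) -> None:
        buf.mul_(factor)

x = torch.ones(4)
inplace_scale(x, 2.0)
print(x)  # tensor([2., 2., 2., 2.])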