From 9c902b1954c55ec152a5ea91ed47e8cb696f7e46 Mon Sep 17 00:00:00 2001 From: Liangsheng Yin Date: Wed, 12 Jun 2024 14:39:12 +0800 Subject: [PATCH] Decode Incrementally (#517) --- examples/usage/chinese_regex.py | 53 +++++ python/sglang/srt/constrained/__init__.py | 7 +- python/sglang/srt/constrained/fsm_cache.py | 4 +- python/sglang/srt/constrained/jump_forward.py | 131 ++++++++--- .../srt/managers/controller/infer_batch.py | 203 ++++++++++++------ .../srt/managers/controller/tp_worker.py | 46 ++-- .../srt/managers/detokenizer_manager.py | 30 ++- python/sglang/srt/managers/io_struct.py | 6 +- 8 files changed, 345 insertions(+), 135 deletions(-) create mode 100644 examples/usage/chinese_regex.py diff --git a/examples/usage/chinese_regex.py b/examples/usage/chinese_regex.py new file mode 100644 index 000000000..78e9c7e16 --- /dev/null +++ b/examples/usage/chinese_regex.py @@ -0,0 +1,53 @@ +import sglang as sgl + +character_regex = ( + r"""\{\n""" + + r""" "姓名": "[^"]{1,32}",\n""" + + r""" "学院": "(格兰芬多|赫奇帕奇|拉文克劳|斯莱特林)",\n""" + + r""" "血型": "(纯血|混血|麻瓜)",\n""" + + r""" "职业": "(学生|教师|傲罗|魔法部|食死徒|凤凰社成员)",\n""" + + r""" "魔杖": \{\n""" + + r""" "材质": "[^"]{1,32}",\n""" + + r""" "杖芯": "[^"]{1,32}",\n""" + + r""" "长度": [0-9]{1,2}\.[0-9]{0,2}\n""" + + r""" \},\n""" + + r""" "存活": "(存活|死亡)",\n""" + + r""" "守护神": "[^"]{1,32}",\n""" + + r""" "博格特": "[^"]{1,32}"\n""" + + r"""\}""" +) + + +@sgl.function +def character_gen(s, name): + s += name + " 是一名哈利波特系列小说中的角色。请填写以下关于这个角色的信息。" + s += """\ +这是一个例子 +{ + "姓名": "哈利波特", + "学院": "格兰芬多", + "血型": "混血", + "职业": "学生", + "魔杖": { + "材质": "冬青木", + "杖芯": "凤凰尾羽", + "长度": 11.0 + }, + "存活": "存活", + "守护神": "麋鹿", + "博格特": "摄魂怪" +} +""" + s += f"现在请你填写{name}的信息:\n" + s += sgl.gen("json_output", max_tokens=256, regex=character_regex) + + +def main(): + backend = sgl.RuntimeEndpoint("http://localhost:30000") + sgl.set_default_backend(backend) + ret = character_gen.run(name="赫敏格兰杰", temperature=0) + print(ret.text()) + + +if __name__ == "__main__": + main() diff --git a/python/sglang/srt/constrained/__init__.py b/python/sglang/srt/constrained/__init__.py index 463d71c22..ab6f56e5d 100644 --- a/python/sglang/srt/constrained/__init__.py +++ b/python/sglang/srt/constrained/__init__.py @@ -3,8 +3,8 @@ from typing import Dict, Optional, Union from outlines.caching import cache as disk_cache from outlines.caching import disable_cache -from outlines.fsm.fsm import RegexFSM -from outlines.fsm.regex import FSMInfo, make_deterministic_fsm +from outlines.fsm.guide import RegexGuide +from outlines.fsm.regex import FSMInfo, make_deterministic_fsm, make_byte_level_fsm from outlines.models.transformers import TransformerTokenizer from pydantic import BaseModel @@ -28,11 +28,12 @@ except ImportError: __all__ = [ - "RegexFSM", + "RegexGuide", "FSMInfo", "make_deterministic_fsm", "build_regex_from_object", "TransformerTokenizer", "disk_cache", "disable_cache", + "make_byte_level_fsm", ] diff --git a/python/sglang/srt/constrained/fsm_cache.py b/python/sglang/srt/constrained/fsm_cache.py index a8cbde1dd..387ccf024 100644 --- a/python/sglang/srt/constrained/fsm_cache.py +++ b/python/sglang/srt/constrained/fsm_cache.py @@ -1,5 +1,5 @@ """Cache for the compressed finite state machine.""" -from sglang.srt.constrained import RegexFSM, TransformerTokenizer +from sglang.srt.constrained import RegexGuide, TransformerTokenizer from sglang.srt.constrained.base_cache import BaseCache @@ -26,4 +26,4 @@ class FSMCache(BaseCache): ) def init_value(self, regex): - return RegexFSM(regex, self.outlines_tokenizer) 
+        return RegexGuide(regex, self.outlines_tokenizer)
diff --git a/python/sglang/srt/constrained/jump_forward.py b/python/sglang/srt/constrained/jump_forward.py
index 9e4a58803..f71123cf2 100644
--- a/python/sglang/srt/constrained/jump_forward.py
+++ b/python/sglang/srt/constrained/jump_forward.py
@@ -2,20 +2,41 @@
 Faster constrained decoding.
 Reference: https://lmsys.org/blog/2024-02-05-compressed-fsm/
 """
-import interegular
-from sglang.srt.constrained import FSMInfo, disk_cache, make_deterministic_fsm
+import dataclasses
+from collections import defaultdict
+
+import interegular
+import outlines.caching
+
+from sglang.srt.constrained import (
+    FSMInfo,
+    disk_cache,
+    make_deterministic_fsm,
+    make_byte_level_fsm,
+)
 from sglang.srt.constrained.base_cache import BaseCache

 IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"


+@dataclasses.dataclass
+class JumpEdge:
+    symbol: str = None
+    symbol_next_state: int = None
+    byte: int = None
+    byte_next_state: int = None
+
+
 class JumpForwardMap:
     def __init__(self, regex_string):
         @disk_cache()
         def _init_state_to_jump_forward(regex_string):
             regex_pattern = interegular.parse_pattern(regex_string)
-            regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())
+
+            byte_fsm = make_byte_level_fsm(
+                regex_pattern.to_fsm().reduce(), keep_utf8=True
+            )
+            regex_fsm, _ = make_deterministic_fsm(byte_fsm)

             fsm_info: FSMInfo = regex_fsm.fsm_info

@@ -25,40 +46,91 @@ class JumpForwardMap:
                 id_to_symbol.setdefault(id_, []).append(symbol)

             transitions = fsm_info.transitions
-            dirty_states = set()

+            outgoings_ct = defaultdict(int)
             state_to_jump_forward = {}
+
             for (state, id_), next_state in transitions.items():
-                if state in dirty_states:
-                    continue
-                if state in state_to_jump_forward:
-                    dirty_states.add(state)
-                    del state_to_jump_forward[state]
-                    continue
-                if len(id_to_symbol[id_]) > 1:
-                    dirty_states.add(state)
+                if id_ == fsm_info.alphabet_anything_value:
                     continue
+                symbols = id_to_symbol[id_]
+                for c in symbols:
+                    if len(c) > 1:
+                        # Skip byte-level transitions
+                        continue

-                state_to_jump_forward[state] = (id_to_symbol[id_][0], next_state)
+                    outgoings_ct[state] += 1
+                    if outgoings_ct[state] > 1:
+                        if state in state_to_jump_forward:
+                            del state_to_jump_forward[state]
+                        break
+
+                    state_to_jump_forward[state] = JumpEdge(
+                        symbol=c,
+                        symbol_next_state=next_state,
+                    )
+
+            # Process the byte-level jump forward
+            outgoings_ct = defaultdict(int)
+            for (state, id_), next_state in transitions.items():
+                if id_ == fsm_info.alphabet_anything_value:
+                    continue
+                symbols = id_to_symbol[id_]
+                for c in symbols:
+                    byte_ = None
+                    if len(c) == 1 and ord(c) < 0x80:
+                        # ASCII character
+                        byte_ = ord(c)
+                    elif len(c) == 2:
+                        byte_ = int(c, 16)  # two-char hex encoding of a single byte
+
+                    if byte_ is not None:
+                        outgoings_ct[state] += 1
+                        if outgoings_ct[state] > 1:
+                            if state in state_to_jump_forward:
+                                del state_to_jump_forward[state]
+                            break
+                        e = state_to_jump_forward.get(state, JumpEdge())
+                        e.byte = byte_
+                        e.byte_next_state = next_state
+                        state_to_jump_forward[state] = e

             return state_to_jump_forward

         self.state_to_jump_forward = _init_state_to_jump_forward(regex_string)

-    def valid_states(self):
-        return self.state_to_jump_forward.keys()
+    def jump_forward_symbol(self, state):
+        jump_forward_str = ""
+        next_state = state
+        while state in self.state_to_jump_forward:
+            e = self.state_to_jump_forward[state]
+            if e.symbol is None:
+                break
+            jump_forward_str += e.symbol
+            next_state = e.symbol_next_state
+            state = next_state

-    def jump_forward(self, state):
+        return jump_forward_str, next_state
+
+    def jump_forward_byte(self, state):
         if state not in self.state_to_jump_forward:
             return None

-        jump_forward_str = ""
+        jump_forward_bytes = []
         next_state = None
         while state in self.state_to_jump_forward:
-            symbol, next_state = self.state_to_jump_forward[state]
-            jump_forward_str += symbol
+            e = self.state_to_jump_forward[state]
+            assert e.byte is not None and e.byte_next_state is not None
+            jump_forward_bytes.append((e.byte, e.byte_next_state))
+            next_state = e.byte_next_state
             state = next_state
-        return jump_forward_str, next_state
+
+        return jump_forward_bytes
+
+    def is_jump_forward_symbol_state(self, state):
+        return (
+            state in self.state_to_jump_forward
+            and self.state_to_jump_forward[state].symbol is not None
+        )


 class JumpForwardCache(BaseCache):
@@ -69,12 +141,21 @@ class JumpForwardCache(BaseCache):
         return JumpForwardMap(regex)


-def test_main():
-    regex_string = r"The google's DNS sever address is " + IP_REGEX
+def test_main(regex_string):
     jump_forward_map = JumpForwardMap(regex_string)
-    for state in jump_forward_map.valid_states():
-        print(state, f'"{jump_forward_map.jump_forward(state)}"')
+    for state, e in jump_forward_map.state_to_jump_forward.items():
+        if e.symbol is not None:
+            jump_forward_str, next_state = jump_forward_map.jump_forward_symbol(state)
+            print(f"{state} -> {next_state}", jump_forward_str)
+        bytes_ = jump_forward_map.jump_forward_byte(state)
+        print(f"{state} -> {bytes_[-1][1]}", [hex(b) for b, _ in bytes_])


 if __name__ == "__main__":
-    test_main()
+    import outlines
+
+    outlines.caching.clear_cache()
+    test_main(r"Google's DNS server address is " + IP_REGEX)
+    test_main(r"霍格沃茨特快列车|霍比特人比尔博")
+    # 霍格: \xe9\x9c\x8d \xe6\xa0\xbc ...
+    # 霍比: \xe9\x9c\x8d \xe6\xaf\x94 ...
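
The infer_batch.py changes below replace the old prev_output_str / prev_output_ids bookkeeping with vLLM-style incremental detokenization: each request keeps a short window of already-decoded tokens (surr_offset up to read_offset) as surrounding context, re-decodes that window together with any new tokens, and emits only the text past the window. A minimal sketch of the idea, illustrative only: it assumes a HuggingFace-style tokenizer, and detokenize_step is a hypothetical helper rather than this patch's API.

    def detokenize_step(tokenizer, all_ids, surr_offset, read_offset, decoded_text):
        # Decode the surrounding window alone, then together with the new tokens.
        surr_text = tokenizer.decode(all_ids[surr_offset:read_offset])
        new_text = tokenizer.decode(all_ids[surr_offset:])
        if len(new_text) > len(surr_text) and not new_text.endswith("\ufffd"):
            # At least one full character was completed: emit the suffix
            # and slide the window forward.
            decoded_text += new_text[len(surr_text):]
            surr_offset, read_offset = read_offset, len(all_ids)
        # Otherwise hold the text back until the partial UTF-8 sequence completes.
        return decoded_text, surr_offset, read_offset

Holding text back while it still ends in U+FFFD is what lets multi-byte characters, such as the Chinese strings exercised above, stream out without corruption.
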
diff --git a/python/sglang/srt/managers/controller/infer_batch.py b/python/sglang/srt/managers/controller/infer_batch.py
index 6e235fefa..7ff9406ea 100644
--- a/python/sglang/srt/managers/controller/infer_batch.py
+++ b/python/sglang/srt/managers/controller/infer_batch.py
@@ -3,12 +3,17 @@ from dataclasses import dataclass
 from enum import IntEnum, auto
 from typing import List

+import warnings
 import numpy as np
 import torch

 from sglang.srt.managers.controller.radix_cache import RadixCache
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
+from sglang.srt.constrained.jump_forward import JumpForwardMap
+from sglang.srt.constrained import RegexGuide
+
+INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5


 class ForwardMode(IntEnum):
@@ -64,12 +69,15 @@ class Req:
     def __init__(self, rid, origin_input_text, origin_input_ids):
         self.rid = rid
         self.origin_input_text = origin_input_text
+        self.origin_input_ids_unpadded = origin_input_ids  # Before image padding
         self.origin_input_ids = origin_input_ids
-        self.origin_input_ids_unpadded = origin_input_ids  # before image padding
-        self.prev_output_str = ""
-        self.prev_output_ids = []
-        self.output_ids = []
-        self.input_ids = None  # input_ids = origin_input_ids + prev_output_ids
+        self.output_ids = []  # All output ids so far, appended at each decode step
+        self.input_ids = None  # input_ids = origin_input_ids + output_ids
+
+        # For incremental decoding
+        self.decoded_text = ""
+        self.surr_offset = None  # Start of the surrounding context window that keeps decoding stable
+        self.read_offset = None  # Tokens before this offset are already decoded

         # The number of decoded tokens for token usage report. Note that
         # this does not include the jump forward tokens.
@@ -109,20 +117,54 @@ class Req:
         self.last_update_decode_tokens = 0

         # Constrained decoding
-        self.regex_fsm = None
-        self.regex_fsm_state = 0
-        self.jump_forward_map = None
+        self.regex_fsm: RegexGuide = None
+        self.regex_fsm_state: int = 0
+        self.jump_forward_map: JumpForwardMap = None

    # Whether the request has reached a finish condition
     def finished(self) -> bool:
         return self.finished_reason is not None

-    def partial_decode(self, ids):
-        first_token = self.tokenizer.convert_ids_to_tokens(ids[0])
-        first_token = (
-            first_token.decode() if isinstance(first_token, bytes) else first_token
+    # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
+    def init_detokenize_incrementally(self):
+        first_iter = self.surr_offset is None or self.read_offset is None
+
+        if first_iter:
+            self.read_offset = len(self.origin_input_ids_unpadded)
+            self.surr_offset = max(
+                self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
+            )
+
+        all_ids = self.origin_input_ids_unpadded + self.output_ids
+        surr_ids = all_ids[self.surr_offset : self.read_offset]
+        read_ids = all_ids[self.surr_offset :]
+
+        return surr_ids, read_ids, len(all_ids)
+
+    def detokenize_incrementally(self, inplace: bool = True):
+        surr_ids, read_ids, num_all_tokens = self.init_detokenize_incrementally()
+
+        surr_text = self.tokenizer.decode(
+            surr_ids,
+            skip_special_tokens=self.sampling_params.skip_special_tokens,
+            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
         )
-        return (" " if first_token.startswith("▁") else "") + self.tokenizer.decode(ids)
+        new_text = self.tokenizer.decode(
+            read_ids,
+            skip_special_tokens=self.sampling_params.skip_special_tokens,
+            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
+        )
+
+        if len(new_text) > len(surr_text) and not new_text.endswith("�"):
+            new_text = new_text[len(surr_text) :]
+            if inplace:
+                self.decoded_text += new_text
+                self.surr_offset = self.read_offset
+                self.read_offset = num_all_tokens
+
+            return True, new_text
+
+        return False, ""

     def max_new_tokens(self):
         return self.sampling_params.max_new_tokens
@@ -131,18 +173,17 @@
         if self.finished():
             return

-        if (
-            len(self.prev_output_ids) + len(self.output_ids)
-            >= self.sampling_params.max_new_tokens
-        ):
-            self.finished_reason = FINISH_LENGTH(len(self.prev_output_ids) + len(self.output_ids))
+        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
+            self.finished_reason = FINISH_LENGTH(len(self.output_ids))
             return

         if (
             self.output_ids[-1] == self.tokenizer.eos_token_id
             and not self.sampling_params.ignore_eos
         ):
-            self.finished_reason = FINISH_MATCHED_TOKEN(matched=self.tokenizer.eos_token_id)
+            self.finished_reason = FINISH_MATCHED_TOKEN(
+                matched=self.tokenizer.eos_token_id
+            )
             return

         if len(self.sampling_params.stop_strs) > 0:
@@ -151,61 +192,59 @@
         )

         for stop_str in self.sampling_params.stop_strs:
-            # FIXME: (minor) try incremental match in prev_output_str
-            if stop_str in tail_str or stop_str in self.prev_output_str:
+            if stop_str in tail_str or stop_str in self.decoded_text:
                 self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
                 return

     def jump_forward_and_retokenize(self, jump_forward_str, next_state):
-        # FIXME: This logic does not really solve the problem of determining whether
-        # there should be a leading space.
-        cur_output_str = self.partial_decode(self.output_ids)
-
-        # TODO(lsyin): apply re-tokenize only for decode tokens so that we do not need origin_input_text anymore
         if self.origin_input_text is None:
             # Recovering text can only use unpadded ids
             self.origin_input_text = self.tokenizer.decode(
                 self.origin_input_ids_unpadded
             )

-        all_text = (
-            self.origin_input_text
-            + self.prev_output_str
-            + cur_output_str
-            + jump_forward_str
-        )
+        all_text = self.origin_input_text + self.decoded_text + jump_forward_str
         all_ids = self.tokenizer.encode(all_text)
         prompt_tokens = len(self.origin_input_ids_unpadded)
-        self.origin_input_ids = all_ids[:prompt_tokens]
-        self.origin_input_ids_unpadded = self.origin_input_ids
-        # NOTE: the output ids may not strictly correspond to the output text
-        old_prev_output_ids = self.prev_output_ids
-        self.prev_output_ids = all_ids[prompt_tokens:]
-        self.prev_output_str = self.prev_output_str + cur_output_str + jump_forward_str
-        self.output_ids = []
+
+        if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
+            # TODO(lsyin): fix token fusion
+            warnings.warn(
+                "Token fusion between input and output; try to avoid this by removing the trailing space from the input."
+            )
+            return False
+
+        old_output_ids = self.output_ids
+        self.output_ids = all_ids[prompt_tokens:]
+        self.decoded_text = self.decoded_text + jump_forward_str
+        self.surr_offset = prompt_tokens
+        self.read_offset = len(all_ids)
+
+        # NOTE: A trick to reduce the decoding overhead of the surrounding tokens
+        for i in range(0, INIT_INCREMENTAL_DETOKENIZATION_OFFSET):
+            surr_text_ = self.tokenizer.decode(
+                all_ids[self.read_offset - i : self.read_offset]
+            )
+            if not surr_text_.endswith("�"):
+                self.surr_offset = self.read_offset - i
+                break

         self.regex_fsm_state = next_state

         if self.return_logprob:
             # For the jump-forward part's logprobs: keep only the shared prefix
             k = 0
-            for i, old_id in enumerate(old_prev_output_ids):
-                if old_id == self.prev_output_ids[i]:
+            for i, old_id in enumerate(old_output_ids):
+                if old_id == self.output_ids[i]:
                     k = k + 1
                 else:
                     break
             self.decode_token_logprobs = self.decode_token_logprobs[:k]
             self.decode_top_logprobs = self.decode_top_logprobs[:k]
             self.logprob_start_len = prompt_tokens + k
-            self.last_update_decode_tokens = len(self.prev_output_ids) - k
+            self.last_update_decode_tokens = len(self.output_ids) - k

-        # print("=" * 100)
-        # print(f"Catch jump forward:\n{jump_forward_str}")
-        # print(self.tokenizer.convert_ids_to_tokens(self.input_ids))
-        # print(self.tokenizer.convert_ids_to_tokens(new_input_ids))
-
-        # print(f"Output and jump forward str:\n{self.output_and_jump_forward_str}")
-        # print("*" * 100)
+        return True

     def __repr__(self):
         return f"rid(n={self.rid}, " f"input_ids={self.origin_input_ids}, "

@@ -381,7 +420,10 @@
         sorted_indices = [i for i in range(len(self.reqs))]
         # TODO(lsyin): improve the priority of retraction
         sorted_indices.sort(
-            key=lambda i: (len(self.reqs[i].output_ids), -len(self.reqs[i].input_ids)),
+            key=lambda i: (
+                len(self.reqs[i].output_ids),
+                -len(self.reqs[i].origin_input_ids),
+            ),
             reverse=True,
         )

@@ -403,14 +445,9 @@
             # release the last node
             self.tree_cache.dec_lock_ref(req.last_node)

-            cur_output_str = req.partial_decode(req.output_ids)
-            req.prev_output_str = req.prev_output_str + cur_output_str
-            req.prev_output_ids.extend(req.output_ids)
-
             req.prefix_indices = None
             req.last_node = None
             req.extend_input_len = 0
-            req.output_ids = []

             # For incremental logprobs
             req.last_update_decode_tokens = 0
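
The check_for_jump_forward hunk below moves jump-forward to the byte level: a jump may now start in the middle of a multi-byte UTF-8 character, so the leading continuation bytes (0x80-0xBF) are first peeled off and appended as byte-fallback tokens before the remaining text is re-tokenized. A small sketch of that splitting step, illustrative only: split_continuation_bytes is a hypothetical helper, and the <0xNN> byte tokens assume a Llama-style tokenizer.

    def split_continuation_bytes(jump_forward_bytes):
        # jump_forward_bytes is a list of (byte, next_state) edges from
        # JumpForwardMap.jump_forward_byte. Peel off the UTF-8 continuation
        # bytes that finish the character the model has already started.
        suffix = []
        while jump_forward_bytes and 0x80 <= jump_forward_bytes[0][0] < 0xC0:
            suffix.append(jump_forward_bytes.pop(0))
        return suffix, jump_forward_bytes

    # "霍" is the byte sequence 0xE9 0x9C 0x8D; if the model has already
    # emitted the token covering 0xE9, the jump starts mid-character
    # (the states here are made up):
    suffix, rest = split_continuation_bytes([(0x9C, 1), (0x8D, 2), (0xE6, 3)])
    assert [b for b, _ in suffix] == [0x9C, 0x8D]  # become <0x9C>, <0x8D>
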
@@ -428,18 +465,53 @@
         for i, req in enumerate(self.reqs):
             if req.jump_forward_map is not None:
-                res = req.jump_forward_map.jump_forward(req.regex_fsm_state)
-                if res is not None:
-                    jump_forward_str, next_state = res
-                    if len(jump_forward_str) <= 1:
+                jump_forward_bytes = req.jump_forward_map.jump_forward_byte(
+                    req.regex_fsm_state
+                )
+                if jump_forward_bytes is not None and len(jump_forward_bytes) > 1:
+                    suffix_bytes = []
+                    continuation_range = range(0x80, 0xC0)
+                    cur_state = req.regex_fsm_state
+                    while (
+                        len(jump_forward_bytes)
+                        and jump_forward_bytes[0][0] in continuation_range
+                    ):
+                        # Collect the UTF-8 continuation bytes (0x80-0xBF)
+                        byte_edge = jump_forward_bytes.pop(0)
+                        suffix_bytes.append(byte_edge[0])
+                        cur_state = byte_edge[1]
+
+                    suffix_tokens = [f"<0x{hex(b)[2:].upper()}>" for b in suffix_bytes]
+                    suffix_ids = req.tokenizer.convert_tokens_to_ids(suffix_tokens)
+
+                    # Current ids, for cache and revert
+                    cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
+                    cur_output_ids = req.output_ids
+
+                    req.output_ids.extend(suffix_ids)
+                    decode_res, new_text = req.detokenize_incrementally(inplace=False)
+                    if not decode_res:
+                        req.output_ids = cur_output_ids
                         continue

-                    if req_pool_indices_cpu is None:
-                        req_pool_indices_cpu = self.req_pool_indices.tolist()
+                    jump_forward_str, next_state = (
+                        req.jump_forward_map.jump_forward_symbol(cur_state)
+                    )
+
+                    # Prepend the incrementally decoded text to jump_forward_str
+                    # so that partial UTF-8 sequences are not corrupted
+                    jump_forward_str = new_text + jump_forward_str
+                    if not req.jump_forward_and_retokenize(
+                        jump_forward_str, next_state
+                    ):
+                        req.output_ids = cur_output_ids
+                        continue

                     # insert the old request into tree_cache
+                    if req_pool_indices_cpu is None:
+                        req_pool_indices_cpu = self.req_pool_indices.tolist()
                     self.tree_cache.cache_req(
-                        token_ids=tuple(req.input_ids + req.output_ids)[:-1],
+                        token_ids=cur_all_ids,
                         last_uncached_pos=len(req.prefix_indices),
                         req_pool_idx=req_pool_indices_cpu[i],
                     )

                     # unlock the last node
                     self.tree_cache.dec_lock_ref(req.last_node)

-                    # jump-forward
-                    req.jump_forward_and_retokenize(jump_forward_str, next_state)
-
                     # re-applying image padding
                     if req.pixel_values is not None:
                         (
@@ -583,7 +652,7 @@
             if req.regex_fsm is not None:
                 allowed_mask.zero_()
                 allowed_mask[
-                    req.regex_fsm.allowed_token_ids(req.regex_fsm_state)
+                    req.regex_fsm.get_next_instruction(req.regex_fsm_state).tokens
                 ] = 1
                 logits[i].masked_fill_(~allowed_mask, float("-inf"))

@@ -602,7 +671,7 @@
         batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy()
         for i, req in enumerate(self.reqs):
             if req.regex_fsm is not None:
-                req.regex_fsm_state = req.regex_fsm.next_state(
+                req.regex_fsm_state = req.regex_fsm.get_next_state(
                     req.regex_fsm_state, batch_next_token_ids_cpu[i]
                 )
diff --git a/python/sglang/srt/managers/controller/tp_worker.py b/python/sglang/srt/managers/controller/tp_worker.py
index 1edd26337..3d4c48e51 100644
--- a/python/sglang/srt/managers/controller/tp_worker.py
+++ b/python/sglang/srt/managers/controller/tp_worker.py
@@ -21,7 +21,13 @@ from sglang.srt.managers.io_struct import (
     FlushCacheReq,
     TokenizedGenerateReqInput,
 )
-from sglang.srt.managers.controller.infer_batch import BaseFinishReason, Batch, FINISH_ABORT, ForwardMode, Req
+from sglang.srt.managers.controller.infer_batch import (
+    BaseFinishReason,
+    Batch,
+    FINISH_ABORT,
+    ForwardMode,
+    Req,
+)
 from sglang.srt.managers.controller.model_runner import ModelRunner
 from sglang.srt.managers.controller.radix_cache import RadixCache
 from 
sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic @@ -98,8 +104,11 @@ class ModelTpServer: else server_args.max_prefill_tokens ), ) - self.max_running_requests = (self.max_total_num_tokens // 2 - if server_args.max_running_requests is None else server_args.max_running_requests) + self.max_running_requests = ( + self.max_total_num_tokens // 2 + if server_args.max_running_requests is None + else server_args.max_running_requests + ) self.int_token_logit_bias = torch.tensor( get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size) ) @@ -314,10 +323,7 @@ class ModelTpServer: # Compute matched prefix length for req in self.forward_queue: - assert ( - len(req.output_ids) == 0 - ), "The output ids should be empty when prefilling" - req.input_ids = req.origin_input_ids + req.prev_output_ids + req.input_ids = req.origin_input_ids + req.output_ids prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids) if req.return_logprob: prefix_indices = prefix_indices[: req.logprob_start_len] @@ -464,7 +470,7 @@ class ModelTpServer: pt = 0 for i, req in enumerate(batch.reqs): req.completion_tokens_wo_jump_forward += 1 - req.output_ids = [next_token_ids[i]] + req.output_ids.append(next_token_ids[i]) req.check_finished() if req.return_logprob: @@ -524,7 +530,7 @@ class ModelTpServer: req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy() for i, req in enumerate(batch.reqs): new_prefix_indices, new_last_node = self.tree_cache.cache_req( - token_ids=tuple(req.input_ids + req.output_ids)[:-1], + token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1], last_uncached_pos=len(req.prefix_indices), req_pool_idx=req_pool_indices_cpu[i], del_in_memory_pool=False, @@ -596,8 +602,9 @@ class ModelTpServer: def handle_finished_requests(self, batch: Batch): output_rids = [] - prev_output_strs = [] - output_tokens = [] + decoded_texts = [] + surr_output_ids = [] + read_output_ids = [] output_skip_special_tokens = [] output_spaces_between_special_tokens = [] output_meta_info = [] @@ -620,8 +627,10 @@ class ModelTpServer: ) ): output_rids.append(req.rid) - prev_output_strs.append(req.prev_output_str) - output_tokens.append(req.output_ids) + decoded_texts.append(req.decoded_text) + surr_ids, read_ids, _ = req.init_detokenize_incrementally() + surr_output_ids.append(surr_ids) + read_output_ids.append(read_ids) output_skip_special_tokens.append( req.sampling_params.skip_special_tokens ) @@ -631,7 +640,7 @@ class ModelTpServer: meta_info = { "prompt_tokens": len(req.origin_input_ids), - "completion_tokens": len(req.prev_output_ids) + len(req.output_ids), + "completion_tokens": len(req.output_ids), "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward, "finish_reason": str(req.finished_reason), } @@ -657,8 +666,9 @@ class ModelTpServer: self.out_pyobjs.append( BatchTokenIDOut( output_rids, - prev_output_strs, - output_tokens, + decoded_texts, + surr_output_ids, + read_output_ids, output_skip_special_tokens, output_spaces_between_special_tokens, output_meta_info, @@ -673,7 +683,7 @@ class ModelTpServer: for i in finished_indices: req = batch.reqs[i] self.tree_cache.cache_req( - token_ids=tuple(req.input_ids + req.output_ids)[:-1], + token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1], last_uncached_pos=len(req.prefix_indices), req_pool_idx=req_pool_indices_cpu[i], ) @@ -790,4 +800,4 @@ class ModelTpClient: return _func - self.step = async_wrap("step") \ No newline at end of file + self.step = async_wrap("step") diff --git 
a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index 1c591a6cc..b5231e69a 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -39,30 +39,24 @@ class DetokenizerManager: recv_obj: BatchTokenIDOut = await self.recv_from_router.recv_pyobj() assert isinstance(recv_obj, BatchTokenIDOut) - output_tokens = recv_obj.output_tokens - # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request - output_strs = self.tokenizer.batch_decode( - output_tokens, + surr_texts = self.tokenizer.batch_decode( + recv_obj.surr_output_ids, skip_special_tokens=recv_obj.skip_special_tokens[0], - spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[ - 0 - ], + spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0], + ) + read_texts = self.tokenizer.batch_decode( + recv_obj.read_output_ids, + skip_special_tokens=recv_obj.skip_special_tokens[0], + spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0], ) # Trim stop str # TODO(lmzheng): handle the case where multiple stop strs are hit - for i in range(len(output_strs)): - if len(output_tokens[i]) > 0: - first_token = self.tokenizer.convert_ids_to_tokens( - int(output_tokens[i][0]) - ) - if not isinstance(first_token, str): - first_token = first_token.decode("utf-8", errors="ignore") - if first_token.startswith("▁"): - output_strs[i] = " " + output_strs[i] - - output_strs[i] = recv_obj.prev_output_strs[i] + output_strs[i] + output_strs = [] + for i in range(len(recv_obj.rids)): + new_text = read_texts[i][len(surr_texts[i]) :] + output_strs.append(recv_obj.decoded_texts[i] + new_text) if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR): pos = output_strs[i].find(recv_obj.finished_reason[i].matched) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index abc4d3033..1897a2c41 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -111,13 +111,15 @@ class TokenizedGenerateReqInput: @dataclass class BatchTokenIDOut: rids: List[str] - prev_output_strs: List[str] - output_tokens: List[List[int]] + decoded_texts: List[str] + surr_output_ids: List[List[int]] + read_output_ids: List[List[int]] skip_special_tokens: List[bool] spaces_between_special_tokens: List[bool] meta_info: List[Dict] finished_reason: List[BaseFinishReason] + @dataclass class BatchStrOut: rids: List[str]
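
Taken together, the new BatchTokenIDOut fields let the detokenizer drop the old leading-space heuristic: it batch-decodes the surrounding window and the full read window, and the newly visible text is simply the suffix past the window. An end-to-end sketch of the reconstruction for a single request, illustrative only: gpt2 stands in for the server's tokenizer, and the offsets are made up.

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    all_ids = tokenizer.encode("The quick brown fox jumps over the lazy dog")
    surr_offset, read_offset = 0, 3  # as tracked by Req on the TP worker
    decoded_text = tokenizer.decode(all_ids[:read_offset])  # already emitted

    surr_ids = all_ids[surr_offset:read_offset]  # surrounding window only
    read_ids = all_ids[surr_offset:]             # window + undecoded tail

    surr_text = tokenizer.decode(surr_ids)
    read_text = tokenizer.decode(read_ids)
    # The newly visible text is the part of read_text past the window.
    output_str = decoded_text + read_text[len(surr_text):]
    assert output_str == "The quick brown fox jumps over the lazy dog"
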