Decode Incrementally (#517)
examples/usage/chinese_regex.py (new file, 53 lines)
@@ -0,0 +1,53 @@
+import sglang as sgl
+
+character_regex = (
+    r"""\{\n"""
+    + r"""    "姓名": "[^"]{1,32}",\n"""
+    + r"""    "学院": "(格兰芬多|赫奇帕奇|拉文克劳|斯莱特林)",\n"""
+    + r"""    "血型": "(纯血|混血|麻瓜)",\n"""
+    + r"""    "职业": "(学生|教师|傲罗|魔法部|食死徒|凤凰社成员)",\n"""
+    + r"""    "魔杖": \{\n"""
+    + r"""        "材质": "[^"]{1,32}",\n"""
+    + r"""        "杖芯": "[^"]{1,32}",\n"""
+    + r"""        "长度": [0-9]{1,2}\.[0-9]{0,2}\n"""
+    + r"""    \},\n"""
+    + r"""    "存活": "(存活|死亡)",\n"""
+    + r"""    "守护神": "[^"]{1,32}",\n"""
+    + r"""    "博格特": "[^"]{1,32}"\n"""
+    + r"""\}"""
+)
+
+
+@sgl.function
+def character_gen(s, name):
+    s += name + " 是一名哈利波特系列小说中的角色。请填写以下关于这个角色的信息。"
+    s += """\
+这是一个例子
+{
+    "姓名": "哈利波特",
+    "学院": "格兰芬多",
+    "血型": "混血",
+    "职业": "学生",
+    "魔杖": {
+        "材质": "冬青木",
+        "杖芯": "凤凰尾羽",
+        "长度": 11.0
+    },
+    "存活": "存活",
+    "守护神": "麋鹿",
+    "博格特": "摄魂怪"
+}
+"""
+    s += f"现在请你填写{name}的信息:\n"
+    s += sgl.gen("json_output", max_tokens=256, regex=character_regex)
+
+
+def main():
+    backend = sgl.RuntimeEndpoint("http://localhost:30000")
+    sgl.set_default_backend(backend)
+    ret = character_gen.run(name="赫敏格兰杰", temperature=0)
+    print(ret.text())
+
+
+if __name__ == "__main__":
+    main()
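This example exercises exactly the path the PR touches: the regex constrains generation to a JSON schema whose keys and values (姓名/name, 学院/house, 血型/blood status, and so on) are multi-byte UTF-8, so jump-forward decoding must be able to stop mid-character. Since the pattern is plain Python regex syntax, it can be sanity-checked offline before handing it to sgl.gen. A minimal sketch using only the standard library, abbreviated to three fields:

    import re

    # Abbreviated version of character_regex above (three fields only).
    character_regex = (
        r"""\{\n"""
        + r"""    "姓名": "[^"]{1,32}",\n"""
        + r"""    "学院": "(格兰芬多|赫奇帕奇|拉文克劳|斯莱特林)",\n"""
        + r"""    "存活": "(存活|死亡)"\n"""
        + r"""\}"""
    )

    candidate = '{\n    "姓名": "哈利波特",\n    "学院": "格兰芬多",\n    "存活": "存活"\n}'
    assert re.fullmatch(character_regex, candidate) is not None
    print("regex accepts the sample character sheet")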
python/sglang/srt/constrained/__init__.py
@@ -3,8 +3,8 @@ from typing import Dict, Optional, Union

 from outlines.caching import cache as disk_cache
 from outlines.caching import disable_cache
-from outlines.fsm.fsm import RegexFSM
-from outlines.fsm.regex import FSMInfo, make_deterministic_fsm
+from outlines.fsm.guide import RegexGuide
+from outlines.fsm.regex import FSMInfo, make_deterministic_fsm, make_byte_level_fsm
 from outlines.models.transformers import TransformerTokenizer
 from pydantic import BaseModel

@@ -28,11 +28,12 @@ except ImportError:

 __all__ = [
-    "RegexFSM",
+    "RegexGuide",
     "FSMInfo",
     "make_deterministic_fsm",
     "build_regex_from_object",
     "TransformerTokenizer",
     "disk_cache",
     "disable_cache",
+    "make_byte_level_fsm",
 ]
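These re-exports track the outlines API: RegexFSM is replaced by RegexGuide, and make_byte_level_fsm compiles a regex into an FSM whose transitions can be individual bytes rather than whole unicode symbols, which is what lets the jump-forward logic below walk through multi-byte UTF-8 characters one byte at a time. A rough sketch of how the helpers compose, mirroring the call pattern used in jump_forward.py below (the description of keep_utf8 is my reading of the flag, not documented here):

    import interegular
    from outlines.fsm.regex import make_byte_level_fsm, make_deterministic_fsm

    pattern = interegular.parse_pattern(r"霍格沃茨|Hogwarts")
    # keep_utf8=True appears to keep complete UTF-8 characters as symbols
    # where possible, while still exposing byte transitions for partial ones.
    byte_fsm = make_byte_level_fsm(pattern.to_fsm().reduce(), keep_utf8=True)
    regex_fsm, _ = make_deterministic_fsm(byte_fsm)
    fsm_info = regex_fsm.fsm_info  # transitions, alphabet, etc., as used below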
python/sglang/srt/constrained/fsm_cache.py
@@ -1,5 +1,5 @@
 """Cache for the compressed finite state machine."""
-from sglang.srt.constrained import RegexFSM, TransformerTokenizer
+from sglang.srt.constrained import RegexGuide, TransformerTokenizer
 from sglang.srt.constrained.base_cache import BaseCache

@@ -26,4 +26,4 @@ class FSMCache(BaseCache):
         )

     def init_value(self, regex):
-        return RegexFSM(regex, self.outlines_tokenizer)
+        return RegexGuide(regex, self.outlines_tokenizer)
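RegexGuide plays the same role the old RegexFSM did: given a state, it reports which token ids may come next and how the state advances after a token is sampled. A minimal sketch of the interface as this PR uses it (get_next_instruction and get_next_state appear in infer_batch.py below; the tokenizer setup here is illustrative and version-dependent, not taken from this diff):

    from outlines.fsm.guide import RegexGuide
    from outlines.models.transformers import TransformerTokenizer

    tokenizer = TransformerTokenizer("gpt2")  # illustrative model choice
    guide = RegexGuide(r"[0-9]{1,3}", tokenizer)

    state = 0
    allowed = guide.get_next_instruction(state).tokens  # token ids allowed next
    state = guide.get_next_state(state, allowed[0])     # advance after "sampling"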
python/sglang/srt/constrained/jump_forward.py
@@ -2,20 +2,41 @@
 Faster constrained decoding.
 Reference: https://lmsys.org/blog/2024-02-05-compressed-fsm/
 """
 import interegular
+import dataclasses
+from collections import defaultdict

-from sglang.srt.constrained import FSMInfo, disk_cache, make_deterministic_fsm
+import outlines.caching
+
+from sglang.srt.constrained import (
+    FSMInfo,
+    disk_cache,
+    make_deterministic_fsm,
+    make_byte_level_fsm,
+)
 from sglang.srt.constrained.base_cache import BaseCache

 IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"


+@dataclasses.dataclass
+class JumpEdge:
+    symbol: str = None
+    symbol_next_state: int = None
+    byte: int = None
+    byte_next_state: int = None
+
+
 class JumpForwardMap:
     def __init__(self, regex_string):
         @disk_cache()
         def _init_state_to_jump_forward(regex_string):
             regex_pattern = interegular.parse_pattern(regex_string)
-            regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())
+
+            byte_fsm = make_byte_level_fsm(
+                regex_pattern.to_fsm().reduce(), keep_utf8=True
+            )
+            regex_fsm, _ = make_deterministic_fsm(byte_fsm)

             fsm_info: FSMInfo = regex_fsm.fsm_info
@@ -25,40 +46,91 @@ class JumpForwardMap:
                 id_to_symbol.setdefault(id_, []).append(symbol)

             transitions = fsm_info.transitions
-            dirty_states = set()
+            outgoings_ct = defaultdict(int)
             state_to_jump_forward = {}

             for (state, id_), next_state in transitions.items():
-                if state in dirty_states:
-                    continue
-                if state in state_to_jump_forward:
-                    dirty_states.add(state)
-                    del state_to_jump_forward[state]
-                    continue
-                if len(id_to_symbol[id_]) > 1:
-                    dirty_states.add(state)
+                if id_ == fsm_info.alphabet_anything_value:
                     continue
-                state_to_jump_forward[state] = (id_to_symbol[id_][0], next_state)
+                symbols = id_to_symbol[id_]
+                for c in symbols:
+                    if len(c) > 1:
+                        # Skip byte level transitions
+                        continue
+
+                    outgoings_ct[state] += 1
+                    if outgoings_ct[state] > 1:
+                        if state in state_to_jump_forward:
+                            del state_to_jump_forward[state]
+                        break
+
+                    state_to_jump_forward[state] = JumpEdge(
+                        symbol=c,
+                        symbol_next_state=next_state,
+                    )
+
+            # Process the byte level jump forward
+            outgoings_ct = defaultdict(int)
+            for (state, id_), next_state in transitions.items():
+                if id_ == fsm_info.alphabet_anything_value:
+                    continue
+                symbols = id_to_symbol[id_]
+                for c in symbols:
+                    byte_ = None
+                    if len(c) == 1 and ord(c) < 0x80:
+                        # ASCII character
+                        byte_ = ord(c)
+                    elif len(c) == 2:
+                        byte_ = int(symbols[0], 16)
+
+                    if byte_ is not None:
+                        outgoings_ct[state] += 1
+                        if outgoings_ct[state] > 1:
+                            if state in state_to_jump_forward:
+                                del state_to_jump_forward[state]
+                            break
+                        e = state_to_jump_forward.get(state, JumpEdge())
+                        e.byte = byte_
+                        e.byte_next_state = next_state
+                        state_to_jump_forward[state] = e

             return state_to_jump_forward

         self.state_to_jump_forward = _init_state_to_jump_forward(regex_string)

-    def valid_states(self):
-        return self.state_to_jump_forward.keys()
+    def jump_forward_symbol(self, state):
+        jump_forward_str = ""
+        next_state = state
+        while state in self.state_to_jump_forward:
+            e = self.state_to_jump_forward[state]
+            if e.symbol is None:
+                break
+            jump_forward_str += e.symbol
+            next_state = e.symbol_next_state
+            state = next_state

-    def jump_forward(self, state):
+        return jump_forward_str, next_state
+
+    def jump_forward_byte(self, state):
         if state not in self.state_to_jump_forward:
             return None

-        jump_forward_str = ""
+        jump_forward_bytes = []
         next_state = None
         while state in self.state_to_jump_forward:
-            symbol, next_state = self.state_to_jump_forward[state]
-            jump_forward_str += symbol
+            e = self.state_to_jump_forward[state]
+            assert e.byte is not None and e.byte_next_state is not None
+            jump_forward_bytes.append((e.byte, e.byte_next_state))
+            next_state = e.byte_next_state
             state = next_state
-        return jump_forward_str, next_state
+
+        return jump_forward_bytes
+
+    def is_jump_forward_symbol_state(self, state):
+        return (
+            state in self.state_to_jump_forward
+            and self.state_to_jump_forward[state].symbol is not None
+        )


 class JumpForwardCache(BaseCache):
@@ -69,12 +141,21 @@ class JumpForwardCache(BaseCache):
         return JumpForwardMap(regex)


-def test_main():
-    regex_string = r"The google's DNS server address is " + IP_REGEX
+def test_main(regex_string):
     jump_forward_map = JumpForwardMap(regex_string)
-    for state in jump_forward_map.valid_states():
-        print(state, f'"{jump_forward_map.jump_forward(state)}"')
+    for state, e in jump_forward_map.state_to_jump_forward.items():
+        if e.symbol is not None:
+            jump_forward_str, next_state = jump_forward_map.jump_forward_symbol(state)
+            print(f"{state} -> {next_state}", jump_forward_str)
+        bytes_ = jump_forward_map.jump_forward_byte(state)
+        print(f"{state} -> {bytes_[-1][1]}", [hex(b) for b, _ in bytes_])


 if __name__ == "__main__":
-    test_main()
+    import outlines
+
+    outlines.caching.clear_cache()
+    test_main(r"The google's DNS server address is " + IP_REGEX)
+    test_main(r"霍格沃茨特快列车|霍比特人比尔博")
+    # 霍格: \xe9\x9c\x8d \xe6\xa0\xbc ...
+    # 霍比: \xe9\x9c\x8d \xe6\xaf\x94 ...
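The trailing comments spell out why the byte-level map matters: 霍格沃茨特快列车 and 霍比特人比尔博 share the first character 霍 (\xe9\x9c\x8d) and even the first byte \xe6 of the second character, so the longest jump-forward ends in the middle of a character. A self-contained illustration of the UTF-8 facts the continuation-byte handling (the 0x80-0xC0 range used in infer_batch.py below) relies on:

    a = "霍格沃茨特快列车".encode("utf-8")
    b = "霍比特人比尔博".encode("utf-8")

    # Longest common prefix of the two byte strings.
    common = 0
    while common < min(len(a), len(b)) and a[common] == b[common]:
        common += 1

    print(a[:common])  # b'\xe9\x9c\x8d\xe6' -- ends inside a character
    # Bytes 0x80-0xBF are UTF-8 continuation bytes; 0xE6 starts a 3-byte char.
    print(hex(a[common]), 0x80 <= a[common] < 0xC0)      # continuation byte of 格
    print(a[:common].decode("utf-8", errors="replace"))  # '霍' plus U+FFFD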
python/sglang/srt/managers/controller/infer_batch.py
@@ -3,12 +3,17 @@
 from dataclasses import dataclass
 from enum import IntEnum, auto
 from typing import List
+import warnings

 import numpy as np
 import torch

 from sglang.srt.managers.controller.radix_cache import RadixCache
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
+from sglang.srt.constrained.jump_forward import JumpForwardMap
+from sglang.srt.constrained import RegexGuide
+
+INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5


 class ForwardMode(IntEnum):
@@ -64,12 +69,15 @@ class Req:
     def __init__(self, rid, origin_input_text, origin_input_ids):
         self.rid = rid
         self.origin_input_text = origin_input_text
+        self.origin_input_ids_unpadded = origin_input_ids  # Before image padding
         self.origin_input_ids = origin_input_ids
-        self.origin_input_ids_unpadded = origin_input_ids  # before image padding
-        self.prev_output_str = ""
-        self.prev_output_ids = []
-        self.output_ids = []
-        self.input_ids = None  # input_ids = origin_input_ids + prev_output_ids
+        self.output_ids = []  # Each decode stage's output ids
+        self.input_ids = None  # input_ids = origin_input_ids + output_ids
+
+        # For incremental decode
+        self.decoded_text = ""
+        self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
+        self.read_offset = None

         # The number of decoded tokens for token usage report. Note that
         # this does not include the jump forward tokens.
@@ -109,20 +117,54 @@ class Req:
         self.last_update_decode_tokens = 0

         # Constrained decoding
-        self.regex_fsm = None
-        self.regex_fsm_state = 0
-        self.jump_forward_map = None
+        self.regex_fsm: RegexGuide = None
+        self.regex_fsm_state: int = 0
+        self.jump_forward_map: JumpForwardMap = None

     # whether request reached finished condition
     def finished(self) -> bool:
         return self.finished_reason is not None

-    def partial_decode(self, ids):
-        first_token = self.tokenizer.convert_ids_to_tokens(ids[0])
-        first_token = (
-            first_token.decode() if isinstance(first_token, bytes) else first_token
-        )
-        return (" " if first_token.startswith("▁") else "") + self.tokenizer.decode(ids)
+    # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
+    def init_detokenize_incrementally(self):
+        first_iter = self.surr_offset is None or self.read_offset is None
+
+        if first_iter:
+            self.read_offset = len(self.origin_input_ids_unpadded)
+            self.surr_offset = max(
+                self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
+            )
+
+        all_ids = self.origin_input_ids_unpadded + self.output_ids
+        surr_ids = all_ids[self.surr_offset : self.read_offset]
+        read_ids = all_ids[self.surr_offset :]
+
+        return surr_ids, read_ids, len(all_ids)
+
+    def detokenize_incrementally(self, inplace: bool = True):
+        surr_ids, read_ids, num_all_tokens = self.init_detokenize_incrementally()
+
+        surr_text = self.tokenizer.decode(
+            surr_ids,
+            skip_special_tokens=self.sampling_params.skip_special_tokens,
+            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
+        )
+        new_text = self.tokenizer.decode(
+            read_ids,
+            skip_special_tokens=self.sampling_params.skip_special_tokens,
+            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
+        )
+
+        if len(new_text) > len(surr_text) and not new_text.endswith("�"):
+            new_text = new_text[len(surr_text) :]
+            if inplace:
+                self.decoded_text += new_text
+                self.surr_offset = self.read_offset
+                self.read_offset = num_all_tokens
+
+            return True, new_text
+
+        return False, ""

     def max_new_tokens(self):
         return self.sampling_params.max_new_tokens
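The surr/read offset scheme keeps a small window of already-decoded "surrounding" tokens, decodes the window plus any new tokens, and emits only the text extending past the window, holding output back while it still ends in U+FFFD (an incomplete UTF-8 sequence). A standalone sketch of the same idea with a HuggingFace tokenizer (model name illustrative; offsets simplified to a single-window version of the algorithm above):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")  # illustrative choice

    ids = tok.encode("Hello, incremental world!")
    decoded_text, surr, read = "", 0, 0

    for n in range(1, len(ids) + 1):  # simulate one new token per step
        surr_text = tok.decode(ids[surr:read])
        new_text = tok.decode(ids[surr:n])
        if len(new_text) > len(surr_text) and not new_text.endswith("\ufffd"):
            decoded_text += new_text[len(surr_text):]
            surr, read = read, n

    print(decoded_text)  # reassembles the full string incrementally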
@@ -131,18 +173,17 @@ class Req:
         if self.finished():
             return

-        if (
-            len(self.prev_output_ids) + len(self.output_ids)
-            >= self.sampling_params.max_new_tokens
-        ):
-            self.finished_reason = FINISH_LENGTH(len(self.prev_output_ids) + len(self.output_ids))
+        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
+            self.finished_reason = FINISH_LENGTH(len(self.output_ids))
             return

         if (
             self.output_ids[-1] == self.tokenizer.eos_token_id
             and not self.sampling_params.ignore_eos
         ):
-            self.finished_reason = FINISH_MATCHED_TOKEN(matched=self.tokenizer.eos_token_id)
+            self.finished_reason = FINISH_MATCHED_TOKEN(
+                matched=self.tokenizer.eos_token_id
+            )
             return

         if len(self.sampling_params.stop_strs) > 0:
@@ -151,61 +192,59 @@ class Req:
             )

             for stop_str in self.sampling_params.stop_strs:
-                # FIXME: (minor) try incremental match in prev_output_str
-                if stop_str in tail_str or stop_str in self.prev_output_str:
+                if stop_str in tail_str or stop_str in self.decoded_text:
                     self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
                     return

     def jump_forward_and_retokenize(self, jump_forward_str, next_state):
-        # FIXME: This logic does not really solve the problem of determining whether
-        # there should be a leading space.
-        cur_output_str = self.partial_decode(self.output_ids)
-
-        # TODO(lsyin): apply re-tokenize only for decode tokens so that we do not need origin_input_text anymore
         if self.origin_input_text is None:
             # Recovering text can only use unpadded ids
             self.origin_input_text = self.tokenizer.decode(
                 self.origin_input_ids_unpadded
             )

-        all_text = (
-            self.origin_input_text
-            + self.prev_output_str
-            + cur_output_str
-            + jump_forward_str
-        )
+        all_text = self.origin_input_text + self.decoded_text + jump_forward_str
         all_ids = self.tokenizer.encode(all_text)
         prompt_tokens = len(self.origin_input_ids_unpadded)
-        self.origin_input_ids = all_ids[:prompt_tokens]
-        self.origin_input_ids_unpadded = self.origin_input_ids
-        # NOTE: the output ids may not strictly correspond to the output text
-        old_prev_output_ids = self.prev_output_ids
-        self.prev_output_ids = all_ids[prompt_tokens:]
-        self.prev_output_str = self.prev_output_str + cur_output_str + jump_forward_str
-        self.output_ids = []
+
+        if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
+            # TODO(lsyin): fix token fusion
+            warnings.warn(
+                "Token fusion between input and output, try to avoid this by removing the space at the end of the input."
+            )
+            return False
+
+        old_output_ids = self.output_ids
+        self.output_ids = all_ids[prompt_tokens:]
+        self.decoded_text = self.decoded_text + jump_forward_str
+        self.surr_offset = prompt_tokens
+        self.read_offset = len(all_ids)
+
+        # NOTE: A trick to reduce the surrounding tokens decoding overhead
+        for i in range(0, INIT_INCREMENTAL_DETOKENIZATION_OFFSET):
+            surr_text_ = self.tokenizer.decode(
+                all_ids[self.read_offset - i : self.read_offset]
+            )
+            if not surr_text_.endswith("�"):
+                self.surr_offset = self.read_offset - i
+                break

         self.regex_fsm_state = next_state

         if self.return_logprob:
             # For fast-forward part's logprobs
             k = 0
-            for i, old_id in enumerate(old_prev_output_ids):
-                if old_id == self.prev_output_ids[i]:
+            for i, old_id in enumerate(old_output_ids):
+                if old_id == self.output_ids[i]:
                     k = k + 1
                 else:
                     break
             self.decode_token_logprobs = self.decode_token_logprobs[:k]
             self.decode_top_logprobs = self.decode_top_logprobs[:k]
             self.logprob_start_len = prompt_tokens + k
-            self.last_update_decode_tokens = len(self.prev_output_ids) - k
+            self.last_update_decode_tokens = len(self.output_ids) - k

-        # print("=" * 100)
-        # print(f"Catch jump forward:\n{jump_forward_str}")
-        # print(self.tokenizer.convert_ids_to_tokens(self.input_ids))
-        # print(self.tokenizer.convert_ids_to_tokens(new_input_ids))
-
-        # print(f"Output and jump forward str:\n{self.output_and_jump_forward_str}")
-        # print("*" * 100)
+        return True

     def __repr__(self):
         return f"rid(n={self.rid}, " f"input_ids={self.origin_input_ids}, "
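jump_forward_and_retokenize re-encodes prompt + decoded text + jump-forward string and checks that the prompt boundary survived: BPE tokenizers can fuse a trailing prompt token with the first output character, in which case all_ids[prompt_tokens - 1] no longer equals the last prompt token and the jump is abandoned. A small demonstration of the fusion effect (model name illustrative):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")  # illustrative choice

    prompt = "The answer is "  # note the trailing space
    output = "42"

    prompt_ids = tok.encode(prompt)
    all_ids = tok.encode(prompt + output)

    # With a trailing space, " 4"/" 42" is often one fused token, so the id
    # at the prompt boundary changes after re-encoding.
    print(prompt_ids)
    print(all_ids)
    print(all_ids[len(prompt_ids) - 1] == prompt_ids[-1])  # typically False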
@@ -381,7 +420,10 @@ class Batch:
         sorted_indices = [i for i in range(len(self.reqs))]
         # TODO(lsyin): improve the priority of retraction
         sorted_indices.sort(
-            key=lambda i: (len(self.reqs[i].output_ids), -len(self.reqs[i].input_ids)),
+            key=lambda i: (
+                len(self.reqs[i].output_ids),
+                -len(self.reqs[i].origin_input_ids),
+            ),
             reverse=True,
         )

@@ -403,14 +445,9 @@ class Batch:
             # release the last node
             self.tree_cache.dec_lock_ref(req.last_node)

-            cur_output_str = req.partial_decode(req.output_ids)
-            req.prev_output_str = req.prev_output_str + cur_output_str
-            req.prev_output_ids.extend(req.output_ids)
-
             req.prefix_indices = None
             req.last_node = None
             req.extend_input_len = 0
-            req.output_ids = []

             # For incremental logprobs
             req.last_update_decode_tokens = 0
@@ -428,18 +465,53 @@ class Batch:

         for i, req in enumerate(self.reqs):
             if req.jump_forward_map is not None:
-                res = req.jump_forward_map.jump_forward(req.regex_fsm_state)
-                if res is not None:
-                    jump_forward_str, next_state = res
-                    if len(jump_forward_str) <= 1:
+                jump_forward_bytes = req.jump_forward_map.jump_forward_byte(
+                    req.regex_fsm_state
+                )
+                if jump_forward_bytes is not None and len(jump_forward_bytes) > 1:
+                    suffix_bytes = []
+                    continuation_range = range(0x80, 0xC0)
+                    cur_state = req.regex_fsm_state
+                    while (
+                        len(jump_forward_bytes)
+                        and jump_forward_bytes[0][0] in continuation_range
+                    ):
+                        # continuation bytes
+                        byte_edge = jump_forward_bytes.pop(0)
+                        suffix_bytes.append(byte_edge[0])
+                        cur_state = byte_edge[1]
+
+                    suffix_tokens = [f"<0x{hex(b)[2:].upper()}>" for b in suffix_bytes]
+                    suffix_ids = req.tokenizer.convert_tokens_to_ids(suffix_tokens)
+
+                    # Current ids, for cache and revert
+                    cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
+                    cur_output_ids = req.output_ids
+
+                    req.output_ids.extend(suffix_ids)
+                    decode_res, new_text = req.detokenize_incrementally(inplace=False)
+                    if not decode_res:
+                        req.output_ids = cur_output_ids
                         continue

-                    if req_pool_indices_cpu is None:
-                        req_pool_indices_cpu = self.req_pool_indices.tolist()
+                    jump_forward_str, next_state = (
+                        req.jump_forward_map.jump_forward_symbol(cur_state)
+                    )
+
+                    # Make the incrementally decoded text part of jump_forward_str
+                    # so that the UTF-8 sequence will not be corrupted
+                    jump_forward_str = new_text + jump_forward_str
+                    if not req.jump_forward_and_retokenize(
+                        jump_forward_str, next_state
+                    ):
+                        req.output_ids = cur_output_ids
+                        continue

                     # insert the old request into tree_cache
+                    if req_pool_indices_cpu is None:
+                        req_pool_indices_cpu = self.req_pool_indices.tolist()
                     self.tree_cache.cache_req(
-                        token_ids=tuple(req.input_ids + req.output_ids)[:-1],
+                        token_ids=cur_all_ids,
                         last_uncached_pos=len(req.prefix_indices),
                         req_pool_idx=req_pool_indices_cpu[i],
                     )
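The suffix_tokens conversion relies on byte-fallback vocabularies (Llama-style), where every raw byte has a literal token spelled <0xNN>; pending continuation bytes are injected as those tokens so the partial character stays in output_ids until the rest arrives. A quick check of that mapping (tokenizer repo illustrative; requires a byte-fallback vocabulary):

    from transformers import AutoTokenizer

    # Illustrative: any Llama-style tokenizer with byte-fallback tokens works.
    tok = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")

    suffix_bytes = [0x9C, 0x8D]  # continuation bytes of 霍 (0xE9 0x9C 0x8D)
    suffix_tokens = [f"<0x{hex(b)[2:].upper()}>" for b in suffix_bytes]
    suffix_ids = tok.convert_tokens_to_ids(suffix_tokens)
    print(suffix_tokens, suffix_ids)  # ['<0x9C>', '<0x8D>'] -> valid token ids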
@@ -447,9 +519,6 @@ class Batch:
                     # unlock the last node
                     self.tree_cache.dec_lock_ref(req.last_node)

-                    # jump-forward
-                    req.jump_forward_and_retokenize(jump_forward_str, next_state)
-
                     # re-applying image padding
                     if req.pixel_values is not None:
                         (

@@ -583,7 +652,7 @@ class Batch:
                 if req.regex_fsm is not None:
                     allowed_mask.zero_()
                     allowed_mask[
-                        req.regex_fsm.allowed_token_ids(req.regex_fsm_state)
+                        req.regex_fsm.get_next_instruction(req.regex_fsm_state).tokens
                     ] = 1
                     logits[i].masked_fill_(~allowed_mask, float("-inf"))
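Constrained sampling itself is just a mask over the vocabulary: every id outside the guide's allowed set is driven to -inf before the softmax, so sampling can only pick allowed tokens. A minimal torch sketch of that masking step (vocab size and allowed ids made up for illustration):

    import torch

    vocab_size = 8
    logits = torch.randn(vocab_size)

    allowed_token_ids = [2, 5, 7]  # e.g. guide.get_next_instruction(state).tokens

    allowed_mask = torch.zeros(vocab_size, dtype=torch.bool)
    allowed_mask[allowed_token_ids] = True
    logits.masked_fill_(~allowed_mask, float("-inf"))

    probs = torch.softmax(logits, dim=-1)  # probability mass only on allowed ids
    print(probs)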
@@ -602,7 +671,7 @@ class Batch:
         batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy()
         for i, req in enumerate(self.reqs):
             if req.regex_fsm is not None:
-                req.regex_fsm_state = req.regex_fsm.next_state(
+                req.regex_fsm_state = req.regex_fsm.get_next_state(
                     req.regex_fsm_state, batch_next_token_ids_cpu[i]
                 )
python/sglang/srt/managers/controller/tp_worker.py
@@ -21,7 +21,13 @@ from sglang.srt.managers.io_struct import (
     FlushCacheReq,
     TokenizedGenerateReqInput,
 )
-from sglang.srt.managers.controller.infer_batch import BaseFinishReason, Batch, FINISH_ABORT, ForwardMode, Req
+from sglang.srt.managers.controller.infer_batch import (
+    BaseFinishReason,
+    Batch,
+    FINISH_ABORT,
+    ForwardMode,
+    Req,
+)
 from sglang.srt.managers.controller.model_runner import ModelRunner
 from sglang.srt.managers.controller.radix_cache import RadixCache
 from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic

@@ -98,8 +104,11 @@ class ModelTpServer:
                 else server_args.max_prefill_tokens
             ),
         )
-        self.max_running_requests = (self.max_total_num_tokens // 2
-            if server_args.max_running_requests is None else server_args.max_running_requests)
+        self.max_running_requests = (
+            self.max_total_num_tokens // 2
+            if server_args.max_running_requests is None
+            else server_args.max_running_requests
+        )
         self.int_token_logit_bias = torch.tensor(
             get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
         )

@@ -314,10 +323,7 @@ class ModelTpServer:

         # Compute matched prefix length
         for req in self.forward_queue:
-            assert (
-                len(req.output_ids) == 0
-            ), "The output ids should be empty when prefilling"
-            req.input_ids = req.origin_input_ids + req.prev_output_ids
+            req.input_ids = req.origin_input_ids + req.output_ids
             prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
             if req.return_logprob:
                 prefix_indices = prefix_indices[: req.logprob_start_len]

@@ -464,7 +470,7 @@ class ModelTpServer:
         pt = 0
         for i, req in enumerate(batch.reqs):
             req.completion_tokens_wo_jump_forward += 1
-            req.output_ids = [next_token_ids[i]]
+            req.output_ids.append(next_token_ids[i])
             req.check_finished()

             if req.return_logprob:

@@ -524,7 +530,7 @@ class ModelTpServer:
         req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
         for i, req in enumerate(batch.reqs):
             new_prefix_indices, new_last_node = self.tree_cache.cache_req(
-                token_ids=tuple(req.input_ids + req.output_ids)[:-1],
+                token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
                 last_uncached_pos=len(req.prefix_indices),
                 req_pool_idx=req_pool_indices_cpu[i],
                 del_in_memory_pool=False,

@@ -596,8 +602,9 @@ class ModelTpServer:

     def handle_finished_requests(self, batch: Batch):
         output_rids = []
-        prev_output_strs = []
-        output_tokens = []
+        decoded_texts = []
+        surr_output_ids = []
+        read_output_ids = []
         output_skip_special_tokens = []
         output_spaces_between_special_tokens = []
         output_meta_info = []

@@ -620,8 +627,10 @@ class ModelTpServer:
                 )
             ):
                 output_rids.append(req.rid)
-                prev_output_strs.append(req.prev_output_str)
-                output_tokens.append(req.output_ids)
+                decoded_texts.append(req.decoded_text)
+                surr_ids, read_ids, _ = req.init_detokenize_incrementally()
+                surr_output_ids.append(surr_ids)
+                read_output_ids.append(read_ids)
                 output_skip_special_tokens.append(
                     req.sampling_params.skip_special_tokens
                 )

@@ -631,7 +640,7 @@ class ModelTpServer:

                 meta_info = {
                     "prompt_tokens": len(req.origin_input_ids),
-                    "completion_tokens": len(req.prev_output_ids) + len(req.output_ids),
+                    "completion_tokens": len(req.output_ids),
                     "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
                     "finish_reason": str(req.finished_reason),
                 }

@@ -657,8 +666,9 @@ class ModelTpServer:
             self.out_pyobjs.append(
                 BatchTokenIDOut(
                     output_rids,
-                    prev_output_strs,
-                    output_tokens,
+                    decoded_texts,
+                    surr_output_ids,
+                    read_output_ids,
                     output_skip_special_tokens,
                     output_spaces_between_special_tokens,
                     output_meta_info,

@@ -673,7 +683,7 @@ class ModelTpServer:
             for i in finished_indices:
                 req = batch.reqs[i]
                 self.tree_cache.cache_req(
-                    token_ids=tuple(req.input_ids + req.output_ids)[:-1],
+                    token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
                     last_uncached_pos=len(req.prefix_indices),
                     req_pool_idx=req_pool_indices_cpu[i],
                 )

@@ -790,4 +800,4 @@ class ModelTpClient:

             return _func

         self.step = async_wrap("step")
python/sglang/srt/managers/detokenizer_manager.py
@@ -39,30 +39,24 @@ class DetokenizerManager:
             recv_obj: BatchTokenIDOut = await self.recv_from_router.recv_pyobj()
             assert isinstance(recv_obj, BatchTokenIDOut)

-            output_tokens = recv_obj.output_tokens
-
             # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
-            output_strs = self.tokenizer.batch_decode(
-                output_tokens,
+            surr_texts = self.tokenizer.batch_decode(
+                recv_obj.surr_output_ids,
+                skip_special_tokens=recv_obj.skip_special_tokens[0],
+                spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
+            )
+            read_texts = self.tokenizer.batch_decode(
+                recv_obj.read_output_ids,
                 skip_special_tokens=recv_obj.skip_special_tokens[0],
-                spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[
-                    0
-                ],
+                spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
             )

             # Trim stop str
             # TODO(lmzheng): handle the case where multiple stop strs are hit
-            for i in range(len(output_strs)):
-                if len(output_tokens[i]) > 0:
-                    first_token = self.tokenizer.convert_ids_to_tokens(
-                        int(output_tokens[i][0])
-                    )
-                    if not isinstance(first_token, str):
-                        first_token = first_token.decode("utf-8", errors="ignore")
-                    if first_token.startswith("▁"):
-                        output_strs[i] = " " + output_strs[i]
-
-                output_strs[i] = recv_obj.prev_output_strs[i] + output_strs[i]
+            output_strs = []
+            for i in range(len(recv_obj.rids)):
+                new_text = read_texts[i][len(surr_texts[i]) :]
+                output_strs.append(recv_obj.decoded_texts[i] + new_text)

                 if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR):
                     pos = output_strs[i].find(recv_obj.finished_reason[i].matched)
python/sglang/srt/managers/io_struct.py
@@ -111,13 +111,15 @@ class TokenizedGenerateReqInput:

 @dataclass
 class BatchTokenIDOut:
     rids: List[str]
-    prev_output_strs: List[str]
-    output_tokens: List[List[int]]
+    decoded_texts: List[str]
+    surr_output_ids: List[List[int]]
+    read_output_ids: List[List[int]]
     skip_special_tokens: List[bool]
     spaces_between_special_tokens: List[bool]
     meta_info: List[Dict]
     finished_reason: List[BaseFinishReason]


 @dataclass
 class BatchStrOut:
     rids: List[str]