Decode Incrementally (#517)

This commit is contained in:
Liangsheng Yin
2024-06-12 14:39:12 +08:00
committed by GitHub
parent 111991fe23
commit 9c902b1954
8 changed files with 345 additions and 135 deletions

View File

@@ -3,8 +3,8 @@ from typing import Dict, Optional, Union
from outlines.caching import cache as disk_cache
from outlines.caching import disable_cache
from outlines.fsm.fsm import RegexFSM
from outlines.fsm.regex import FSMInfo, make_deterministic_fsm
from outlines.fsm.guide import RegexGuide
from outlines.fsm.regex import FSMInfo, make_deterministic_fsm, make_byte_level_fsm
from outlines.models.transformers import TransformerTokenizer
from pydantic import BaseModel
@@ -28,11 +28,12 @@ except ImportError:
__all__ = [
"RegexFSM",
"RegexGuide",
"FSMInfo",
"make_deterministic_fsm",
"build_regex_from_object",
"TransformerTokenizer",
"disk_cache",
"disable_cache",
"make_byte_level_fsm",
]

View File

@@ -1,5 +1,5 @@
"""Cache for the compressed finite state machine."""
from sglang.srt.constrained import RegexFSM, TransformerTokenizer
from sglang.srt.constrained import RegexGuide, TransformerTokenizer
from sglang.srt.constrained.base_cache import BaseCache
@@ -26,4 +26,4 @@ class FSMCache(BaseCache):
)
def init_value(self, regex):
return RegexFSM(regex, self.outlines_tokenizer)
return RegexGuide(regex, self.outlines_tokenizer)

View File

@@ -2,20 +2,41 @@
Faster constrained decoding.
Reference: https://lmsys.org/blog/2024-02-05-compressed-fsm/
"""
import interegular
from sglang.srt.constrained import FSMInfo, disk_cache, make_deterministic_fsm
import interegular
import dataclasses
from collections import defaultdict
import outlines.caching
from sglang.srt.constrained import (
FSMInfo,
disk_cache,
make_deterministic_fsm,
make_byte_level_fsm,
)
from sglang.srt.constrained.base_cache import BaseCache
IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
@dataclasses.dataclass
class JumpEdge:
symbol: str = None
symbol_next_state: int = None
byte: int = None
byte_next_state: int = None
class JumpForwardMap:
def __init__(self, regex_string):
@disk_cache()
def _init_state_to_jump_forward(regex_string):
regex_pattern = interegular.parse_pattern(regex_string)
regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())
byte_fsm = make_byte_level_fsm(
regex_pattern.to_fsm().reduce(), keep_utf8=True
)
regex_fsm, _ = make_deterministic_fsm(byte_fsm)
fsm_info: FSMInfo = regex_fsm.fsm_info
@@ -25,40 +46,91 @@ class JumpForwardMap:
id_to_symbol.setdefault(id_, []).append(symbol)
transitions = fsm_info.transitions
dirty_states = set()
outgoings_ct = defaultdict(int)
state_to_jump_forward = {}
for (state, id_), next_state in transitions.items():
if state in dirty_states:
continue
if state in state_to_jump_forward:
dirty_states.add(state)
del state_to_jump_forward[state]
continue
if len(id_to_symbol[id_]) > 1:
dirty_states.add(state)
if id_ == fsm_info.alphabet_anything_value:
continue
symbols = id_to_symbol[id_]
for c in symbols:
if len(c) > 1:
# Skip byte level transitions
continue
state_to_jump_forward[state] = (id_to_symbol[id_][0], next_state)
outgoings_ct[state] += 1
if outgoings_ct[state] > 1:
if state in state_to_jump_forward:
del state_to_jump_forward[state]
break
state_to_jump_forward[state] = JumpEdge(
symbol=c,
symbol_next_state=next_state,
)
# Process the byte level jump forward
outgoings_ct = defaultdict(int)
for (state, id_), next_state in transitions.items():
if id_ == fsm_info.alphabet_anything_value:
continue
symbols = id_to_symbol[id_]
for c in symbols:
byte_ = None
if len(c) == 1 and ord(c) < 0x80:
# ASCII character
byte_ = ord(c)
elif len(c) == 2:
byte_ = int(symbols[0], 16)
if byte_ is not None:
outgoings_ct[state] += 1
if outgoings_ct[state] > 1:
if state in state_to_jump_forward:
del state_to_jump_forward[state]
break
e = state_to_jump_forward.get(state, JumpEdge())
e.byte = byte_
e.byte_next_state = next_state
state_to_jump_forward[state] = e
return state_to_jump_forward
self.state_to_jump_forward = _init_state_to_jump_forward(regex_string)
def valid_states(self):
return self.state_to_jump_forward.keys()
def jump_forward_symbol(self, state):
jump_forward_str = ""
next_state = state
while state in self.state_to_jump_forward:
e = self.state_to_jump_forward[state]
if e.symbol is None:
break
jump_forward_str += e.symbol
next_state = e.symbol_next_state
state = next_state
def jump_forward(self, state):
return jump_forward_str, next_state
def jump_forward_byte(self, state):
if state not in self.state_to_jump_forward:
return None
jump_forward_str = ""
jump_forward_bytes = []
next_state = None
while state in self.state_to_jump_forward:
symbol, next_state = self.state_to_jump_forward[state]
jump_forward_str += symbol
e = self.state_to_jump_forward[state]
assert e.byte is not None and e.byte_next_state is not None
jump_forward_bytes.append((e.byte, e.byte_next_state))
next_state = e.byte_next_state
state = next_state
return jump_forward_str, next_state
return jump_forward_bytes
def is_jump_forward_symbol_state(self, state):
return (
state in self.state_to_jump_forward
and self.state_to_jump_forward[state].symbol is not None
)
class JumpForwardCache(BaseCache):
@@ -69,12 +141,21 @@ class JumpForwardCache(BaseCache):
return JumpForwardMap(regex)
def test_main():
regex_string = r"The google's DNS sever address is " + IP_REGEX
def test_main(regex_string):
jump_forward_map = JumpForwardMap(regex_string)
for state in jump_forward_map.valid_states():
print(state, f'"{jump_forward_map.jump_forward(state)}"')
for state, e in jump_forward_map.state_to_jump_forward.items():
if e.symbol is not None:
jump_forward_str, next_state = jump_forward_map.jump_forward_symbol(state)
print(f"{state} -> {next_state}", jump_forward_str)
bytes_ = jump_forward_map.jump_forward_byte(state)
print(f"{state} -> {bytes_[-1][1]}", [hex(b) for b, _ in bytes_])
if __name__ == "__main__":
test_main()
import outlines
outlines.caching.clear_cache()
test_main(r"The google's DNS sever address is " + IP_REGEX)
test_main(r"霍格沃茨特快列车|霍比特人比尔博")
# 霍格: \xe9\x9c\x8d \xe6\xa0\xbc ...
# 霍比: \xe9\x9c\x8d \xe6\xaf\x94 ...