Decode Incrementally (#517)
This commit is contained in:
@@ -3,8 +3,8 @@ from typing import Dict, Optional, Union
|
||||
|
||||
from outlines.caching import cache as disk_cache
|
||||
from outlines.caching import disable_cache
|
||||
from outlines.fsm.fsm import RegexFSM
|
||||
from outlines.fsm.regex import FSMInfo, make_deterministic_fsm
|
||||
from outlines.fsm.guide import RegexGuide
|
||||
from outlines.fsm.regex import FSMInfo, make_deterministic_fsm, make_byte_level_fsm
|
||||
from outlines.models.transformers import TransformerTokenizer
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -28,11 +28,12 @@ except ImportError:
|
||||
|
||||
|
||||
__all__ = [
|
||||
"RegexFSM",
|
||||
"RegexGuide",
|
||||
"FSMInfo",
|
||||
"make_deterministic_fsm",
|
||||
"build_regex_from_object",
|
||||
"TransformerTokenizer",
|
||||
"disk_cache",
|
||||
"disable_cache",
|
||||
"make_byte_level_fsm",
|
||||
]
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""Cache for the compressed finite state machine."""
|
||||
from sglang.srt.constrained import RegexFSM, TransformerTokenizer
|
||||
from sglang.srt.constrained import RegexGuide, TransformerTokenizer
|
||||
from sglang.srt.constrained.base_cache import BaseCache
|
||||
|
||||
|
||||
@@ -26,4 +26,4 @@ class FSMCache(BaseCache):
|
||||
)
|
||||
|
||||
def init_value(self, regex):
|
||||
return RegexFSM(regex, self.outlines_tokenizer)
|
||||
return RegexGuide(regex, self.outlines_tokenizer)
|
||||
|
||||
@@ -2,20 +2,41 @@
|
||||
Faster constrained decoding.
|
||||
Reference: https://lmsys.org/blog/2024-02-05-compressed-fsm/
|
||||
"""
|
||||
import interegular
|
||||
|
||||
from sglang.srt.constrained import FSMInfo, disk_cache, make_deterministic_fsm
|
||||
import interegular
|
||||
import dataclasses
|
||||
from collections import defaultdict
|
||||
|
||||
import outlines.caching
|
||||
from sglang.srt.constrained import (
|
||||
FSMInfo,
|
||||
disk_cache,
|
||||
make_deterministic_fsm,
|
||||
make_byte_level_fsm,
|
||||
)
|
||||
from sglang.srt.constrained.base_cache import BaseCache
|
||||
|
||||
IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class JumpEdge:
|
||||
symbol: str = None
|
||||
symbol_next_state: int = None
|
||||
byte: int = None
|
||||
byte_next_state: int = None
|
||||
|
||||
|
||||
class JumpForwardMap:
|
||||
def __init__(self, regex_string):
|
||||
@disk_cache()
|
||||
def _init_state_to_jump_forward(regex_string):
|
||||
regex_pattern = interegular.parse_pattern(regex_string)
|
||||
regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())
|
||||
|
||||
byte_fsm = make_byte_level_fsm(
|
||||
regex_pattern.to_fsm().reduce(), keep_utf8=True
|
||||
)
|
||||
regex_fsm, _ = make_deterministic_fsm(byte_fsm)
|
||||
|
||||
fsm_info: FSMInfo = regex_fsm.fsm_info
|
||||
|
||||
@@ -25,40 +46,91 @@ class JumpForwardMap:
|
||||
id_to_symbol.setdefault(id_, []).append(symbol)
|
||||
|
||||
transitions = fsm_info.transitions
|
||||
dirty_states = set()
|
||||
outgoings_ct = defaultdict(int)
|
||||
state_to_jump_forward = {}
|
||||
|
||||
for (state, id_), next_state in transitions.items():
|
||||
if state in dirty_states:
|
||||
continue
|
||||
if state in state_to_jump_forward:
|
||||
dirty_states.add(state)
|
||||
del state_to_jump_forward[state]
|
||||
continue
|
||||
if len(id_to_symbol[id_]) > 1:
|
||||
dirty_states.add(state)
|
||||
if id_ == fsm_info.alphabet_anything_value:
|
||||
continue
|
||||
symbols = id_to_symbol[id_]
|
||||
for c in symbols:
|
||||
if len(c) > 1:
|
||||
# Skip byte level transitions
|
||||
continue
|
||||
|
||||
state_to_jump_forward[state] = (id_to_symbol[id_][0], next_state)
|
||||
outgoings_ct[state] += 1
|
||||
if outgoings_ct[state] > 1:
|
||||
if state in state_to_jump_forward:
|
||||
del state_to_jump_forward[state]
|
||||
break
|
||||
|
||||
state_to_jump_forward[state] = JumpEdge(
|
||||
symbol=c,
|
||||
symbol_next_state=next_state,
|
||||
)
|
||||
|
||||
# Process the byte level jump forward
|
||||
outgoings_ct = defaultdict(int)
|
||||
for (state, id_), next_state in transitions.items():
|
||||
if id_ == fsm_info.alphabet_anything_value:
|
||||
continue
|
||||
symbols = id_to_symbol[id_]
|
||||
for c in symbols:
|
||||
byte_ = None
|
||||
if len(c) == 1 and ord(c) < 0x80:
|
||||
# ASCII character
|
||||
byte_ = ord(c)
|
||||
elif len(c) == 2:
|
||||
byte_ = int(symbols[0], 16)
|
||||
|
||||
if byte_ is not None:
|
||||
outgoings_ct[state] += 1
|
||||
if outgoings_ct[state] > 1:
|
||||
if state in state_to_jump_forward:
|
||||
del state_to_jump_forward[state]
|
||||
break
|
||||
e = state_to_jump_forward.get(state, JumpEdge())
|
||||
e.byte = byte_
|
||||
e.byte_next_state = next_state
|
||||
state_to_jump_forward[state] = e
|
||||
|
||||
return state_to_jump_forward
|
||||
|
||||
self.state_to_jump_forward = _init_state_to_jump_forward(regex_string)
|
||||
|
||||
def valid_states(self):
|
||||
return self.state_to_jump_forward.keys()
|
||||
def jump_forward_symbol(self, state):
|
||||
jump_forward_str = ""
|
||||
next_state = state
|
||||
while state in self.state_to_jump_forward:
|
||||
e = self.state_to_jump_forward[state]
|
||||
if e.symbol is None:
|
||||
break
|
||||
jump_forward_str += e.symbol
|
||||
next_state = e.symbol_next_state
|
||||
state = next_state
|
||||
|
||||
def jump_forward(self, state):
|
||||
return jump_forward_str, next_state
|
||||
|
||||
def jump_forward_byte(self, state):
|
||||
if state not in self.state_to_jump_forward:
|
||||
return None
|
||||
|
||||
jump_forward_str = ""
|
||||
jump_forward_bytes = []
|
||||
next_state = None
|
||||
while state in self.state_to_jump_forward:
|
||||
symbol, next_state = self.state_to_jump_forward[state]
|
||||
jump_forward_str += symbol
|
||||
e = self.state_to_jump_forward[state]
|
||||
assert e.byte is not None and e.byte_next_state is not None
|
||||
jump_forward_bytes.append((e.byte, e.byte_next_state))
|
||||
next_state = e.byte_next_state
|
||||
state = next_state
|
||||
return jump_forward_str, next_state
|
||||
|
||||
return jump_forward_bytes
|
||||
|
||||
def is_jump_forward_symbol_state(self, state):
|
||||
return (
|
||||
state in self.state_to_jump_forward
|
||||
and self.state_to_jump_forward[state].symbol is not None
|
||||
)
|
||||
|
||||
|
||||
class JumpForwardCache(BaseCache):
|
||||
@@ -69,12 +141,21 @@ class JumpForwardCache(BaseCache):
|
||||
return JumpForwardMap(regex)
|
||||
|
||||
|
||||
def test_main():
|
||||
regex_string = r"The google's DNS sever address is " + IP_REGEX
|
||||
def test_main(regex_string):
|
||||
jump_forward_map = JumpForwardMap(regex_string)
|
||||
for state in jump_forward_map.valid_states():
|
||||
print(state, f'"{jump_forward_map.jump_forward(state)}"')
|
||||
for state, e in jump_forward_map.state_to_jump_forward.items():
|
||||
if e.symbol is not None:
|
||||
jump_forward_str, next_state = jump_forward_map.jump_forward_symbol(state)
|
||||
print(f"{state} -> {next_state}", jump_forward_str)
|
||||
bytes_ = jump_forward_map.jump_forward_byte(state)
|
||||
print(f"{state} -> {bytes_[-1][1]}", [hex(b) for b, _ in bytes_])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_main()
|
||||
import outlines
|
||||
|
||||
outlines.caching.clear_cache()
|
||||
test_main(r"The google's DNS sever address is " + IP_REGEX)
|
||||
test_main(r"霍格沃茨特快列车|霍比特人比尔博")
|
||||
# 霍格: \xe9\x9c\x8d \xe6\xa0\xbc ...
|
||||
# 霍比: \xe9\x9c\x8d \xe6\xaf\x94 ...
|
||||
|
||||
Reference in New Issue
Block a user