jump-forward rename (#144)

This commit is contained in:
Liangsheng Yin
2024-02-05 16:50:37 +08:00
committed by GitHub
parent 82fa69b3cc
commit 26f0bedc8f
12 changed files with 70 additions and 70 deletions

View File

@@ -6,10 +6,10 @@ from sglang.srt.constrained.regex import FSMInfo, make_deterministic_fsm
IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
class FastForwardMap:
class JumpForwardMap:
def __init__(self, regex_string):
@disk_cache()
def _init_state_to_fast_forward(regex_string):
def _init_state_to_jump_forward(regex_string):
regex_pattern = interegular.parse_pattern(regex_string)
regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())
@@ -22,54 +22,54 @@ class FastForwardMap:
transitions = fsm_info.transitions
dirty_states = set()
state_to_fast_forward = {}
state_to_jump_forward = {}
for (state, id_), next_state in transitions.items():
if state in dirty_states:
continue
if state in state_to_fast_forward:
if state in state_to_jump_forward:
dirty_states.add(state)
del state_to_fast_forward[state]
del state_to_jump_forward[state]
continue
if len(id_to_symbol[id_]) > 1:
dirty_states.add(state)
continue
state_to_fast_forward[state] = (id_to_symbol[id_][0], next_state)
state_to_jump_forward[state] = (id_to_symbol[id_][0], next_state)
return state_to_fast_forward
return state_to_jump_forward
self.state_to_fast_forward = _init_state_to_fast_forward(regex_string)
self.state_to_jump_forward = _init_state_to_jump_forward(regex_string)
def valid_states(self):
return self.state_to_fast_forward.keys()
return self.state_to_jump_forward.keys()
def fast_forward(self, state):
if state not in self.state_to_fast_forward:
def jump_forward(self, state):
if state not in self.state_to_jump_forward:
return None
fast_forward_str = ""
jump_forward_str = ""
next_state = None
while state in self.state_to_fast_forward:
symbol, next_state = self.state_to_fast_forward[state]
fast_forward_str += symbol
while state in self.state_to_jump_forward:
symbol, next_state = self.state_to_jump_forward[state]
jump_forward_str += symbol
state = next_state
return fast_forward_str, next_state
return jump_forward_str, next_state
class FastForwardCache(BaseCache):
class JumpForwardCache(BaseCache):
def __init__(self):
super().__init__()
def init_value(self, regex):
return FastForwardMap(regex)
return JumpForwardMap(regex)
def test_main():
regex_string = r"The google's DNS sever address is " + IP_REGEX
fast_forward_map = FastForwardMap(regex_string)
for state in fast_forward_map.valid_states():
print(state, f'"{fast_forward_map.fast_forward(state)}"')
jump_forward_map = JumpForwardMap(regex_string)
for state in jump_forward_map.valid_states():
print(state, f'"{jump_forward_map.jump_forward(state)}"')
if __name__ == "__main__":