jump-forward rename (#144)
This commit is contained in:
@@ -6,10 +6,10 @@ from sglang.srt.constrained.regex import FSMInfo, make_deterministic_fsm
|
||||
IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
|
||||
|
||||
|
||||
class FastForwardMap:
|
||||
class JumpForwardMap:
|
||||
def __init__(self, regex_string):
|
||||
@disk_cache()
|
||||
def _init_state_to_fast_forward(regex_string):
|
||||
def _init_state_to_jump_forward(regex_string):
|
||||
regex_pattern = interegular.parse_pattern(regex_string)
|
||||
regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())
|
||||
|
||||
@@ -22,54 +22,54 @@ class FastForwardMap:
|
||||
|
||||
transitions = fsm_info.transitions
|
||||
dirty_states = set()
|
||||
state_to_fast_forward = {}
|
||||
state_to_jump_forward = {}
|
||||
|
||||
for (state, id_), next_state in transitions.items():
|
||||
if state in dirty_states:
|
||||
continue
|
||||
if state in state_to_fast_forward:
|
||||
if state in state_to_jump_forward:
|
||||
dirty_states.add(state)
|
||||
del state_to_fast_forward[state]
|
||||
del state_to_jump_forward[state]
|
||||
continue
|
||||
if len(id_to_symbol[id_]) > 1:
|
||||
dirty_states.add(state)
|
||||
continue
|
||||
|
||||
state_to_fast_forward[state] = (id_to_symbol[id_][0], next_state)
|
||||
state_to_jump_forward[state] = (id_to_symbol[id_][0], next_state)
|
||||
|
||||
return state_to_fast_forward
|
||||
return state_to_jump_forward
|
||||
|
||||
self.state_to_fast_forward = _init_state_to_fast_forward(regex_string)
|
||||
self.state_to_jump_forward = _init_state_to_jump_forward(regex_string)
|
||||
|
||||
def valid_states(self):
|
||||
return self.state_to_fast_forward.keys()
|
||||
return self.state_to_jump_forward.keys()
|
||||
|
||||
def fast_forward(self, state):
|
||||
if state not in self.state_to_fast_forward:
|
||||
def jump_forward(self, state):
|
||||
if state not in self.state_to_jump_forward:
|
||||
return None
|
||||
|
||||
fast_forward_str = ""
|
||||
jump_forward_str = ""
|
||||
next_state = None
|
||||
while state in self.state_to_fast_forward:
|
||||
symbol, next_state = self.state_to_fast_forward[state]
|
||||
fast_forward_str += symbol
|
||||
while state in self.state_to_jump_forward:
|
||||
symbol, next_state = self.state_to_jump_forward[state]
|
||||
jump_forward_str += symbol
|
||||
state = next_state
|
||||
return fast_forward_str, next_state
|
||||
return jump_forward_str, next_state
|
||||
|
||||
|
||||
class FastForwardCache(BaseCache):
|
||||
class JumpForwardCache(BaseCache):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def init_value(self, regex):
|
||||
return FastForwardMap(regex)
|
||||
return JumpForwardMap(regex)
|
||||
|
||||
|
||||
def test_main():
|
||||
regex_string = r"The google's DNS sever address is " + IP_REGEX
|
||||
fast_forward_map = FastForwardMap(regex_string)
|
||||
for state in fast_forward_map.valid_states():
|
||||
print(state, f'"{fast_forward_map.fast_forward(state)}"')
|
||||
jump_forward_map = JumpForwardMap(regex_string)
|
||||
for state in jump_forward_map.valid_states():
|
||||
print(state, f'"{jump_forward_map.jump_forward(state)}"')
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
Reference in New Issue
Block a user