Fix grammar backend for tensor parallelism (#2020)
This commit is contained in:
@@ -27,8 +27,6 @@ from interegular import InvalidSyntax
|
||||
from outlines.caching import cache as disk_cache
|
||||
from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm
|
||||
|
||||
from sglang.srt.constrained.base_tool_cache import BaseToolCache
|
||||
|
||||
# Regex for a dotted-quad IPv4 address: four decimal octets (0-255) separated by dots.
IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -42,92 +40,90 @@ class JumpEdge:
|
||||
byte_next_state: int = None
|
||||
|
||||
|
||||
@disk_cache()
def init_state_to_jump_forward(regex_string):
    """Build the jump-forward table for ``regex_string``.

    The regex is compiled to a deterministic byte-level FSM. Every state that
    has exactly one usable outgoing edge is recorded so that the edge can be
    followed without sampling.

    Args:
        regex_string: The regular expression to analyse.

    Returns:
        A dict mapping an FSM state to its ``JumpEdge``, or ``None`` when
        ``regex_string`` cannot be parsed (a warning is logged in that case).
    """
    try:
        regex_pattern = interegular.parse_pattern(regex_string)
    except InvalidSyntax as e:
        logger.warning(f"skip invalid regex: {regex_string}, {e=}")
        return None

    byte_fsm = make_byte_level_fsm(regex_pattern.to_fsm().reduce(), keep_utf8=True)
    regex_fsm, _ = make_deterministic_fsm(byte_fsm)

    fsm_info: FSMInfo = regex_fsm.fsm_info

    # Invert the alphabet mapping: several symbols may share one transition id.
    symbol_to_id = fsm_info.alphabet_symbol_mapping
    id_to_symbol = {}
    for symbol, id_ in symbol_to_id.items():
        id_to_symbol.setdefault(id_, []).append(symbol)

    transitions = fsm_info.transitions

    outgoings_ct = defaultdict(int)
    # NOTE(lsyin): Final states can lead to terminate, so they have one outgoing edge naturally
    for s in fsm_info.finals:
        outgoings_ct[s] = 1

    # Pass 1: symbol-level edges (single-character symbols only).
    state_to_jump_forward = {}
    for (state, id_), next_state in transitions.items():
        if id_ == fsm_info.alphabet_anything_value:
            # Arbitrarily symbol cannot be recognized as jump forward
            continue

        symbols = id_to_symbol[id_]
        for c in symbols:
            if len(c) > 1:
                # Skip byte level transitions like c = "5E"
                continue

            outgoings_ct[state] += 1
            if outgoings_ct[state] > 1:
                # More than one way out: this state is not deterministic.
                if state in state_to_jump_forward:
                    del state_to_jump_forward[state]
                break

            state_to_jump_forward[state] = JumpEdge(
                symbol=c,
                symbol_next_state=next_state,
            )

    # Pass 2: process the byte level jump forward, with fresh edge counts.
    outgoings_ct = defaultdict(int)
    for s in fsm_info.finals:
        outgoings_ct[s] = 1

    for (state, id_), next_state in transitions.items():
        if id_ == fsm_info.alphabet_anything_value:
            continue
        symbols = id_to_symbol[id_]
        for c in symbols:
            byte_ = None
            if len(c) == 1 and ord(c) < 0x80:
                # ASCII character
                byte_ = ord(c)
            elif len(c) > 1:
                # FIXME: This logic is due to the leading \x00
                # https://github.com/outlines-dev/outlines/pull/930
                # Parse the symbol being visited (not symbols[0], which may be
                # a single character with no hex payload after the first char).
                byte_ = int(c[1:], 16)

            if byte_ is not None:
                outgoings_ct[state] += 1
                if outgoings_ct[state] > 1:
                    if state in state_to_jump_forward:
                        del state_to_jump_forward[state]
                    break
                # Reuse the symbol-level edge from pass 1 when there is one.
                e = state_to_jump_forward.get(state, JumpEdge())
                e.byte = byte_
                e.byte_next_state = next_state
                state_to_jump_forward[state] = e

    return state_to_jump_forward
|
||||
|
||||
|
||||
class OutlinesJumpForwardMap:
|
||||
def __init__(self, regex_string):
    """Store the jump-forward table for ``regex_string``.

    The table is built by the module-level ``init_state_to_jump_forward``
    rather than by a closure defined (and decorated with ``disk_cache``)
    anew on every construction, and it is computed only once.
    ``self.state_to_jump_forward`` is ``None`` when the regex is invalid.
    """
    self.state_to_jump_forward = init_state_to_jump_forward(regex_string)
|
||||
|
||||
def jump_forward_symbol(self, state):
|
||||
jump_forward_str = ""
|
||||
@@ -164,18 +160,6 @@ class OutlinesJumpForwardMap:
|
||||
)
|
||||
|
||||
|
||||
class OutlinesJumpForwardCache(BaseToolCache):
    """``BaseToolCache`` subclass whose values are ``OutlinesJumpForwardMap`` objects."""

    def __init__(self):
        super().__init__()

    def init_value(self, regex):
        """Create the jump-forward map for ``regex``.

        Returns the map, or ``None`` when its ``state_to_jump_forward`` table
        is falsy (``None`` or empty).
        """
        jump_map = OutlinesJumpForwardMap(regex)
        return jump_map if jump_map.state_to_jump_forward else None
|
||||
|
||||
|
||||
def test_main(regex_string):
|
||||
jump_forward_map = OutlinesJumpForwardMap(regex_string)
|
||||
for state, e in jump_forward_map.state_to_jump_forward.items():
|
||||
|
||||
Reference in New Issue
Block a user