[FIX] Catch syntax error of Regex Guide to avoid crash (#1521)
This commit is contained in:
@@ -14,13 +14,17 @@ limitations under the License.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
"""Cache for the compressed finite state machine."""
|
"""Cache for the compressed finite state machine."""
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from interegular import InvalidSyntax, parse_pattern
|
||||||
from outlines.fsm.json_schema import build_regex_from_schema
|
from outlines.fsm.json_schema import build_regex_from_schema
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
from sglang.srt.constrained import RegexGuide, TransformerTokenizer
|
from sglang.srt.constrained import RegexGuide, TransformerTokenizer
|
||||||
from sglang.srt.constrained.base_tool_cache import BaseToolCache
|
from sglang.srt.constrained.base_tool_cache import BaseToolCache
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class FSMCache(BaseToolCache):
|
class FSMCache(BaseToolCache):
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -76,5 +80,9 @@ class FSMCache(BaseToolCache):
|
|||||||
regex = key_string
|
regex = key_string
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid key_type: {key_type}")
|
raise ValueError(f"Invalid key_type: {key_type}")
|
||||||
|
try:
|
||||||
|
parse_pattern(regex)
|
||||||
|
except InvalidSyntax as e:
|
||||||
|
logger.warning(f"skip invalid regex guide: {regex=}, {e=}")
|
||||||
|
return None, regex
|
||||||
return RegexGuide(regex, self.outlines_tokenizer), regex
|
return RegexGuide(regex, self.outlines_tokenizer), regex
|
||||||
|
|||||||
@@ -19,10 +19,12 @@ Reference: https://lmsys.org/blog/2024-02-05-compressed-fsm/
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
|
import logging
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
import interegular
|
import interegular
|
||||||
import outlines.caching
|
import outlines.caching
|
||||||
|
from interegular import InvalidSyntax
|
||||||
|
|
||||||
from sglang.srt.constrained import (
|
from sglang.srt.constrained import (
|
||||||
FSMInfo,
|
FSMInfo,
|
||||||
@@ -34,6 +36,8 @@ from sglang.srt.constrained.base_tool_cache import BaseToolCache
|
|||||||
|
|
||||||
IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
|
IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class JumpEdge:
|
class JumpEdge:
|
||||||
@@ -47,7 +51,12 @@ class JumpForwardMap:
|
|||||||
def __init__(self, regex_string):
|
def __init__(self, regex_string):
|
||||||
@disk_cache()
|
@disk_cache()
|
||||||
def _init_state_to_jump_forward(regex_string):
|
def _init_state_to_jump_forward(regex_string):
|
||||||
regex_pattern = interegular.parse_pattern(regex_string)
|
try:
|
||||||
|
regex_pattern = interegular.parse_pattern(regex_string)
|
||||||
|
except InvalidSyntax as e:
|
||||||
|
logger.warning(f"skip invalid regex: {regex_string}, {e=}")
|
||||||
|
self.state_to_jump_forward = None
|
||||||
|
return
|
||||||
|
|
||||||
byte_fsm = make_byte_level_fsm(
|
byte_fsm = make_byte_level_fsm(
|
||||||
regex_pattern.to_fsm().reduce(), keep_utf8=True
|
regex_pattern.to_fsm().reduce(), keep_utf8=True
|
||||||
@@ -165,7 +174,11 @@ class JumpForwardCache(BaseToolCache):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def init_value(self, regex):
|
def init_value(self, regex):
|
||||||
return JumpForwardMap(regex)
|
forward_map = JumpForwardMap(regex)
|
||||||
|
if forward_map.state_to_jump_forward:
|
||||||
|
return forward_map
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def test_main(regex_string):
|
def test_main(regex_string):
|
||||||
|
|||||||
Reference in New Issue
Block a user