[FIX] Catch syntax error of Regex Guide to avoid crash (#1521)
This commit is contained in:
@@ -14,13 +14,17 @@ limitations under the License.
|
||||
"""
|
||||
|
||||
"""Cache for the compressed finite state machine."""
|
||||
import logging
|
||||
|
||||
from interegular import InvalidSyntax, parse_pattern
|
||||
from outlines.fsm.json_schema import build_regex_from_schema
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from sglang.srt.constrained import RegexGuide, TransformerTokenizer
|
||||
from sglang.srt.constrained.base_tool_cache import BaseToolCache
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FSMCache(BaseToolCache):
|
||||
def __init__(
|
||||
@@ -76,5 +80,9 @@ class FSMCache(BaseToolCache):
|
||||
regex = key_string
|
||||
else:
|
||||
raise ValueError(f"Invalid key_type: {key_type}")
|
||||
|
||||
try:
|
||||
parse_pattern(regex)
|
||||
except InvalidSyntax as e:
|
||||
logger.warning(f"skip invalid regex guide: {regex=}, {e=}")
|
||||
return None, regex
|
||||
return RegexGuide(regex, self.outlines_tokenizer), regex
|
||||
|
||||
@@ -19,10 +19,12 @@ Reference: https://lmsys.org/blog/2024-02-05-compressed-fsm/
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
|
||||
import interegular
|
||||
import outlines.caching
|
||||
from interegular import InvalidSyntax
|
||||
|
||||
from sglang.srt.constrained import (
|
||||
FSMInfo,
|
||||
@@ -34,6 +36,8 @@ from sglang.srt.constrained.base_tool_cache import BaseToolCache
|
||||
|
||||
IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class JumpEdge:
|
||||
@@ -47,7 +51,12 @@ class JumpForwardMap:
|
||||
def __init__(self, regex_string):
|
||||
@disk_cache()
|
||||
def _init_state_to_jump_forward(regex_string):
|
||||
regex_pattern = interegular.parse_pattern(regex_string)
|
||||
try:
|
||||
regex_pattern = interegular.parse_pattern(regex_string)
|
||||
except InvalidSyntax as e:
|
||||
logger.warning(f"skip invalid regex: {regex_string}, {e=}")
|
||||
self.state_to_jump_forward = None
|
||||
return
|
||||
|
||||
byte_fsm = make_byte_level_fsm(
|
||||
regex_pattern.to_fsm().reduce(), keep_utf8=True
|
||||
@@ -165,7 +174,11 @@ class JumpForwardCache(BaseToolCache):
|
||||
super().__init__()
|
||||
|
||||
def init_value(self, regex):
|
||||
return JumpForwardMap(regex)
|
||||
forward_map = JumpForwardMap(regex)
|
||||
if forward_map.state_to_jump_forward:
|
||||
return forward_map
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def test_main(regex_string):
|
||||
|
||||
Reference in New Issue
Block a user