diff --git a/python/sglang/srt/constrained/fsm_cache.py b/python/sglang/srt/constrained/fsm_cache.py index 4ac4cef48..ab025f26e 100644 --- a/python/sglang/srt/constrained/fsm_cache.py +++ b/python/sglang/srt/constrained/fsm_cache.py @@ -14,13 +14,17 @@ limitations under the License. """ """Cache for the compressed finite state machine.""" +import logging +from interegular import InvalidSyntax, parse_pattern from outlines.fsm.json_schema import build_regex_from_schema from transformers import AutoTokenizer from sglang.srt.constrained import RegexGuide, TransformerTokenizer from sglang.srt.constrained.base_tool_cache import BaseToolCache +logger = logging.getLogger(__name__) + class FSMCache(BaseToolCache): def __init__( @@ -76,5 +80,9 @@ class FSMCache(BaseToolCache): regex = key_string else: raise ValueError(f"Invalid key_type: {key_type}") - + try: + parse_pattern(regex) + except InvalidSyntax as e: + logger.warning(f"skip invalid regex guide: {regex=}, {e=}") + return None, regex return RegexGuide(regex, self.outlines_tokenizer), regex diff --git a/python/sglang/srt/constrained/jump_forward.py b/python/sglang/srt/constrained/jump_forward.py index b00c48d47..1ebc8b217 100644 --- a/python/sglang/srt/constrained/jump_forward.py +++ b/python/sglang/srt/constrained/jump_forward.py @@ -19,10 +19,12 @@ Reference: https://lmsys.org/blog/2024-02-05-compressed-fsm/ """ import dataclasses +import logging from collections import defaultdict import interegular import outlines.caching +from interegular import InvalidSyntax from sglang.srt.constrained import ( FSMInfo, @@ -34,6 +36,8 @@ from sglang.srt.constrained.base_tool_cache import BaseToolCache IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)" +logger = logging.getLogger(__name__) + @dataclasses.dataclass class JumpEdge: @@ -47,7 +51,12 @@ class JumpForwardMap: def __init__(self, regex_string): @disk_cache() def _init_state_to_jump_forward(regex_string): - regex_pattern = interegular.parse_pattern(regex_string) + try: + regex_pattern = interegular.parse_pattern(regex_string) + except InvalidSyntax as e: + logger.warning(f"skip invalid regex: {regex_string}, {e=}") + self.state_to_jump_forward = None + return byte_fsm = make_byte_level_fsm( regex_pattern.to_fsm().reduce(), keep_utf8=True @@ -165,7 +174,11 @@ class JumpForwardCache(BaseToolCache): super().__init__() def init_value(self, regex): - return JumpForwardMap(regex) + forward_map = JumpForwardMap(regex) + if forward_map.state_to_jump_forward: + return forward_map + else: + return None def test_main(regex_string):