Correctly abort the failed grammar requests & Improve the handling of abort (#6803)

This commit is contained in:
Lianmin Zheng
2025-06-01 19:00:07 -07:00
committed by GitHub
parent 6a47b73024
commit 20fd53b8f6
16 changed files with 199 additions and 142 deletions

View File

@@ -60,7 +60,7 @@ class BaseGrammarObject:
raise NotImplementedError()
def copy(self) -> "BaseGrammarObject":
raise NotImplementedError()
return self
@property
def finished(self):
@@ -99,9 +99,12 @@ class BaseGrammarObject:
raise NotImplementedError()
INVALID_GRAMMAR_OBJ = BaseGrammarObject()
@dataclass
class CacheEntry:
value: Optional[BaseGrammarObject]
value: BaseGrammarObject
event: Event

View File

@@ -28,6 +28,7 @@ from llguidance.torch import (
)
from sglang.srt.constrained.base_grammar_backend import (
INVALID_GRAMMAR_OBJ,
BaseGrammarBackend,
BaseGrammarObject,
)
@@ -126,8 +127,8 @@ class GuidanceBackend(BaseGrammarBackend):
serialized_grammar=serialized_grammar,
)
except Exception as e:
logger.warning(f"Skip invalid grammar: {serialized_grammar}, {e=}")
return None
logger.error(f"Hit invalid grammar: {serialized_grammar=}, {e=}")
return INVALID_GRAMMAR_OBJ
def dispatch_json(self, key_string: str) -> Optional[GuidanceGrammar]:
try:
@@ -138,8 +139,8 @@ class GuidanceBackend(BaseGrammarBackend):
},
)
except Exception as e:
logger.warning(f"Skip invalid grammar: {key_string=}, {e=}")
return None
logger.error(f"Hit invalid json_schema: {key_string=}, {e=}")
return INVALID_GRAMMAR_OBJ
return self._from_serialized(serialized_grammar)
def dispatch_regex(self, key_string: str) -> Optional[GuidanceGrammar]:
@@ -151,8 +152,8 @@ class GuidanceBackend(BaseGrammarBackend):
serialized_grammar = grammar_from("ebnf", key_string)
return self._from_serialized(serialized_grammar)
except ValueError as e:
logger.warning(f"Skip invalid ebnf: regex={key_string}, {e=}")
return None
logger.error(f"Hit invalid ebnf: {key_string=}, {e=}")
return INVALID_GRAMMAR_OBJ
def dispatch_structural_tag(self, key_string: str) -> Optional[GuidanceGrammar]:
try:
@@ -169,5 +170,5 @@ class GuidanceBackend(BaseGrammarBackend):
g = StructTag.to_grammar(tags)
return self._from_serialized(g)
except Exception as e:
logging.warning(f"Skip invalid structural_tag: {key_string}, {e=}")
return None
logging.error(f"Hit invalid structural_tag: {key_string=}, {e=}")
return INVALID_GRAMMAR_OBJ

View File

@@ -24,6 +24,7 @@ from outlines.models.transformers import TransformerTokenizer
from pydantic import BaseModel
from sglang.srt.constrained.base_grammar_backend import (
INVALID_GRAMMAR_OBJ,
BaseGrammarBackend,
BaseGrammarObject,
)
@@ -151,8 +152,8 @@ class OutlinesGrammarBackend(BaseGrammarBackend):
# outlines <= 0.0.46
guide = RegexGuide(regex, self.outlines_tokenizer)
except interegular.patterns.InvalidSyntax as e:
logger.warning(f"skip invalid regex schema: {regex=}, {e=}")
return None
logger.error(f"Hit invalid regex schema: {regex=}, {e=}")
return INVALID_GRAMMAR_OBJ
jump_forward_map = None
return OutlinesGrammar(guide, jump_forward_map)
@@ -170,8 +171,8 @@ class OutlinesGrammarBackend(BaseGrammarBackend):
whitespace_pattern=self.whitespace_pattern,
)
except (NotImplementedError, json.decoder.JSONDecodeError, ValueError) as e:
logger.warning(f"Skip invalid json_schema: {key_string=}, {e=}")
return None
logger.error(f"Hit invalid json_schema: {key_string=}, {e=}")
return INVALID_GRAMMAR_OBJ
return self._compile_regex(regex)
def dispatch_regex(self, key_string: str):

View File

@@ -28,6 +28,7 @@ from xgrammar import (
)
from sglang.srt.constrained.base_grammar_backend import (
INVALID_GRAMMAR_OBJ,
BaseGrammarBackend,
BaseGrammarObject,
)
@@ -152,10 +153,11 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
):
super().__init__()
tokenizer_info = TokenizerInfo.from_huggingface(
tokenizer, vocab_size=vocab_size
)
override_stop_tokens = None
if True:
tokenizer_info = TokenizerInfo.from_huggingface(
tokenizer, vocab_size=vocab_size
)
override_stop_tokens = None
self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info)
self.vocab_size = vocab_size
@@ -178,25 +180,26 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
ctx = self.grammar_compiler.compile_builtin_json_grammar()
else:
ctx = self.grammar_compiler.compile_json_schema(schema=key_string)
except RuntimeError as e:
logging.warning(f"Skip invalid json_schema: json_schema={key_string}, {e=}")
return None
except (RuntimeError, json.decoder.JSONDecodeError) as e:
logging.error(f"Hit invalid json_schema: {key_string=}, {e=}")
return INVALID_GRAMMAR_OBJ
return self._from_context(ctx, key_string)
def dispatch_ebnf(self, key_string: str) -> Optional[XGrammarGrammar]:
try:
ctx = self.grammar_compiler.compile_grammar(key_string)
except RuntimeError as e:
logging.warning(f"Skip invalid ebnf: ebnf={key_string}, {e=}")
return None
logging.error(f"Hit invalid ebnf: {key_string=}, {e=}")
return INVALID_GRAMMAR_OBJ
return self._from_context(ctx, key_string)
def dispatch_regex(self, key_string: str) -> Optional[XGrammarGrammar]:
try:
ctx = self.grammar_compiler.compile_regex(key_string)
except RuntimeError as e:
logging.warning(f"Skip invalid regex: regex={key_string}, {e=}")
return None
logging.error(f"Hit invalid regex: {key_string=}, {e=}")
return INVALID_GRAMMAR_OBJ
return self._from_context(ctx, key_string)
def dispatch_structural_tag(self, key_string: str) -> Optional[XGrammarGrammar]:
@@ -213,13 +216,10 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
ctx = self.grammar_compiler.compile_structural_tag(
tags, structural_tag["triggers"]
)
except RuntimeError as e:
logging.warning(
f"Skip invalid structural_tag: structural_tag={key_string}, {e=}"
)
return None
except (RuntimeError, json.decoder.JSONDecodeError) as e:
logging.error(f"Hit invalid structural_tag: {key_string=}, {e=}")
return INVALID_GRAMMAR_OBJ
return self._from_context(ctx, key_string)
def reset(self):
if self.grammar_compiler:
self.grammar_compiler.clear_cache()
self.grammar_compiler.clear_cache()