[Fix & Style] Refactor the grammar backend to reduce human errors and improve readability (#4030)
This commit is contained in:
@@ -13,31 +13,130 @@
|
|||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""The baseclass of a backend for grammar-guided constrained decoding."""
|
"""The baseclass of a backend for grammar-guided constrained decoding."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
from concurrent.futures import Future, ThreadPoolExecutor
|
from concurrent.futures import Future, ThreadPoolExecutor
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from threading import Event, Lock
|
from threading import Event, Lock
|
||||||
from typing import Any, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseGrammarObject(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
|
||||||
|
"""
|
||||||
|
Try to jump forward in the grammar.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A jump forward helper which may be used in `jump_forward_str_state`.
|
||||||
|
None if the jump forward is not possible.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
|
||||||
|
"""
|
||||||
|
Jump forward for the grammar.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple of the jump forward string and the next state of the grammar
|
||||||
|
(which can be used in `jump_and_retokenize` if needed).
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def jump_and_retokenize(
|
||||||
|
self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Jump forward occurs, and update the grammar state if needed.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def allocate_vocab_mask(
|
||||||
|
self, vocab_size: int, batch_size: int, device
|
||||||
|
) -> torch.Tensor:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@abstractmethod
|
||||||
|
def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@abstractmethod
|
||||||
|
def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def copy(self) -> "BaseGrammarObject":
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class CacheEntry:
|
class CacheEntry:
|
||||||
value: Any
|
value: Optional[BaseGrammarObject]
|
||||||
event: Event
|
event: Event
|
||||||
|
|
||||||
|
|
||||||
class BaseGrammarObject:
|
class BaseGrammarBackend(ABC):
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class BaseGrammarBackend:
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.executor = ThreadPoolExecutor()
|
self.executor = ThreadPoolExecutor()
|
||||||
self.cache = {}
|
self.cache: Dict[Tuple[str, str], CacheEntry] = {}
|
||||||
self.cache_lock = Lock()
|
self.cache_lock = Lock()
|
||||||
|
|
||||||
def init_value(self, key: Tuple[str, str]) -> BaseGrammarObject:
|
def _not_supported(self, key_type: str, key_string: str) -> None:
|
||||||
|
logger.warning(f"Skip unsupported {key_type}: {key_type}={key_string}")
|
||||||
|
|
||||||
|
def dispatch_fallback(
|
||||||
|
self, key_type: str, key_string: str
|
||||||
|
) -> Optional[BaseGrammarObject]:
|
||||||
|
"""
|
||||||
|
This function should not be reached in any case.
|
||||||
|
"""
|
||||||
|
raise ValueError(f"Invalid key_type: {key_type}={key_string}")
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def dispatch_json(self, key_string: str) -> Optional[BaseGrammarObject]:
|
||||||
|
return self._not_supported("json", key_string)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def dispatch_regex(self, key_string: str) -> Optional[BaseGrammarObject]:
|
||||||
|
return self._not_supported("regex", key_string)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def dispatch_ebnf(self, key_string: str) -> Optional[BaseGrammarObject]:
|
||||||
|
return self._not_supported("ebnf", key_string)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def dispatch_structural_tag(self, key_string: str) -> Optional[BaseGrammarObject]:
|
||||||
|
return self._not_supported("structural_tag", key_string)
|
||||||
|
|
||||||
|
def _init_value_dispatch(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
|
||||||
|
key_type, key_string = key
|
||||||
|
if key_type == "json":
|
||||||
|
return self.dispatch_json(key_string)
|
||||||
|
elif key_type == "regex":
|
||||||
|
return self.dispatch_regex(key_string)
|
||||||
|
elif key_type == "ebnf":
|
||||||
|
return self.dispatch_ebnf(key_string)
|
||||||
|
elif key_type == "structural_tag":
|
||||||
|
return self.dispatch_structural_tag(key_string)
|
||||||
|
else:
|
||||||
|
return self.dispatch_fallback(key_type, key_string)
|
||||||
|
|
||||||
|
def _init_value(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
|
||||||
with self.cache_lock:
|
with self.cache_lock:
|
||||||
if key in self.cache:
|
if key in self.cache:
|
||||||
cache_hit = True
|
cache_hit = True
|
||||||
@@ -50,13 +149,10 @@ class BaseGrammarBackend:
|
|||||||
if cache_hit:
|
if cache_hit:
|
||||||
entry.event.wait()
|
entry.event.wait()
|
||||||
else:
|
else:
|
||||||
entry.value = self.init_value_impl(key)
|
entry.value = self._init_value_dispatch(key)
|
||||||
entry.event.set()
|
entry.event.set()
|
||||||
return entry.value.copy() if entry.value else None
|
return entry.value.copy() if entry.value else None
|
||||||
|
|
||||||
def init_value_impl(self, key: Tuple[str, str]) -> BaseGrammarObject:
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
def get_cached_value(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
|
def get_cached_value(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
|
||||||
with self.cache_lock:
|
with self.cache_lock:
|
||||||
entry = self.cache.get(key)
|
entry = self.cache.get(key)
|
||||||
@@ -66,7 +162,7 @@ class BaseGrammarBackend:
|
|||||||
return val.copy() if val else None
|
return val.copy() if val else None
|
||||||
|
|
||||||
def get_future_value(self, key: Tuple[str, str]) -> Future:
|
def get_future_value(self, key: Tuple[str, str]) -> Future:
|
||||||
return self.executor.submit(self.init_value, key)
|
return self.executor.submit(self._init_value, key)
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
with self.cache_lock:
|
with self.cache_lock:
|
||||||
|
|||||||
@@ -48,7 +48,7 @@ class GuidanceGrammar(BaseGrammarObject):
|
|||||||
self.finished = False
|
self.finished = False
|
||||||
self.bitmask = None
|
self.bitmask = None
|
||||||
|
|
||||||
def try_jump_forward(self, tokenizer) -> Tuple[List[int], str]:
|
def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
|
||||||
if len(self.pending_ff_tokens) > 0:
|
if len(self.pending_ff_tokens) > 0:
|
||||||
s = self.llguidance_tokenizer.decode_str(self.pending_ff_tokens)
|
s = self.llguidance_tokenizer.decode_str(self.pending_ff_tokens)
|
||||||
ff_tokens = self.pending_ff_tokens
|
ff_tokens = self.pending_ff_tokens
|
||||||
@@ -125,22 +125,27 @@ class GuidanceBackend(BaseGrammarBackend):
|
|||||||
)
|
)
|
||||||
self.llguidance_tokenizer = llguidance.hf.from_tokenizer(self.tokenizer, None)
|
self.llguidance_tokenizer = llguidance.hf.from_tokenizer(self.tokenizer, None)
|
||||||
|
|
||||||
def init_value_impl(self, key: Tuple[str, str]) -> GuidanceGrammar:
|
def _from_serialized(self, serialized_grammar) -> GuidanceGrammar:
|
||||||
mode, value = key
|
|
||||||
if mode == "json":
|
|
||||||
json_schema = value
|
|
||||||
compiler = llguidance.JsonCompiler(
|
|
||||||
whitespace_flexible=self.whitespace_flexible
|
|
||||||
)
|
|
||||||
serialized_grammar = compiler.compile(json_schema)
|
|
||||||
elif mode == "regex":
|
|
||||||
compiler = llguidance.RegexCompiler()
|
|
||||||
serialized_grammar = compiler.compile(regex=value)
|
|
||||||
elif mode == "ebnf":
|
|
||||||
compiler = llguidance.LarkCompiler()
|
|
||||||
serialized_grammar = compiler.compile(any_to_lark(value))
|
|
||||||
|
|
||||||
return GuidanceGrammar(
|
return GuidanceGrammar(
|
||||||
llguidance_tokenizer=self.llguidance_tokenizer,
|
llguidance_tokenizer=self.llguidance_tokenizer,
|
||||||
serialized_grammar=serialized_grammar,
|
serialized_grammar=serialized_grammar,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def dispatch_json(self, key_string: str) -> GuidanceGrammar:
|
||||||
|
json_schema = key_string
|
||||||
|
compiler = llguidance.JsonCompiler(whitespace_flexible=self.whitespace_flexible)
|
||||||
|
serialized_grammar = compiler.compile(json_schema)
|
||||||
|
return self._from_serialized(serialized_grammar)
|
||||||
|
|
||||||
|
def dispatch_regex(self, key_string: str) -> GuidanceGrammar:
|
||||||
|
compiler = llguidance.RegexCompiler()
|
||||||
|
serialized_grammar = compiler.compile(regex=key_string)
|
||||||
|
return self._from_serialized(serialized_grammar)
|
||||||
|
|
||||||
|
def dispatch_ebnf(self, key_string: str) -> GuidanceGrammar:
|
||||||
|
compiler = llguidance.LarkCompiler()
|
||||||
|
serialized_grammar = compiler.compile(any_to_lark(key_string))
|
||||||
|
return self._from_serialized(serialized_grammar)
|
||||||
|
|
||||||
|
def dispatch_structural_tag(self, key_string: str):
|
||||||
|
return super().dispatch_structural_tag(key_string)
|
||||||
|
|||||||
@@ -141,24 +141,7 @@ class OutlinesGrammarBackend(BaseGrammarBackend):
|
|||||||
)
|
)
|
||||||
self.whitespace_pattern = whitespace_pattern
|
self.whitespace_pattern = whitespace_pattern
|
||||||
|
|
||||||
def init_value_impl(self, key: Tuple[str, str]) -> OutlinesGrammar:
|
def _compile_regex(self, regex: str) -> Optional[OutlinesGrammar]:
|
||||||
key_type, key_string = key
|
|
||||||
if key_type == "json":
|
|
||||||
try:
|
|
||||||
regex = build_regex_from_object(
|
|
||||||
key_string,
|
|
||||||
whitespace_pattern=self.whitespace_pattern,
|
|
||||||
)
|
|
||||||
except (NotImplementedError, json.decoder.JSONDecodeError) as e:
|
|
||||||
logger.warning(
|
|
||||||
f"Skip invalid json_schema: json_schema={key_string}, {e=}"
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
elif key_type == "regex":
|
|
||||||
regex = key_string
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Invalid key_type: {key_type}")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if hasattr(RegexGuide, "from_regex"):
|
if hasattr(RegexGuide, "from_regex"):
|
||||||
# outlines >= 0.1.1
|
# outlines >= 0.1.1
|
||||||
@@ -173,6 +156,25 @@ class OutlinesGrammarBackend(BaseGrammarBackend):
|
|||||||
jump_forward_map = None
|
jump_forward_map = None
|
||||||
return OutlinesGrammar(guide, jump_forward_map)
|
return OutlinesGrammar(guide, jump_forward_map)
|
||||||
|
|
||||||
|
def dispatch_ebnf(self, key_string: str):
|
||||||
|
return super().dispatch_ebnf(key_string)
|
||||||
|
|
||||||
|
def dispatch_structural_tag(self, key_string: str):
|
||||||
|
return super().dispatch_structural_tag(key_string)
|
||||||
|
|
||||||
|
def dispatch_json(self, key_string: str):
|
||||||
|
try:
|
||||||
|
regex = build_regex_from_object(
|
||||||
|
key_string,
|
||||||
|
whitespace_pattern=self.whitespace_pattern,
|
||||||
|
)
|
||||||
|
except (NotImplementedError, json.decoder.JSONDecodeError) as e:
|
||||||
|
logger.warning(f"Skip invalid json_schema: json_schema={key_string}, {e=}")
|
||||||
|
return self._compile_regex(regex)
|
||||||
|
|
||||||
|
def dispatch_regex(self, key_string: str):
|
||||||
|
return self._compile_regex(key_string)
|
||||||
|
|
||||||
|
|
||||||
def build_regex_from_object(
|
def build_regex_from_object(
|
||||||
object: Union[str, BaseModel, Dict], whitespace_pattern: Optional[str] = None
|
object: Union[str, BaseModel, Dict], whitespace_pattern: Optional[str] = None
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ class XGrammarGrammar(BaseGrammarObject):
|
|||||||
def accept_token(self, token: int):
|
def accept_token(self, token: int):
|
||||||
assert self.matcher.accept_token(token)
|
assert self.matcher.accept_token(token)
|
||||||
|
|
||||||
def try_jump_forward(self, tokenizer) -> Tuple[List[int], str]:
|
def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
|
||||||
s = self.matcher.find_jump_forward_string()
|
s = self.matcher.find_jump_forward_string()
|
||||||
if s:
|
if s:
|
||||||
return [], s
|
return [], s
|
||||||
@@ -128,55 +128,56 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
|
|||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
self.override_stop_tokens = override_stop_tokens
|
self.override_stop_tokens = override_stop_tokens
|
||||||
|
|
||||||
def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar:
|
def _from_context(self, ctx: CompiledGrammar) -> XGrammarGrammar:
|
||||||
|
|
||||||
key_type, key_string = key
|
|
||||||
if key_type == "json":
|
|
||||||
try:
|
|
||||||
if key_string == "$$ANY$$":
|
|
||||||
ctx = self.grammar_compiler.compile_builtin_json_grammar()
|
|
||||||
else:
|
|
||||||
ctx = self.grammar_compiler.compile_json_schema(schema=key_string)
|
|
||||||
except RuntimeError as e:
|
|
||||||
logging.warning(
|
|
||||||
f"Skip invalid json_schema: json_schema={key_string}, {e=}"
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
elif key_type == "ebnf":
|
|
||||||
try:
|
|
||||||
ctx = self.grammar_compiler.compile_grammar(key_string)
|
|
||||||
except RuntimeError as e:
|
|
||||||
logging.warning(f"Skip invalid ebnf: ebnf={key_string}, {e=}")
|
|
||||||
return None
|
|
||||||
elif key_type == "regex":
|
|
||||||
try:
|
|
||||||
ctx = self.grammar_compiler.compile_regex(key_string)
|
|
||||||
except RuntimeError as e:
|
|
||||||
logging.warning(f"Skip invalid regex: regex={key_string}, {e=}")
|
|
||||||
return None
|
|
||||||
elif key_type == "structural_tag":
|
|
||||||
try:
|
|
||||||
structural_tag = json.loads(key_string)
|
|
||||||
tags = [
|
|
||||||
StructuralTagItem(
|
|
||||||
begin=structure["begin"],
|
|
||||||
schema=json.dumps(structure["schema"]),
|
|
||||||
end=structure["end"],
|
|
||||||
)
|
|
||||||
for structure in structural_tag["structures"]
|
|
||||||
]
|
|
||||||
ctx = self.grammar_compiler.compile_structural_tag(
|
|
||||||
tags, structural_tag["triggers"]
|
|
||||||
)
|
|
||||||
except RuntimeError as e:
|
|
||||||
logging.warning(f"Skip invalid regex: regex={key_string}, {e=}")
|
|
||||||
return None
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Invalid key_type: {key_type}")
|
|
||||||
|
|
||||||
matcher = GrammarMatcher(ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS)
|
matcher = GrammarMatcher(ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS)
|
||||||
return XGrammarGrammar(matcher, self.vocab_size, ctx, self.override_stop_tokens)
|
return XGrammarGrammar(matcher, self.vocab_size, ctx, self.override_stop_tokens)
|
||||||
|
|
||||||
|
def dispatch_json(self, key_string: str) -> Optional[XGrammarGrammar]:
|
||||||
|
try:
|
||||||
|
if key_string == "$$ANY$$":
|
||||||
|
ctx = self.grammar_compiler.compile_builtin_json_grammar()
|
||||||
|
else:
|
||||||
|
ctx = self.grammar_compiler.compile_json_schema(schema=key_string)
|
||||||
|
except RuntimeError as e:
|
||||||
|
logging.warning(f"Skip invalid json_schema: json_schema={key_string}, {e=}")
|
||||||
|
return None
|
||||||
|
return self._from_context(ctx)
|
||||||
|
|
||||||
|
def dispatch_ebnf(self, key_string: str) -> Optional[XGrammarGrammar]:
|
||||||
|
try:
|
||||||
|
ctx = self.grammar_compiler.compile_grammar(key_string)
|
||||||
|
except RuntimeError as e:
|
||||||
|
logging.warning(f"Skip invalid ebnf: ebnf={key_string}, {e=}")
|
||||||
|
return None
|
||||||
|
return self._from_context(ctx)
|
||||||
|
|
||||||
|
def dispatch_regex(self, key_string: str) -> Optional[XGrammarGrammar]:
|
||||||
|
try:
|
||||||
|
ctx = self.grammar_compiler.compile_regex(key_string)
|
||||||
|
except RuntimeError as e:
|
||||||
|
logging.warning(f"Skip invalid regex: regex={key_string}, {e=}")
|
||||||
|
return None
|
||||||
|
return self._from_context(ctx)
|
||||||
|
|
||||||
|
def dispatch_structural_tag(self, key_string: str) -> Optional[XGrammarGrammar]:
|
||||||
|
try:
|
||||||
|
structural_tag = json.loads(key_string)
|
||||||
|
tags = [
|
||||||
|
StructuralTagItem(
|
||||||
|
begin=structure["begin"],
|
||||||
|
schema=json.dumps(structure["schema"]),
|
||||||
|
end=structure["end"],
|
||||||
|
)
|
||||||
|
for structure in structural_tag["structures"]
|
||||||
|
]
|
||||||
|
ctx = self.grammar_compiler.compile_structural_tag(
|
||||||
|
tags, structural_tag["triggers"]
|
||||||
|
)
|
||||||
|
except RuntimeError as e:
|
||||||
|
logging.warning(f"Skip invalid regex: regex={key_string}, {e=}")
|
||||||
|
return None
|
||||||
|
return self._from_context(ctx)
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
if self.grammar_compiler:
|
if self.grammar_compiler:
|
||||||
self.grammar_compiler.clear_cache()
|
self.grammar_compiler.clear_cache()
|
||||||
|
|||||||
Reference in New Issue
Block a user