[Fix & Style] Refactor the grammar backend to reduce human errors and improve readability (#4030)
This commit is contained in:
@@ -13,31 +13,130 @@
|
||||
# ==============================================================================
|
||||
"""The baseclass of a backend for grammar-guided constrained decoding."""
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from threading import Event, Lock
|
||||
from typing import Any, Optional, Tuple
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseGrammarObject(ABC):
    """
    Abstract interface of a grammar object used for grammar-guided
    constrained decoding.

    Every method here is abstract, so a concrete grammar class has to
    implement all of them before it can be instantiated.
    """

    @abstractmethod
    def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
        """
        Attempt a jump forward in the grammar.

        Returns:
            A jump forward helper that may be handed to `jump_forward_str_state`,
            or None when a jump forward is not possible.
        """
        raise NotImplementedError

    @abstractmethod
    def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
        """
        Perform the jump forward described by `helper`.

        Returns:
            A tuple of the jump forward string and the next state of the grammar
            (the state can be passed on to `jump_and_retokenize` if needed).
        """
        raise NotImplementedError

    @abstractmethod
    def jump_and_retokenize(
        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
    ) -> None:
        """
        Called when a jump forward occurs; update the grammar state if needed.
        """
        raise NotImplementedError

    @abstractmethod
    def allocate_vocab_mask(
        self, vocab_size: int, batch_size: int, device
    ) -> torch.Tensor:
        """
        Return a vocabulary mask tensor.

        NOTE(review): presumably sized from `vocab_size` and `batch_size` and
        placed on `device`; confirm against the concrete backends.
        """
        raise NotImplementedError

    @abstractmethod
    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
        """
        Write this grammar's constraints into `vocab_mask` in place.

        NOTE(review): `idx` presumably selects the batch row; confirm against
        the concrete backends.
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
        """
        Return `vocab_mask` for use on `device`.
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
        """
        Apply `vocab_mask` to `logits`; nothing is returned.
        """
        raise NotImplementedError

    @abstractmethod
    def copy(self) -> "BaseGrammarObject":
        """
        Return a copy of this grammar object.
        """
        raise NotImplementedError
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheEntry:
|
||||
value: Any
|
||||
value: Optional[BaseGrammarObject]
|
||||
event: Event
|
||||
|
||||
|
||||
class BaseGrammarObject:
|
||||
pass
|
||||
|
||||
|
||||
class BaseGrammarBackend:
|
||||
class BaseGrammarBackend(ABC):
|
||||
def __init__(self):
|
||||
self.executor = ThreadPoolExecutor()
|
||||
self.cache = {}
|
||||
self.cache: Dict[Tuple[str, str], CacheEntry] = {}
|
||||
self.cache_lock = Lock()
|
||||
|
||||
def init_value(self, key: Tuple[str, str]) -> BaseGrammarObject:
|
||||
def _not_supported(self, key_type: str, key_string: str) -> None:
    """
    Log a warning that `key_type` is not supported; implicitly returns None.
    """
    message = f"Skip unsupported {key_type}: {key_type}={key_string}"
    logger.warning(message)
|
||||
|
||||
def dispatch_fallback(
    self, key_type: str, key_string: str
) -> Optional[BaseGrammarObject]:
    """
    Handler for a key type that matches none of the known kinds.

    This function should not be reached in any case, so it always raises.

    Raises:
        ValueError: always, naming the offending key type and string.
    """
    error_text = f"Invalid key_type: {key_type}={key_string}"
    raise ValueError(error_text)
|
||||
|
||||
@abstractmethod
def dispatch_json(self, key_string: str) -> Optional[BaseGrammarObject]:
    """
    Default handler for "json" keys: report the type as unsupported.

    The method is abstract, so a concrete backend must override it; an
    override can still delegate here through `super()`.
    """
    unsupported = self._not_supported("json", key_string)
    return unsupported
|
||||
|
||||
@abstractmethod
def dispatch_regex(self, key_string: str) -> Optional[BaseGrammarObject]:
    """
    Default handler for "regex" keys: report the type as unsupported.

    The method is abstract, so a concrete backend must override it; an
    override can still delegate here through `super()`.
    """
    unsupported = self._not_supported("regex", key_string)
    return unsupported
|
||||
|
||||
@abstractmethod
def dispatch_ebnf(self, key_string: str) -> Optional[BaseGrammarObject]:
    """
    Default handler for "ebnf" keys: report the type as unsupported.

    The method is abstract, so a concrete backend must override it; an
    override can still delegate here through `super()`.
    """
    unsupported = self._not_supported("ebnf", key_string)
    return unsupported
|
||||
|
||||
@abstractmethod
def dispatch_structural_tag(self, key_string: str) -> Optional[BaseGrammarObject]:
    """
    Default handler for "structural_tag" keys: report the type as unsupported.

    The method is abstract, so a concrete backend must override it; an
    override can still delegate here through `super()`.
    """
    unsupported = self._not_supported("structural_tag", key_string)
    return unsupported
|
||||
|
||||
def _init_value_dispatch(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
|
||||
key_type, key_string = key
|
||||
if key_type == "json":
|
||||
return self.dispatch_json(key_string)
|
||||
elif key_type == "regex":
|
||||
return self.dispatch_regex(key_string)
|
||||
elif key_type == "ebnf":
|
||||
return self.dispatch_ebnf(key_string)
|
||||
elif key_type == "structural_tag":
|
||||
return self.dispatch_structural_tag(key_string)
|
||||
else:
|
||||
return self.dispatch_fallback(key_type, key_string)
|
||||
|
||||
def _init_value(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
|
||||
with self.cache_lock:
|
||||
if key in self.cache:
|
||||
cache_hit = True
|
||||
@@ -50,13 +149,10 @@ class BaseGrammarBackend:
|
||||
if cache_hit:
|
||||
entry.event.wait()
|
||||
else:
|
||||
entry.value = self.init_value_impl(key)
|
||||
entry.value = self._init_value_dispatch(key)
|
||||
entry.event.set()
|
||||
return entry.value.copy() if entry.value else None
|
||||
|
||||
def init_value_impl(self, key: Tuple[str, str]) -> BaseGrammarObject:
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_cached_value(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
|
||||
with self.cache_lock:
|
||||
entry = self.cache.get(key)
|
||||
@@ -66,7 +162,7 @@ class BaseGrammarBackend:
|
||||
return val.copy() if val else None
|
||||
|
||||
def get_future_value(self, key: Tuple[str, str]) -> Future:
|
||||
return self.executor.submit(self.init_value, key)
|
||||
return self.executor.submit(self._init_value, key)
|
||||
|
||||
def reset(self):
|
||||
with self.cache_lock:
|
||||
|
||||
@@ -48,7 +48,7 @@ class GuidanceGrammar(BaseGrammarObject):
|
||||
self.finished = False
|
||||
self.bitmask = None
|
||||
|
||||
def try_jump_forward(self, tokenizer) -> Tuple[List[int], str]:
|
||||
def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
|
||||
if len(self.pending_ff_tokens) > 0:
|
||||
s = self.llguidance_tokenizer.decode_str(self.pending_ff_tokens)
|
||||
ff_tokens = self.pending_ff_tokens
|
||||
@@ -125,22 +125,27 @@ class GuidanceBackend(BaseGrammarBackend):
|
||||
)
|
||||
self.llguidance_tokenizer = llguidance.hf.from_tokenizer(self.tokenizer, None)
|
||||
|
||||
def init_value_impl(self, key: Tuple[str, str]) -> GuidanceGrammar:
|
||||
mode, value = key
|
||||
if mode == "json":
|
||||
json_schema = value
|
||||
compiler = llguidance.JsonCompiler(
|
||||
whitespace_flexible=self.whitespace_flexible
|
||||
)
|
||||
serialized_grammar = compiler.compile(json_schema)
|
||||
elif mode == "regex":
|
||||
compiler = llguidance.RegexCompiler()
|
||||
serialized_grammar = compiler.compile(regex=value)
|
||||
elif mode == "ebnf":
|
||||
compiler = llguidance.LarkCompiler()
|
||||
serialized_grammar = compiler.compile(any_to_lark(value))
|
||||
|
||||
def _from_serialized(self, serialized_grammar) -> GuidanceGrammar:
    """
    Build a `GuidanceGrammar` from a serialized grammar, using this backend's
    llguidance tokenizer.
    """
    grammar = GuidanceGrammar(
        llguidance_tokenizer=self.llguidance_tokenizer,
        serialized_grammar=serialized_grammar,
    )
    return grammar
|
||||
|
||||
def dispatch_json(self, key_string: str) -> GuidanceGrammar:
    """
    Compile a JSON schema string with `llguidance.JsonCompiler`, honoring this
    backend's `whitespace_flexible` setting, and wrap the result.
    """
    json_compiler = llguidance.JsonCompiler(
        whitespace_flexible=self.whitespace_flexible
    )
    return self._from_serialized(json_compiler.compile(key_string))
|
||||
|
||||
def dispatch_regex(self, key_string: str) -> GuidanceGrammar:
    """
    Compile a regex string with `llguidance.RegexCompiler` and wrap the result.
    """
    regex_compiler = llguidance.RegexCompiler()
    return self._from_serialized(regex_compiler.compile(regex=key_string))
|
||||
|
||||
def dispatch_ebnf(self, key_string: str) -> GuidanceGrammar:
    """
    Convert the grammar string with `any_to_lark`, compile it with
    `llguidance.LarkCompiler`, and wrap the result.
    """
    lark_compiler = llguidance.LarkCompiler()
    lark_source = any_to_lark(key_string)
    return self._from_serialized(lark_compiler.compile(lark_source))
|
||||
|
||||
def dispatch_structural_tag(self, key_string: str):
    """
    Structural tags have no implementation of their own in this backend, so
    defer to the parent class's handler.
    """
    return super().dispatch_structural_tag(key_string)
|
||||
|
||||
@@ -141,24 +141,7 @@ class OutlinesGrammarBackend(BaseGrammarBackend):
|
||||
)
|
||||
self.whitespace_pattern = whitespace_pattern
|
||||
|
||||
def init_value_impl(self, key: Tuple[str, str]) -> OutlinesGrammar:
|
||||
key_type, key_string = key
|
||||
if key_type == "json":
|
||||
try:
|
||||
regex = build_regex_from_object(
|
||||
key_string,
|
||||
whitespace_pattern=self.whitespace_pattern,
|
||||
)
|
||||
except (NotImplementedError, json.decoder.JSONDecodeError) as e:
|
||||
logger.warning(
|
||||
f"Skip invalid json_schema: json_schema={key_string}, {e=}"
|
||||
)
|
||||
return None
|
||||
elif key_type == "regex":
|
||||
regex = key_string
|
||||
else:
|
||||
raise ValueError(f"Invalid key_type: {key_type}")
|
||||
|
||||
def _compile_regex(self, regex: str) -> Optional[OutlinesGrammar]:
|
||||
try:
|
||||
if hasattr(RegexGuide, "from_regex"):
|
||||
# outlines >= 0.1.1
|
||||
@@ -173,6 +156,25 @@ class OutlinesGrammarBackend(BaseGrammarBackend):
|
||||
jump_forward_map = None
|
||||
return OutlinesGrammar(guide, jump_forward_map)
|
||||
|
||||
def dispatch_ebnf(self, key_string: str):
    """
    EBNF grammars have no implementation of their own in this backend, so
    defer to the parent class's handler.
    """
    return super().dispatch_ebnf(key_string)
|
||||
|
||||
def dispatch_structural_tag(self, key_string: str):
    """
    Structural tags have no implementation of their own in this backend, so
    defer to the parent class's handler.
    """
    return super().dispatch_structural_tag(key_string)
|
||||
|
||||
def dispatch_json(self, key_string: str):
    """
    Turn a JSON schema into a regex and compile it.

    Args:
        key_string: The JSON schema, passed to `build_regex_from_object`
            together with this backend's `whitespace_pattern`.

    Returns:
        The result of `_compile_regex` for the generated regex, or None when
        the schema is rejected (a warning is logged in that case).
    """
    try:
        regex = build_regex_from_object(
            key_string,
            whitespace_pattern=self.whitespace_pattern,
        )
    except (NotImplementedError, json.decoder.JSONDecodeError) as e:
        logger.warning(f"Skip invalid json_schema: json_schema={key_string}, {e=}")
        # Bail out here: `regex` was never bound, so falling through to
        # `_compile_regex(regex)` would raise UnboundLocalError instead of
        # skipping the invalid schema.
        return None
    return self._compile_regex(regex)
|
||||
|
||||
def dispatch_regex(self, key_string: str):
    """
    Compile the regex string as-is via `_compile_regex` and return its result.
    """
    compiled = self._compile_regex(key_string)
    return compiled
|
||||
|
||||
|
||||
def build_regex_from_object(
|
||||
object: Union[str, BaseModel, Dict], whitespace_pattern: Optional[str] = None
|
||||
|
||||
@@ -57,7 +57,7 @@ class XGrammarGrammar(BaseGrammarObject):
|
||||
def accept_token(self, token: int):
    """
    Feed `token` to the underlying matcher.

    Raises:
        AssertionError: if the matcher reports that it did not accept the token.
    """
    # Call the matcher outside of an `assert` statement: asserts are stripped
    # under `python -O`, which would silently skip the state update.
    accepted = self.matcher.accept_token(token)
    if not accepted:
        raise AssertionError(f"Token {token} was not accepted by the matcher")
|
||||
|
||||
def try_jump_forward(self, tokenizer) -> Tuple[List[int], str]:
|
||||
def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
|
||||
s = self.matcher.find_jump_forward_string()
|
||||
if s:
|
||||
return [], s
|
||||
@@ -128,55 +128,56 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
|
||||
self.vocab_size = vocab_size
|
||||
self.override_stop_tokens = override_stop_tokens
|
||||
|
||||
def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar:
|
||||
|
||||
key_type, key_string = key
|
||||
if key_type == "json":
|
||||
try:
|
||||
if key_string == "$$ANY$$":
|
||||
ctx = self.grammar_compiler.compile_builtin_json_grammar()
|
||||
else:
|
||||
ctx = self.grammar_compiler.compile_json_schema(schema=key_string)
|
||||
except RuntimeError as e:
|
||||
logging.warning(
|
||||
f"Skip invalid json_schema: json_schema={key_string}, {e=}"
|
||||
)
|
||||
return None
|
||||
elif key_type == "ebnf":
|
||||
try:
|
||||
ctx = self.grammar_compiler.compile_grammar(key_string)
|
||||
except RuntimeError as e:
|
||||
logging.warning(f"Skip invalid ebnf: ebnf={key_string}, {e=}")
|
||||
return None
|
||||
elif key_type == "regex":
|
||||
try:
|
||||
ctx = self.grammar_compiler.compile_regex(key_string)
|
||||
except RuntimeError as e:
|
||||
logging.warning(f"Skip invalid regex: regex={key_string}, {e=}")
|
||||
return None
|
||||
elif key_type == "structural_tag":
|
||||
try:
|
||||
structural_tag = json.loads(key_string)
|
||||
tags = [
|
||||
StructuralTagItem(
|
||||
begin=structure["begin"],
|
||||
schema=json.dumps(structure["schema"]),
|
||||
end=structure["end"],
|
||||
)
|
||||
for structure in structural_tag["structures"]
|
||||
]
|
||||
ctx = self.grammar_compiler.compile_structural_tag(
|
||||
tags, structural_tag["triggers"]
|
||||
)
|
||||
except RuntimeError as e:
|
||||
logging.warning(f"Skip invalid regex: regex={key_string}, {e=}")
|
||||
return None
|
||||
else:
|
||||
raise ValueError(f"Invalid key_type: {key_type}")
|
||||
|
||||
def _from_context(self, ctx: CompiledGrammar) -> XGrammarGrammar:
    """
    Create a `GrammarMatcher` for the compiled grammar `ctx` and wrap it in an
    `XGrammarGrammar` carrying this backend's vocab size and stop-token override.
    """
    return XGrammarGrammar(
        GrammarMatcher(ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS),
        self.vocab_size,
        ctx,
        self.override_stop_tokens,
    )
|
||||
|
||||
def dispatch_json(self, key_string: str) -> Optional[XGrammarGrammar]:
    """
    Compile a JSON schema with the grammar compiler.

    The sentinel string "$$ANY$$" selects the compiler's builtin JSON grammar
    instead of a schema.

    Returns:
        The grammar built by `_from_context`, or None (after logging a warning)
        when compilation raises RuntimeError.
    """
    try:
        compiled = (
            self.grammar_compiler.compile_builtin_json_grammar()
            if key_string == "$$ANY$$"
            else self.grammar_compiler.compile_json_schema(schema=key_string)
        )
    except RuntimeError as e:
        logging.warning(f"Skip invalid json_schema: json_schema={key_string}, {e=}")
        return None
    else:
        return self._from_context(compiled)
|
||||
|
||||
def dispatch_ebnf(self, key_string: str) -> Optional[XGrammarGrammar]:
    """
    Compile an EBNF grammar string with the grammar compiler.

    Returns:
        The grammar built by `_from_context`, or None (after logging a warning)
        when compilation raises RuntimeError.
    """
    try:
        compiled = self.grammar_compiler.compile_grammar(key_string)
    except RuntimeError as e:
        logging.warning(f"Skip invalid ebnf: ebnf={key_string}, {e=}")
        return None
    else:
        return self._from_context(compiled)
|
||||
|
||||
def dispatch_regex(self, key_string: str) -> Optional[XGrammarGrammar]:
    """
    Compile a regex string with the grammar compiler.

    Returns:
        The grammar built by `_from_context`, or None (after logging a warning)
        when compilation raises RuntimeError.
    """
    try:
        compiled = self.grammar_compiler.compile_regex(key_string)
    except RuntimeError as e:
        logging.warning(f"Skip invalid regex: regex={key_string}, {e=}")
        return None
    else:
        return self._from_context(compiled)
|
||||
|
||||
def dispatch_structural_tag(self, key_string: str) -> Optional[XGrammarGrammar]:
    """
    Compile a structural-tag specification given as a JSON string.

    The decoded object is read through its "structures" list (each entry
    providing "begin", "schema" and "end") and its "triggers" value.

    Returns:
        The grammar built by `_from_context`, or None (after logging a warning)
        when the string is not valid JSON or compilation raises RuntimeError.
    """
    try:
        structural_tag = json.loads(key_string)
        tags = [
            StructuralTagItem(
                begin=structure["begin"],
                # Each schema is re-serialized to a JSON string for the tag item.
                schema=json.dumps(structure["schema"]),
                end=structure["end"],
            )
            for structure in structural_tag["structures"]
        ]
        ctx = self.grammar_compiler.compile_structural_tag(
            tags, structural_tag["triggers"]
        )
    except (RuntimeError, json.JSONDecodeError) as e:
        # Report the real key type; the message previously said "regex",
        # a copy-paste slip from the regex handler.
        logging.warning(
            f"Skip invalid structural_tag: structural_tag={key_string}, {e=}"
        )
        return None
    return self._from_context(ctx)
|
||||
|
||||
def reset(self):
|
||||
if self.grammar_compiler:
|
||||
self.grammar_compiler.clear_cache()
|
||||
|
||||
Reference in New Issue
Block a user