Improve structured outputs: fix race condition, server crash, metrics and style (#6188)
This commit is contained in:
@@ -14,10 +14,9 @@
|
||||
"""The baseclass of a backend for grammar-guided constrained decoding."""
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from threading import Event, Lock
|
||||
from threading import Event
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
@@ -27,11 +26,36 @@ from sglang.srt.server_args import ServerArgs
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseGrammarObject(ABC):
|
||||
class BaseGrammarObject:
|
||||
|
||||
def __init__(self):
|
||||
self._finished = False
|
||||
|
||||
def accept_token(self, token: int) -> None:
|
||||
"""
|
||||
Accept a token in the grammar.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def allocate_vocab_mask(
|
||||
self, vocab_size: int, batch_size: int, device
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError()
|
||||
|
||||
def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
|
||||
raise NotImplementedError()
|
||||
|
||||
def copy(self) -> "BaseGrammarObject":
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def finished(self):
|
||||
return self._finished
|
||||
@@ -40,7 +64,6 @@ class BaseGrammarObject(ABC):
|
||||
def finished(self, finished):
|
||||
self._finished = finished
|
||||
|
||||
@abstractmethod
|
||||
def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
|
||||
"""
|
||||
Try to jump forward in the grammar.
|
||||
@@ -49,9 +72,8 @@ class BaseGrammarObject(ABC):
|
||||
A jump forward helper which may be used in `jump_forward_str_state`.
|
||||
None if the jump forward is not possible.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
|
||||
"""
|
||||
Jump forward for the grammar.
|
||||
@@ -60,47 +82,15 @@ class BaseGrammarObject(ABC):
|
||||
A tuple of the jump forward string and the next state of the grammar
|
||||
(which can be used in `jump_and_retokenize` if needed).
|
||||
"""
|
||||
raise NotImplementedError
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def jump_and_retokenize(
|
||||
self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
|
||||
) -> None:
|
||||
"""
|
||||
Jump forward occurs, and update the grammar state if needed.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def accept_token(self, token: int) -> None:
|
||||
"""
|
||||
Accept a token in the grammar.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def allocate_vocab_mask(
|
||||
self, vocab_size: int, batch_size: int, device
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def copy(self) -> "BaseGrammarObject":
|
||||
raise NotImplementedError
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -113,10 +103,9 @@ class BaseGrammarBackend:
|
||||
def __init__(self):
|
||||
self.executor = ThreadPoolExecutor()
|
||||
self.cache: Dict[Tuple[str, str], CacheEntry] = {}
|
||||
self.cache_lock = Lock()
|
||||
|
||||
def _not_supported(self, key_type: str, key_string: str) -> None:
|
||||
logger.warning(f"Skip unsupported {key_type}: {key_type}={key_string}")
|
||||
logger.warning(f"Skip unsupported {key_type=}, {key_string=}")
|
||||
|
||||
def dispatch_fallback(
|
||||
self, key_type: str, key_string: str
|
||||
@@ -148,40 +137,25 @@ class BaseGrammarBackend:
|
||||
return self.dispatch_ebnf(key_string)
|
||||
elif key_type == "structural_tag":
|
||||
return self.dispatch_structural_tag(key_string)
|
||||
elif key_type == "structural_pattern":
|
||||
return self.dispatch_structural_pattern(key_string)
|
||||
else:
|
||||
return self.dispatch_fallback(key_type, key_string)
|
||||
|
||||
def _init_value(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
|
||||
with self.cache_lock:
|
||||
if key in self.cache:
|
||||
cache_hit = True
|
||||
entry = self.cache[key]
|
||||
else:
|
||||
cache_hit = False
|
||||
entry = CacheEntry(None, Event())
|
||||
self.cache[key] = entry
|
||||
def get_cached_or_future_value(
|
||||
self, key: Tuple[str, str]
|
||||
) -> Optional[BaseGrammarObject]:
|
||||
value = self.cache.get(key)
|
||||
if value:
|
||||
return value.copy(), True
|
||||
value = self.executor.submit(self._init_value_dispatch, key)
|
||||
return value, False
|
||||
|
||||
if cache_hit:
|
||||
entry.event.wait()
|
||||
else:
|
||||
entry.value = self._init_value_dispatch(key)
|
||||
entry.event.set()
|
||||
return entry.value.copy() if entry.value else None
|
||||
|
||||
def get_cached_value(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
|
||||
with self.cache_lock:
|
||||
entry = self.cache.get(key)
|
||||
if not entry or not entry.event.is_set():
|
||||
return None
|
||||
val = self.cache[key].value
|
||||
return val.copy() if val else None
|
||||
|
||||
def get_future_value(self, key: Tuple[str, str]) -> Future:
|
||||
return self.executor.submit(self._init_value, key)
|
||||
def set_cache(self, key: Tuple[str, str], value: BaseGrammarObject):
|
||||
self.cache[key] = value
|
||||
|
||||
def reset(self):
|
||||
with self.cache_lock:
|
||||
self.cache.clear()
|
||||
self.cache.clear()
|
||||
|
||||
|
||||
def create_grammar_backend(
|
||||
@@ -211,9 +185,12 @@ def create_grammar_backend(
|
||||
raise ValueError(f"Invalid grammar backend: {server_args.grammar_backend}")
|
||||
|
||||
if server_args.reasoning_parser and hasattr(tokenizer, "think_end_id"):
|
||||
from .reasoner_grammar_backend import ReasonerGrammarBackend
|
||||
from sglang.srt.constrained.reasoner_grammar_backend import (
|
||||
ReasonerGrammarBackend,
|
||||
)
|
||||
|
||||
grammar_backend = ReasonerGrammarBackend(
|
||||
grammar_backend, tokenizer.think_end_id
|
||||
)
|
||||
|
||||
return grammar_backend
|
||||
|
||||
@@ -50,21 +50,6 @@ class GuidanceGrammar(BaseGrammarObject):
|
||||
self.finished = False
|
||||
self.bitmask = None
|
||||
|
||||
def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
|
||||
ff_tokens = self.ll_matcher.compute_ff_tokens()
|
||||
if ff_tokens:
|
||||
return ff_tokens, ""
|
||||
else:
|
||||
return None
|
||||
|
||||
def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
|
||||
return "", -1
|
||||
|
||||
def jump_and_retokenize(
|
||||
self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
|
||||
):
|
||||
pass
|
||||
|
||||
def accept_token(self, token: int):
|
||||
if not self.ll_matcher.consume_token(token):
|
||||
logger.warning(f"matcher error: {self.ll_matcher.get_error()}")
|
||||
@@ -104,6 +89,21 @@ class GuidanceGrammar(BaseGrammarObject):
|
||||
serialized_grammar=self.serialized_grammar,
|
||||
)
|
||||
|
||||
def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
|
||||
ff_tokens = self.ll_matcher.compute_ff_tokens()
|
||||
if ff_tokens:
|
||||
return ff_tokens, ""
|
||||
else:
|
||||
return None
|
||||
|
||||
def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
|
||||
return "", -1
|
||||
|
||||
def jump_and_retokenize(
|
||||
self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class GuidanceBackend(BaseGrammarBackend):
|
||||
|
||||
@@ -130,12 +130,16 @@ class GuidanceBackend(BaseGrammarBackend):
|
||||
return None
|
||||
|
||||
def dispatch_json(self, key_string: str) -> Optional[GuidanceGrammar]:
|
||||
serialized_grammar = LLMatcher.grammar_from_json_schema(
|
||||
key_string,
|
||||
defaults={
|
||||
"whitespace_pattern": self.whitespace_pattern,
|
||||
},
|
||||
)
|
||||
try:
|
||||
serialized_grammar = LLMatcher.grammar_from_json_schema(
|
||||
key_string,
|
||||
defaults={
|
||||
"whitespace_pattern": self.whitespace_pattern,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Skip invalid grammar: {key_string=}, {e=}")
|
||||
return None
|
||||
return self._from_serialized(serialized_grammar)
|
||||
|
||||
def dispatch_regex(self, key_string: str) -> Optional[GuidanceGrammar]:
|
||||
|
||||
@@ -53,6 +53,30 @@ class OutlinesGrammar(BaseGrammarObject):
|
||||
def accept_token(self, token: int):
|
||||
self.state = self.guide.get_next_state(self.state, token)
|
||||
|
||||
def allocate_vocab_mask(
|
||||
self, vocab_size: int, batch_size: int, device
|
||||
) -> torch.Tensor:
|
||||
return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)
|
||||
|
||||
@staticmethod
|
||||
def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
|
||||
return vocab_mask
|
||||
|
||||
def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
|
||||
tokens = torch.tensor(
|
||||
self.guide.get_next_instruction(self.state).tokens, dtype=torch.int64
|
||||
).to(vocab_mask.device, non_blocking=True)
|
||||
vocab_mask = vocab_mask[idx]
|
||||
vocab_mask.fill_(1)
|
||||
vocab_mask.scatter_(0, tokens, torch.zeros_like(tokens, dtype=torch.bool))
|
||||
|
||||
@staticmethod
|
||||
def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor):
|
||||
logits.masked_fill_(vocab_mask, float("-inf"))
|
||||
|
||||
def copy(self):
|
||||
return OutlinesGrammar(self.guide, self.jump_forward_map)
|
||||
|
||||
def try_jump_forward(self, tokenizer) -> Optional[Tuple]:
|
||||
if not self.jump_forward_map:
|
||||
return None
|
||||
@@ -86,30 +110,6 @@ class OutlinesGrammar(BaseGrammarObject):
|
||||
):
|
||||
self.state = next_state
|
||||
|
||||
def allocate_vocab_mask(
|
||||
self, vocab_size: int, batch_size: int, device
|
||||
) -> torch.Tensor:
|
||||
return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)
|
||||
|
||||
@staticmethod
|
||||
def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
|
||||
return vocab_mask
|
||||
|
||||
def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
|
||||
tokens = torch.tensor(
|
||||
self.guide.get_next_instruction(self.state).tokens, dtype=torch.int64
|
||||
).to(vocab_mask.device, non_blocking=True)
|
||||
vocab_mask = vocab_mask[idx]
|
||||
vocab_mask.fill_(1)
|
||||
vocab_mask.scatter_(0, tokens, torch.zeros_like(tokens, dtype=torch.bool))
|
||||
|
||||
@staticmethod
|
||||
def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor):
|
||||
logits.masked_fill_(vocab_mask, float("-inf"))
|
||||
|
||||
def copy(self):
|
||||
return OutlinesGrammar(self.guide, self.jump_forward_map)
|
||||
|
||||
|
||||
class OutlinesGrammarBackend(BaseGrammarBackend):
|
||||
def __init__(
|
||||
@@ -169,8 +169,9 @@ class OutlinesGrammarBackend(BaseGrammarBackend):
|
||||
key_string,
|
||||
whitespace_pattern=self.whitespace_pattern,
|
||||
)
|
||||
except (NotImplementedError, json.decoder.JSONDecodeError) as e:
|
||||
logger.warning(f"Skip invalid json_schema: json_schema={key_string}, {e=}")
|
||||
except (NotImplementedError, json.decoder.JSONDecodeError, ValueError) as e:
|
||||
logger.warning(f"Skip invalid json_schema: {key_string=}, {e=}")
|
||||
return None
|
||||
return self._compile_regex(regex)
|
||||
|
||||
def dispatch_regex(self, key_string: str):
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# ==============================================================================
|
||||
"""The baseclass of a backend for reasoner grammar-guided constrained decoding."""
|
||||
|
||||
from concurrent.futures import Future
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
@@ -28,13 +27,12 @@ class ReasonerGrammarObject(BaseGrammarObject):
|
||||
self.think_end_id = think_end_id
|
||||
self.is_in_reasoning = True
|
||||
|
||||
@property
|
||||
def finished(self):
|
||||
return self.grammar.finished
|
||||
def accept_token(self, token: int):
|
||||
if token == self.think_end_id:
|
||||
self.is_in_reasoning = False
|
||||
|
||||
@finished.setter
|
||||
def finished(self, finished):
|
||||
self.grammar.finished = finished
|
||||
if not self.is_in_reasoning and token != self.think_end_id:
|
||||
self.grammar.accept_token(token)
|
||||
|
||||
def allocate_vocab_mask(
|
||||
self, vocab_size: int, batch_size: int, device
|
||||
@@ -52,12 +50,16 @@ class ReasonerGrammarObject(BaseGrammarObject):
|
||||
def apply_vocab_mask(self):
|
||||
return self.grammar.apply_vocab_mask
|
||||
|
||||
def accept_token(self, token: int):
|
||||
if token == self.think_end_id:
|
||||
self.is_in_reasoning = False
|
||||
def copy(self) -> BaseGrammarObject:
|
||||
return ReasonerGrammarObject(self.grammar.copy(), self.think_end_id)
|
||||
|
||||
if not self.is_in_reasoning and token != self.think_end_id:
|
||||
self.grammar.accept_token(token)
|
||||
@property
|
||||
def finished(self):
|
||||
return self.grammar.finished
|
||||
|
||||
@finished.setter
|
||||
def finished(self, finished):
|
||||
self.grammar.finished = finished
|
||||
|
||||
def try_jump_forward(self, tokenizer):
|
||||
return self.grammar.try_jump_forward(tokenizer)
|
||||
@@ -72,30 +74,17 @@ class ReasonerGrammarObject(BaseGrammarObject):
|
||||
old_output_ids, new_output_ids, next_state
|
||||
)
|
||||
|
||||
def copy(self) -> BaseGrammarObject:
|
||||
return ReasonerGrammarObject(self.grammar.copy(), self.think_end_id)
|
||||
|
||||
|
||||
class ReasonerGrammarBackend(BaseGrammarBackend):
|
||||
def __init__(self, grammar_backend: BaseGrammarBackend, think_end_id):
|
||||
super().__init__()
|
||||
self.grammar_backend = grammar_backend
|
||||
self.think_end_id = think_end_id
|
||||
|
||||
def get_cached_value(self, key: Tuple[str, str]) -> Optional[ReasonerGrammarObject]:
|
||||
grammar = self.grammar_backend.get_cached_value(key)
|
||||
return ReasonerGrammarObject(grammar, self.think_end_id) if grammar else None
|
||||
|
||||
def get_future_value(self, key: Tuple[str, str]) -> Future:
|
||||
grammar = Future()
|
||||
|
||||
def callback(f: Future):
|
||||
if result := f.result():
|
||||
grammar.set_result(ReasonerGrammarObject(result, self.think_end_id))
|
||||
else:
|
||||
grammar.set_result(None)
|
||||
|
||||
self.grammar_backend.get_future_value(key).add_done_callback(callback)
|
||||
return grammar
|
||||
|
||||
def reset(self):
|
||||
self.grammar_backend.reset()
|
||||
def _init_value_dispatch(
|
||||
self, key: Tuple[str, str]
|
||||
) -> Optional[ReasonerGrammarObject]:
|
||||
ret = self.grammar_backend._init_value_dispatch(key)
|
||||
if ret is None:
|
||||
return None
|
||||
return ReasonerGrammarObject(ret, self.think_end_id)
|
||||
|
||||
@@ -18,7 +18,6 @@ import logging
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import xgrammar
|
||||
from xgrammar import (
|
||||
CompiledGrammar,
|
||||
GrammarCompiler,
|
||||
@@ -35,7 +34,6 @@ from sglang.srt.constrained.base_grammar_backend import (
|
||||
from sglang.srt.constrained.triton_ops.bitmask_ops import (
|
||||
apply_token_bitmask_inplace_triton,
|
||||
)
|
||||
from sglang.srt.utils import get_bool_env_var
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -51,49 +49,35 @@ class XGrammarGrammar(BaseGrammarObject):
|
||||
vocab_size: int,
|
||||
ctx: CompiledGrammar,
|
||||
override_stop_tokens: Optional[Union[List[int], int]],
|
||||
key_string: Optional[str] = None, # TODO (sk): for debugging, remove later
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.matcher = matcher
|
||||
self.vocab_size = vocab_size
|
||||
self.ctx = ctx
|
||||
self.override_stop_tokens = override_stop_tokens
|
||||
self.finished = False
|
||||
|
||||
from xgrammar.kernels.apply_token_bitmask_inplace_cpu import (
|
||||
apply_token_bitmask_inplace_cpu,
|
||||
)
|
||||
|
||||
self.apply_vocab_mask_cpu = apply_token_bitmask_inplace_cpu
|
||||
self.accepted_tokens = []
|
||||
self.key_string = key_string
|
||||
|
||||
def accept_token(self, token: int):
|
||||
assert self.matcher.accept_token(token)
|
||||
|
||||
def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
|
||||
s = self.matcher.find_jump_forward_string()
|
||||
if s:
|
||||
return [], s
|
||||
return None
|
||||
|
||||
def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
|
||||
_, data = helper
|
||||
return data, -1
|
||||
|
||||
def jump_and_retokenize(
|
||||
self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
|
||||
):
|
||||
k = 0
|
||||
for i, old_id in enumerate(old_output_ids):
|
||||
if old_id == new_output_ids[i]:
|
||||
k = i + 1
|
||||
if not self.is_terminated():
|
||||
accepted = self.matcher.accept_token(token)
|
||||
if not accepted:
|
||||
# log for debugging
|
||||
raise ValueError(
|
||||
f"Tokens not accepted: {token}\n"
|
||||
f"Accepted tokens: {self.accepted_tokens}\n"
|
||||
f"Key string: {self.key_string}"
|
||||
)
|
||||
else:
|
||||
break
|
||||
self.accepted_tokens.append(token)
|
||||
|
||||
# rollback to the last token that is the same
|
||||
if k < len(old_output_ids):
|
||||
self.matcher.rollback(len(old_output_ids) - k)
|
||||
def rollback(self, k: int):
|
||||
self.matcher.rollback(k)
|
||||
self.accepted_tokens = self.accepted_tokens[:-k]
|
||||
|
||||
for i in range(k, len(new_output_ids)):
|
||||
assert self.matcher.accept_token(new_output_ids[i])
|
||||
def is_terminated(self):
|
||||
return self.matcher.is_terminated()
|
||||
|
||||
def allocate_vocab_mask(
|
||||
self, vocab_size: int, batch_size: int, device
|
||||
@@ -122,9 +106,43 @@ class XGrammarGrammar(BaseGrammarObject):
|
||||
override_stop_tokens=self.override_stop_tokens,
|
||||
)
|
||||
return XGrammarGrammar(
|
||||
matcher, self.vocab_size, self.ctx, self.override_stop_tokens
|
||||
matcher,
|
||||
self.vocab_size,
|
||||
self.ctx,
|
||||
self.override_stop_tokens,
|
||||
self.key_string,
|
||||
)
|
||||
|
||||
def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
|
||||
s = self.matcher.find_jump_forward_string()
|
||||
if s:
|
||||
return [], s
|
||||
return None
|
||||
|
||||
def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
|
||||
_, data = helper
|
||||
return data, -1
|
||||
|
||||
def jump_and_retokenize(
|
||||
self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
|
||||
):
|
||||
k = 0
|
||||
for i, old_id in enumerate(old_output_ids):
|
||||
if old_id == new_output_ids[i]:
|
||||
k = i + 1
|
||||
else:
|
||||
break
|
||||
|
||||
# rollback to the last token that is the same
|
||||
if k < len(old_output_ids):
|
||||
self.matcher.rollback(len(old_output_ids) - k)
|
||||
|
||||
for i in range(k, len(new_output_ids)):
|
||||
assert self.matcher.accept_token(new_output_ids[i])
|
||||
|
||||
def __repr__(self):
|
||||
return f"XGrammarGrammar({self.key_string=}, {self.accepted_tokens=})"
|
||||
|
||||
|
||||
class XGrammarGrammarBackend(BaseGrammarBackend):
|
||||
def __init__(
|
||||
@@ -143,9 +161,15 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
|
||||
self.vocab_size = vocab_size
|
||||
self.override_stop_tokens = override_stop_tokens
|
||||
|
||||
def _from_context(self, ctx: CompiledGrammar) -> XGrammarGrammar:
|
||||
matcher = GrammarMatcher(ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS)
|
||||
return XGrammarGrammar(matcher, self.vocab_size, ctx, self.override_stop_tokens)
|
||||
def _from_context(self, ctx: CompiledGrammar, key_string: str) -> XGrammarGrammar:
|
||||
matcher = GrammarMatcher(
|
||||
ctx,
|
||||
max_rollback_tokens=MAX_ROLLBACK_TOKENS,
|
||||
override_stop_tokens=self.override_stop_tokens,
|
||||
)
|
||||
return XGrammarGrammar(
|
||||
matcher, self.vocab_size, ctx, self.override_stop_tokens, key_string
|
||||
)
|
||||
|
||||
def dispatch_json(self, key_string: str) -> Optional[XGrammarGrammar]:
|
||||
try:
|
||||
@@ -157,7 +181,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
|
||||
except RuntimeError as e:
|
||||
logging.warning(f"Skip invalid json_schema: json_schema={key_string}, {e=}")
|
||||
return None
|
||||
return self._from_context(ctx)
|
||||
return self._from_context(ctx, key_string)
|
||||
|
||||
def dispatch_ebnf(self, key_string: str) -> Optional[XGrammarGrammar]:
|
||||
try:
|
||||
@@ -165,7 +189,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
|
||||
except RuntimeError as e:
|
||||
logging.warning(f"Skip invalid ebnf: ebnf={key_string}, {e=}")
|
||||
return None
|
||||
return self._from_context(ctx)
|
||||
return self._from_context(ctx, key_string)
|
||||
|
||||
def dispatch_regex(self, key_string: str) -> Optional[XGrammarGrammar]:
|
||||
try:
|
||||
@@ -173,7 +197,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
|
||||
except RuntimeError as e:
|
||||
logging.warning(f"Skip invalid regex: regex={key_string}, {e=}")
|
||||
return None
|
||||
return self._from_context(ctx)
|
||||
return self._from_context(ctx, key_string)
|
||||
|
||||
def dispatch_structural_tag(self, key_string: str) -> Optional[XGrammarGrammar]:
|
||||
try:
|
||||
@@ -190,9 +214,11 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
|
||||
tags, structural_tag["triggers"]
|
||||
)
|
||||
except RuntimeError as e:
|
||||
logging.warning(f"Skip invalid regex: regex={key_string}, {e=}")
|
||||
logging.warning(
|
||||
f"Skip invalid structural_tag: structural_tag={key_string}, {e=}"
|
||||
)
|
||||
return None
|
||||
return self._from_context(ctx)
|
||||
return self._from_context(ctx, key_string)
|
||||
|
||||
def reset(self):
|
||||
if self.grammar_compiler:
|
||||
|
||||
Reference in New Issue
Block a user