Improve structured outputs: fix race condition, server crash, metrics and style (#6188)

Lianmin Zheng
2025-05-11 08:36:16 -07:00
committed by GitHub
parent 94d42b6794
commit 01bdbf7f80
13 changed files with 568 additions and 258 deletions

View File

@@ -14,10 +14,9 @@
"""The baseclass of a backend for grammar-guided constrained decoding."""
import logging
from abc import ABC, abstractmethod
from concurrent.futures import Future, ThreadPoolExecutor
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from threading import Event, Lock
from threading import Event
from typing import Dict, List, Optional, Tuple
import torch
@@ -27,11 +26,36 @@ from sglang.srt.server_args import ServerArgs
logger = logging.getLogger(__name__)
class BaseGrammarObject(ABC):
class BaseGrammarObject:
def __init__(self):
self._finished = False
def accept_token(self, token: int) -> None:
"""
Accept a token in the grammar.
"""
raise NotImplementedError()
def allocate_vocab_mask(
self, vocab_size: int, batch_size: int, device
) -> torch.Tensor:
raise NotImplementedError()
def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
raise NotImplementedError()
@staticmethod
def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
raise NotImplementedError()
@staticmethod
def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
raise NotImplementedError()
def copy(self) -> "BaseGrammarObject":
raise NotImplementedError()
@property
def finished(self):
return self._finished
@@ -40,7 +64,6 @@ class BaseGrammarObject(ABC):
def finished(self, finished):
self._finished = finished
@abstractmethod
def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
"""
Try to jump forward in the grammar.
@@ -49,9 +72,8 @@ class BaseGrammarObject(ABC):
A jump forward helper which may be used in `jump_forward_str_state`.
None if the jump forward is not possible.
"""
raise NotImplementedError
raise NotImplementedError()
@abstractmethod
def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
"""
Jump forward for the grammar.
@@ -60,47 +82,15 @@ class BaseGrammarObject(ABC):
A tuple of the jump forward string and the next state of the grammar
(which can be used in `jump_and_retokenize` if needed).
"""
raise NotImplementedError
raise NotImplementedError()
@abstractmethod
def jump_and_retokenize(
self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
) -> None:
"""
Jump forward occurs, and update the grammar state if needed.
"""
raise NotImplementedError
@abstractmethod
def accept_token(self, token: int) -> None:
"""
Accept a token in the grammar.
"""
raise NotImplementedError
@abstractmethod
def allocate_vocab_mask(
self, vocab_size: int, batch_size: int, device
) -> torch.Tensor:
raise NotImplementedError
@abstractmethod
def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
raise NotImplementedError
@staticmethod
@abstractmethod
def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
raise NotImplementedError
@staticmethod
@abstractmethod
def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
raise NotImplementedError
@abstractmethod
def copy(self) -> "BaseGrammarObject":
raise NotImplementedError
raise NotImplementedError()
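
A note on the restructuring above: the class drops `ABC`/`@abstractmethod` in favor of plain methods that raise `NotImplementedError()`. A minimal sketch of the practical difference, using toy class names that are not from this commit: an ABC rejects an incomplete subclass at construction time, while the plain base class only fails when the missing method is actually called.

from abc import ABC, abstractmethod

class WithABC(ABC):
    @abstractmethod
    def copy(self) -> "WithABC": ...

class PlainBase:
    def copy(self) -> "PlainBase":
        raise NotImplementedError()

class A(WithABC):   # incomplete subclass
    pass

class B(PlainBase):  # incomplete subclass
    pass

# A()      # TypeError at construction: abstract method copy not implemented
b = B()    # constructing the incomplete subclass is fine
# b.copy() # NotImplementedError surfaces only if the method is used
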
@dataclass
@@ -113,10 +103,9 @@ class BaseGrammarBackend:
def __init__(self):
self.executor = ThreadPoolExecutor()
self.cache: Dict[Tuple[str, str], CacheEntry] = {}
self.cache_lock = Lock()
def _not_supported(self, key_type: str, key_string: str) -> None:
logger.warning(f"Skip unsupported {key_type}: {key_type}={key_string}")
logger.warning(f"Skip unsupported {key_type=}, {key_string=}")
def dispatch_fallback(
self, key_type: str, key_string: str
@@ -148,40 +137,25 @@ class BaseGrammarBackend:
return self.dispatch_ebnf(key_string)
elif key_type == "structural_tag":
return self.dispatch_structural_tag(key_string)
elif key_type == "structural_pattern":
return self.dispatch_structural_pattern(key_string)
else:
return self.dispatch_fallback(key_type, key_string)
def _init_value(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
with self.cache_lock:
if key in self.cache:
cache_hit = True
entry = self.cache[key]
else:
cache_hit = False
entry = CacheEntry(None, Event())
self.cache[key] = entry
def get_cached_or_future_value(
self, key: Tuple[str, str]
) -> Optional[BaseGrammarObject]:
value = self.cache.get(key)
if value:
return value.copy(), True
value = self.executor.submit(self._init_value_dispatch, key)
return value, False
if cache_hit:
entry.event.wait()
else:
entry.value = self._init_value_dispatch(key)
entry.event.set()
return entry.value.copy() if entry.value else None
def get_cached_value(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
with self.cache_lock:
entry = self.cache.get(key)
if not entry or not entry.event.is_set():
return None
val = self.cache[key].value
return val.copy() if val else None
def get_future_value(self, key: Tuple[str, str]) -> Future:
return self.executor.submit(self._init_value, key)
def set_cache(self, key: Tuple[str, str], value: BaseGrammarObject):
self.cache[key] = value
def reset(self):
with self.cache_lock:
self.cache.clear()
self.cache.clear()
def create_grammar_backend(
@@ -211,9 +185,12 @@ def create_grammar_backend(
raise ValueError(f"Invalid grammar backend: {server_args.grammar_backend}")
if server_args.reasoning_parser and hasattr(tokenizer, "think_end_id"):
from .reasoner_grammar_backend import ReasonerGrammarBackend
from sglang.srt.constrained.reasoner_grammar_backend import (
ReasonerGrammarBackend,
)
grammar_backend = ReasonerGrammarBackend(
grammar_backend, tokenizer.think_end_id
)
return grammar_backend
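
For context on the race-condition fix in this file: the removed `_init_value` path shared a `Lock` plus a per-entry `Event` between threads, and the worker thread both computed and published cache entries. The new `get_cached_or_future_value` returns either a ready copy or a `Future`, and publication via `set_cache` stays on the calling thread, so the lock goes away. A minimal self-contained sketch of the pattern; `ToyGrammar` and `ToyBackend` are illustrative stand-ins, not from this commit:

from concurrent.futures import ThreadPoolExecutor
from typing import Dict, Optional, Tuple

class ToyGrammar:
    def copy(self) -> "ToyGrammar":
        return ToyGrammar()

class ToyBackend:
    # The cache dict is read and written only from the caller's thread,
    # so the old Lock/Event machinery is unnecessary.
    def __init__(self) -> None:
        self.executor = ThreadPoolExecutor()
        self.cache: Dict[Tuple[str, str], ToyGrammar] = {}

    def _init_value_dispatch(self, key: Tuple[str, str]) -> Optional[ToyGrammar]:
        return ToyGrammar()  # stand-in for real grammar compilation

    def get_cached_or_future_value(self, key: Tuple[str, str]):
        value = self.cache.get(key)
        if value:
            return value.copy(), True  # hit: hand out a private copy
        return self.executor.submit(self._init_value_dispatch, key), False

    def set_cache(self, key: Tuple[str, str], value: ToyGrammar) -> None:
        self.cache[key] = value

backend = ToyBackend()
key = ("json", "{}")
value, cache_hit = backend.get_cached_or_future_value(key)
if not cache_hit:
    grammar = value.result()            # block until compiled
    if grammar is not None:
        backend.set_cache(key, grammar)  # publish from the caller's thread
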

View File

@@ -50,21 +50,6 @@ class GuidanceGrammar(BaseGrammarObject):
self.finished = False
self.bitmask = None
def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
ff_tokens = self.ll_matcher.compute_ff_tokens()
if ff_tokens:
return ff_tokens, ""
else:
return None
def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
return "", -1
def jump_and_retokenize(
self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
):
pass
def accept_token(self, token: int):
if not self.ll_matcher.consume_token(token):
logger.warning(f"matcher error: {self.ll_matcher.get_error()}")
@@ -104,6 +89,21 @@ class GuidanceGrammar(BaseGrammarObject):
serialized_grammar=self.serialized_grammar,
)
def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
ff_tokens = self.ll_matcher.compute_ff_tokens()
if ff_tokens:
return ff_tokens, ""
else:
return None
def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
return "", -1
def jump_and_retokenize(
self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
):
pass
class GuidanceBackend(BaseGrammarBackend):
@@ -130,12 +130,16 @@ class GuidanceBackend(BaseGrammarBackend):
return None
def dispatch_json(self, key_string: str) -> Optional[GuidanceGrammar]:
serialized_grammar = LLMatcher.grammar_from_json_schema(
key_string,
defaults={
"whitespace_pattern": self.whitespace_pattern,
},
)
try:
serialized_grammar = LLMatcher.grammar_from_json_schema(
key_string,
defaults={
"whitespace_pattern": self.whitespace_pattern,
},
)
except Exception as e:
logger.warning(f"Skip invalid grammar: {key_string=}, {e=}")
return None
return self._from_serialized(serialized_grammar)
def dispatch_regex(self, key_string: str) -> Optional[GuidanceGrammar]:
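
The try/except added in `dispatch_json` is the server-crash fix for this backend: `LLMatcher.grammar_from_json_schema` can raise on a malformed schema, and under the executor-based flow an exception thrown inside a dispatch method is re-raised in the caller when the `Future` is resolved. A short illustration of that re-raise behavior; `compile_grammar` is a stand-in, not llguidance's API:

from concurrent.futures import ThreadPoolExecutor

def compile_grammar(schema: str):
    # Stand-in for a dispatch method that raises on bad input.
    raise ValueError(f"invalid schema: {schema!r}")

with ThreadPoolExecutor() as executor:
    future = executor.submit(compile_grammar, "{not json")
    try:
        future.result()  # the worker's exception re-raises here, in the caller
    except ValueError as e:
        print(f"Skip invalid grammar: {e}")  # log and skip instead of crashing
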

View File

@@ -53,6 +53,30 @@ class OutlinesGrammar(BaseGrammarObject):
def accept_token(self, token: int):
self.state = self.guide.get_next_state(self.state, token)
def allocate_vocab_mask(
self, vocab_size: int, batch_size: int, device
) -> torch.Tensor:
return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)
@staticmethod
def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
return vocab_mask
def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
tokens = torch.tensor(
self.guide.get_next_instruction(self.state).tokens, dtype=torch.int64
).to(vocab_mask.device, non_blocking=True)
vocab_mask = vocab_mask[idx]
vocab_mask.fill_(1)
vocab_mask.scatter_(0, tokens, torch.zeros_like(tokens, dtype=torch.bool))
@staticmethod
def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor):
logits.masked_fill_(vocab_mask, float("-inf"))
def copy(self):
return OutlinesGrammar(self.guide, self.jump_forward_map)
def try_jump_forward(self, tokenizer) -> Optional[Tuple]:
if not self.jump_forward_map:
return None
@@ -86,30 +110,6 @@ class OutlinesGrammar(BaseGrammarObject):
):
self.state = next_state
def allocate_vocab_mask(
self, vocab_size: int, batch_size: int, device
) -> torch.Tensor:
return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)
@staticmethod
def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
return vocab_mask
def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
tokens = torch.tensor(
self.guide.get_next_instruction(self.state).tokens, dtype=torch.int64
).to(vocab_mask.device, non_blocking=True)
vocab_mask = vocab_mask[idx]
vocab_mask.fill_(1)
vocab_mask.scatter_(0, tokens, torch.zeros_like(tokens, dtype=torch.bool))
@staticmethod
def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor):
logits.masked_fill_(vocab_mask, float("-inf"))
def copy(self):
return OutlinesGrammar(self.guide, self.jump_forward_map)
class OutlinesGrammarBackend(BaseGrammarBackend):
def __init__(
@@ -169,8 +169,9 @@ class OutlinesGrammarBackend(BaseGrammarBackend):
key_string,
whitespace_pattern=self.whitespace_pattern,
)
except (NotImplementedError, json.decoder.JSONDecodeError) as e:
logger.warning(f"Skip invalid json_schema: json_schema={key_string}, {e=}")
except (NotImplementedError, json.decoder.JSONDecodeError, ValueError) as e:
logger.warning(f"Skip invalid json_schema: {key_string=}, {e=}")
return None
return self._compile_regex(regex)
def dispatch_regex(self, key_string: str):
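
The Outlines mask methods moved above act as a ban-list: `allocate_vocab_mask` makes a boolean mask per batch row, `fill_vocab_mask` sets the whole row to True and then scatters False at the token ids the guide allows, and `apply_vocab_mask` writes `-inf` into the banned logits. A small worked example under those semantics; the token ids are made up:

import torch

vocab_size, batch_size = 8, 1
allowed = torch.tensor([2, 5], dtype=torch.int64)  # ids the guide would allow

mask = torch.zeros(batch_size, vocab_size, dtype=torch.bool)  # allocate_vocab_mask
row = mask[0]                                                 # fill_vocab_mask, idx=0
row.fill_(1)                                                  # ban every token...
row.scatter_(0, allowed, torch.zeros_like(allowed, dtype=torch.bool))  # ...except allowed

logits = torch.randn(batch_size, vocab_size)
logits.masked_fill_(mask, float("-inf"))                      # apply_vocab_mask
assert torch.isfinite(logits[0, allowed]).all()
assert torch.isinf(logits[0, [0, 1, 3, 4, 6, 7]]).all()
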

View File

@@ -13,7 +13,6 @@
# ==============================================================================
"""The baseclass of a backend for reasoner grammar-guided constrained decoding."""
from concurrent.futures import Future
from typing import List, Optional, Tuple
import torch
@@ -28,13 +27,12 @@ class ReasonerGrammarObject(BaseGrammarObject):
self.think_end_id = think_end_id
self.is_in_reasoning = True
@property
def finished(self):
return self.grammar.finished
def accept_token(self, token: int):
if token == self.think_end_id:
self.is_in_reasoning = False
@finished.setter
def finished(self, finished):
self.grammar.finished = finished
if not self.is_in_reasoning and token != self.think_end_id:
self.grammar.accept_token(token)
def allocate_vocab_mask(
self, vocab_size: int, batch_size: int, device
@@ -52,12 +50,16 @@ class ReasonerGrammarObject(BaseGrammarObject):
def apply_vocab_mask(self):
return self.grammar.apply_vocab_mask
def accept_token(self, token: int):
if token == self.think_end_id:
self.is_in_reasoning = False
def copy(self) -> BaseGrammarObject:
return ReasonerGrammarObject(self.grammar.copy(), self.think_end_id)
if not self.is_in_reasoning and token != self.think_end_id:
self.grammar.accept_token(token)
@property
def finished(self):
return self.grammar.finished
@finished.setter
def finished(self, finished):
self.grammar.finished = finished
def try_jump_forward(self, tokenizer):
return self.grammar.try_jump_forward(tokenizer)
@@ -72,30 +74,17 @@ class ReasonerGrammarObject(BaseGrammarObject):
old_output_ids, new_output_ids, next_state
)
def copy(self) -> BaseGrammarObject:
return ReasonerGrammarObject(self.grammar.copy(), self.think_end_id)
class ReasonerGrammarBackend(BaseGrammarBackend):
def __init__(self, grammar_backend: BaseGrammarBackend, think_end_id):
super().__init__()
self.grammar_backend = grammar_backend
self.think_end_id = think_end_id
def get_cached_value(self, key: Tuple[str, str]) -> Optional[ReasonerGrammarObject]:
grammar = self.grammar_backend.get_cached_value(key)
return ReasonerGrammarObject(grammar, self.think_end_id) if grammar else None
def get_future_value(self, key: Tuple[str, str]) -> Future:
grammar = Future()
def callback(f: Future):
if result := f.result():
grammar.set_result(ReasonerGrammarObject(result, self.think_end_id))
else:
grammar.set_result(None)
self.grammar_backend.get_future_value(key).add_done_callback(callback)
return grammar
def reset(self):
self.grammar_backend.reset()
def _init_value_dispatch(
self, key: Tuple[str, str]
) -> Optional[ReasonerGrammarObject]:
ret = self.grammar_backend._init_value_dispatch(key)
if ret is None:
return None
return ReasonerGrammarObject(ret, self.think_end_id)
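
With the cache logic centralized in `BaseGrammarBackend`, the reasoner backend now only overrides `_init_value_dispatch` instead of the removed `get_cached_value`/`get_future_value` pair. The wrapped object's token gating, shown as a standalone trace; `RecordingGrammar` and `GatedGrammar` are stand-ins that inline the same `accept_token` logic as `ReasonerGrammarObject`, and the token ids are made up:

from typing import List

class RecordingGrammar:
    # Minimal stand-in that records which tokens are forwarded.
    def __init__(self) -> None:
        self.tokens: List[int] = []

    def accept_token(self, token: int) -> None:
        self.tokens.append(token)

class GatedGrammar:
    # Same gating as ReasonerGrammarObject.accept_token, inlined to run standalone.
    def __init__(self, grammar: RecordingGrammar, think_end_id: int) -> None:
        self.grammar = grammar
        self.think_end_id = think_end_id
        self.is_in_reasoning = True

    def accept_token(self, token: int) -> None:
        if token == self.think_end_id:
            self.is_in_reasoning = False
        if not self.is_in_reasoning and token != self.think_end_id:
            self.grammar.accept_token(token)

g = RecordingGrammar()
r = GatedGrammar(g, think_end_id=42)
for t in (7, 8, 42, 9, 10):  # 7 and 8 are "thinking" tokens
    r.accept_token(t)
assert g.tokens == [9, 10]   # only post-reasoning tokens reach the grammar
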

View File

@@ -18,7 +18,6 @@ import logging
from typing import List, Optional, Tuple, Union
import torch
import xgrammar
from xgrammar import (
CompiledGrammar,
GrammarCompiler,
@@ -35,7 +34,6 @@ from sglang.srt.constrained.base_grammar_backend import (
from sglang.srt.constrained.triton_ops.bitmask_ops import (
apply_token_bitmask_inplace_triton,
)
from sglang.srt.utils import get_bool_env_var
logger = logging.getLogger(__name__)
@@ -51,49 +49,35 @@ class XGrammarGrammar(BaseGrammarObject):
vocab_size: int,
ctx: CompiledGrammar,
override_stop_tokens: Optional[Union[List[int], int]],
key_string: Optional[str] = None, # TODO (sk): for debugging, remove later
) -> None:
super().__init__()
self.matcher = matcher
self.vocab_size = vocab_size
self.ctx = ctx
self.override_stop_tokens = override_stop_tokens
self.finished = False
from xgrammar.kernels.apply_token_bitmask_inplace_cpu import (
apply_token_bitmask_inplace_cpu,
)
self.apply_vocab_mask_cpu = apply_token_bitmask_inplace_cpu
self.accepted_tokens = []
self.key_string = key_string
def accept_token(self, token: int):
assert self.matcher.accept_token(token)
def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
s = self.matcher.find_jump_forward_string()
if s:
return [], s
return None
def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
_, data = helper
return data, -1
def jump_and_retokenize(
self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
):
k = 0
for i, old_id in enumerate(old_output_ids):
if old_id == new_output_ids[i]:
k = i + 1
if not self.is_terminated():
accepted = self.matcher.accept_token(token)
if not accepted:
# log for debugging
raise ValueError(
f"Tokens not accepted: {token}\n"
f"Accepted tokens: {self.accepted_tokens}\n"
f"Key string: {self.key_string}"
)
else:
break
self.accepted_tokens.append(token)
# rollback to the last token that is the same
if k < len(old_output_ids):
self.matcher.rollback(len(old_output_ids) - k)
def rollback(self, k: int):
self.matcher.rollback(k)
self.accepted_tokens = self.accepted_tokens[:-k]
for i in range(k, len(new_output_ids)):
assert self.matcher.accept_token(new_output_ids[i])
def is_terminated(self):
return self.matcher.is_terminated()
def allocate_vocab_mask(
self, vocab_size: int, batch_size: int, device
@@ -122,9 +106,43 @@ class XGrammarGrammar(BaseGrammarObject):
override_stop_tokens=self.override_stop_tokens,
)
return XGrammarGrammar(
matcher, self.vocab_size, self.ctx, self.override_stop_tokens
matcher,
self.vocab_size,
self.ctx,
self.override_stop_tokens,
self.key_string,
)
def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
s = self.matcher.find_jump_forward_string()
if s:
return [], s
return None
def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
_, data = helper
return data, -1
def jump_and_retokenize(
self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
):
k = 0
for i, old_id in enumerate(old_output_ids):
if old_id == new_output_ids[i]:
k = i + 1
else:
break
# rollback to the last token that is the same
if k < len(old_output_ids):
self.matcher.rollback(len(old_output_ids) - k)
for i in range(k, len(new_output_ids)):
assert self.matcher.accept_token(new_output_ids[i])
def __repr__(self):
return f"XGrammarGrammar({self.key_string=}, {self.accepted_tokens=})"
class XGrammarGrammarBackend(BaseGrammarBackend):
def __init__(
@@ -143,9 +161,15 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
self.vocab_size = vocab_size
self.override_stop_tokens = override_stop_tokens
def _from_context(self, ctx: CompiledGrammar) -> XGrammarGrammar:
matcher = GrammarMatcher(ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS)
return XGrammarGrammar(matcher, self.vocab_size, ctx, self.override_stop_tokens)
def _from_context(self, ctx: CompiledGrammar, key_string: str) -> XGrammarGrammar:
matcher = GrammarMatcher(
ctx,
max_rollback_tokens=MAX_ROLLBACK_TOKENS,
override_stop_tokens=self.override_stop_tokens,
)
return XGrammarGrammar(
matcher, self.vocab_size, ctx, self.override_stop_tokens, key_string
)
def dispatch_json(self, key_string: str) -> Optional[XGrammarGrammar]:
try:
@@ -157,7 +181,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
except RuntimeError as e:
logging.warning(f"Skip invalid json_schema: json_schema={key_string}, {e=}")
return None
return self._from_context(ctx)
return self._from_context(ctx, key_string)
def dispatch_ebnf(self, key_string: str) -> Optional[XGrammarGrammar]:
try:
@@ -165,7 +189,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
except RuntimeError as e:
logging.warning(f"Skip invalid ebnf: ebnf={key_string}, {e=}")
return None
return self._from_context(ctx)
return self._from_context(ctx, key_string)
def dispatch_regex(self, key_string: str) -> Optional[XGrammarGrammar]:
try:
@@ -173,7 +197,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
except RuntimeError as e:
logging.warning(f"Skip invalid regex: regex={key_string}, {e=}")
return None
return self._from_context(ctx)
return self._from_context(ctx, key_string)
def dispatch_structural_tag(self, key_string: str) -> Optional[XGrammarGrammar]:
try:
@@ -190,9 +214,11 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
tags, structural_tag["triggers"]
)
except RuntimeError as e:
logging.warning(f"Skip invalid regex: regex={key_string}, {e=}")
logging.warning(
f"Skip invalid structural_tag: structural_tag={key_string}, {e=}"
)
return None
return self._from_context(ctx)
return self._from_context(ctx, key_string)
def reset(self):
if self.grammar_compiler:
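
For reference when reading the corrected warning above: `dispatch_structural_tag` parses `key_string` as JSON and compiles tag/trigger pairs. A hypothetical payload of the shape it consumes; the field values are invented, only the `triggers` key is visible in this hunk, and the `structures` layout is an assumption based on SGLang's structural-tag request format:

import json

key_string = json.dumps({
    "type": "structural_tag",
    "structures": [
        {
            "begin": "<function=get_weather>",
            "schema": {"type": "object", "properties": {"city": {"type": "string"}}},
            "end": "</function>",
        }
    ],
    "triggers": ["<function="],
})
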