diff --git a/docs/backend/structured_outputs_for_reasoning_models.ipynb b/docs/backend/structured_outputs_for_reasoning_models.ipynb
index d17dbd967..03824743f 100644
--- a/docs/backend/structured_outputs_for_reasoning_models.ipynb
+++ b/docs/backend/structured_outputs_for_reasoning_models.ipynb
@@ -94,8 +94,8 @@
     "    model=\"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B\",\n",
     "    messages=[\n",
     "        {\n",
     "            \"role\": \"user\",\n",
-    "            \"content\": \"Please generate the information of the capital of France in the JSON format.\",\n",
+    "            \"content\": \"Give me the information and population of the capital of France in the JSON format.\",\n",
     "        },\n",
     "    ],\n",
     "    temperature=0,\n",
@@ -145,8 +145,8 @@
     "    model=\"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B\",\n",
     "    messages=[\n",
     "        {\n",
     "            \"role\": \"user\",\n",
-    "            \"content\": \"Give me the information of the capital of France in the JSON format.\",\n",
+    "            \"content\": \"Give me the information and population of the capital of France in the JSON format.\",\n",
     "        },\n",
     "    ],\n",
     "    temperature=0,\n",
@@ -188,8 +188,8 @@
     "    messages=[\n",
     "        {\"role\": \"system\", \"content\": \"You are a helpful geography bot.\"},\n",
     "        {\n",
     "            \"role\": \"user\",\n",
-    "            \"content\": \"Give me the information of the capital of France.\",\n",
+    "            \"content\": \"Give me the information and population of the capital of France in the JSON format.\",\n",
     "        },\n",
     "    ],\n",
     "    temperature=0,\n",
@@ -218,7 +218,7 @@
     "response = client.chat.completions.create(\n",
     "    model=\"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B\",\n",
     "    messages=[\n",
     "        {\"role\": \"user\", \"content\": \"What is the capital of France?\"},\n",
     "    ],\n",
     "    temperature=0,\n",
     "    max_tokens=2048,\n",
@@ -323,7 +323,7 @@
     "You are a helpful assistant.\"\"\",\n",
     "    },\n",
     "    {\n",
     "        \"role\": \"user\",\n",
     "        \"content\": \"You are in New York.
Please get the current date and time, and the weather.\",\n", " },\n", " ]\n", @@ -400,9 +400,9 @@ "\n", "messages = [\n", " {\n", - " \"role\": \"user\",\n", - " \"content\": \"Here is the information of the capital of France in the JSON format.\\n\",\n", - " }\n", + " \"role\": \"assistant\",\n", + " \"content\": \"Give me the information and population of the capital of France in the JSON format.\",\n", + " },\n", "]\n", "text = tokenizer.apply_chat_template(\n", " messages, tokenize=False, add_generation_prompt=True\n", @@ -452,7 +452,9 @@ ")\n", "\n", "# JSON\n", - "text = tokenizer.apply_chat_template(text, tokenize=False, add_generation_prompt=True)\n", + "text = tokenizer.apply_chat_template(\n", + " messages, tokenize=False, add_generation_prompt=True\n", + ")\n", "response = requests.post(\n", " f\"http://localhost:{port}/generate\",\n", " json={\n", diff --git a/python/sglang/srt/constrained/base_grammar_backend.py b/python/sglang/srt/constrained/base_grammar_backend.py index 58ee5ccb2..0316a8dfc 100644 --- a/python/sglang/srt/constrained/base_grammar_backend.py +++ b/python/sglang/srt/constrained/base_grammar_backend.py @@ -14,10 +14,9 @@ """The baseclass of a backend for grammar-guided constrained decoding.""" import logging -from abc import ABC, abstractmethod -from concurrent.futures import Future, ThreadPoolExecutor +from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass -from threading import Event, Lock +from threading import Event from typing import Dict, List, Optional, Tuple import torch @@ -27,11 +26,36 @@ from sglang.srt.server_args import ServerArgs logger = logging.getLogger(__name__) -class BaseGrammarObject(ABC): +class BaseGrammarObject: def __init__(self): self._finished = False + def accept_token(self, token: int) -> None: + """ + Accept a token in the grammar. + """ + raise NotImplementedError() + + def allocate_vocab_mask( + self, vocab_size: int, batch_size: int, device + ) -> torch.Tensor: + raise NotImplementedError() + + def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None: + raise NotImplementedError() + + @staticmethod + def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor: + raise NotImplementedError() + + @staticmethod + def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None: + raise NotImplementedError() + + def copy(self) -> "BaseGrammarObject": + raise NotImplementedError() + @property def finished(self): return self._finished @@ -40,7 +64,6 @@ class BaseGrammarObject(ABC): def finished(self, finished): self._finished = finished - @abstractmethod def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]: """ Try to jump forward in the grammar. @@ -49,9 +72,8 @@ class BaseGrammarObject(ABC): A jump forward helper which may be used in `jump_forward_str_state`. None if the jump forward is not possible. """ - raise NotImplementedError + raise NotImplementedError() - @abstractmethod def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]: """ Jump forward for the grammar. @@ -60,47 +82,15 @@ class BaseGrammarObject(ABC): A tuple of the jump forward string and the next state of the grammar (which can be used in `jump_and_retokenize` if needed). """ - raise NotImplementedError + raise NotImplementedError() - @abstractmethod def jump_and_retokenize( self, old_output_ids: List[int], new_output_ids: List[int], next_state: int ) -> None: """ Jump forward occurs, and update the grammar state if needed. 
""" - raise NotImplementedError - - @abstractmethod - def accept_token(self, token: int) -> None: - """ - Accept a token in the grammar. - """ - raise NotImplementedError - - @abstractmethod - def allocate_vocab_mask( - self, vocab_size: int, batch_size: int, device - ) -> torch.Tensor: - raise NotImplementedError - - @abstractmethod - def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None: - raise NotImplementedError - - @staticmethod - @abstractmethod - def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor: - raise NotImplementedError - - @staticmethod - @abstractmethod - def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None: - raise NotImplementedError - - @abstractmethod - def copy(self) -> "BaseGrammarObject": - raise NotImplementedError + raise NotImplementedError() @dataclass @@ -113,10 +103,9 @@ class BaseGrammarBackend: def __init__(self): self.executor = ThreadPoolExecutor() self.cache: Dict[Tuple[str, str], CacheEntry] = {} - self.cache_lock = Lock() def _not_supported(self, key_type: str, key_string: str) -> None: - logger.warning(f"Skip unsupported {key_type}: {key_type}={key_string}") + logger.warning(f"Skip unsupported {key_type=}, {key_string=}") def dispatch_fallback( self, key_type: str, key_string: str @@ -148,40 +137,25 @@ class BaseGrammarBackend: return self.dispatch_ebnf(key_string) elif key_type == "structural_tag": return self.dispatch_structural_tag(key_string) + elif key_type == "structural_pattern": + return self.dispatch_structural_pattern(key_string) else: return self.dispatch_fallback(key_type, key_string) - def _init_value(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]: - with self.cache_lock: - if key in self.cache: - cache_hit = True - entry = self.cache[key] - else: - cache_hit = False - entry = CacheEntry(None, Event()) - self.cache[key] = entry + def get_cached_or_future_value( + self, key: Tuple[str, str] + ) -> Optional[BaseGrammarObject]: + value = self.cache.get(key) + if value: + return value.copy(), True + value = self.executor.submit(self._init_value_dispatch, key) + return value, False - if cache_hit: - entry.event.wait() - else: - entry.value = self._init_value_dispatch(key) - entry.event.set() - return entry.value.copy() if entry.value else None - - def get_cached_value(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]: - with self.cache_lock: - entry = self.cache.get(key) - if not entry or not entry.event.is_set(): - return None - val = self.cache[key].value - return val.copy() if val else None - - def get_future_value(self, key: Tuple[str, str]) -> Future: - return self.executor.submit(self._init_value, key) + def set_cache(self, key: Tuple[str, str], value: BaseGrammarObject): + self.cache[key] = value def reset(self): - with self.cache_lock: - self.cache.clear() + self.cache.clear() def create_grammar_backend( @@ -211,9 +185,12 @@ def create_grammar_backend( raise ValueError(f"Invalid grammar backend: {server_args.grammar_backend}") if server_args.reasoning_parser and hasattr(tokenizer, "think_end_id"): - from .reasoner_grammar_backend import ReasonerGrammarBackend + from sglang.srt.constrained.reasoner_grammar_backend import ( + ReasonerGrammarBackend, + ) grammar_backend = ReasonerGrammarBackend( grammar_backend, tokenizer.think_end_id ) + return grammar_backend diff --git a/python/sglang/srt/constrained/llguidance_backend.py b/python/sglang/srt/constrained/llguidance_backend.py index 49c3740fb..cdd0aecd6 100644 --- 
a/python/sglang/srt/constrained/llguidance_backend.py +++ b/python/sglang/srt/constrained/llguidance_backend.py @@ -50,21 +50,6 @@ class GuidanceGrammar(BaseGrammarObject): self.finished = False self.bitmask = None - def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]: - ff_tokens = self.ll_matcher.compute_ff_tokens() - if ff_tokens: - return ff_tokens, "" - else: - return None - - def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]: - return "", -1 - - def jump_and_retokenize( - self, old_output_ids: List[int], new_output_ids: List[int], next_state: int - ): - pass - def accept_token(self, token: int): if not self.ll_matcher.consume_token(token): logger.warning(f"matcher error: {self.ll_matcher.get_error()}") @@ -104,6 +89,21 @@ class GuidanceGrammar(BaseGrammarObject): serialized_grammar=self.serialized_grammar, ) + def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]: + ff_tokens = self.ll_matcher.compute_ff_tokens() + if ff_tokens: + return ff_tokens, "" + else: + return None + + def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]: + return "", -1 + + def jump_and_retokenize( + self, old_output_ids: List[int], new_output_ids: List[int], next_state: int + ): + pass + class GuidanceBackend(BaseGrammarBackend): @@ -130,12 +130,16 @@ class GuidanceBackend(BaseGrammarBackend): return None def dispatch_json(self, key_string: str) -> Optional[GuidanceGrammar]: - serialized_grammar = LLMatcher.grammar_from_json_schema( - key_string, - defaults={ - "whitespace_pattern": self.whitespace_pattern, - }, - ) + try: + serialized_grammar = LLMatcher.grammar_from_json_schema( + key_string, + defaults={ + "whitespace_pattern": self.whitespace_pattern, + }, + ) + except Exception as e: + logger.warning(f"Skip invalid grammar: {key_string=}, {e=}") + return None return self._from_serialized(serialized_grammar) def dispatch_regex(self, key_string: str) -> Optional[GuidanceGrammar]: diff --git a/python/sglang/srt/constrained/outlines_backend.py b/python/sglang/srt/constrained/outlines_backend.py index 41128108a..2f4f97149 100644 --- a/python/sglang/srt/constrained/outlines_backend.py +++ b/python/sglang/srt/constrained/outlines_backend.py @@ -53,6 +53,30 @@ class OutlinesGrammar(BaseGrammarObject): def accept_token(self, token: int): self.state = self.guide.get_next_state(self.state, token) + def allocate_vocab_mask( + self, vocab_size: int, batch_size: int, device + ) -> torch.Tensor: + return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device) + + @staticmethod + def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor: + return vocab_mask + + def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None: + tokens = torch.tensor( + self.guide.get_next_instruction(self.state).tokens, dtype=torch.int64 + ).to(vocab_mask.device, non_blocking=True) + vocab_mask = vocab_mask[idx] + vocab_mask.fill_(1) + vocab_mask.scatter_(0, tokens, torch.zeros_like(tokens, dtype=torch.bool)) + + @staticmethod + def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor): + logits.masked_fill_(vocab_mask, float("-inf")) + + def copy(self): + return OutlinesGrammar(self.guide, self.jump_forward_map) + def try_jump_forward(self, tokenizer) -> Optional[Tuple]: if not self.jump_forward_map: return None @@ -86,30 +110,6 @@ class OutlinesGrammar(BaseGrammarObject): ): self.state = next_state - def allocate_vocab_mask( - self, vocab_size: int, batch_size: int, device - ) -> 
torch.Tensor: - return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device) - - @staticmethod - def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor: - return vocab_mask - - def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None: - tokens = torch.tensor( - self.guide.get_next_instruction(self.state).tokens, dtype=torch.int64 - ).to(vocab_mask.device, non_blocking=True) - vocab_mask = vocab_mask[idx] - vocab_mask.fill_(1) - vocab_mask.scatter_(0, tokens, torch.zeros_like(tokens, dtype=torch.bool)) - - @staticmethod - def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor): - logits.masked_fill_(vocab_mask, float("-inf")) - - def copy(self): - return OutlinesGrammar(self.guide, self.jump_forward_map) - class OutlinesGrammarBackend(BaseGrammarBackend): def __init__( @@ -169,8 +169,9 @@ class OutlinesGrammarBackend(BaseGrammarBackend): key_string, whitespace_pattern=self.whitespace_pattern, ) - except (NotImplementedError, json.decoder.JSONDecodeError) as e: - logger.warning(f"Skip invalid json_schema: json_schema={key_string}, {e=}") + except (NotImplementedError, json.decoder.JSONDecodeError, ValueError) as e: + logger.warning(f"Skip invalid json_schema: {key_string=}, {e=}") + return None return self._compile_regex(regex) def dispatch_regex(self, key_string: str): diff --git a/python/sglang/srt/constrained/reasoner_grammar_backend.py b/python/sglang/srt/constrained/reasoner_grammar_backend.py index 3f6f59e5b..ca5f118ef 100644 --- a/python/sglang/srt/constrained/reasoner_grammar_backend.py +++ b/python/sglang/srt/constrained/reasoner_grammar_backend.py @@ -13,7 +13,6 @@ # ============================================================================== """The baseclass of a backend for reasoner grammar-guided constrained decoding.""" -from concurrent.futures import Future from typing import List, Optional, Tuple import torch @@ -28,13 +27,12 @@ class ReasonerGrammarObject(BaseGrammarObject): self.think_end_id = think_end_id self.is_in_reasoning = True - @property - def finished(self): - return self.grammar.finished + def accept_token(self, token: int): + if token == self.think_end_id: + self.is_in_reasoning = False - @finished.setter - def finished(self, finished): - self.grammar.finished = finished + if not self.is_in_reasoning and token != self.think_end_id: + self.grammar.accept_token(token) def allocate_vocab_mask( self, vocab_size: int, batch_size: int, device @@ -52,12 +50,16 @@ class ReasonerGrammarObject(BaseGrammarObject): def apply_vocab_mask(self): return self.grammar.apply_vocab_mask - def accept_token(self, token: int): - if token == self.think_end_id: - self.is_in_reasoning = False + def copy(self) -> BaseGrammarObject: + return ReasonerGrammarObject(self.grammar.copy(), self.think_end_id) - if not self.is_in_reasoning and token != self.think_end_id: - self.grammar.accept_token(token) + @property + def finished(self): + return self.grammar.finished + + @finished.setter + def finished(self, finished): + self.grammar.finished = finished def try_jump_forward(self, tokenizer): return self.grammar.try_jump_forward(tokenizer) @@ -72,30 +74,17 @@ class ReasonerGrammarObject(BaseGrammarObject): old_output_ids, new_output_ids, next_state ) - def copy(self) -> BaseGrammarObject: - return ReasonerGrammarObject(self.grammar.copy(), self.think_end_id) - class ReasonerGrammarBackend(BaseGrammarBackend): def __init__(self, grammar_backend: BaseGrammarBackend, think_end_id): + super().__init__() self.grammar_backend = grammar_backend 
self.think_end_id = think_end_id - def get_cached_value(self, key: Tuple[str, str]) -> Optional[ReasonerGrammarObject]: - grammar = self.grammar_backend.get_cached_value(key) - return ReasonerGrammarObject(grammar, self.think_end_id) if grammar else None - - def get_future_value(self, key: Tuple[str, str]) -> Future: - grammar = Future() - - def callback(f: Future): - if result := f.result(): - grammar.set_result(ReasonerGrammarObject(result, self.think_end_id)) - else: - grammar.set_result(None) - - self.grammar_backend.get_future_value(key).add_done_callback(callback) - return grammar - - def reset(self): - self.grammar_backend.reset() + def _init_value_dispatch( + self, key: Tuple[str, str] + ) -> Optional[ReasonerGrammarObject]: + ret = self.grammar_backend._init_value_dispatch(key) + if ret is None: + return None + return ReasonerGrammarObject(ret, self.think_end_id) diff --git a/python/sglang/srt/constrained/xgrammar_backend.py b/python/sglang/srt/constrained/xgrammar_backend.py index 8e715b3d8..7c02978e4 100644 --- a/python/sglang/srt/constrained/xgrammar_backend.py +++ b/python/sglang/srt/constrained/xgrammar_backend.py @@ -18,7 +18,6 @@ import logging from typing import List, Optional, Tuple, Union import torch -import xgrammar from xgrammar import ( CompiledGrammar, GrammarCompiler, @@ -35,7 +34,6 @@ from sglang.srt.constrained.base_grammar_backend import ( from sglang.srt.constrained.triton_ops.bitmask_ops import ( apply_token_bitmask_inplace_triton, ) -from sglang.srt.utils import get_bool_env_var logger = logging.getLogger(__name__) @@ -51,49 +49,35 @@ class XGrammarGrammar(BaseGrammarObject): vocab_size: int, ctx: CompiledGrammar, override_stop_tokens: Optional[Union[List[int], int]], + key_string: Optional[str] = None, # TODO (sk): for debugging, remove later ) -> None: - super().__init__() self.matcher = matcher self.vocab_size = vocab_size self.ctx = ctx self.override_stop_tokens = override_stop_tokens self.finished = False - - from xgrammar.kernels.apply_token_bitmask_inplace_cpu import ( - apply_token_bitmask_inplace_cpu, - ) - - self.apply_vocab_mask_cpu = apply_token_bitmask_inplace_cpu + self.accepted_tokens = [] + self.key_string = key_string def accept_token(self, token: int): - assert self.matcher.accept_token(token) - - def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]: - s = self.matcher.find_jump_forward_string() - if s: - return [], s - return None - - def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]: - _, data = helper - return data, -1 - - def jump_and_retokenize( - self, old_output_ids: List[int], new_output_ids: List[int], next_state: int - ): - k = 0 - for i, old_id in enumerate(old_output_ids): - if old_id == new_output_ids[i]: - k = i + 1 + if not self.is_terminated(): + accepted = self.matcher.accept_token(token) + if not accepted: + # log for debugging + raise ValueError( + f"Tokens not accepted: {token}\n" + f"Accepted tokens: {self.accepted_tokens}\n" + f"Key string: {self.key_string}" + ) else: - break + self.accepted_tokens.append(token) - # rollback to the last token that is the same - if k < len(old_output_ids): - self.matcher.rollback(len(old_output_ids) - k) + def rollback(self, k: int): + self.matcher.rollback(k) + self.accepted_tokens = self.accepted_tokens[:-k] - for i in range(k, len(new_output_ids)): - assert self.matcher.accept_token(new_output_ids[i]) + def is_terminated(self): + return self.matcher.is_terminated() def allocate_vocab_mask( self, vocab_size: int, batch_size: int, 
device @@ -122,9 +106,43 @@ class XGrammarGrammar(BaseGrammarObject): override_stop_tokens=self.override_stop_tokens, ) return XGrammarGrammar( - matcher, self.vocab_size, self.ctx, self.override_stop_tokens + matcher, + self.vocab_size, + self.ctx, + self.override_stop_tokens, + self.key_string, ) + def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]: + s = self.matcher.find_jump_forward_string() + if s: + return [], s + return None + + def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]: + _, data = helper + return data, -1 + + def jump_and_retokenize( + self, old_output_ids: List[int], new_output_ids: List[int], next_state: int + ): + k = 0 + for i, old_id in enumerate(old_output_ids): + if old_id == new_output_ids[i]: + k = i + 1 + else: + break + + # rollback to the last token that is the same + if k < len(old_output_ids): + self.matcher.rollback(len(old_output_ids) - k) + + for i in range(k, len(new_output_ids)): + assert self.matcher.accept_token(new_output_ids[i]) + + def __repr__(self): + return f"XGrammarGrammar({self.key_string=}, {self.accepted_tokens=})" + class XGrammarGrammarBackend(BaseGrammarBackend): def __init__( @@ -143,9 +161,15 @@ class XGrammarGrammarBackend(BaseGrammarBackend): self.vocab_size = vocab_size self.override_stop_tokens = override_stop_tokens - def _from_context(self, ctx: CompiledGrammar) -> XGrammarGrammar: - matcher = GrammarMatcher(ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS) - return XGrammarGrammar(matcher, self.vocab_size, ctx, self.override_stop_tokens) + def _from_context(self, ctx: CompiledGrammar, key_string: str) -> XGrammarGrammar: + matcher = GrammarMatcher( + ctx, + max_rollback_tokens=MAX_ROLLBACK_TOKENS, + override_stop_tokens=self.override_stop_tokens, + ) + return XGrammarGrammar( + matcher, self.vocab_size, ctx, self.override_stop_tokens, key_string + ) def dispatch_json(self, key_string: str) -> Optional[XGrammarGrammar]: try: @@ -157,7 +181,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend): except RuntimeError as e: logging.warning(f"Skip invalid json_schema: json_schema={key_string}, {e=}") return None - return self._from_context(ctx) + return self._from_context(ctx, key_string) def dispatch_ebnf(self, key_string: str) -> Optional[XGrammarGrammar]: try: @@ -165,7 +189,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend): except RuntimeError as e: logging.warning(f"Skip invalid ebnf: ebnf={key_string}, {e=}") return None - return self._from_context(ctx) + return self._from_context(ctx, key_string) def dispatch_regex(self, key_string: str) -> Optional[XGrammarGrammar]: try: @@ -173,7 +197,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend): except RuntimeError as e: logging.warning(f"Skip invalid regex: regex={key_string}, {e=}") return None - return self._from_context(ctx) + return self._from_context(ctx, key_string) def dispatch_structural_tag(self, key_string: str) -> Optional[XGrammarGrammar]: try: @@ -190,9 +214,11 @@ class XGrammarGrammarBackend(BaseGrammarBackend): tags, structural_tag["triggers"] ) except RuntimeError as e: - logging.warning(f"Skip invalid regex: regex={key_string}, {e=}") + logging.warning( + f"Skip invalid structural_tag: structural_tag={key_string}, {e=}" + ) return None - return self._from_context(ctx) + return self._from_context(ctx, key_string) def reset(self): if self.grammar_compiler: diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py index 61177b3a7..8ed50b1c9 100644 --- 
a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py @@ -239,10 +239,6 @@ def top_p_normalize_probs_torch( def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]): - assert len(top_logprobs_nums) == logprobs.shape[0], ( - len(top_logprobs_nums), - logprobs.shape[0], - ) max_k = max(top_logprobs_nums) ret = logprobs.topk(max_k, dim=1) values = ret.values.tolist() diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 38420076a..a797a7f3a 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -533,6 +533,7 @@ class Req: # Constrained decoding self.grammar: Optional[BaseGrammarObject] = None + self.grammar_wait_ct = 0 # The number of cached tokens that were already cached in the KV cache self.cached_tokens = 0 diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 28c68a41f..4498f2cfc 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -149,6 +149,7 @@ logger = logging.getLogger(__name__) # Test retract decode for debugging purposes TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT") RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME") +GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300)) @dataclass @@ -1024,9 +1025,11 @@ class Scheduler( elif req.sampling_params.structural_tag: key = ("structural_tag", req.sampling_params.structural_tag) - req.grammar = self.grammar_backend.get_cached_value(key) - if not req.grammar: - req.grammar = self.grammar_backend.get_future_value(key) + value, cache_hit = self.grammar_backend.get_cached_or_future_value(key) + req.grammar = value + + if not cache_hit: + req.grammar_key = key add_to_grammar_queue = True if add_to_grammar_queue: @@ -1208,6 +1211,7 @@ class Scheduler( self.stats.cache_hit_rate = 0.0 self.stats.gen_throughput = self.last_gen_throughput self.stats.num_queue_reqs = len(self.waiting_queue) + self.stats.num_grammar_queue_reqs = len(self.grammar_queue) self.stats.spec_accept_length = spec_accept_length self.metrics_collector.log_stats(self.stats) @@ -1255,6 +1259,7 @@ class Scheduler( self.stats.token_usage = num_used / self.max_total_num_tokens self.stats.gen_throughput = 0 self.stats.num_queue_reqs = len(self.waiting_queue) + self.stats.num_grammar_queue_reqs = len(self.grammar_queue) self.metrics_collector.log_stats(self.stats) def get_next_batch_to_run(self) -> Optional[ScheduleBatch]: @@ -1715,11 +1720,17 @@ class Scheduler( """Move requests whose grammar objects are ready from grammar_queue to waiting_queue.""" num_ready_reqs = 0 + num_abort_reqs = 0 for req in self.grammar_queue: try: - req.grammar = req.grammar.result(timeout=0.05) + req.grammar = req.grammar.result(timeout=0.03) + if req.grammar: + self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy()) num_ready_reqs += 1 except futures._base.TimeoutError: + req.grammar_wait_ct += 1 + if req.grammar_wait_ct > GRAMMAR_TIMEOUT / 0.03: + num_abort_reqs = 1 break if self.server_args.enable_dp_attention: @@ -1731,14 +1742,28 @@ class Scheduler( if tp_size > 1: # Sync across TP ranks to make sure they have the same number of ready requests - tensor = torch.tensor(num_ready_reqs, dtype=torch.int32) + tensor = torch.tensor([num_ready_reqs, num_abort_reqs], dtype=torch.int32) torch.distributed.all_reduce( tensor, op=torch.distributed.ReduceOp.MAX, group=tp_group ) - num_ready_reqs_max = tensor.item() 
+ num_ready_reqs_max, num_abort_reqs_max = tensor.tolist() + for i in range(num_ready_reqs, num_ready_reqs_max): - self.grammar_queue[i].grammar = self.grammar_queue[i].grammar.result() - num_ready_reqs = num_ready_reqs_max + req = self.grammar_queue[i] + req.grammar = req.grammar.result() + if req.grammar: + self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy()) + + for i in range(num_ready_reqs, num_ready_reqs + num_abort_reqs_max): + req = self.grammar_queue[i] + req.grammar.cancel() + req.grammar = None + error_msg = f"Grammar preprocessing timed out for {req.grammar_key=}" + logger.error(error_msg) + req.finished_reason = FINISH_ABORT( + error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError" + ) + num_ready_reqs = num_ready_reqs_max + num_abort_reqs_max self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs]) self.grammar_queue = self.grammar_queue[num_ready_reqs:] diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 89190d8a4..306359dda 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -1230,11 +1230,18 @@ class TokenizerManager: state.last_completion_tokens = completion_tokens if state.finished: + has_grammar = ( + state.obj.sampling_params.get("json_schema", None) + or state.obj.sampling_params.get("regex", None) + or state.obj.sampling_params.get("ebnf", None) + or state.obj.sampling_params.get("structural_tag", None) + ) self.metrics_collector.observe_one_finished_request( recv_obj.prompt_tokens[i], completion_tokens, recv_obj.cached_tokens[i], state.finished_time - state.created_time, + has_grammar, ) def dump_requests(self, state: ReqState, out_dict: dict): diff --git a/python/sglang/srt/metrics/collector.py b/python/sglang/srt/metrics/collector.py index b881406e6..aa407e0ec 100644 --- a/python/sglang/srt/metrics/collector.py +++ b/python/sglang/srt/metrics/collector.py @@ -15,7 +15,119 @@ import time from dataclasses import dataclass -from typing import Dict, Union +from enum import Enum +from typing import Dict, List, Optional, Union + +from sglang.srt.utils import get_bool_env_var + +SGLANG_TEST_REQUEST_TIME_STATS = get_bool_env_var("SGLANG_TEST_REQUEST_TIME_STATS") + + +@dataclass +class TimeStats: + """ + Store the timestamps for each stage of a request. 
+ + Unified: wait_queue -> forward -> completion + Prefill: bootstrap_queue -> wait_queue -> forward -> transfer_queue -> completion + Decode: prealloc_queue -> transfer_queue -> wait_queue -> forward -> completion + """ + + lb_entry_time: float = 0.0 + wait_queue_entry_time: float = 0.0 + forward_entry_time: float = 0.0 + completion_time: float = 0.0 + prefill_bootstrap_queue_entry_time: float = 0.0 + prefill_transfer_queue_entry_time: float = 0.0 + decode_prealloc_queue_entry_time: float = 0.0 + decode_transfer_queue_entry_time: float = 0.0 + + class RequestType(Enum): + UNIFIED = "unified" + PREFILL = "prefill" + DECODE = "decode" + INVALID = "invalid" + + def __str__(self) -> str: + # if unified + _type = self.get_type() + + if _type == self.RequestType.UNIFIED: + queue_duration = self.forward_entry_time - self.wait_queue_entry_time + forward_duration = self.completion_time - self.forward_entry_time + + if SGLANG_TEST_REQUEST_TIME_STATS: + assert ( + queue_duration >= 0 and forward_duration >= 0 + ), f"queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0" + + return f"queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.wait_queue_entry_time}" + elif _type == self.RequestType.PREFILL: + bootstrap_duration = ( + self.wait_queue_entry_time - self.prefill_bootstrap_queue_entry_time + ) + + queue_duration = self.forward_entry_time - self.wait_queue_entry_time + + forward_duration = self.completion_time - self.forward_entry_time + + if SGLANG_TEST_REQUEST_TIME_STATS: + assert ( + bootstrap_duration >= 0 + and queue_duration >= 0 + and forward_duration >= 0 + ), f"bootstrap_duration={bootstrap_duration} < 0 or queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0" + return f"bootstrap_duration={self.format_duration(bootstrap_duration)}, queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.prefill_bootstrap_queue_entry_time}" + # if decode + elif _type == self.RequestType.DECODE: + prealloc_duration = ( + self.decode_transfer_queue_entry_time + - self.decode_prealloc_queue_entry_time + ) + + transfer_duration = ( + self.wait_queue_entry_time - self.decode_transfer_queue_entry_time + ) + queue_duration = self.forward_entry_time - self.wait_queue_entry_time + forward_duration = self.completion_time - self.forward_entry_time + + if SGLANG_TEST_REQUEST_TIME_STATS: + assert ( + prealloc_duration >= 0 + and transfer_duration >= 0 + and queue_duration >= 0 + and forward_duration >= 0 + ), f"prealloc_duration={prealloc_duration} < 0 or transfer_duration={transfer_duration} < 0 or queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0" + + return f"prealloc_duration={self.format_duration(prealloc_duration)}, transfer_duration={self.format_duration(transfer_duration)}, queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.decode_prealloc_queue_entry_time}" + else: + return "Invalid Time Stats" + + def format_duration(self, duration: float) -> str: + return f"{duration * 1e3:.2f}ms" + + def get_type(self) -> RequestType: + """Determine the type of request based on timestamp values.""" + if ( + self.prefill_bootstrap_queue_entry_time == 0.0 + and self.prefill_transfer_queue_entry_time == 0.0 + and self.decode_prealloc_queue_entry_time == 0.0 + and self.decode_transfer_queue_entry_time == 0.0 + ): + 
return self.RequestType.UNIFIED + elif ( + self.prefill_bootstrap_queue_entry_time > 0.0 + and self.prefill_transfer_queue_entry_time > 0.0 + ): + return self.RequestType.PREFILL + elif ( + self.decode_prealloc_queue_entry_time > 0.0 + and self.decode_transfer_queue_entry_time > 0.0 + and self.wait_queue_entry_time > 0.0 + ): + return self.RequestType.DECODE + else: + return self.RequestType.INVALID @dataclass @@ -26,15 +138,20 @@ class SchedulerStats: gen_throughput: float = 0.0 num_queue_reqs: int = 0 cache_hit_rate: float = 0.0 + num_grammar_queue_reqs: int = 0 spec_accept_length: float = 0.0 avg_request_queue_latency: float = 0.0 + num_prefill_prealloc_queue_reqs: int = 0 + num_prefill_infight_queue_reqs: int = 0 + num_decode_prealloc_queue_reqs: int = 0 + num_decode_transfer_queue_reqs: int = 0 class SchedulerMetricsCollector: def __init__(self, labels: Dict[str, str]) -> None: # We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR` - from prometheus_client import Gauge, Histogram + from prometheus_client import Counter, Gauge self.labels = labels self.last_log_time = time.time() @@ -74,6 +191,13 @@ class SchedulerMetricsCollector: multiprocess_mode="mostrecent", ) + self.num_grammar_queue_reqs = Gauge( + name="sglang:num_grammar_queue_reqs", + documentation="The number of requests in the grammar waiting queue.", + labelnames=labels.keys(), + multiprocess_mode="mostrecent", + ) + self.cache_hit_rate = Gauge( name="sglang:cache_hit_rate", documentation="The prefix cache hit rate.", @@ -95,28 +219,98 @@ class SchedulerMetricsCollector: multiprocess_mode="mostrecent", ) + # Disaggregation queue metrics + self.num_prefill_prealloc_queue_reqs = Gauge( + name="sglang:num_prefill_prealloc_queue_reqs", + documentation="The number of requests in the prefill prealloc queue.", + labelnames=labels.keys(), + multiprocess_mode="mostrecent", + ) + + self.num_prefill_infight_queue_reqs = Gauge( + name="sglang:num_prefill_infight_queue_reqs", + documentation="The number of requests in the prefill infight queue.", + labelnames=labels.keys(), + multiprocess_mode="mostrecent", + ) + + self.num_decode_prealloc_queue_reqs = Gauge( + name="sglang:num_decode_prealloc_queue_reqs", + documentation="The number of requests in the decode prealloc queue.", + labelnames=labels.keys(), + multiprocess_mode="mostrecent", + ) + + self.num_decode_transfer_queue_reqs = Gauge( + name="sglang:num_decode_transfer_queue_reqs", + documentation="The number of requests in the decode transfer queue.", + labelnames=labels.keys(), + multiprocess_mode="mostrecent", + ) + + self.num_bootstrap_failed_reqs = Counter( + name="sglang:num_bootstrap_failed_reqs", + documentation="The number of bootstrap failed requests.", + labelnames=labels.keys(), + ) + + self.num_transfer_failed_reqs = Counter( + name="sglang:num_transfer_failed_reqs", + documentation="The number of transfer failed requests.", + labelnames=labels.keys(), + ) + def _log_gauge(self, gauge, data: Union[int, float]) -> None: # Convenience function for logging to gauge. 
gauge.labels(**self.labels).set(data) + def increment_bootstrap_failed_reqs(self) -> None: + self.num_bootstrap_failed_reqs.labels(**self.labels).inc(1) + + def increment_transfer_failed_reqs(self) -> None: + self.num_transfer_failed_reqs.labels(**self.labels).inc(1) + def log_stats(self, stats: SchedulerStats) -> None: self._log_gauge(self.num_running_reqs, stats.num_running_reqs) self._log_gauge(self.num_used_tokens, stats.num_used_tokens) self._log_gauge(self.token_usage, stats.token_usage) self._log_gauge(self.gen_throughput, stats.gen_throughput) self._log_gauge(self.num_queue_reqs, stats.num_queue_reqs) + self._log_gauge(self.num_grammar_queue_reqs, stats.num_grammar_queue_reqs) self._log_gauge(self.cache_hit_rate, stats.cache_hit_rate) self._log_gauge(self.spec_accept_length, stats.spec_accept_length) - self._log_gauge(self.avg_request_queue_latency, stats.avg_request_queue_latency) + + # Disaggregation metrics + self._log_gauge( + self.num_prefill_prealloc_queue_reqs, stats.num_prefill_prealloc_queue_reqs + ) + self._log_gauge( + self.num_prefill_infight_queue_reqs, stats.num_prefill_infight_queue_reqs + ) + self._log_gauge( + self.num_decode_prealloc_queue_reqs, stats.num_decode_prealloc_queue_reqs + ) + self._log_gauge( + self.num_decode_transfer_queue_reqs, stats.num_decode_transfer_queue_reqs + ) + self.last_log_time = time.time() class TokenizerMetricsCollector: - def __init__(self, labels: Dict[str, str]) -> None: + def __init__( + self, + labels: Dict[str, str], + bucket_time_to_first_token: Optional[List[float]] = None, + bucket_inter_token_latency: Optional[List[float]] = None, + bucket_e2e_request_latency: Optional[List[float]] = None, + collect_tokens_histogram: bool = False, + ) -> None: # We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR` from prometheus_client import Counter, Histogram self.labels = labels + self.collect_tokens_histogram = collect_tokens_histogram self.prompt_tokens_total = Counter( name="sglang:prompt_tokens_total", @@ -130,6 +324,66 @@ class TokenizerMetricsCollector: labelnames=labels.keys(), ) + if collect_tokens_histogram: + bucket_prompt_tokens = [ + 100, + 300, + 500, + 700, + 1000, + 1500, + 2000, + 3000, + 4000, + 5000, + 6000, + 7000, + 8000, + 9000, + 10000, + 12000, + 15000, + 20000, + 22000, + 25000, + 30000, + 35000, + 40000, + ] + self.prompt_tokens_histogram = Histogram( + name="sglang:prompt_tokens_histogram", + documentation="Histogram of prompt token length.", + labelnames=labels.keys(), + buckets=bucket_prompt_tokens, + ) + bucket_generation_tokens = [ + 100, + 300, + 500, + 1000, + 1200, + 1500, + 1700, + 2000, + 2500, + 3000, + 3500, + 4000, + 4500, + 5000, + 6000, + 7000, + 8000, + 9000, + 10000, + ] + self.generation_tokens_histogram = Histogram( + name="sglang:generation_tokens_histogram", + documentation="Histogram of generation token length.", + labelnames=labels.keys(), + buckets=bucket_generation_tokens, + ) + self.cached_tokens_total = Counter( name="sglang:cached_tokens_total", documentation="Number of cached prompt tokens.", @@ -142,11 +396,14 @@ class TokenizerMetricsCollector: labelnames=labels.keys(), ) - self.histogram_time_to_first_token = Histogram( - name="sglang:time_to_first_token_seconds", - documentation="Histogram of time to first token in seconds.", + self.num_so_requests_total = Counter( + name="sglang:num_so_requests_total", + documentation="Number of structured output requests processed.", labelnames=labels.keys(), - buckets=[ + ) + + if 
bucket_time_to_first_token is None: + bucket_time_to_first_token = [ 0.1, 0.2, 0.4, @@ -165,14 +422,33 @@ class TokenizerMetricsCollector: 100, 200, 400, - ], - ) + ] - self.histogram_inter_token_latency_seconds = Histogram( - name="sglang:inter_token_latency_seconds", - documentation="Histogram of inter-token latency in seconds.", - labelnames=labels.keys(), - buckets=[ + if bucket_e2e_request_latency is None: + bucket_e2e_request_latency = [ + 0.1, + 0.2, + 0.4, + 0.6, + 0.8, + 1, + 2, + 4, + 6, + 8, + 10, + 20, + 40, + 60, + 80, + 100, + 200, + 400, + 800, + ] + + if bucket_inter_token_latency is None: + bucket_inter_token_latency = [ 0.002, 0.004, 0.006, @@ -196,34 +472,27 @@ class TokenizerMetricsCollector: 4.000, 6.000, 8.000, - ], + ] + + self.histogram_time_to_first_token = Histogram( + name="sglang:time_to_first_token_seconds", + documentation="Histogram of time to first token in seconds.", + labelnames=labels.keys(), + buckets=bucket_time_to_first_token, + ) + + self.histogram_inter_token_latency_seconds = Histogram( + name="sglang:inter_token_latency_seconds", + documentation="Histogram of inter-token latency in seconds.", + labelnames=labels.keys(), + buckets=bucket_inter_token_latency, ) self.histogram_e2e_request_latency = Histogram( name="sglang:e2e_request_latency_seconds", documentation="Histogram of End-to-end request latency in seconds", labelnames=labels.keys(), - buckets=[ - 0.1, - 0.2, - 0.4, - 0.6, - 0.8, - 1, - 2, - 4, - 6, - 8, - 10, - 20, - 40, - 60, - 80, - 100, - 200, - 400, - 800, - ], + buckets=bucket_e2e_request_latency, ) def _log_histogram(self, histogram, data: Union[int, float]) -> None: @@ -235,13 +504,19 @@ class TokenizerMetricsCollector: generation_tokens: int, cached_tokens: int, e2e_latency: float, + has_grammar: bool, ): self.prompt_tokens_total.labels(**self.labels).inc(prompt_tokens) self.generation_tokens_total.labels(**self.labels).inc(generation_tokens) if cached_tokens > 0: self.cached_tokens_total.labels(**self.labels).inc(cached_tokens) self.num_requests_total.labels(**self.labels).inc(1) + if has_grammar: + self.num_so_requests_total.labels(**self.labels).inc(1) self._log_histogram(self.histogram_e2e_request_latency, e2e_latency) + if self.collect_tokens_histogram: + self._log_histogram(self.prompt_tokens_histogram, prompt_tokens) + self._log_histogram(self.generation_tokens_histogram, generation_tokens) def observe_time_to_first_token(self, value: float): self.histogram_time_to_first_token.labels(**self.labels).observe(value) diff --git a/test/srt/test_json_constrained.py b/test/srt/test_json_constrained.py index 7b42319bf..49b02a1ed 100644 --- a/test/srt/test_json_constrained.py +++ b/test/srt/test_json_constrained.py @@ -82,7 +82,7 @@ class TestJSONConstrainedOutlinesBackend(CustomTestCase): print(json.dumps(ret)) print("=" * 100) - if not json_schema: + if not json_schema or json_schema == "INVALID": return # Make sure the json output is valid @@ -97,6 +97,9 @@ class TestJSONConstrainedOutlinesBackend(CustomTestCase): def test_json_generate(self): self.run_decode(json_schema=self.json_schema) + def test_json_invalid(self): + self.run_decode(json_schema="INVALID") + def test_json_openai(self): client = openai.Client(api_key="EMPTY", base_url=f"{self.base_url}/v1") @@ -104,7 +107,10 @@ class TestJSONConstrainedOutlinesBackend(CustomTestCase): model=self.model, messages=[ {"role": "system", "content": "You are a helpful AI assistant"}, - {"role": "user", "content": "Introduce the capital of France."}, + { + "role": "user", + "content": 
"Introduce the capital of France. Return in a JSON format.", + }, ], temperature=0, max_tokens=128, diff --git a/test/srt/test_metrics.py b/test/srt/test_metrics.py index 1ab487025..ce85e0d8a 100644 --- a/test/srt/test_metrics.py +++ b/test/srt/test_metrics.py @@ -56,6 +56,7 @@ class TestEnableMetrics(CustomTestCase): "sglang:token_usage", "sglang:gen_throughput", "sglang:num_queue_reqs", + "sglang:num_grammar_queue_reqs", "sglang:cache_hit_rate", "sglang:spec_accept_length", "sglang:prompt_tokens_total",