Improve structured outputs: fix race condition, server crash, metrics and style (#6188)
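The headline race-condition fix replaces the lock/event-based grammar cache with a flow where the scheduler either receives a copy of an already-compiled grammar or a Future for an in-flight compilation, and writes the compiled grammar back into the cache itself once that Future resolves. Below is a minimal, self-contained sketch of that flow, not the actual SGLang code: ToyGrammar, ToyGrammarBackend and the main block are illustrative stand-ins; only the get_cached_or_future_value / set_cache / _init_value_dispatch names and the short 0.03 s polling timeout are taken from the diff that follows.

# Standalone sketch (assumptions noted above): grammar compilation, the cache,
# and the scheduler-side polling are reduced to the minimum needed to show the
# cached-or-future pattern used in this commit.
from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError
from typing import Dict, Optional, Tuple, Union


class ToyGrammar:
    """Stands in for a compiled grammar object; copy() mirrors BaseGrammarObject.copy()."""

    def __init__(self, key: Tuple[str, str]) -> None:
        self.key = key

    def copy(self) -> "ToyGrammar":
        return ToyGrammar(self.key)


class ToyGrammarBackend:
    def __init__(self) -> None:
        self.executor = ThreadPoolExecutor()
        self.cache: Dict[Tuple[str, str], ToyGrammar] = {}

    def _init_value_dispatch(self, key: Tuple[str, str]) -> Optional[ToyGrammar]:
        # The expensive grammar compilation would happen here.
        return ToyGrammar(key)

    def get_cached_or_future_value(
        self, key: Tuple[str, str]
    ) -> Tuple[Union[ToyGrammar, Future], bool]:
        # Cache hit: hand back a copy immediately.
        # Cache miss: submit compilation and return the Future, so the caller
        # never blocks on a lock or an event shared with the worker thread.
        value = self.cache.get(key)
        if value:
            return value.copy(), True
        return self.executor.submit(self._init_value_dispatch, key), False

    def set_cache(self, key: Tuple[str, str], value: ToyGrammar) -> None:
        # Only the scheduler writes the cache, after the Future has resolved.
        self.cache[key] = value


if __name__ == "__main__":
    backend = ToyGrammarBackend()
    key = ("json", '{"type": "object"}')
    value, cache_hit = backend.get_cached_or_future_value(key)
    if not cache_hit:
        # Scheduler side: poll with a short timeout instead of blocking forever.
        try:
            grammar = value.result(timeout=0.03)
            if grammar:
                backend.set_cache(key, grammar.copy())
        except TimeoutError:
            grammar = None  # would stay in the grammar queue and be retried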
@@ -14,10 +14,9 @@
"""The baseclass of a backend for grammar-guided constrained decoding."""

import logging
from abc import ABC, abstractmethod
from concurrent.futures import Future, ThreadPoolExecutor
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from threading import Event, Lock
from threading import Event
from typing import Dict, List, Optional, Tuple

import torch
@@ -27,11 +26,36 @@ from sglang.srt.server_args import ServerArgs
logger = logging.getLogger(__name__)


class BaseGrammarObject(ABC):
class BaseGrammarObject:

    def __init__(self):
        self._finished = False

    def accept_token(self, token: int) -> None:
        """
        Accept a token in the grammar.
        """
        raise NotImplementedError()

    def allocate_vocab_mask(
        self, vocab_size: int, batch_size: int, device
    ) -> torch.Tensor:
        raise NotImplementedError()

    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
        raise NotImplementedError()

    @staticmethod
    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
        raise NotImplementedError()

    @staticmethod
    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
        raise NotImplementedError()

    def copy(self) -> "BaseGrammarObject":
        raise NotImplementedError()

    @property
    def finished(self):
        return self._finished
@@ -40,7 +64,6 @@ class BaseGrammarObject(ABC):
    def finished(self, finished):
        self._finished = finished

    @abstractmethod
    def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
        """
        Try to jump forward in the grammar.

@@ -49,9 +72,8 @@ class BaseGrammarObject(ABC):
        A jump forward helper which may be used in `jump_forward_str_state`.
        None if the jump forward is not possible.
        """
        raise NotImplementedError
        raise NotImplementedError()

    @abstractmethod
    def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
        """
        Jump forward for the grammar.

@@ -60,47 +82,15 @@ class BaseGrammarObject(ABC):
        A tuple of the jump forward string and the next state of the grammar
        (which can be used in `jump_and_retokenize` if needed).
        """
        raise NotImplementedError
        raise NotImplementedError()

    @abstractmethod
    def jump_and_retokenize(
        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
    ) -> None:
        """
        Jump forward occurs, and update the grammar state if needed.
        """
        raise NotImplementedError

    @abstractmethod
    def accept_token(self, token: int) -> None:
        """
        Accept a token in the grammar.
        """
        raise NotImplementedError

    @abstractmethod
    def allocate_vocab_mask(
        self, vocab_size: int, batch_size: int, device
    ) -> torch.Tensor:
        raise NotImplementedError

    @abstractmethod
    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
        raise NotImplementedError

    @abstractmethod
    def copy(self) -> "BaseGrammarObject":
        raise NotImplementedError
        raise NotImplementedError()


@dataclass
@@ -113,10 +103,9 @@ class BaseGrammarBackend:
    def __init__(self):
        self.executor = ThreadPoolExecutor()
        self.cache: Dict[Tuple[str, str], CacheEntry] = {}
        self.cache_lock = Lock()

    def _not_supported(self, key_type: str, key_string: str) -> None:
        logger.warning(f"Skip unsupported {key_type}: {key_type}={key_string}")
        logger.warning(f"Skip unsupported {key_type=}, {key_string=}")

    def dispatch_fallback(
        self, key_type: str, key_string: str

@@ -148,40 +137,25 @@ class BaseGrammarBackend:
            return self.dispatch_ebnf(key_string)
        elif key_type == "structural_tag":
            return self.dispatch_structural_tag(key_string)
        elif key_type == "structural_pattern":
            return self.dispatch_structural_pattern(key_string)
        else:
            return self.dispatch_fallback(key_type, key_string)

    def _init_value(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
        with self.cache_lock:
            if key in self.cache:
                cache_hit = True
                entry = self.cache[key]
            else:
                cache_hit = False
                entry = CacheEntry(None, Event())
                self.cache[key] = entry
    def get_cached_or_future_value(
        self, key: Tuple[str, str]
    ) -> Optional[BaseGrammarObject]:
        value = self.cache.get(key)
        if value:
            return value.copy(), True
        value = self.executor.submit(self._init_value_dispatch, key)
        return value, False

        if cache_hit:
            entry.event.wait()
        else:
            entry.value = self._init_value_dispatch(key)
            entry.event.set()
        return entry.value.copy() if entry.value else None

    def get_cached_value(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
        with self.cache_lock:
            entry = self.cache.get(key)
            if not entry or not entry.event.is_set():
                return None
        val = self.cache[key].value
        return val.copy() if val else None

    def get_future_value(self, key: Tuple[str, str]) -> Future:
        return self.executor.submit(self._init_value, key)
    def set_cache(self, key: Tuple[str, str], value: BaseGrammarObject):
        self.cache[key] = value

    def reset(self):
        with self.cache_lock:
            self.cache.clear()
        self.cache.clear()


def create_grammar_backend(
@@ -211,9 +185,12 @@ def create_grammar_backend(
        raise ValueError(f"Invalid grammar backend: {server_args.grammar_backend}")

    if server_args.reasoning_parser and hasattr(tokenizer, "think_end_id"):
        from .reasoner_grammar_backend import ReasonerGrammarBackend
        from sglang.srt.constrained.reasoner_grammar_backend import (
            ReasonerGrammarBackend,
        )

        grammar_backend = ReasonerGrammarBackend(
            grammar_backend, tokenizer.think_end_id
        )

    return grammar_backend
@@ -50,21 +50,6 @@ class GuidanceGrammar(BaseGrammarObject):
        self.finished = False
        self.bitmask = None

    def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
        ff_tokens = self.ll_matcher.compute_ff_tokens()
        if ff_tokens:
            return ff_tokens, ""
        else:
            return None

    def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
        return "", -1

    def jump_and_retokenize(
        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
    ):
        pass

    def accept_token(self, token: int):
        if not self.ll_matcher.consume_token(token):
            logger.warning(f"matcher error: {self.ll_matcher.get_error()}")

@@ -104,6 +89,21 @@ class GuidanceGrammar(BaseGrammarObject):
            serialized_grammar=self.serialized_grammar,
        )

    def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
        ff_tokens = self.ll_matcher.compute_ff_tokens()
        if ff_tokens:
            return ff_tokens, ""
        else:
            return None

    def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
        return "", -1

    def jump_and_retokenize(
        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
    ):
        pass


class GuidanceBackend(BaseGrammarBackend):
@@ -130,12 +130,16 @@ class GuidanceBackend(BaseGrammarBackend):
            return None

    def dispatch_json(self, key_string: str) -> Optional[GuidanceGrammar]:
        serialized_grammar = LLMatcher.grammar_from_json_schema(
            key_string,
            defaults={
                "whitespace_pattern": self.whitespace_pattern,
            },
        )
        try:
            serialized_grammar = LLMatcher.grammar_from_json_schema(
                key_string,
                defaults={
                    "whitespace_pattern": self.whitespace_pattern,
                },
            )
        except Exception as e:
            logger.warning(f"Skip invalid grammar: {key_string=}, {e=}")
            return None
        return self._from_serialized(serialized_grammar)

    def dispatch_regex(self, key_string: str) -> Optional[GuidanceGrammar]:
@@ -53,6 +53,30 @@ class OutlinesGrammar(BaseGrammarObject):
    def accept_token(self, token: int):
        self.state = self.guide.get_next_state(self.state, token)

    def allocate_vocab_mask(
        self, vocab_size: int, batch_size: int, device
    ) -> torch.Tensor:
        return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)

    @staticmethod
    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
        return vocab_mask

    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
        tokens = torch.tensor(
            self.guide.get_next_instruction(self.state).tokens, dtype=torch.int64
        ).to(vocab_mask.device, non_blocking=True)
        vocab_mask = vocab_mask[idx]
        vocab_mask.fill_(1)
        vocab_mask.scatter_(0, tokens, torch.zeros_like(tokens, dtype=torch.bool))

    @staticmethod
    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor):
        logits.masked_fill_(vocab_mask, float("-inf"))

    def copy(self):
        return OutlinesGrammar(self.guide, self.jump_forward_map)

    def try_jump_forward(self, tokenizer) -> Optional[Tuple]:
        if not self.jump_forward_map:
            return None

@@ -86,30 +110,6 @@ class OutlinesGrammar(BaseGrammarObject):
    ):
        self.state = next_state

    def allocate_vocab_mask(
        self, vocab_size: int, batch_size: int, device
    ) -> torch.Tensor:
        return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)

    @staticmethod
    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
        return vocab_mask

    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
        tokens = torch.tensor(
            self.guide.get_next_instruction(self.state).tokens, dtype=torch.int64
        ).to(vocab_mask.device, non_blocking=True)
        vocab_mask = vocab_mask[idx]
        vocab_mask.fill_(1)
        vocab_mask.scatter_(0, tokens, torch.zeros_like(tokens, dtype=torch.bool))

    @staticmethod
    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor):
        logits.masked_fill_(vocab_mask, float("-inf"))

    def copy(self):
        return OutlinesGrammar(self.guide, self.jump_forward_map)


class OutlinesGrammarBackend(BaseGrammarBackend):
    def __init__(
@@ -169,8 +169,9 @@ class OutlinesGrammarBackend(BaseGrammarBackend):
                key_string,
                whitespace_pattern=self.whitespace_pattern,
            )
        except (NotImplementedError, json.decoder.JSONDecodeError) as e:
            logger.warning(f"Skip invalid json_schema: json_schema={key_string}, {e=}")
        except (NotImplementedError, json.decoder.JSONDecodeError, ValueError) as e:
            logger.warning(f"Skip invalid json_schema: {key_string=}, {e=}")
            return None
        return self._compile_regex(regex)

    def dispatch_regex(self, key_string: str):
@@ -13,7 +13,6 @@
# ==============================================================================
"""The baseclass of a backend for reasoner grammar-guided constrained decoding."""

from concurrent.futures import Future
from typing import List, Optional, Tuple

import torch

@@ -28,13 +27,12 @@ class ReasonerGrammarObject(BaseGrammarObject):
        self.think_end_id = think_end_id
        self.is_in_reasoning = True

    @property
    def finished(self):
        return self.grammar.finished
    def accept_token(self, token: int):
        if token == self.think_end_id:
            self.is_in_reasoning = False

    @finished.setter
    def finished(self, finished):
        self.grammar.finished = finished
        if not self.is_in_reasoning and token != self.think_end_id:
            self.grammar.accept_token(token)

    def allocate_vocab_mask(
        self, vocab_size: int, batch_size: int, device

@@ -52,12 +50,16 @@ class ReasonerGrammarObject(BaseGrammarObject):
    def apply_vocab_mask(self):
        return self.grammar.apply_vocab_mask

    def accept_token(self, token: int):
        if token == self.think_end_id:
            self.is_in_reasoning = False
    def copy(self) -> BaseGrammarObject:
        return ReasonerGrammarObject(self.grammar.copy(), self.think_end_id)

        if not self.is_in_reasoning and token != self.think_end_id:
            self.grammar.accept_token(token)
    @property
    def finished(self):
        return self.grammar.finished

    @finished.setter
    def finished(self, finished):
        self.grammar.finished = finished

    def try_jump_forward(self, tokenizer):
        return self.grammar.try_jump_forward(tokenizer)

@@ -72,30 +74,17 @@ class ReasonerGrammarObject(BaseGrammarObject):
            old_output_ids, new_output_ids, next_state
        )

    def copy(self) -> BaseGrammarObject:
        return ReasonerGrammarObject(self.grammar.copy(), self.think_end_id)


class ReasonerGrammarBackend(BaseGrammarBackend):
    def __init__(self, grammar_backend: BaseGrammarBackend, think_end_id):
        super().__init__()
        self.grammar_backend = grammar_backend
        self.think_end_id = think_end_id

    def get_cached_value(self, key: Tuple[str, str]) -> Optional[ReasonerGrammarObject]:
        grammar = self.grammar_backend.get_cached_value(key)
        return ReasonerGrammarObject(grammar, self.think_end_id) if grammar else None

    def get_future_value(self, key: Tuple[str, str]) -> Future:
        grammar = Future()

        def callback(f: Future):
            if result := f.result():
                grammar.set_result(ReasonerGrammarObject(result, self.think_end_id))
            else:
                grammar.set_result(None)

        self.grammar_backend.get_future_value(key).add_done_callback(callback)
        return grammar

    def reset(self):
        self.grammar_backend.reset()
    def _init_value_dispatch(
        self, key: Tuple[str, str]
    ) -> Optional[ReasonerGrammarObject]:
        ret = self.grammar_backend._init_value_dispatch(key)
        if ret is None:
            return None
        return ReasonerGrammarObject(ret, self.think_end_id)
@@ -18,7 +18,6 @@ import logging
from typing import List, Optional, Tuple, Union

import torch
import xgrammar
from xgrammar import (
    CompiledGrammar,
    GrammarCompiler,

@@ -35,7 +34,6 @@ from sglang.srt.constrained.base_grammar_backend import (
from sglang.srt.constrained.triton_ops.bitmask_ops import (
    apply_token_bitmask_inplace_triton,
)
from sglang.srt.utils import get_bool_env_var

logger = logging.getLogger(__name__)
@@ -51,49 +49,35 @@ class XGrammarGrammar(BaseGrammarObject):
        vocab_size: int,
        ctx: CompiledGrammar,
        override_stop_tokens: Optional[Union[List[int], int]],
        key_string: Optional[str] = None,  # TODO (sk): for debugging, remove later
    ) -> None:
        super().__init__()
        self.matcher = matcher
        self.vocab_size = vocab_size
        self.ctx = ctx
        self.override_stop_tokens = override_stop_tokens
        self.finished = False

        from xgrammar.kernels.apply_token_bitmask_inplace_cpu import (
            apply_token_bitmask_inplace_cpu,
        )

        self.apply_vocab_mask_cpu = apply_token_bitmask_inplace_cpu
        self.accepted_tokens = []
        self.key_string = key_string

    def accept_token(self, token: int):
        assert self.matcher.accept_token(token)

    def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
        s = self.matcher.find_jump_forward_string()
        if s:
            return [], s
        return None

    def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
        _, data = helper
        return data, -1

    def jump_and_retokenize(
        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
    ):
        k = 0
        for i, old_id in enumerate(old_output_ids):
            if old_id == new_output_ids[i]:
                k = i + 1
        if not self.is_terminated():
            accepted = self.matcher.accept_token(token)
            if not accepted:
                # log for debugging
                raise ValueError(
                    f"Tokens not accepted: {token}\n"
                    f"Accepted tokens: {self.accepted_tokens}\n"
                    f"Key string: {self.key_string}"
                )
            else:
                break
        self.accepted_tokens.append(token)

        # rollback to the last token that is the same
        if k < len(old_output_ids):
            self.matcher.rollback(len(old_output_ids) - k)
    def rollback(self, k: int):
        self.matcher.rollback(k)
        self.accepted_tokens = self.accepted_tokens[:-k]

        for i in range(k, len(new_output_ids)):
            assert self.matcher.accept_token(new_output_ids[i])
    def is_terminated(self):
        return self.matcher.is_terminated()

    def allocate_vocab_mask(
        self, vocab_size: int, batch_size: int, device
@@ -122,9 +106,43 @@ class XGrammarGrammar(BaseGrammarObject):
            override_stop_tokens=self.override_stop_tokens,
        )
        return XGrammarGrammar(
            matcher, self.vocab_size, self.ctx, self.override_stop_tokens
            matcher,
            self.vocab_size,
            self.ctx,
            self.override_stop_tokens,
            self.key_string,
        )

    def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
        s = self.matcher.find_jump_forward_string()
        if s:
            return [], s
        return None

    def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
        _, data = helper
        return data, -1

    def jump_and_retokenize(
        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
    ):
        k = 0
        for i, old_id in enumerate(old_output_ids):
            if old_id == new_output_ids[i]:
                k = i + 1
            else:
                break

        # rollback to the last token that is the same
        if k < len(old_output_ids):
            self.matcher.rollback(len(old_output_ids) - k)

        for i in range(k, len(new_output_ids)):
            assert self.matcher.accept_token(new_output_ids[i])

    def __repr__(self):
        return f"XGrammarGrammar({self.key_string=}, {self.accepted_tokens=})"


class XGrammarGrammarBackend(BaseGrammarBackend):
    def __init__(
@@ -143,9 +161,15 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
        self.vocab_size = vocab_size
        self.override_stop_tokens = override_stop_tokens

    def _from_context(self, ctx: CompiledGrammar) -> XGrammarGrammar:
        matcher = GrammarMatcher(ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS)
        return XGrammarGrammar(matcher, self.vocab_size, ctx, self.override_stop_tokens)
    def _from_context(self, ctx: CompiledGrammar, key_string: str) -> XGrammarGrammar:
        matcher = GrammarMatcher(
            ctx,
            max_rollback_tokens=MAX_ROLLBACK_TOKENS,
            override_stop_tokens=self.override_stop_tokens,
        )
        return XGrammarGrammar(
            matcher, self.vocab_size, ctx, self.override_stop_tokens, key_string
        )

    def dispatch_json(self, key_string: str) -> Optional[XGrammarGrammar]:
        try:
@@ -157,7 +181,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
        except RuntimeError as e:
            logging.warning(f"Skip invalid json_schema: json_schema={key_string}, {e=}")
            return None
        return self._from_context(ctx)
        return self._from_context(ctx, key_string)

    def dispatch_ebnf(self, key_string: str) -> Optional[XGrammarGrammar]:
        try:

@@ -165,7 +189,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
        except RuntimeError as e:
            logging.warning(f"Skip invalid ebnf: ebnf={key_string}, {e=}")
            return None
        return self._from_context(ctx)
        return self._from_context(ctx, key_string)

    def dispatch_regex(self, key_string: str) -> Optional[XGrammarGrammar]:
        try:

@@ -173,7 +197,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
        except RuntimeError as e:
            logging.warning(f"Skip invalid regex: regex={key_string}, {e=}")
            return None
        return self._from_context(ctx)
        return self._from_context(ctx, key_string)

    def dispatch_structural_tag(self, key_string: str) -> Optional[XGrammarGrammar]:
        try:

@@ -190,9 +214,11 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
                tags, structural_tag["triggers"]
            )
        except RuntimeError as e:
            logging.warning(f"Skip invalid regex: regex={key_string}, {e=}")
            logging.warning(
                f"Skip invalid structural_tag: structural_tag={key_string}, {e=}"
            )
            return None
        return self._from_context(ctx)
        return self._from_context(ctx, key_string)

    def reset(self):
        if self.grammar_compiler:
@@ -239,10 +239,6 @@ def top_p_normalize_probs_torch(


def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]):
    assert len(top_logprobs_nums) == logprobs.shape[0], (
        len(top_logprobs_nums),
        logprobs.shape[0],
    )
    max_k = max(top_logprobs_nums)
    ret = logprobs.topk(max_k, dim=1)
    values = ret.values.tolist()
@@ -533,6 +533,7 @@ class Req:

        # Constrained decoding
        self.grammar: Optional[BaseGrammarObject] = None
        self.grammar_wait_ct = 0

        # The number of cached tokens that were already cached in the KV cache
        self.cached_tokens = 0
@@ -149,6 +149,7 @@ logger = logging.getLogger(__name__)

# Test retract decode for debugging purposes
TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))


@dataclass
@@ -1024,9 +1025,11 @@ class Scheduler(
        elif req.sampling_params.structural_tag:
            key = ("structural_tag", req.sampling_params.structural_tag)

        req.grammar = self.grammar_backend.get_cached_value(key)
        if not req.grammar:
            req.grammar = self.grammar_backend.get_future_value(key)
        value, cache_hit = self.grammar_backend.get_cached_or_future_value(key)
        req.grammar = value

        if not cache_hit:
            req.grammar_key = key
            add_to_grammar_queue = True

        if add_to_grammar_queue:
@@ -1208,6 +1211,7 @@ class Scheduler(
            self.stats.cache_hit_rate = 0.0
            self.stats.gen_throughput = self.last_gen_throughput
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
            self.stats.spec_accept_length = spec_accept_length
            self.metrics_collector.log_stats(self.stats)

@@ -1255,6 +1259,7 @@ class Scheduler(
            self.stats.token_usage = num_used / self.max_total_num_tokens
            self.stats.gen_throughput = 0
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
            self.metrics_collector.log_stats(self.stats)

    def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
@@ -1715,11 +1720,17 @@ class Scheduler(
        """Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""

        num_ready_reqs = 0
        num_abort_reqs = 0
        for req in self.grammar_queue:
            try:
                req.grammar = req.grammar.result(timeout=0.05)
                req.grammar = req.grammar.result(timeout=0.03)
                if req.grammar:
                    self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
                num_ready_reqs += 1
            except futures._base.TimeoutError:
                req.grammar_wait_ct += 1
                if req.grammar_wait_ct > GRAMMAR_TIMEOUT / 0.03:
                    num_abort_reqs = 1
                break

        if self.server_args.enable_dp_attention:

@@ -1731,14 +1742,28 @@
        if tp_size > 1:
            # Sync across TP ranks to make sure they have the same number of ready requests
            tensor = torch.tensor(num_ready_reqs, dtype=torch.int32)
            tensor = torch.tensor([num_ready_reqs, num_abort_reqs], dtype=torch.int32)
            torch.distributed.all_reduce(
                tensor, op=torch.distributed.ReduceOp.MAX, group=tp_group
            )
            num_ready_reqs_max = tensor.item()
            num_ready_reqs_max, num_abort_reqs_max = tensor.tolist()

            for i in range(num_ready_reqs, num_ready_reqs_max):
                self.grammar_queue[i].grammar = self.grammar_queue[i].grammar.result()
            num_ready_reqs = num_ready_reqs_max
                req = self.grammar_queue[i]
                req.grammar = req.grammar.result()
                if req.grammar:
                    self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())

            for i in range(num_ready_reqs, num_ready_reqs + num_abort_reqs_max):
                req = self.grammar_queue[i]
                req.grammar.cancel()
                req.grammar = None
                error_msg = f"Grammar preprocessing timed out for {req.grammar_key=}"
                logger.error(error_msg)
                req.finished_reason = FINISH_ABORT(
                    error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
                )
            num_ready_reqs = num_ready_reqs_max + num_abort_reqs_max

        self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
        self.grammar_queue = self.grammar_queue[num_ready_reqs:]
@@ -1230,11 +1230,18 @@ class TokenizerManager:
            state.last_completion_tokens = completion_tokens

            if state.finished:
                has_grammar = (
                    state.obj.sampling_params.get("json_schema", None)
                    or state.obj.sampling_params.get("regex", None)
                    or state.obj.sampling_params.get("ebnf", None)
                    or state.obj.sampling_params.get("structural_tag", None)
                )
                self.metrics_collector.observe_one_finished_request(
                    recv_obj.prompt_tokens[i],
                    completion_tokens,
                    recv_obj.cached_tokens[i],
                    state.finished_time - state.created_time,
                    has_grammar,
                )

    def dump_requests(self, state: ReqState, out_dict: dict):
@@ -15,7 +15,119 @@

import time
from dataclasses import dataclass
from typing import Dict, Union
from enum import Enum
from typing import Dict, List, Optional, Union

from sglang.srt.utils import get_bool_env_var

SGLANG_TEST_REQUEST_TIME_STATS = get_bool_env_var("SGLANG_TEST_REQUEST_TIME_STATS")


@dataclass
class TimeStats:
    """
    Store the timestamps for each stage of a request.

    Unified: wait_queue -> forward -> completion
    Prefill: bootstrap_queue -> wait_queue -> forward -> transfer_queue -> completion
    Decode: prealloc_queue -> transfer_queue -> wait_queue -> forward -> completion
    """

    lb_entry_time: float = 0.0
    wait_queue_entry_time: float = 0.0
    forward_entry_time: float = 0.0
    completion_time: float = 0.0
    prefill_bootstrap_queue_entry_time: float = 0.0
    prefill_transfer_queue_entry_time: float = 0.0
    decode_prealloc_queue_entry_time: float = 0.0
    decode_transfer_queue_entry_time: float = 0.0

    class RequestType(Enum):
        UNIFIED = "unified"
        PREFILL = "prefill"
        DECODE = "decode"
        INVALID = "invalid"

    def __str__(self) -> str:
        # if unified
        _type = self.get_type()

        if _type == self.RequestType.UNIFIED:
            queue_duration = self.forward_entry_time - self.wait_queue_entry_time
            forward_duration = self.completion_time - self.forward_entry_time

            if SGLANG_TEST_REQUEST_TIME_STATS:
                assert (
                    queue_duration >= 0 and forward_duration >= 0
                ), f"queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0"

            return f"queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.wait_queue_entry_time}"
        elif _type == self.RequestType.PREFILL:
            bootstrap_duration = (
                self.wait_queue_entry_time - self.prefill_bootstrap_queue_entry_time
            )

            queue_duration = self.forward_entry_time - self.wait_queue_entry_time

            forward_duration = self.completion_time - self.forward_entry_time

            if SGLANG_TEST_REQUEST_TIME_STATS:
                assert (
                    bootstrap_duration >= 0
                    and queue_duration >= 0
                    and forward_duration >= 0
                ), f"bootstrap_duration={bootstrap_duration} < 0 or queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0"
            return f"bootstrap_duration={self.format_duration(bootstrap_duration)}, queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.prefill_bootstrap_queue_entry_time}"
        # if decode
        elif _type == self.RequestType.DECODE:
            prealloc_duration = (
                self.decode_transfer_queue_entry_time
                - self.decode_prealloc_queue_entry_time
            )

            transfer_duration = (
                self.wait_queue_entry_time - self.decode_transfer_queue_entry_time
            )
            queue_duration = self.forward_entry_time - self.wait_queue_entry_time
            forward_duration = self.completion_time - self.forward_entry_time

            if SGLANG_TEST_REQUEST_TIME_STATS:
                assert (
                    prealloc_duration >= 0
                    and transfer_duration >= 0
                    and queue_duration >= 0
                    and forward_duration >= 0
                ), f"prealloc_duration={prealloc_duration} < 0 or transfer_duration={transfer_duration} < 0 or queue_duration={queue_duration} < 0 or forward_duration={forward_duration} < 0"

            return f"prealloc_duration={self.format_duration(prealloc_duration)}, transfer_duration={self.format_duration(transfer_duration)}, queue_duration={self.format_duration(queue_duration)}, forward_duration={self.format_duration(forward_duration)}, start_time={self.decode_prealloc_queue_entry_time}"
        else:
            return "Invalid Time Stats"

    def format_duration(self, duration: float) -> str:
        return f"{duration * 1e3:.2f}ms"

    def get_type(self) -> RequestType:
        """Determine the type of request based on timestamp values."""
        if (
            self.prefill_bootstrap_queue_entry_time == 0.0
            and self.prefill_transfer_queue_entry_time == 0.0
            and self.decode_prealloc_queue_entry_time == 0.0
            and self.decode_transfer_queue_entry_time == 0.0
        ):
            return self.RequestType.UNIFIED
        elif (
            self.prefill_bootstrap_queue_entry_time > 0.0
            and self.prefill_transfer_queue_entry_time > 0.0
        ):
            return self.RequestType.PREFILL
        elif (
            self.decode_prealloc_queue_entry_time > 0.0
            and self.decode_transfer_queue_entry_time > 0.0
            and self.wait_queue_entry_time > 0.0
        ):
            return self.RequestType.DECODE
        else:
            return self.RequestType.INVALID
@dataclass
@@ -26,15 +138,20 @@ class SchedulerStats:
    gen_throughput: float = 0.0
    num_queue_reqs: int = 0
    cache_hit_rate: float = 0.0
    num_grammar_queue_reqs: int = 0
    spec_accept_length: float = 0.0
    avg_request_queue_latency: float = 0.0
    num_prefill_prealloc_queue_reqs: int = 0
    num_prefill_infight_queue_reqs: int = 0
    num_decode_prealloc_queue_reqs: int = 0
    num_decode_transfer_queue_reqs: int = 0


class SchedulerMetricsCollector:

    def __init__(self, labels: Dict[str, str]) -> None:
        # We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR`
        from prometheus_client import Gauge, Histogram
        from prometheus_client import Counter, Gauge

        self.labels = labels
        self.last_log_time = time.time()
@@ -74,6 +191,13 @@ class SchedulerMetricsCollector:
            multiprocess_mode="mostrecent",
        )

        self.num_grammar_queue_reqs = Gauge(
            name="sglang:num_grammar_queue_reqs",
            documentation="The number of requests in the grammar waiting queue.",
            labelnames=labels.keys(),
            multiprocess_mode="mostrecent",
        )

        self.cache_hit_rate = Gauge(
            name="sglang:cache_hit_rate",
            documentation="The prefix cache hit rate.",
@@ -95,28 +219,98 @@ class SchedulerMetricsCollector:
            multiprocess_mode="mostrecent",
        )

        # Disaggregation queue metrics
        self.num_prefill_prealloc_queue_reqs = Gauge(
            name="sglang:num_prefill_prealloc_queue_reqs",
            documentation="The number of requests in the prefill prealloc queue.",
            labelnames=labels.keys(),
            multiprocess_mode="mostrecent",
        )

        self.num_prefill_infight_queue_reqs = Gauge(
            name="sglang:num_prefill_infight_queue_reqs",
            documentation="The number of requests in the prefill infight queue.",
            labelnames=labels.keys(),
            multiprocess_mode="mostrecent",
        )

        self.num_decode_prealloc_queue_reqs = Gauge(
            name="sglang:num_decode_prealloc_queue_reqs",
            documentation="The number of requests in the decode prealloc queue.",
            labelnames=labels.keys(),
            multiprocess_mode="mostrecent",
        )

        self.num_decode_transfer_queue_reqs = Gauge(
            name="sglang:num_decode_transfer_queue_reqs",
            documentation="The number of requests in the decode transfer queue.",
            labelnames=labels.keys(),
            multiprocess_mode="mostrecent",
        )

        self.num_bootstrap_failed_reqs = Counter(
            name="sglang:num_bootstrap_failed_reqs",
            documentation="The number of bootstrap failed requests.",
            labelnames=labels.keys(),
        )

        self.num_transfer_failed_reqs = Counter(
            name="sglang:num_transfer_failed_reqs",
            documentation="The number of transfer failed requests.",
            labelnames=labels.keys(),
        )

    def _log_gauge(self, gauge, data: Union[int, float]) -> None:
        # Convenience function for logging to gauge.
        gauge.labels(**self.labels).set(data)

    def increment_bootstrap_failed_reqs(self) -> None:
        self.num_bootstrap_failed_reqs.labels(**self.labels).inc(1)

    def increment_transfer_failed_reqs(self) -> None:
        self.num_transfer_failed_reqs.labels(**self.labels).inc(1)

    def log_stats(self, stats: SchedulerStats) -> None:
        self._log_gauge(self.num_running_reqs, stats.num_running_reqs)
        self._log_gauge(self.num_used_tokens, stats.num_used_tokens)
        self._log_gauge(self.token_usage, stats.token_usage)
        self._log_gauge(self.gen_throughput, stats.gen_throughput)
        self._log_gauge(self.num_queue_reqs, stats.num_queue_reqs)
        self._log_gauge(self.num_grammar_queue_reqs, stats.num_grammar_queue_reqs)
        self._log_gauge(self.cache_hit_rate, stats.cache_hit_rate)
        self._log_gauge(self.spec_accept_length, stats.spec_accept_length)
        self._log_gauge(self.avg_request_queue_latency, stats.avg_request_queue_latency)

        # Disaggregation metrics
        self._log_gauge(
            self.num_prefill_prealloc_queue_reqs, stats.num_prefill_prealloc_queue_reqs
        )
        self._log_gauge(
            self.num_prefill_infight_queue_reqs, stats.num_prefill_infight_queue_reqs
        )
        self._log_gauge(
            self.num_decode_prealloc_queue_reqs, stats.num_decode_prealloc_queue_reqs
        )
        self._log_gauge(
            self.num_decode_transfer_queue_reqs, stats.num_decode_transfer_queue_reqs
        )

        self.last_log_time = time.time()

class TokenizerMetricsCollector:
    def __init__(self, labels: Dict[str, str]) -> None:
    def __init__(
        self,
        labels: Dict[str, str],
        bucket_time_to_first_token: Optional[List[float]] = None,
        bucket_inter_token_latency: Optional[List[float]] = None,
        bucket_e2e_request_latency: Optional[List[float]] = None,
        collect_tokens_histogram: bool = False,
    ) -> None:
        # We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR`
        from prometheus_client import Counter, Histogram

        self.labels = labels
        self.collect_tokens_histogram = collect_tokens_histogram

        self.prompt_tokens_total = Counter(
            name="sglang:prompt_tokens_total",
@@ -130,6 +324,66 @@ class TokenizerMetricsCollector:
            labelnames=labels.keys(),
        )

        if collect_tokens_histogram:
            bucket_prompt_tokens = [
                100,
                300,
                500,
                700,
                1000,
                1500,
                2000,
                3000,
                4000,
                5000,
                6000,
                7000,
                8000,
                9000,
                10000,
                12000,
                15000,
                20000,
                22000,
                25000,
                30000,
                35000,
                40000,
            ]
            self.prompt_tokens_histogram = Histogram(
                name="sglang:prompt_tokens_histogram",
                documentation="Histogram of prompt token length.",
                labelnames=labels.keys(),
                buckets=bucket_prompt_tokens,
            )
            bucket_generation_tokens = [
                100,
                300,
                500,
                1000,
                1200,
                1500,
                1700,
                2000,
                2500,
                3000,
                3500,
                4000,
                4500,
                5000,
                6000,
                7000,
                8000,
                9000,
                10000,
            ]
            self.generation_tokens_histogram = Histogram(
                name="sglang:generation_tokens_histogram",
                documentation="Histogram of generation token length.",
                labelnames=labels.keys(),
                buckets=bucket_generation_tokens,
            )

        self.cached_tokens_total = Counter(
            name="sglang:cached_tokens_total",
            documentation="Number of cached prompt tokens.",
@@ -142,11 +396,14 @@ class TokenizerMetricsCollector:
            labelnames=labels.keys(),
        )

        self.histogram_time_to_first_token = Histogram(
            name="sglang:time_to_first_token_seconds",
            documentation="Histogram of time to first token in seconds.",
        self.num_so_requests_total = Counter(
            name="sglang:num_so_requests_total",
            documentation="Number of structured output requests processed.",
            labelnames=labels.keys(),
            buckets=[
        )

        if bucket_time_to_first_token is None:
            bucket_time_to_first_token = [
                0.1,
                0.2,
                0.4,
@@ -165,14 +422,33 @@ class TokenizerMetricsCollector:
                100,
                200,
                400,
            ],
        )
            ]

        self.histogram_inter_token_latency_seconds = Histogram(
            name="sglang:inter_token_latency_seconds",
            documentation="Histogram of inter-token latency in seconds.",
            labelnames=labels.keys(),
            buckets=[
        if bucket_e2e_request_latency is None:
            bucket_e2e_request_latency = [
                0.1,
                0.2,
                0.4,
                0.6,
                0.8,
                1,
                2,
                4,
                6,
                8,
                10,
                20,
                40,
                60,
                80,
                100,
                200,
                400,
                800,
            ]

        if bucket_inter_token_latency is None:
            bucket_inter_token_latency = [
                0.002,
                0.004,
                0.006,
@@ -196,34 +472,27 @@ class TokenizerMetricsCollector:
                4.000,
                6.000,
                8.000,
            ],
            ]

        self.histogram_time_to_first_token = Histogram(
            name="sglang:time_to_first_token_seconds",
            documentation="Histogram of time to first token in seconds.",
            labelnames=labels.keys(),
            buckets=bucket_time_to_first_token,
        )

        self.histogram_inter_token_latency_seconds = Histogram(
            name="sglang:inter_token_latency_seconds",
            documentation="Histogram of inter-token latency in seconds.",
            labelnames=labels.keys(),
            buckets=bucket_inter_token_latency,
        )

        self.histogram_e2e_request_latency = Histogram(
            name="sglang:e2e_request_latency_seconds",
            documentation="Histogram of End-to-end request latency in seconds",
            labelnames=labels.keys(),
            buckets=[
                0.1,
                0.2,
                0.4,
                0.6,
                0.8,
                1,
                2,
                4,
                6,
                8,
                10,
                20,
                40,
                60,
                80,
                100,
                200,
                400,
                800,
            ],
            buckets=bucket_e2e_request_latency,
        )
    def _log_histogram(self, histogram, data: Union[int, float]) -> None:

@@ -235,13 +504,19 @@ class TokenizerMetricsCollector:
        generation_tokens: int,
        cached_tokens: int,
        e2e_latency: float,
        has_grammar: bool,
    ):
        self.prompt_tokens_total.labels(**self.labels).inc(prompt_tokens)
        self.generation_tokens_total.labels(**self.labels).inc(generation_tokens)
        if cached_tokens > 0:
            self.cached_tokens_total.labels(**self.labels).inc(cached_tokens)
        self.num_requests_total.labels(**self.labels).inc(1)
        if has_grammar:
            self.num_so_requests_total.labels(**self.labels).inc(1)
        self._log_histogram(self.histogram_e2e_request_latency, e2e_latency)
        if self.collect_tokens_histogram:
            self._log_histogram(self.prompt_tokens_histogram, prompt_tokens)
            self._log_histogram(self.generation_tokens_histogram, generation_tokens)

    def observe_time_to_first_token(self, value: float):
        self.histogram_time_to_first_token.labels(**self.labels).observe(value)