diff --git a/python/sglang/srt/constrained/base_grammar_backend.py b/python/sglang/srt/constrained/base_grammar_backend.py new file mode 100644 index 000000000..e298b3d0c --- /dev/null +++ b/python/sglang/srt/constrained/base_grammar_backend.py @@ -0,0 +1,72 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +"""The baseclass of backends for grammar-guided constrained decoding.""" + +from concurrent.futures import Future, ThreadPoolExecutor +from dataclasses import dataclass +from threading import Event, Lock +from typing import Any, Optional, Tuple + + +@dataclass +class CacheEntry: + value: Any + event: Event + + +class BaseGrammarObject: + pass + + +class BaseGrammarBackend: + def __init__(self): + self.executor = ThreadPoolExecutor() + self.cache = {} + self.cache_lock = Lock() + + def init_value(self, key: Tuple[str, str]) -> BaseGrammarObject: + with self.cache_lock: + if key in self.cache: + cache_hit = True + entry = self.cache[key] + else: + cache_hit = False + entry = CacheEntry(None, Event()) + self.cache[key] = entry + + if cache_hit: + entry.event.wait() + else: + entry.value = self.init_value_impl(key) + entry.event.set() + return entry.value.copy() + + def init_value_impl(self, key: Tuple[str, str]) -> BaseGrammarObject: + raise NotImplementedError() + + def get_cached_value(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]: + with self.cache_lock: + entry = self.cache.get(key) + if not entry or not entry.event.is_set(): + return None + return self.cache[key].value.copy() + + def get_future_value(self, key: Tuple[str, str]) -> Future: + return self.executor.submit(self.init_value, key) + + def reset(self): + with self.cache_lock: + self.cache.clear() diff --git a/python/sglang/srt/constrained/base_tool_cache.py b/python/sglang/srt/constrained/base_tool_cache.py deleted file mode 100644 index 1910eb730..000000000 --- a/python/sglang/srt/constrained/base_tool_cache.py +++ /dev/null @@ -1,102 +0,0 @@ -""" -Copyright 2023-2024 SGLang Team -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -""" - -"""Base cache class for constrained decoding tools.""" - -import time -from dataclasses import dataclass -from threading import Event, Lock -from typing import Any, Dict, Tuple - - -@dataclass -class MapEntry: - event: Event - value: Any - - def __iter__(self): - return iter((self.event, self.value)) - - -class BaseToolCache: - - def __init__(self, enable=True): - self.enable: bool = enable - self.cache: Dict[str, MapEntry] = {} - self.metrics: Dict[str, Any] = {} - self.lock_cache: Lock = Lock() - self.lock_metrics: Lock = Lock() - self.reset() - - def reset(self): - with self.lock_cache: - self.cache = {} - with self.lock_metrics: - self.metrics = {"total": 0, "hit": 0, "avg_init_time": 0} - - def _init_with_timer(self, key) -> Tuple[Any, float]: - start = time.monotonic() - val = self.init_value(key) - init_time = time.monotonic() - start - return val, init_time - - def update_time(self, init_time): - with self.lock_metrics: - curr_total = self.metrics["total"] - new_total = curr_total + 1 - - # Update average init time without old_avg * old_total to avoid overflow. - self.metrics["avg_init_time"] = (init_time / new_total) + ( - curr_total / new_total - ) * self.metrics["avg_init_time"] - - def query(self, key): - if not self.enable: - value, init_time = self._init_with_timer(key) - self.update_time(init_time) - return value - - with self.lock_cache: - if key in self.cache: - entry = self.cache[key] - cache_hit = True - else: - entry = MapEntry(Event(), None) - self.cache[key] = entry - cache_hit = False - - with self.lock_metrics: - self.metrics["total"] += 1 - if cache_hit: - self.metrics["hit"] += 1 - - if cache_hit: - entry.event.wait() - else: - entry.value, init_time = self._init_with_timer(key) - self.update_time(init_time) - entry.event.set() - return entry.value - - def init_value(self, key): - raise NotImplementedError() - - def get_cache_hit_rate(self): - with self.lock_metrics: - return self.metrics["hit"] / max(self.metrics["total"], 1) - - def get_avg_init_time(self): - with self.lock_metrics: - return self.metrics["avg_init_time"] diff --git a/python/sglang/srt/constrained/outlines_backend.py b/python/sglang/srt/constrained/outlines_backend.py index 434d94b57..16f32d93f 100644 --- a/python/sglang/srt/constrained/outlines_backend.py +++ b/python/sglang/srt/constrained/outlines_backend.py @@ -17,20 +17,17 @@ limitations under the License. import json import logging -from concurrent.futures import Future, ThreadPoolExecutor from typing import Dict, List, Optional, Tuple, Union import torch -from interegular import InvalidSyntax, parse_pattern from outlines.fsm.guide import RegexGuide from outlines.models.transformers import TransformerTokenizer -from pydantic import BaseModel -from sglang.srt.constrained.base_tool_cache import BaseToolCache -from sglang.srt.constrained.outlines_jump_forward import ( - OutlinesJumpForwardCache, - OutlinesJumpForwardMap, +from sglang.srt.constrained.base_grammar_backend import ( + BaseGrammarBackend, + BaseGrammarObject, ) +from sglang.srt.constrained.outlines_jump_forward import OutlinesJumpForwardMap logger = logging.getLogger(__name__) @@ -41,6 +38,7 @@ except ImportError: # Since outlines 0.0.32, build_regex_from_object is replaced by build_regex_from_schema, # which only accepts string schema as input. from outlines.fsm.json_schema import build_regex_from_schema + from pydantic import BaseModel def build_regex_from_object( object: Union[str, BaseModel, Dict], whitespace_pattern: Optional[str] = None @@ -54,16 +52,15 @@ except ImportError: return build_regex_from_schema(schema, whitespace_pattern) -class OutlinesGrammar: +class OutlinesGrammar(BaseGrammarObject): def __init__( self, guide: RegexGuide, - state: int, jump_forward_map: Union[OutlinesJumpForwardMap, None], ) -> None: self.guide = guide - self.state = state self.jump_forward_map = jump_forward_map + self.state = 0 def accept_token(self, token: int): self.state = self.guide.get_next_state(self.state, token) @@ -105,46 +102,18 @@ class OutlinesGrammar: vocab_mask.fill_(1) vocab_mask[self.guide.get_next_instruction(self.state).tokens] = 0 + def copy(self): + return OutlinesGrammar(self.guide, self.jump_forward_map) -class OutlinesGrammarBackend: + +class OutlinesGrammarBackend(BaseGrammarBackend): def __init__( self, tokenizer, - whitespace_patterns: bool, + whitespace_pattern: bool, allow_jump_forward: bool, ): - self.executor = ThreadPoolExecutor() - self.grammar_cache = OutlinesCache( - tokenizer, - whitespace_pattern=whitespace_patterns, - ) - self.jump_forward_cache = ( - OutlinesJumpForwardCache() if allow_jump_forward else None - ) - - def _query(self, key: Tuple[str, str]) -> OutlinesGrammar: - guide, regex = self.grammar_cache.query(key) - jump_forward_map = ( - self.jump_forward_cache.query(regex) if self.jump_forward_cache else None - ) - return OutlinesGrammar(guide, 0, jump_forward_map) - - def query(self, key: Tuple[str, str]) -> Future: - return self.executor.submit(self._query, key) - - def reset(self): - self.grammar_cache.reset() - if self.jump_forward_cache: - self.jump_forward_cache.reset() - - -class OutlinesCache(BaseToolCache): - def __init__( - self, - tokenizer, - whitespace_pattern=None, - ): - super().__init__(enable=True) + super().__init__() try: self.outlines_tokenizer = TransformerTokenizer(tokenizer) @@ -167,9 +136,10 @@ class OutlinesCache(BaseToolCache): self.outlines_tokenizer.vocabulary = ( self.outlines_tokenizer.tokenizer.get_vocab() ) + self.allow_jump_forward = allow_jump_forward self.whitespace_pattern = whitespace_pattern - def init_value(self, key): + def init_value_impl(self, key: Tuple[str, str]) -> OutlinesGrammar: key_type, key_string = key if key_type == "json": try: @@ -186,18 +156,10 @@ class OutlinesCache(BaseToolCache): regex = key_string else: raise ValueError(f"Invalid key_type: {key_type}") - try: - parse_pattern(regex) - except InvalidSyntax as e: - logger.warning(f"skip invalid regex guide: {regex=}, {e=}") - return None, regex - ret = RegexGuide(regex, self.outlines_tokenizer), regex - return ret - - def _query(self, key: Tuple[str, str]): - guide, regex = self.grammar_cache.query(key) - jump_forward_map = ( - self.jump_forward_cache.query(regex) if self.jump_forward_cache else None - ) - return OutlinesGrammar(guide, 0, jump_forward_map) + guide = RegexGuide(regex, self.outlines_tokenizer) + if self.allow_jump_forward: + jump_forward_map = OutlinesJumpForwardMap(regex) + else: + jump_forward_map = None + return OutlinesGrammar(guide, jump_forward_map) diff --git a/python/sglang/srt/constrained/outlines_jump_forward.py b/python/sglang/srt/constrained/outlines_jump_forward.py index e3dd6b166..006b8dcd6 100644 --- a/python/sglang/srt/constrained/outlines_jump_forward.py +++ b/python/sglang/srt/constrained/outlines_jump_forward.py @@ -27,8 +27,6 @@ from interegular import InvalidSyntax from outlines.caching import cache as disk_cache from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm -from sglang.srt.constrained.base_tool_cache import BaseToolCache - IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)" logger = logging.getLogger(__name__) @@ -42,92 +40,90 @@ class JumpEdge: byte_next_state: int = None +@disk_cache() +def init_state_to_jump_forward(regex_string): + try: + regex_pattern = interegular.parse_pattern(regex_string) + except InvalidSyntax as e: + logger.warning(f"skip invalid regex: {regex_string}, {e=}") + return + + byte_fsm = make_byte_level_fsm(regex_pattern.to_fsm().reduce(), keep_utf8=True) + regex_fsm, _ = make_deterministic_fsm(byte_fsm) + + fsm_info: FSMInfo = regex_fsm.fsm_info + + symbol_to_id = fsm_info.alphabet_symbol_mapping + id_to_symbol = {} + for symbol, id_ in symbol_to_id.items(): + id_to_symbol.setdefault(id_, []).append(symbol) + + transitions = fsm_info.transitions + + outgoings_ct = defaultdict(int) + # NOTE(lsyin): Final states can lead to terminate, so they have one outgoing edge naturally + for s in fsm_info.finals: + outgoings_ct[s] = 1 + + state_to_jump_forward = {} + for (state, id_), next_state in transitions.items(): + if id_ == fsm_info.alphabet_anything_value: + # Arbitrarily symbol cannot be recognized as jump forward + continue + + symbols = id_to_symbol[id_] + for c in symbols: + if len(c) > 1: + # Skip byte level transitions like c = "5E" + continue + + outgoings_ct[state] += 1 + if outgoings_ct[state] > 1: + if state in state_to_jump_forward: + del state_to_jump_forward[state] + break + + state_to_jump_forward[state] = JumpEdge( + symbol=c, + symbol_next_state=next_state, + ) + + # Process the byte level jump forward + outgoings_ct = defaultdict(int) + for s in fsm_info.finals: + outgoings_ct[s] = 1 + + for (state, id_), next_state in transitions.items(): + if id_ == fsm_info.alphabet_anything_value: + continue + symbols = id_to_symbol[id_] + for c in symbols: + byte_ = None + if len(c) == 1 and ord(c) < 0x80: + # ASCII character + byte_ = ord(c) + elif len(c) > 1: + # FIXME: This logic is due to the leading \x00 + # https://github.com/outlines-dev/outlines/pull/930 + byte_ = int(symbols[0][1:], 16) + + if byte_ is not None: + outgoings_ct[state] += 1 + if outgoings_ct[state] > 1: + if state in state_to_jump_forward: + del state_to_jump_forward[state] + break + e = state_to_jump_forward.get(state, JumpEdge()) + e.byte = byte_ + e.byte_next_state = next_state + state_to_jump_forward[state] = e + + return state_to_jump_forward + + class OutlinesJumpForwardMap: def __init__(self, regex_string): - @disk_cache() - def _init_state_to_jump_forward(regex_string): - try: - regex_pattern = interegular.parse_pattern(regex_string) - except InvalidSyntax as e: - logger.warning(f"skip invalid regex: {regex_string}, {e=}") - self.state_to_jump_forward = None - return - - byte_fsm = make_byte_level_fsm( - regex_pattern.to_fsm().reduce(), keep_utf8=True - ) - regex_fsm, _ = make_deterministic_fsm(byte_fsm) - - fsm_info: FSMInfo = regex_fsm.fsm_info - - symbol_to_id = fsm_info.alphabet_symbol_mapping - id_to_symbol = {} - for symbol, id_ in symbol_to_id.items(): - id_to_symbol.setdefault(id_, []).append(symbol) - - transitions = fsm_info.transitions - - outgoings_ct = defaultdict(int) - # NOTE(lsyin): Final states can lead to terminate, so they have one outgoing edge naturally - for s in fsm_info.finals: - outgoings_ct[s] = 1 - - state_to_jump_forward = {} - for (state, id_), next_state in transitions.items(): - if id_ == fsm_info.alphabet_anything_value: - # Arbitrarily symbol cannot be recognized as jump forward - continue - - symbols = id_to_symbol[id_] - for c in symbols: - if len(c) > 1: - # Skip byte level transitions like c = "5E" - continue - - outgoings_ct[state] += 1 - if outgoings_ct[state] > 1: - if state in state_to_jump_forward: - del state_to_jump_forward[state] - break - - state_to_jump_forward[state] = JumpEdge( - symbol=c, - symbol_next_state=next_state, - ) - - # Process the byte level jump forward - outgoings_ct = defaultdict(int) - for s in fsm_info.finals: - outgoings_ct[s] = 1 - - for (state, id_), next_state in transitions.items(): - if id_ == fsm_info.alphabet_anything_value: - continue - symbols = id_to_symbol[id_] - for c in symbols: - byte_ = None - if len(c) == 1 and ord(c) < 0x80: - # ASCII character - byte_ = ord(c) - elif len(c) > 1: - # FIXME: This logic is due to the leading \x00 - # https://github.com/outlines-dev/outlines/pull/930 - byte_ = int(symbols[0][1:], 16) - - if byte_ is not None: - outgoings_ct[state] += 1 - if outgoings_ct[state] > 1: - if state in state_to_jump_forward: - del state_to_jump_forward[state] - break - e = state_to_jump_forward.get(state, JumpEdge()) - e.byte = byte_ - e.byte_next_state = next_state - state_to_jump_forward[state] = e - - return state_to_jump_forward - - self.state_to_jump_forward = _init_state_to_jump_forward(regex_string) + self.state_to_jump_forward = init_state_to_jump_forward(regex_string) def jump_forward_symbol(self, state): jump_forward_str = "" @@ -164,18 +160,6 @@ class OutlinesJumpForwardMap: ) -class OutlinesJumpForwardCache(BaseToolCache): - def __init__(self): - super().__init__() - - def init_value(self, regex): - forward_map = OutlinesJumpForwardMap(regex) - if forward_map.state_to_jump_forward: - return forward_map - else: - return None - - def test_main(regex_string): jump_forward_map = OutlinesJumpForwardMap(regex_string) for state, e in jump_forward_map.state_to_jump_forward.items(): diff --git a/python/sglang/srt/constrained/xgrammar_backend.py b/python/sglang/srt/constrained/xgrammar_backend.py index 812865da6..d0416ec3d 100644 --- a/python/sglang/srt/constrained/xgrammar_backend.py +++ b/python/sglang/srt/constrained/xgrammar_backend.py @@ -15,38 +15,36 @@ limitations under the License. """Constrained decoding with xgrammar backend.""" -from concurrent.futures import Future, ThreadPoolExecutor from typing import List, Tuple import torch +from xgrammar import CachedGrammarCompiler, CompiledGrammar, GrammarMatcher -try: - from xgrammar import CachedGrammarCompiler, CompiledGrammar, GrammarMatcher - - import_error = None -except ImportError as e: - import_error = e - - class Dummy: - pass - - GrammarMatcher = CompiledGrammar = CachedGrammarCompiler = Dummy - +from sglang.srt.constrained.base_grammar_backend import ( + BaseGrammarBackend, + BaseGrammarObject, +) MAX_ROLLBACK_TOKENS = 10 -class XGrammarGrammar: +class XGrammarGrammar(BaseGrammarObject): - def __init__(self, matcher: GrammarMatcher, vocab_size: int) -> None: + def __init__( + self, matcher: GrammarMatcher, vocab_size: int, ctx: CompiledGrammar + ) -> None: self.matcher = matcher self.vocab_size = vocab_size + self.ctx = ctx def accept_token(self, token: int): assert self.matcher.accept_token(token) def try_jump_forward(self, tokenizer) -> Tuple[List[int], str]: - return [], self.matcher.find_jump_forward_string() + s = self.matcher.find_jump_forward_string() + if s: + return [], s + return None def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]: _, data = helper @@ -77,51 +75,40 @@ class XGrammarGrammar: self.matcher.get_rejected_tokens_from_bitmask(bitmask, self.vocab_size) ] = 1 + def copy(self): + matcher = GrammarMatcher( + self.ctx, + max_rollback_tokens=MAX_ROLLBACK_TOKENS, + mask_vocab_size=self.vocab_size, + ) + return XGrammarGrammar(matcher, self.vocab_size, self.ctx) -class XGrammarGrammarBackend: + +class XGrammarGrammarBackend(BaseGrammarBackend): def __init__( self, tokenizer, vocab_size: int, ): - if import_error: - raise import_error - - self.executor = ThreadPoolExecutor() - self.grammar_cache = XGrammarCache(tokenizer, vocab_size) - self.vocab_size = vocab_size - - def _query(self, key: Tuple[str, str]) -> XGrammarGrammar: - return XGrammarGrammar(self.grammar_cache.query(key), self.vocab_size) - - def query(self, key: Tuple[str, str]) -> Future: - return self.executor.submit(self._query, key) - - def reset(self): - self.grammar_cache.reset() - - -class XGrammarCache: - def __init__(self, tokenizer, vocab_size: int): + super().__init__() self.grammar_cache = CachedGrammarCompiler(tokenizer_or_vocab=tokenizer) self.vocab_size = vocab_size - def get_context(self, key: Tuple[str, str]) -> CompiledGrammar: + def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar: key_type, key_string = key if key_type == "json": - return self.grammar_cache.get_compiled_grammar_for_json_schema(key_string) + ctx = self.grammar_cache.get_compiled_grammar_for_json_schema(key_string) elif key_type == "regex": raise ValueError("regex hasn't been supported by xgrammar yet") else: raise ValueError(f"Invalid key_type: {key_type}") - def query(self, key: Tuple[str, str]) -> GrammarMatcher: - ctx = self.get_context(key) - return GrammarMatcher( + matcher = GrammarMatcher( ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS, mask_vocab_size=self.vocab_size, ) + return XGrammarGrammar(matcher, self.vocab_size, ctx) def reset(self): self.grammar_cache.clear() diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index a3bcbeba5..0d57abdd2 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -37,6 +37,7 @@ import torch from sglang.global_config import global_config from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.chunk_cache import ChunkCache from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool @@ -248,7 +249,7 @@ class Req: self.embedding = None # Constrained decoding - self.grammar = None + self.grammar: Optional[BaseGrammarObject] = None # The number of cached tokens, that were already cached in the KV cache self.cached_tokens = 0 diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 2cdb69f1a..176a1f2f5 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -244,7 +244,7 @@ class Scheduler: self.grammar_backend = OutlinesGrammarBackend( self.tokenizer, - whitespace_patterns=server_args.constrained_json_whitespace_pattern, + whitespace_pattern=server_args.constrained_json_whitespace_pattern, allow_jump_forward=not server_args.disable_jump_forward, ) elif server_args.grammar_backend == "xgrammar": @@ -467,21 +467,6 @@ class Scheduler: # By default, only return the logprobs for output tokens req.logprob_start_len = len(recv_req.input_ids) - 1 - # Init grammar cache for this request - if ( - req.sampling_params.json_schema is not None - or req.sampling_params.regex is not None - ): - assert self.grammar_backend is not None - if req.sampling_params.json_schema is not None: - req.grammar = self.grammar_backend.query( - ("json", req.sampling_params.json_schema), - ) - elif req.sampling_params.regex is not None: - req.grammar = self.grammar_backend.query( - ("regex", req.sampling_params.regex) - ) - # Truncate prompts that are too long if len(req.origin_input_ids) > self.max_req_input_len: logger.warning( @@ -499,7 +484,24 @@ class Scheduler: self.max_req_len - len(req.origin_input_ids) - 1, ) - if req.grammar is not None: + # Init grammar cache for this request + add_to_grammar_queue = False + if ( + req.sampling_params.json_schema is not None + or req.sampling_params.regex is not None + ): + assert self.grammar_backend is not None + if req.sampling_params.json_schema is not None: + key = ("json", req.sampling_params.json_schema) + elif req.sampling_params.regex is not None: + key = ("regex", req.sampling_params.regex) + + req.grammar = self.grammar_backend.get_cached_value(key) + if not req.grammar: + req.grammar = self.grammar_backend.get_future_value(key) + add_to_grammar_queue = True + + if add_to_grammar_queue: self.grammar_queue.append(req) else: self.waiting_queue.append(req) @@ -650,14 +652,7 @@ class Scheduler: def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: # Check if the grammar is ready in the grammar queue if self.grammar_queue: - new_grammar_queue = [] - for req in self.grammar_queue: - try: - req.grammar = req.grammar.result(timeout=0.05) - self.waiting_queue.append(req) - except futures._base.TimeoutError: - new_grammar_queue.append(req) - self.grammar_queue = new_grammar_queue + self.move_ready_grammar_requests() # Handle the cases where prefill is not allowed if ( @@ -1145,6 +1140,30 @@ class Scheduler: ) ) + def move_ready_grammar_requests(self): + """Move requests whose grammar objects are ready from grammar_queue to waiting_queue.""" + num_ready_reqs = 0 + for req in self.grammar_queue: + try: + req.grammar = req.grammar.result(timeout=0.05) + num_ready_reqs += 1 + except futures._base.TimeoutError: + break + + if self.tp_size > 1: + # Sync across TP ranks to make sure they have the same number of ready requests + tensor = torch.tensor(num_ready_reqs, dtype=torch.int32) + torch.distributed.all_reduce( + tensor, op=torch.distributed.ReduceOp.MAX, group=self.tp_cpu_group + ) + num_ready_reqs_max = tensor.item() + for i in range(num_ready_reqs, num_ready_reqs_max): + self.grammar_queue[i].grammar = self.grammar_queue[i].grammar.result() + num_ready_reqs = num_ready_reqs_max + + self.waiting_queue.extend(self.grammar_queue[:num_ready_reqs]) + self.grammar_queue = self.grammar_queue[num_ready_reqs:] + def flush_cache(self): """Flush the memory pool and cache.""" if len(self.waiting_queue) == 0 and ( @@ -1152,9 +1171,8 @@ class Scheduler: ): self.tree_cache.reset() self.tree_cache_metrics = {"total": 0, "hit": 0} - if self.grammar_backend is not None: + if self.grammar_backend: self.grammar_backend.reset() - # TODO(dark): reset the bnf cache self.req_to_token_pool.clear() self.token_to_kv_pool.clear() torch.cuda.empty_cache()