Fix grammar backend (#2018)
This commit is contained in:
@@ -21,6 +21,7 @@ import threading
|
||||
import time
|
||||
import warnings
|
||||
from collections import deque
|
||||
from concurrent import futures
|
||||
from types import SimpleNamespace
|
||||
from typing import List, Optional
|
||||
|
||||
@@ -29,7 +30,6 @@ import zmq
|
||||
|
||||
from sglang.global_config import global_config
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.constrained.grammar import GrammarBackend
|
||||
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.managers.io_struct import (
|
||||
@@ -100,7 +100,7 @@ class Scheduler:
|
||||
self.tp_rank = tp_rank
|
||||
self.tp_size = server_args.tp_size
|
||||
self.schedule_policy = server_args.schedule_policy
|
||||
self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
|
||||
self.disable_jump_forward = server_args.disable_jump_forward
|
||||
self.lora_paths = server_args.lora_paths
|
||||
self.max_loras_per_batch = server_args.max_loras_per_batch
|
||||
self.enable_overlap = server_args.enable_overlap_schedule
|
||||
@@ -234,22 +234,33 @@ class Scheduler:
|
||||
self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
|
||||
)
|
||||
|
||||
# Init the grammar cache for constrained generation
|
||||
self.grammar_cache = None
|
||||
# Init the grammar backend for constrained generation
|
||||
self.grammar_queue: List[Req] = []
|
||||
|
||||
if not server_args.skip_tokenizer_init:
|
||||
self.grammar_cache = GrammarBackend(
|
||||
server_args.tokenizer_path,
|
||||
{
|
||||
"tokenizer_mode": server_args.tokenizer_mode,
|
||||
"trust_remote_code": server_args.trust_remote_code,
|
||||
},
|
||||
skip_tokenizer_init=server_args.skip_tokenizer_init,
|
||||
whitespace_patterns=server_args.constrained_json_whitespace_pattern,
|
||||
backend=server_args.grammar_backend,
|
||||
allow_jump=not server_args.disable_regex_jump_forward,
|
||||
)
|
||||
if server_args.grammar_backend == "outlines":
|
||||
from sglang.srt.constrained.outlines_backend import (
|
||||
OutlinesGrammarBackend,
|
||||
)
|
||||
|
||||
self.grammar_backend = OutlinesGrammarBackend(
|
||||
self.tokenizer,
|
||||
whitespace_patterns=server_args.constrained_json_whitespace_pattern,
|
||||
allow_jump_forward=not server_args.disable_jump_forward,
|
||||
)
|
||||
elif server_args.grammar_backend == "xgrammar":
|
||||
from sglang.srt.constrained.xgrammar_backend import (
|
||||
XGrammarGrammarBackend,
|
||||
)
|
||||
|
||||
self.grammar_backend = XGrammarGrammarBackend(
|
||||
self.tokenizer, vocab_size=self.model_config.vocab_size
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid grammar backend: {server_args.grammar_backend}"
|
||||
)
|
||||
else:
|
||||
self.grammar_backend = None
|
||||
|
||||
# Init new token estimation
|
||||
assert (
|
||||
@@ -461,15 +472,14 @@ class Scheduler:
|
||||
req.sampling_params.json_schema is not None
|
||||
or req.sampling_params.regex is not None
|
||||
):
|
||||
assert self.grammar_cache is not None
|
||||
assert self.grammar_backend is not None
|
||||
if req.sampling_params.json_schema is not None:
|
||||
req.grammar = self.grammar_cache.query(
|
||||
req.grammar = self.grammar_backend.query(
|
||||
("json", req.sampling_params.json_schema),
|
||||
self.model_config.vocab_size,
|
||||
)
|
||||
elif req.sampling_params.regex is not None:
|
||||
req.grammar = self.grammar_cache.query(
|
||||
("regex", req.sampling_params.regex), self.model_config.vocab_size
|
||||
req.grammar = self.grammar_backend.query(
|
||||
("regex", req.sampling_params.regex)
|
||||
)
|
||||
|
||||
# Truncate prompts that are too long
|
||||
@@ -638,14 +648,14 @@ class Scheduler:
|
||||
return self.running_batch
|
||||
|
||||
def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
|
||||
# Check if the grammar queue is ready
|
||||
# Check if the grammar is ready in the grammar queue
|
||||
if self.grammar_queue:
|
||||
new_grammar_queue = []
|
||||
for req in self.grammar_queue:
|
||||
if req.grammar.done():
|
||||
req.grammar = req.grammar.result()
|
||||
try:
|
||||
req.grammar = req.grammar.result(timeout=0.05)
|
||||
self.waiting_queue.append(req)
|
||||
else:
|
||||
except futures._base.TimeoutError:
|
||||
new_grammar_queue.append(req)
|
||||
self.grammar_queue = new_grammar_queue
|
||||
|
||||
@@ -783,7 +793,7 @@ class Scheduler:
|
||||
)
|
||||
|
||||
# Check for jump-forward
|
||||
if not self.disable_regex_jump_forward:
|
||||
if not self.disable_jump_forward:
|
||||
jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
|
||||
self.waiting_queue.extend(jump_forward_reqs)
|
||||
if batch.is_empty():
|
||||
@@ -1142,8 +1152,8 @@ class Scheduler:
|
||||
):
|
||||
self.tree_cache.reset()
|
||||
self.tree_cache_metrics = {"total": 0, "hit": 0}
|
||||
if self.grammar_cache is not None:
|
||||
self.grammar_cache.reset()
|
||||
if self.grammar_backend is not None:
|
||||
self.grammar_backend.reset()
|
||||
# TODO(dark): reset the bnf cache
|
||||
self.req_to_token_pool.clear()
|
||||
self.token_to_kv_pool.clear()
|
||||
|
||||
Reference in New Issue
Block a user