Fix grammar backend (#2018)

This commit is contained in:
Lianmin Zheng
2024-11-12 21:17:38 -08:00
committed by GitHub
parent 125b1199c5
commit ba069a24d3
13 changed files with 401 additions and 434 deletions

View File

@@ -21,6 +21,7 @@ import threading
import time
import warnings
from collections import deque
from concurrent import futures
from types import SimpleNamespace
from typing import List, Optional
@@ -29,7 +30,6 @@ import zmq
from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.grammar import GrammarBackend
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import (
@@ -100,7 +100,7 @@ class Scheduler:
self.tp_rank = tp_rank
self.tp_size = server_args.tp_size
self.schedule_policy = server_args.schedule_policy
self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
self.disable_jump_forward = server_args.disable_jump_forward
self.lora_paths = server_args.lora_paths
self.max_loras_per_batch = server_args.max_loras_per_batch
self.enable_overlap = server_args.enable_overlap_schedule
@@ -234,22 +234,33 @@ class Scheduler:
self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
)
# Init the grammar cache for constrained generation
self.grammar_cache = None
# Init the grammar backend for constrained generation
self.grammar_queue: List[Req] = []
if not server_args.skip_tokenizer_init:
self.grammar_cache = GrammarBackend(
server_args.tokenizer_path,
{
"tokenizer_mode": server_args.tokenizer_mode,
"trust_remote_code": server_args.trust_remote_code,
},
skip_tokenizer_init=server_args.skip_tokenizer_init,
whitespace_patterns=server_args.constrained_json_whitespace_pattern,
backend=server_args.grammar_backend,
allow_jump=not server_args.disable_regex_jump_forward,
)
if server_args.grammar_backend == "outlines":
from sglang.srt.constrained.outlines_backend import (
OutlinesGrammarBackend,
)
self.grammar_backend = OutlinesGrammarBackend(
self.tokenizer,
whitespace_patterns=server_args.constrained_json_whitespace_pattern,
allow_jump_forward=not server_args.disable_jump_forward,
)
elif server_args.grammar_backend == "xgrammar":
from sglang.srt.constrained.xgrammar_backend import (
XGrammarGrammarBackend,
)
self.grammar_backend = XGrammarGrammarBackend(
self.tokenizer, vocab_size=self.model_config.vocab_size
)
else:
raise ValueError(
f"Invalid grammar backend: {server_args.grammar_backend}"
)
else:
self.grammar_backend = None
# Init new token estimation
assert (
@@ -461,15 +472,14 @@ class Scheduler:
req.sampling_params.json_schema is not None
or req.sampling_params.regex is not None
):
assert self.grammar_cache is not None
assert self.grammar_backend is not None
if req.sampling_params.json_schema is not None:
req.grammar = self.grammar_cache.query(
req.grammar = self.grammar_backend.query(
("json", req.sampling_params.json_schema),
self.model_config.vocab_size,
)
elif req.sampling_params.regex is not None:
req.grammar = self.grammar_cache.query(
("regex", req.sampling_params.regex), self.model_config.vocab_size
req.grammar = self.grammar_backend.query(
("regex", req.sampling_params.regex)
)
# Truncate prompts that are too long
@@ -638,14 +648,14 @@ class Scheduler:
return self.running_batch
def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
# Check if the grammar queue is ready
# Check if the grammar is ready in the grammar queue
if self.grammar_queue:
new_grammar_queue = []
for req in self.grammar_queue:
if req.grammar.done():
req.grammar = req.grammar.result()
try:
req.grammar = req.grammar.result(timeout=0.05)
self.waiting_queue.append(req)
else:
except futures._base.TimeoutError:
new_grammar_queue.append(req)
self.grammar_queue = new_grammar_queue
@@ -783,7 +793,7 @@ class Scheduler:
)
# Check for jump-forward
if not self.disable_regex_jump_forward:
if not self.disable_jump_forward:
jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
self.waiting_queue.extend(jump_forward_reqs)
if batch.is_empty():
@@ -1142,8 +1152,8 @@ class Scheduler:
):
self.tree_cache.reset()
self.tree_cache_metrics = {"total": 0, "hit": 0}
if self.grammar_cache is not None:
self.grammar_cache.reset()
if self.grammar_backend is not None:
self.grammar_backend.reset()
# TODO(dark): reset the bnf cache
self.req_to_token_pool.clear()
self.token_to_kv_pool.clear()