Support parallel grammar preprocessing (#1996)

Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
This commit is contained in:
DarkSharpness
2024-11-13 01:45:28 +09:00
committed by GitHub
parent eff468dd5a
commit 125b1199c5
9 changed files with 159 additions and 141 deletions

View File

@@ -29,7 +29,7 @@ import zmq
from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.grammar import GrammarCache
from sglang.srt.constrained.grammar import GrammarBackend
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import (
@@ -234,11 +234,12 @@ class Scheduler:
self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
)
# Init the FSM cache for constrained generation
# Init the grammar cache for constrained generation
self.grammar_cache = None
self.grammar_queue: List[Req] = []
if not server_args.skip_tokenizer_init:
self.grammar_cache = GrammarCache(
self.grammar_cache = GrammarBackend(
server_args.tokenizer_path,
{
"tokenizer_mode": server_args.tokenizer_mode,
@@ -455,7 +456,7 @@ class Scheduler:
# By default, only return the logprobs for output tokens
req.logprob_start_len = len(recv_req.input_ids) - 1
# Init regex FSM or BNF
# Init grammar cache for this request
if (
req.sampling_params.json_schema is not None
or req.sampling_params.regex is not None
@@ -488,7 +489,10 @@ class Scheduler:
self.max_req_len - len(req.origin_input_ids) - 1,
)
self.waiting_queue.append(req)
if req.grammar is not None:
self.grammar_queue.append(req)
else:
self.waiting_queue.append(req)
def handle_embedding_request(
self,
@@ -634,6 +638,17 @@ class Scheduler:
return self.running_batch
def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
# Check if the grammar queue is ready
if self.grammar_queue:
new_grammar_queue = []
for req in self.grammar_queue:
if req.grammar.done():
req.grammar = req.grammar.result()
self.waiting_queue.append(req)
else:
new_grammar_queue.append(req)
self.grammar_queue = new_grammar_queue
# Handle the cases where prefill is not allowed
if (
self.batch_is_full or len(self.waiting_queue) == 0