[Performance] Support both xgrammar and outlines for constrained decoding (#1752)

This commit is contained in:
DarkSharpness
2024-10-26 06:47:02 +09:00
committed by GitHub
parent 30643fed7f
commit b77a02cdfd
7 changed files with 325 additions and 77 deletions

View File

@@ -37,8 +37,7 @@ import torch
from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained import RegexGuide
from sglang.srt.constrained.jump_forward import JumpForwardMap
from sglang.srt.constrained.grammar import Grammar
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
@@ -247,9 +246,7 @@ class Req:
self.embedding = None
# Constrained decoding
self.regex_fsm: RegexGuide = None
self.regex_fsm_state: int = 0
self.jump_forward_map: JumpForwardMap = None
self.grammar: Optional[Grammar] = None
# For Qwen2-VL
self.mrope_position_delta = [] # use mutable object
@@ -359,6 +356,8 @@ class Req:
return
def jump_forward_and_retokenize(self, jump_forward_str, next_state):
assert self.grammar is not None and self.tokenizer is not None
if self.origin_input_text is None:
# Recovering text can only use unpadded ids
self.origin_input_text = self.tokenizer.decode(
@@ -398,7 +397,8 @@ class Req:
self.surr_offset = self.read_offset - i
break
self.regex_fsm_state = next_state
# update the inner state of the grammar
self.grammar.jump_and_retokenize(old_output_ids, self.output_ids, next_state)
if self.return_logprob:
# For fast-forward part's logprobs
@@ -468,8 +468,8 @@ class ScheduleBatch:
# Stream
has_stream: bool = False
# Has regex
has_regex: bool = False
# Has grammar
has_grammar: bool = False
# device
device: str = "cuda"
@@ -477,7 +477,7 @@ class ScheduleBatch:
@classmethod
def init_new(
cls,
reqs,
reqs: List[Req],
req_to_token_pool,
token_to_kv_pool,
tree_cache,
@@ -491,7 +491,7 @@ class ScheduleBatch:
model_config=model_config,
return_logprob=any(req.return_logprob for req in reqs),
has_stream=any(req.stream for req in reqs),
has_regex=any(req.regex_fsm for req in reqs),
has_grammar=any(req.grammar for req in reqs),
device=req_to_token_pool.device,
)
@@ -803,26 +803,10 @@ class ScheduleBatch:
keep_indices = set(i for i in range(len(self.reqs)))
for i, req in enumerate(self.reqs):
if req.jump_forward_map is not None:
jump_forward_bytes = req.jump_forward_map.jump_forward_byte(
req.regex_fsm_state
)
if jump_forward_bytes is not None and len(jump_forward_bytes) > 1:
suffix_bytes = []
continuation_range = range(0x80, 0xC0)
cur_state = req.regex_fsm_state
while (
len(jump_forward_bytes)
and jump_forward_bytes[0][0] in continuation_range
):
# continuation bytes
byte_edge = jump_forward_bytes.pop(0)
suffix_bytes.append(byte_edge[0])
cur_state = byte_edge[1]
suffix_tokens = [f"<0x{hex(b)[2:].upper()}>" for b in suffix_bytes]
suffix_ids = req.tokenizer.convert_tokens_to_ids(suffix_tokens)
if req.grammar is not None:
jump_helper = req.grammar.try_jump(req.tokenizer)
if jump_helper.can_jump():
suffix_ids = jump_helper.suffix_ids
# Current ids, for cache and revert
cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
cur_output_ids = req.output_ids
@@ -836,10 +820,8 @@ class ScheduleBatch:
(
jump_forward_str,
next_state,
) = req.jump_forward_map.jump_forward_symbol(cur_state)
) = req.grammar.jump_forward_str_state(jump_helper)
# Make the incrementally decoded text part of jump_forward_str
# so that the UTF-8 will not corrupt
jump_forward_str = new_text + jump_forward_str
if not req.jump_forward_and_retokenize(
jump_forward_str, next_state
@@ -946,7 +928,7 @@ class ScheduleBatch:
self.top_logprobs_nums = None
self.has_stream = any(req.stream for req in self.reqs)
self.has_regex = any(req.regex_fsm for req in self.reqs)
self.has_grammar = any(req.grammar for req in self.reqs)
self.sampling_info.filter_batch(keep_indices, new_indices)
@@ -979,7 +961,7 @@ class ScheduleBatch:
self.return_logprob = self.return_logprob or other.return_logprob
self.has_stream = self.has_stream or other.has_stream
self.has_regex = self.has_regex or other.has_regex
self.has_grammar = self.has_grammar or other.has_grammar
def get_model_worker_batch(self):
if self.forward_mode.is_decode():
@@ -989,13 +971,10 @@ class ScheduleBatch:
extend_prefix_lens = self.prefix_lens
extend_logprob_start_lens = self.extend_logprob_start_lens
if self.has_regex:
self.sampling_info.regex_fsms = [req.regex_fsm for req in self.reqs]
self.sampling_info.regex_fsm_states = [
req.regex_fsm_state for req in self.reqs
]
if self.has_grammar:
self.sampling_info.grammars = [req.grammar for req in self.reqs]
else:
self.sampling_info.regex_fsms = None
self.sampling_info.grammars = None
global bid
bid += 1