[Performance] Support both xgrammar and outlines for constrained decoding (#1752)

This commit is contained in:
DarkSharpness
2024-10-26 06:47:02 +09:00
committed by GitHub
parent 30643fed7f
commit b77a02cdfd
7 changed files with 325 additions and 77 deletions

View File

@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, List, Optional
import torch
import sglang.srt.sampling.penaltylib as penaltylib
-from sglang.srt.constrained import RegexGuide
+from sglang.srt.constrained.grammar import Grammar
if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import ScheduleBatch
@@ -29,11 +29,9 @@ class SamplingBatchInfo:
# Bias Tensors
vocab_size: int
logit_bias: torch.Tensor = None
-vocab_mask: torch.Tensor = None
+vocab_mask: Optional[torch.Tensor] = None
# FSM states
-regex_fsms: List[RegexGuide] = None
-regex_fsm_states: List[int] = None
+grammars: Optional[List[Optional[Grammar]]] = None
# Penalizer
penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None
@@ -136,8 +134,7 @@ class SamplingBatchInfo:
self.linear_penalties = penalizer.apply(self.linear_penalties)
def update_regex_vocab_mask(self):
-has_regex = self.regex_fsms and any(regex_fsm for regex_fsm in self.regex_fsms)
-if not has_regex:
+if not self.grammars or not any(grammar for grammar in self.grammars):
self.vocab_mask = None
return
@@ -147,12 +144,9 @@ class SamplingBatchInfo:
dtype=torch.bool,
device=self.device,
)
-for i, regex_fsm in enumerate(self.regex_fsms):
-if regex_fsm is not None:
-self.vocab_mask[i].fill_(1)
-self.vocab_mask[i][
-regex_fsm.get_next_instruction(self.regex_fsm_states[i]).tokens
-] = 0
+for i, grammar in enumerate(self.grammars):
+if grammar is not None:
+grammar.fill_vocab_mask(self.vocab_mask[i], self.vocab_size)
def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
if self.penalizer_orchestrator: