[Performance] Support both xgrammar and outlines for constrained decoding (#1752)
This commit is contained in:
@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, List, Optional
|
||||
import torch
|
||||
|
||||
import sglang.srt.sampling.penaltylib as penaltylib
|
||||
from sglang.srt.constrained import RegexGuide
|
||||
from sglang.srt.constrained.grammar import Grammar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||
@@ -29,11 +29,9 @@ class SamplingBatchInfo:
|
||||
# Bias Tensors
|
||||
vocab_size: int
|
||||
logit_bias: torch.Tensor = None
|
||||
vocab_mask: torch.Tensor = None
|
||||
vocab_mask: Optional[torch.Tensor] = None
|
||||
|
||||
# FSM states
|
||||
regex_fsms: List[RegexGuide] = None
|
||||
regex_fsm_states: List[int] = None
|
||||
grammars: Optional[List[Optional[Grammar]]] = None
|
||||
|
||||
# Penalizer
|
||||
penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None
|
||||
@@ -136,8 +134,7 @@ class SamplingBatchInfo:
|
||||
self.linear_penalties = penalizer.apply(self.linear_penalties)
|
||||
|
||||
def update_regex_vocab_mask(self):
|
||||
has_regex = self.regex_fsms and any(regex_fsm for regex_fsm in self.regex_fsms)
|
||||
if not has_regex:
|
||||
if not self.grammars or not any(grammar for grammar in self.grammars):
|
||||
self.vocab_mask = None
|
||||
return
|
||||
|
||||
@@ -147,12 +144,9 @@ class SamplingBatchInfo:
|
||||
dtype=torch.bool,
|
||||
device=self.device,
|
||||
)
|
||||
for i, regex_fsm in enumerate(self.regex_fsms):
|
||||
if regex_fsm is not None:
|
||||
self.vocab_mask[i].fill_(1)
|
||||
self.vocab_mask[i][
|
||||
regex_fsm.get_next_instruction(self.regex_fsm_states[i]).tokens
|
||||
] = 0
|
||||
for i, grammar in enumerate(self.grammars):
|
||||
if grammar is not None:
|
||||
grammar.fill_vocab_mask(self.vocab_mask[i], self.vocab_size)
|
||||
|
||||
def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
|
||||
if self.penalizer_orchestrator:
|
||||
|
||||
Reference in New Issue
Block a user