Fix grammar backend (#2018)

This commit is contained in:
Lianmin Zheng
2024-11-12 21:17:38 -08:00
committed by GitHub
parent 125b1199c5
commit ba069a24d3
13 changed files with 401 additions and 434 deletions

View File

@@ -6,7 +6,6 @@ from typing import TYPE_CHECKING, List, Optional
import torch
import sglang.srt.sampling.penaltylib as penaltylib
from sglang.srt.constrained.grammar import Grammar
if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import ScheduleBatch
@@ -31,7 +30,7 @@ class SamplingBatchInfo:
logit_bias: torch.Tensor = None
vocab_mask: Optional[torch.Tensor] = None
grammars: Optional[List[Optional[Grammar]]] = None
grammars: Optional[List] = None
# Penalizer
penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None
@@ -146,7 +145,7 @@ class SamplingBatchInfo:
)
for i, grammar in enumerate(self.grammars):
if grammar is not None:
grammar.fill_vocab_mask(self.vocab_mask[i], self.vocab_size)
grammar.fill_vocab_mask(self.vocab_mask[i])
def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
if self.penalizer_orchestrator: