From 47f20da223c62473577231cec49dedb86c56220f Mon Sep 17 00:00:00 2001
From: Liangsheng Yin
Date: Sun, 1 Sep 2024 21:50:58 -0700
Subject: [PATCH] Fix regex mask (#1296)

---
 python/sglang/srt/layers/sampler.py               |  2 +-
 python/sglang/srt/sampling/sampling_batch_info.py | 10 +++++-----
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py
index 6cb7d0a7c..f56fee828 100644
--- a/python/sglang/srt/layers/sampler.py
+++ b/python/sglang/srt/layers/sampler.py
@@ -63,7 +63,7 @@ class Sampler(CustomOp):
             logits.add_(sampling_info.logit_bias)
 
         if sampling_info.vocab_mask is not None:
-            logits = logits.masked_fill(~sampling_info.vocab_mask, float("-inf"))
+            logits = logits.masked_fill(sampling_info.vocab_mask, float("-inf"))
 
         logits = self._apply_penalties(logits, sampling_info)
 
diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py
index 7843f4bd3..38b6701c7 100644
--- a/python/sglang/srt/sampling/sampling_batch_info.py
+++ b/python/sglang/srt/sampling/sampling_batch_info.py
@@ -154,15 +154,15 @@ class SamplingBatchInfo:
         self.vocab_mask = None
 
         if has_regex:
+            self.vocab_mask = torch.zeros(
+                bs, self.vocab_size, dtype=torch.bool, device=device
+            )
             for i, req in enumerate(reqs):
                 if req.regex_fsm is not None:
-                    if self.vocab_mask is None:
-                        self.vocab_mask = torch.zeros(
-                            bs, self.vocab_size, dtype=torch.bool, device=device
-                        )
+                    self.vocab_mask[i].fill_(1)
                     self.vocab_mask[i][
                         req.regex_fsm.get_next_instruction(req.regex_fsm_state).tokens
-                    ] = 1
+                    ] = 0
 
     def filter(self, unfinished_indices: List[int], new_indices: torch.Tensor):
         self.penalizer_orchestrator.filter(unfinished_indices, new_indices)