From b564835364e13979226faa6d56ba6d70e07caa9f Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Wed, 2 Oct 2024 13:19:44 -0700
Subject: [PATCH] [Fix] do not maintain regex_fsm in SamplingBatchInfo (#1555)

---
 .../srt/sampling/sampling_batch_info.py      | 21 +++++++++++----------
 1 file changed, 11 insertions(+), 10 deletions(-)

diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py
index 7d4f39e68..606f11d98 100644
--- a/python/sglang/srt/sampling/sampling_batch_info.py
+++ b/python/sglang/srt/sampling/sampling_batch_info.py
@@ -59,7 +59,6 @@ class SamplingBatchInfo:
             [r.sampling_params.min_p for r in reqs], dtype=torch.float
         )
-        ret.regex_fsms = [r.regex_fsm for r in reqs]
 
         # TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
         ret.need_min_p_sampling = any(r.sampling_params.min_p > 0 for r in reqs)
 
@@ -85,6 +84,10 @@ class SamplingBatchInfo:
         # Handle logit bias but only allocate when needed
         ret.logit_bias = None
 
+        # This is only used for regex_fsm. We noticed a regression when maintaining
+        # the list of regex_fsm in SamplingBatchInfo, so we keep it here.
+        ret.schedule_batch = batch
+
         return ret
 
     def __len__(self):
@@ -110,18 +113,20 @@ class SamplingBatchInfo:
             self.linear_penalties = penalizer.apply(self.linear_penalties)
 
     def update_regex_vocab_mask(self):
+        has_regex = any(req.regex_fsm is not None for req in self.schedule_batch.reqs)
+
         # Reset the vocab mask
         self.vocab_mask = None
 
-        if any(regex_fsm is not None for regex_fsm in self.regex_fsms):
+        if has_regex:
             self.vocab_mask = torch.zeros(
-                len(self.regex_fsms), self.vocab_size, dtype=torch.bool, device="cuda"
+                len(self.temperatures), self.vocab_size, dtype=torch.bool, device="cuda"
             )
-            for i, regex_fsm in enumerate(self.regex_fsms):
-                if regex_fsm is not None:
+            for i, req in enumerate(self.schedule_batch.reqs):
+                if req.regex_fsm is not None:
                     self.vocab_mask[i].fill_(1)
                     self.vocab_mask[i][
-                        regex_fsm.get_next_instruction(self.regex_fsm_states[i]).tokens
+                        req.regex_fsm.get_next_instruction(req.regex_fsm_state).tokens
                     ] = 0
 
     def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
@@ -138,8 +143,6 @@ class SamplingBatchInfo:
             if value is not None:  # logit_bias can be None
                 setattr(self, item, value[new_indices])
 
-        self.regex_fsms = [self.regex_fsms[i] for i in new_indices]
-
     @staticmethod
     def merge_bias_tensor(
         lhs: torch.Tensor, rhs: torch.Tensor, bs1: int, bs2: int, default: int = 0
@@ -176,5 +179,3 @@ class SamplingBatchInfo:
         self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
             self.logit_bias, other.logit_bias, len(self), len(other)
         )
-
-        self.regex_fsms.extend(other.regex_fsms)
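
Notes for reviewers: the core change is a data-ownership fix. SamplingBatchInfo no
longer copies each request's regex_fsm into its own list (which had to be re-sliced
in filter_batch and re-extended in merge, and caused a regression); it instead keeps
a back-reference to the schedule batch and reads req.regex_fsm lazily when the vocab
mask is built. Below is a minimal, self-contained sketch of that pattern. Req,
ScheduleBatch, SamplingBatchInfoSketch, and StubFSM are simplified stand-ins for
illustration, not sglang's real classes, and the mask is built on CPU so the sketch
runs anywhere (the real code uses device="cuda").

    # Minimal sketch of the ownership pattern in this patch. All class names here
    # are hypothetical stand-ins, not sglang's actual implementations.
    from dataclasses import dataclass, field
    from typing import List, Optional

    import torch


    @dataclass
    class Req:
        # An FSM exposing get_next_instruction(state).tokens, as in the diff above.
        regex_fsm: Optional[object] = None
        regex_fsm_state: int = 0


    @dataclass
    class ScheduleBatch:
        reqs: List[Req] = field(default_factory=list)


    class SamplingBatchInfoSketch:
        def __init__(self, batch: ScheduleBatch, vocab_size: int):
            # Keep a back-reference to the schedule batch instead of copying
            # the per-request FSMs into a second list.
            self.schedule_batch = batch
            self.vocab_size = vocab_size
            self.vocab_mask = None

        def update_regex_vocab_mask(self):
            reqs = self.schedule_batch.reqs
            self.vocab_mask = None  # reset, as in the patch
            if not any(r.regex_fsm is not None for r in reqs):
                return
            # True = token forbidden; rows without an FSM stay all-False.
            self.vocab_mask = torch.zeros(len(reqs), self.vocab_size, dtype=torch.bool)
            for i, req in enumerate(reqs):
                if req.regex_fsm is not None:
                    self.vocab_mask[i].fill_(True)
                    allowed = req.regex_fsm.get_next_instruction(req.regex_fsm_state).tokens
                    self.vocab_mask[i][allowed] = False


    class StubFSM:
        """Hypothetical FSM for the demo: every state allows only tokens 1 and 3."""

        class Instruction:
            def __init__(self, tokens):
                self.tokens = tokens

        def get_next_instruction(self, state):
            return self.Instruction([1, 3])


    if __name__ == "__main__":
        batch = ScheduleBatch(reqs=[Req(), Req(regex_fsm=StubFSM())])
        info = SamplingBatchInfoSketch(batch, vocab_size=5)
        info.update_regex_vocab_mask()
        print(info.vocab_mask)
        # row 0: all False (no FSM); row 1: True except at tokens 1 and 3

With this layout, filter_batch and merge no longer need to touch the FSMs at all:
the schedule batch already filters its own reqs, so the back-reference stays
consistent for free. The one non-obvious line in the patch itself is
len(self.temperatures): temperatures is always allocated with one entry per request,
so its length is a cheap proxy for the batch size now that the regex_fsms list is gone.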