[Fix] Move ScheduleBatch out of SamplingInfo (#1556)

This commit is contained in:
Lianmin Zheng
2024-10-02 17:18:04 -07:00
committed by GitHub
parent b564835364
commit 317631cada
2 changed files with 19 additions and 10 deletions

View File

@@ -84,10 +84,6 @@ class SamplingBatchInfo:
# Handle logit bias but only allocate when needed
ret.logit_bias = None
# This is only for regex_fsm. We notice a regression if we maintain the list of regex_fsm
# in SamplingBatchInfo, so we keep it here.
ret.schedule_batch = batch
return ret
def __len__(self):
@@ -113,7 +109,7 @@ class SamplingBatchInfo:
self.linear_penalties = penalizer.apply(self.linear_penalties)
def update_regex_vocab_mask(self):
has_regex = any(req.regex_fsm is not None for req in self.schedule_batch.reqs)
has_regex = self.regex_fsms and any(regex_fsm for regex_fsm in self.regex_fsms)
# Reset the vocab mask
self.vocab_mask = None
@@ -122,11 +118,11 @@ class SamplingBatchInfo:
self.vocab_mask = torch.zeros(
len(self.temperatures), self.vocab_size, dtype=torch.bool, device="cuda"
)
for i, req in enumerate(self.schedule_batch.reqs):
if req.regex_fsm is not None:
for i, regex_fsm in enumerate(self.regex_fsms):
if regex_fsm is not None:
self.vocab_mask[i].fill_(1)
self.vocab_mask[i][
req.regex_fsm.get_next_instruction(req.regex_fsm_state).tokens
regex_fsm.get_next_instruction(self.regex_fsm_states[i]).tokens
] = 0
def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):