Refactor attention backend (#1381)
This commit is contained in:
@@ -143,18 +143,16 @@ class SamplingBatchInfo:
|
||||
self.linear_penalties = penalizer.apply(self.linear_penalties)
|
||||
|
||||
def update_regex_vocab_mask(self, batch: ScheduleBatch):
|
||||
bs, reqs = batch.batch_size(), batch.reqs
|
||||
device = "cuda"
|
||||
has_regex = any(req.regex_fsm is not None for req in reqs)
|
||||
has_regex = any(req.regex_fsm is not None for req in batch.reqs)
|
||||
|
||||
# Reset the vocab mask
|
||||
self.vocab_mask = None
|
||||
|
||||
if has_regex:
|
||||
self.vocab_mask = torch.zeros(
|
||||
bs, self.vocab_size, dtype=torch.bool, device=device
|
||||
batch.batch_size(), self.vocab_size, dtype=torch.bool, device="cuda"
|
||||
)
|
||||
for i, req in enumerate(reqs):
|
||||
for i, req in enumerate(batch.reqs):
|
||||
if req.regex_fsm is not None:
|
||||
self.vocab_mask[i].fill_(1)
|
||||
self.vocab_mask[i][
|
||||
|
||||
Reference in New Issue
Block a user