Refactor attention backend (#1381)

This commit is contained in:
Lianmin Zheng
2024-09-11 11:44:26 -07:00
committed by GitHub
parent c03cece42f
commit fec185ce0c
16 changed files with 568 additions and 564 deletions

View File

@@ -143,18 +143,16 @@ class SamplingBatchInfo:
self.linear_penalties = penalizer.apply(self.linear_penalties)
def update_regex_vocab_mask(self, batch: ScheduleBatch):
bs, reqs = batch.batch_size(), batch.reqs
device = "cuda"
has_regex = any(req.regex_fsm is not None for req in reqs)
has_regex = any(req.regex_fsm is not None for req in batch.reqs)
# Reset the vocab mask
self.vocab_mask = None
if has_regex:
self.vocab_mask = torch.zeros(
bs, self.vocab_size, dtype=torch.bool, device=device
batch.batch_size(), self.vocab_size, dtype=torch.bool, device="cuda"
)
for i, req in enumerate(reqs):
for i, req in enumerate(batch.reqs):
if req.regex_fsm is not None:
self.vocab_mask[i].fill_(1)
self.vocab_mask[i][