Skip unnecessary penalizer (#1707)

This commit is contained in:
Lianmin Zheng
2024-10-18 17:54:03 -07:00
committed by GitHub
parent bc12d4033f
commit 2bcfba1b08
7 changed files with 104 additions and 75 deletions

View File

@@ -48,20 +48,24 @@ class SamplingBatchInfo:
disable_penalizer: bool,
):
reqs = batch.reqs
with batch.input_ids.device:
temperatures = torch.tensor(
device = batch.input_ids.device
temperatures = (
torch.tensor(
[r.sampling_params.temperature for r in reqs],
dtype=torch.float,
).view(-1, 1)
top_ps = torch.tensor(
[r.sampling_params.top_p for r in reqs], dtype=torch.float
)
top_ks = torch.tensor(
[r.sampling_params.top_k for r in reqs], dtype=torch.int32
)
min_ps = torch.tensor(
[r.sampling_params.min_p for r in reqs], dtype=torch.float
)
.view(-1, 1)
.to(device, non_blocking=True)
)
top_ps = torch.tensor(
[r.sampling_params.top_p for r in reqs], dtype=torch.float
).to(device, non_blocking=True)
top_ks = torch.tensor(
[r.sampling_params.top_k for r in reqs], dtype=torch.int32
).to(device, non_blocking=True)
min_ps = torch.tensor(
[r.sampling_params.min_p for r in reqs], dtype=torch.float
).to(device, non_blocking=True)
ret = cls(
temperatures=temperatures,
@@ -80,7 +84,7 @@ class SamplingBatchInfo:
#
# While we choose not to even create the class instances if they are not required, this
# could add additional complexity to the {ScheduleBatch} class, especially we need to
# handle {filter_batch()} and {merge()} cases as well.
# handle {filter_batch()} and {merge_batch()} cases as well.
if disable_penalizer:
ret.penalizer_orchestrator = None
else:
@@ -112,19 +116,20 @@ class SamplingBatchInfo:
self.linear_penalties = None
for penalizer in self.penalizer_orchestrator.penalizers.values():
if not penalizer.is_prepared():
continue
if isinstance(penalizer, penaltylib.BatchedRepetitionPenalizer):
if penalizer.is_prepared():
self.scaling_penalties = penalizer.cumulated_repetition_penalties
self.scaling_penalties = penalizer.cumulated_repetition_penalties
else:
if penalizer.is_prepared():
if self.linear_penalties is None:
bs = self.penalizer_orchestrator.batch.batch_size()
self.linear_penalties = torch.zeros(
(bs, self.vocab_size),
dtype=torch.float32,
device=self.device,
)
self.linear_penalties = penalizer.apply(self.linear_penalties)
if self.linear_penalties is None:
bs = self.penalizer_orchestrator.batch.batch_size()
self.linear_penalties = torch.zeros(
(bs, self.vocab_size),
dtype=torch.float32,
device=self.device,
)
self.linear_penalties = penalizer.apply(self.linear_penalties)
def update_regex_vocab_mask(self):
has_regex = self.regex_fsms and any(regex_fsm for regex_fsm in self.regex_fsms)