Skip unnecessary penalizer (#1707)
This commit is contained in:
@@ -48,20 +48,24 @@ class SamplingBatchInfo:
|
||||
disable_penalizer: bool,
|
||||
):
|
||||
reqs = batch.reqs
|
||||
with batch.input_ids.device:
|
||||
temperatures = torch.tensor(
|
||||
device = batch.input_ids.device
|
||||
temperatures = (
|
||||
torch.tensor(
|
||||
[r.sampling_params.temperature for r in reqs],
|
||||
dtype=torch.float,
|
||||
).view(-1, 1)
|
||||
top_ps = torch.tensor(
|
||||
[r.sampling_params.top_p for r in reqs], dtype=torch.float
|
||||
)
|
||||
top_ks = torch.tensor(
|
||||
[r.sampling_params.top_k for r in reqs], dtype=torch.int32
|
||||
)
|
||||
min_ps = torch.tensor(
|
||||
[r.sampling_params.min_p for r in reqs], dtype=torch.float
|
||||
)
|
||||
.view(-1, 1)
|
||||
.to(device, non_blocking=True)
|
||||
)
|
||||
top_ps = torch.tensor(
|
||||
[r.sampling_params.top_p for r in reqs], dtype=torch.float
|
||||
).to(device, non_blocking=True)
|
||||
top_ks = torch.tensor(
|
||||
[r.sampling_params.top_k for r in reqs], dtype=torch.int32
|
||||
).to(device, non_blocking=True)
|
||||
min_ps = torch.tensor(
|
||||
[r.sampling_params.min_p for r in reqs], dtype=torch.float
|
||||
).to(device, non_blocking=True)
|
||||
|
||||
ret = cls(
|
||||
temperatures=temperatures,
|
||||
@@ -80,7 +84,7 @@ class SamplingBatchInfo:
|
||||
#
|
||||
# While we choose not to even create the class instances if they are not required, this
|
||||
# could add additional complexity to the {ScheduleBatch} class, especially we need to
|
||||
# handle {filter_batch()} and {merge()} cases as well.
|
||||
# handle {filter_batch()} and {merge_batch()} cases as well.
|
||||
if disable_penalizer:
|
||||
ret.penalizer_orchestrator = None
|
||||
else:
|
||||
@@ -112,19 +116,20 @@ class SamplingBatchInfo:
|
||||
self.linear_penalties = None
|
||||
|
||||
for penalizer in self.penalizer_orchestrator.penalizers.values():
|
||||
if not penalizer.is_prepared():
|
||||
continue
|
||||
|
||||
if isinstance(penalizer, penaltylib.BatchedRepetitionPenalizer):
|
||||
if penalizer.is_prepared():
|
||||
self.scaling_penalties = penalizer.cumulated_repetition_penalties
|
||||
self.scaling_penalties = penalizer.cumulated_repetition_penalties
|
||||
else:
|
||||
if penalizer.is_prepared():
|
||||
if self.linear_penalties is None:
|
||||
bs = self.penalizer_orchestrator.batch.batch_size()
|
||||
self.linear_penalties = torch.zeros(
|
||||
(bs, self.vocab_size),
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
self.linear_penalties = penalizer.apply(self.linear_penalties)
|
||||
if self.linear_penalties is None:
|
||||
bs = self.penalizer_orchestrator.batch.batch_size()
|
||||
self.linear_penalties = torch.zeros(
|
||||
(bs, self.vocab_size),
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
self.linear_penalties = penalizer.apply(self.linear_penalties)
|
||||
|
||||
def update_regex_vocab_mask(self):
|
||||
has_regex = self.regex_fsms and any(regex_fsm for regex_fsm in self.regex_fsms)
|
||||
|
||||
Reference in New Issue
Block a user