Skip unnecessary penalizer (#1707)
This commit is contained in:
@@ -37,12 +37,16 @@ class BatchedPenalizerOrchestrator:
|
||||
|
||||
self.penalizers = {Penalizer: Penalizer(self) for Penalizer in Penalizers}
|
||||
|
||||
is_required = False
|
||||
for penalizer in self.penalizers.values():
|
||||
penalizer.prepare_if_required()
|
||||
pen_is_required = penalizer.prepare_if_required()
|
||||
is_required |= pen_is_required
|
||||
self.is_required = is_required
|
||||
|
||||
self.cumulate_input_tokens(
|
||||
input_ids=[req.origin_input_ids for req in self.reqs()]
|
||||
)
|
||||
if self.is_required:
|
||||
self.cumulate_input_tokens(
|
||||
input_ids=[req.origin_input_ids for req in self.reqs()]
|
||||
)
|
||||
|
||||
def reqs(self):
|
||||
return self.batch.reqs
|
||||
@@ -79,6 +83,9 @@ class BatchedPenalizerOrchestrator:
|
||||
Args:
|
||||
output_ids (typing.Union[typing.List[torch.Tensor], typing.List[typing.List[int]]]): The output tokens.
|
||||
"""
|
||||
if not self.is_required:
|
||||
return
|
||||
|
||||
token_ids = _TokenIDs(orchestrator=self, token_ids=output_ids)
|
||||
|
||||
for penalizer in self.penalizers.values():
|
||||
@@ -95,6 +102,9 @@ class BatchedPenalizerOrchestrator:
|
||||
Returns:
|
||||
torch.Tensor: The logits after applying the penalizers.
|
||||
"""
|
||||
if not self.is_required:
|
||||
return
|
||||
|
||||
for penalizer in self.penalizers.values():
|
||||
logits = penalizer.apply(logits)
|
||||
|
||||
@@ -112,10 +122,16 @@ class BatchedPenalizerOrchestrator:
|
||||
indices_to_keep (typing.List[int]): List of indices to keep in the batch.
|
||||
indices_tensor_to_keep (torch.Tensor = None): Tensor of indices to keep in the batch. If not None, it will be used instead of converting indices_to_keep to a tensor.
|
||||
"""
|
||||
if not self.is_required:
|
||||
return
|
||||
|
||||
empty_indices = len(indices_to_keep) == 0
|
||||
|
||||
is_required = False
|
||||
for penalizer in self.penalizers.values():
|
||||
if not penalizer.is_required() or empty_indices:
|
||||
tmp_is_required = penalizer.is_required()
|
||||
is_required = is_required or tmp_is_required
|
||||
if not tmp_is_required or empty_indices:
|
||||
penalizer.teardown()
|
||||
else:
|
||||
# create tensor index only when it's needed
|
||||
@@ -128,6 +144,7 @@ class BatchedPenalizerOrchestrator:
|
||||
indices_to_keep=indices_to_keep,
|
||||
indices_tensor_to_keep=indices_tensor_to_keep,
|
||||
)
|
||||
self.is_required = is_required
|
||||
|
||||
def merge(self, their: "BatchedPenalizerOrchestrator"):
|
||||
"""
|
||||
@@ -140,11 +157,10 @@ class BatchedPenalizerOrchestrator:
|
||||
Args:
|
||||
their (BatchedPenalizerOrchestrator): The orchestrator to merge into this one.
|
||||
"""
|
||||
if self.vocab_size != their.vocab_size:
|
||||
raise ValueError(
|
||||
f"vocab_size mismatch: {self.vocab_size} != {their.vocab_size}"
|
||||
)
|
||||
if not self.is_required and not their.is_required:
|
||||
return
|
||||
|
||||
self.is_required |= their.is_required
|
||||
for Penalizer, their_penalizer in their.penalizers.items():
|
||||
if Penalizer not in self.penalizers:
|
||||
raise ValueError(f"Penalizer {Penalizer} not found in self.penalizers")
|
||||
@@ -250,6 +266,9 @@ class _BatchedPenalizer(abc.ABC):
|
||||
def prepare_if_required(self):
|
||||
if self.is_required():
|
||||
self.prepare()
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def teardown(self):
|
||||
if self.is_prepared():
|
||||
|
||||
@@ -48,20 +48,24 @@ class SamplingBatchInfo:
|
||||
disable_penalizer: bool,
|
||||
):
|
||||
reqs = batch.reqs
|
||||
with batch.input_ids.device:
|
||||
temperatures = torch.tensor(
|
||||
device = batch.input_ids.device
|
||||
temperatures = (
|
||||
torch.tensor(
|
||||
[r.sampling_params.temperature for r in reqs],
|
||||
dtype=torch.float,
|
||||
).view(-1, 1)
|
||||
top_ps = torch.tensor(
|
||||
[r.sampling_params.top_p for r in reqs], dtype=torch.float
|
||||
)
|
||||
top_ks = torch.tensor(
|
||||
[r.sampling_params.top_k for r in reqs], dtype=torch.int32
|
||||
)
|
||||
min_ps = torch.tensor(
|
||||
[r.sampling_params.min_p for r in reqs], dtype=torch.float
|
||||
)
|
||||
.view(-1, 1)
|
||||
.to(device, non_blocking=True)
|
||||
)
|
||||
top_ps = torch.tensor(
|
||||
[r.sampling_params.top_p for r in reqs], dtype=torch.float
|
||||
).to(device, non_blocking=True)
|
||||
top_ks = torch.tensor(
|
||||
[r.sampling_params.top_k for r in reqs], dtype=torch.int32
|
||||
).to(device, non_blocking=True)
|
||||
min_ps = torch.tensor(
|
||||
[r.sampling_params.min_p for r in reqs], dtype=torch.float
|
||||
).to(device, non_blocking=True)
|
||||
|
||||
ret = cls(
|
||||
temperatures=temperatures,
|
||||
@@ -80,7 +84,7 @@ class SamplingBatchInfo:
|
||||
#
|
||||
# While we could choose not to even create the class instances if they are not required, this
|
||||
# would add additional complexity to the {ScheduleBatch} class, especially since we would need to
|
||||
# handle {filter_batch()} and {merge()} cases as well.
|
||||
# handle {filter_batch()} and {merge_batch()} cases as well.
|
||||
if disable_penalizer:
|
||||
ret.penalizer_orchestrator = None
|
||||
else:
|
||||
@@ -112,19 +116,20 @@ class SamplingBatchInfo:
|
||||
self.linear_penalties = None
|
||||
|
||||
for penalizer in self.penalizer_orchestrator.penalizers.values():
|
||||
if not penalizer.is_prepared():
|
||||
continue
|
||||
|
||||
if isinstance(penalizer, penaltylib.BatchedRepetitionPenalizer):
|
||||
if penalizer.is_prepared():
|
||||
self.scaling_penalties = penalizer.cumulated_repetition_penalties
|
||||
self.scaling_penalties = penalizer.cumulated_repetition_penalties
|
||||
else:
|
||||
if penalizer.is_prepared():
|
||||
if self.linear_penalties is None:
|
||||
bs = self.penalizer_orchestrator.batch.batch_size()
|
||||
self.linear_penalties = torch.zeros(
|
||||
(bs, self.vocab_size),
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
self.linear_penalties = penalizer.apply(self.linear_penalties)
|
||||
if self.linear_penalties is None:
|
||||
bs = self.penalizer_orchestrator.batch.batch_size()
|
||||
self.linear_penalties = torch.zeros(
|
||||
(bs, self.vocab_size),
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
self.linear_penalties = penalizer.apply(self.linear_penalties)
|
||||
|
||||
def update_regex_vocab_mask(self):
|
||||
has_regex = self.regex_fsms and any(regex_fsm for regex_fsm in self.regex_fsms)
|
||||
|
||||
Reference in New Issue
Block a user