Skip unnecessary penalizer (#1707)

Lianmin Zheng
2024-10-18 17:54:03 -07:00
committed by GitHub
parent bc12d4033f
commit 2bcfba1b08
7 changed files with 104 additions and 75 deletions
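The change in a sentence: the orchestrator now computes a single batch-level is_required flag when it is built, and every hot-path method (cumulate_output_tokens, apply, filter_batch, merge) returns immediately when the flag is off, so batches whose requests set no penalty parameters skip all penalizer work. Below is a minimal, self-contained sketch of that pattern; the Toy* names are simplified stand-ins, not the real penaltylib API.

from typing import List


class ToyPenalizer:
    """Stand-in penalizer: active only when some request sets a penalty."""

    def __init__(self, penalties: List[float]):
        self.penalties = penalties

    def is_required(self) -> bool:
        return any(p != 0.0 for p in self.penalties)

    def apply(self, logits: List[float]) -> List[float]:
        return [l - p for l, p in zip(logits, self.penalties)]


class ToyOrchestrator:
    def __init__(self, penalizers: List[ToyPenalizer]):
        self.penalizers = penalizers
        # Compute the batch-level flag once, at construction time.
        self.is_required = any(p.is_required() for p in penalizers)

    def apply(self, logits: List[float]) -> List[float]:
        # Hot path: when no request in the batch uses penalties, skip the
        # per-penalizer loop entirely instead of applying no-op penalties.
        if not self.is_required:
            return logits
        for p in self.penalizers:
            logits = p.apply(logits)
        return logits

With all-zero penalties, apply() reduces to a single flag check, which is the common case this commit optimizes for.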


@@ -37,12 +37,16 @@ class BatchedPenalizerOrchestrator:
         self.penalizers = {Penalizer: Penalizer(self) for Penalizer in Penalizers}
 
+        is_required = False
         for penalizer in self.penalizers.values():
-            penalizer.prepare_if_required()
+            pen_is_required = penalizer.prepare_if_required()
+            is_required |= pen_is_required
+        self.is_required = is_required
 
-        self.cumulate_input_tokens(
-            input_ids=[req.origin_input_ids for req in self.reqs()]
-        )
+        if self.is_required:
+            self.cumulate_input_tokens(
+                input_ids=[req.origin_input_ids for req in self.reqs()]
+            )
 
     def reqs(self):
         return self.batch.reqs
@@ -79,6 +83,9 @@ class BatchedPenalizerOrchestrator:
         Args:
             output_ids (typing.Union[typing.List[torch.Tensor], typing.List[typing.List[int]]]): The output tokens.
         """
+        if not self.is_required:
+            return
+
         token_ids = _TokenIDs(orchestrator=self, token_ids=output_ids)
 
         for penalizer in self.penalizers.values():
@@ -95,6 +102,9 @@ class BatchedPenalizerOrchestrator:
         Returns:
             torch.Tensor: The logits after applying the penalizers.
         """
+        if not self.is_required:
+            return
+
         for penalizer in self.penalizers.values():
             logits = penalizer.apply(logits)
@@ -112,10 +122,16 @@ class BatchedPenalizerOrchestrator:
             indices_to_keep (typing.List[int]): List of indices to keep in the batch.
             indices_tensor_to_keep (torch.Tensor = None): Tensor of indices to keep in the batch. If not None, it will be used instead of converting indices_to_keep to a tensor.
         """
+        if not self.is_required:
+            return
+
         empty_indices = len(indices_to_keep) == 0
 
+        is_required = False
         for penalizer in self.penalizers.values():
-            if not penalizer.is_required() or empty_indices:
+            tmp_is_required = penalizer.is_required()
+            is_required = is_required or tmp_is_required
+            if not tmp_is_required or empty_indices:
                 penalizer.teardown()
             else:
                 # create tensor index only when it's needed
@@ -128,6 +144,7 @@ class BatchedPenalizerOrchestrator:
                     indices_to_keep=indices_to_keep,
                     indices_tensor_to_keep=indices_tensor_to_keep,
                 )
+        self.is_required = is_required
 
     def merge(self, their: "BatchedPenalizerOrchestrator"):
         """
@@ -140,11 +157,10 @@ class BatchedPenalizerOrchestrator:
         Args:
             their (BatchedPenalizerOrchestrator): The orchestrator to merge into this one.
         """
-        if self.vocab_size != their.vocab_size:
-            raise ValueError(
-                f"vocab_size mismatch: {self.vocab_size} != {their.vocab_size}"
-            )
+        if not self.is_required and not their.is_required:
+            return
 
+        self.is_required |= their.is_required
         for Penalizer, their_penalizer in their.penalizers.items():
             if Penalizer not in self.penalizers:
                 raise ValueError(f"Penalizer {Penalizer} not found in self.penalizers")
@@ -250,6 +266,9 @@ class _BatchedPenalizer(abc.ABC):
     def prepare_if_required(self):
         if self.is_required():
             self.prepare()
+            return True
+        else:
+            return False
 
     def teardown(self):
         if self.is_prepared():
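Taken together, the filter_batch hunks both short-circuit on the flag and recompute it as requests leave the batch, so an orchestrator can drop back to the fast path once its last penalized request finishes. Roughly, as an extra ToyOrchestrator method under the simplified names from the sketch above (teardown and filter are hypothetical hooks; the real code also carries a tensor of kept indices):

    def filter_batch(self, indices_to_keep: List[int]) -> None:
        if not self.is_required:
            return  # fast path: nothing was prepared for this batch

        is_required = False
        for p in self.penalizers:
            required = p.is_required()
            is_required = is_required or required
            if not required or not indices_to_keep:
                p.teardown()  # hypothetical cleanup hook
            else:
                p.filter(indices_to_keep)  # hypothetical per-penalizer filter
        # The flag can drop back to False once the last penalized request
        # leaves the batch.
        self.is_required = is_required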


@@ -48,20 +48,24 @@ class SamplingBatchInfo:
         disable_penalizer: bool,
     ):
         reqs = batch.reqs
-        with batch.input_ids.device:
-            temperatures = torch.tensor(
+        device = batch.input_ids.device
+        temperatures = (
+            torch.tensor(
                 [r.sampling_params.temperature for r in reqs],
                 dtype=torch.float,
-            ).view(-1, 1)
-            top_ps = torch.tensor(
-                [r.sampling_params.top_p for r in reqs], dtype=torch.float
-            )
-            top_ks = torch.tensor(
-                [r.sampling_params.top_k for r in reqs], dtype=torch.int32
-            )
-            min_ps = torch.tensor(
-                [r.sampling_params.min_p for r in reqs], dtype=torch.float
-            )
+            )
+            .view(-1, 1)
+            .to(device, non_blocking=True)
+        )
+        top_ps = torch.tensor(
+            [r.sampling_params.top_p for r in reqs], dtype=torch.float
+        ).to(device, non_blocking=True)
+        top_ks = torch.tensor(
+            [r.sampling_params.top_k for r in reqs], dtype=torch.int32
+        ).to(device, non_blocking=True)
+        min_ps = torch.tensor(
+            [r.sampling_params.min_p for r in reqs], dtype=torch.float
+        ).to(device, non_blocking=True)
 
         ret = cls(
             temperatures=temperatures,
@@ -80,7 +84,7 @@ class SamplingBatchInfo:
         #
         # While we choose not to even create the class instances if they are not required, this
         # could add additional complexity to the {ScheduleBatch} class, especially we need to
-        # handle {filter_batch()} and {merge()} cases as well.
+        # handle {filter_batch()} and {merge_batch()} cases as well.
         if disable_penalizer:
             ret.penalizer_orchestrator = None
         else:
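The comment in this hunk records the design choice: rather than skipping construction entirely, the orchestrator attribute is always present and simply set to None when penalizers are disabled, so downstream code needs only a None check. A toy illustration of that guard (hypothetical names, not the real SamplingBatchInfo):

from typing import Optional


class ToySamplingInfo:
    """Stand-in for the nullable-orchestrator guard described above."""

    def __init__(self, orchestrator: Optional[object], disable_penalizer: bool):
        # Keep the attribute present either way so call sites need only a
        # None check instead of hasattr().
        self.penalizer_orchestrator = None if disable_penalizer else orchestrator

    def update_penalties(self) -> None:
        if self.penalizer_orchestrator is None:
            return  # penalizers disabled; nothing to update
        # ... delegate to the orchestrator here ...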
@@ -112,19 +116,20 @@ class SamplingBatchInfo:
         self.linear_penalties = None
         for penalizer in self.penalizer_orchestrator.penalizers.values():
+            if not penalizer.is_prepared():
+                continue
+
             if isinstance(penalizer, penaltylib.BatchedRepetitionPenalizer):
-                if penalizer.is_prepared():
-                    self.scaling_penalties = penalizer.cumulated_repetition_penalties
+                self.scaling_penalties = penalizer.cumulated_repetition_penalties
             else:
-                if penalizer.is_prepared():
-                    if self.linear_penalties is None:
-                        bs = self.penalizer_orchestrator.batch.batch_size()
-                        self.linear_penalties = torch.zeros(
-                            (bs, self.vocab_size),
-                            dtype=torch.float32,
-                            device=self.device,
-                        )
-                    self.linear_penalties = penalizer.apply(self.linear_penalties)
+                if self.linear_penalties is None:
+                    bs = self.penalizer_orchestrator.batch.batch_size()
+                    self.linear_penalties = torch.zeros(
+                        (bs, self.vocab_size),
+                        dtype=torch.float32,
+                        device=self.device,
+                    )
+                self.linear_penalties = penalizer.apply(self.linear_penalties)
 
     def update_regex_vocab_mask(self):
         has_regex = self.regex_fsms and any(regex_fsm for regex_fsm in self.regex_fsms)
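The tensor-construction hunk above swaps a device context manager for an explicit idiom: build each small sampling-parameter tensor on the CPU, then issue one copy per tensor with non_blocking=True. A standalone sketch of the idiom (the CUDA fallback line is illustrative, not from the commit):

import torch

temps = [0.7, 1.0, 0.9]  # e.g. per-request temperatures
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build on the CPU, reshape, then copy to the target device without
# blocking the host. non_blocking only overlaps with compute when the
# source is in pinned (page-locked) host memory; from ordinary pageable
# memory the copy is effectively synchronous.
temperatures = (
    torch.tensor(temps, dtype=torch.float)
    .view(-1, 1)
    .to(device, non_blocking=True)
)
print(temperatures.shape)  # torch.Size([3, 1])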