bugfix: penalizers to be merged before reqs (#1001)

This commit is contained in:
Juwan Yoo
2024-08-09 04:46:24 -07:00
committed by GitHub
parent b91a4cb1b1
commit 10bca45bc6
3 changed files with 44 additions and 2 deletions

View File

@@ -679,6 +679,11 @@ class ScheduleBatch:
setattr(self, item, self_val[new_indices])
def merge(self, other: "ScheduleBatch"):
# Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
# orchestrator.merge() depends on Batch.reqs during preparation of each penalizer, so it
# needs to be called with the pre-merge (original) Batch.reqs.
self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
self.reqs.extend(other.reqs)
self.req_pool_indices = torch.concat(
@@ -692,8 +697,6 @@ class ScheduleBatch:
self.top_logprobs_nums.extend(other.top_logprobs_nums)
self.return_logprob = any(req.return_logprob for req in self.reqs)
self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
for item in [
"temperatures",
"top_ps",

View File

@@ -133,6 +133,10 @@ class BatchedPenalizerOrchestrator:
"""
Merge the penalizers of another orchestrator into this one.
Note that this function **must** be called _before_ self.batch.reqs is updated (filtered).
Each unprepared penalizer has to be prepared (creating tensors, etc.) before merging.
This step requires the original batch.reqs, before it gets merged with the other batch.reqs.
Args:
their (BatchedPenalizerOrchestrator): The orchestrator to merge into this one.
"""