bugfix: penalizers to be merged before reqs (#1001)
This commit is contained in:
@@ -679,6 +679,11 @@ class ScheduleBatch:
|
||||
setattr(self, item, self_val[new_indices])
|
||||
|
||||
def merge(self, other: "ScheduleBatch"):
|
||||
# Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
|
||||
# orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it
|
||||
# needs to be called with pre-merged Batch.reqs.
|
||||
self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
|
||||
|
||||
self.reqs.extend(other.reqs)
|
||||
|
||||
self.req_pool_indices = torch.concat(
|
||||
@@ -692,8 +697,6 @@ class ScheduleBatch:
|
||||
self.top_logprobs_nums.extend(other.top_logprobs_nums)
|
||||
self.return_logprob = any(req.return_logprob for req in self.reqs)
|
||||
|
||||
self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
|
||||
|
||||
for item in [
|
||||
"temperatures",
|
||||
"top_ps",
|
||||
|
||||
@@ -133,6 +133,10 @@ class BatchedPenalizerOrchestrator:
|
||||
"""
|
||||
Merge the penalizers of another orchestrator into this one.
|
||||
|
||||
Note that this function **must** be called _before_ self.batch.reqs is updated (filtered).
|
||||
Each unprepared penalizers would have to be prepared (creating tensors, etc.) first before merging.
|
||||
This step requires the original batch.reqs, before it gets merged with other batch.reqs.
|
||||
|
||||
Args:
|
||||
their (BatchedPenalizerOrchestrator): The orchestrator to merge into this one.
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user