bugfix: penalizers to be merged before reqs (#1001)

This commit is contained in:
Juwan Yoo
2024-08-09 04:46:24 -07:00
committed by GitHub
parent b91a4cb1b1
commit 10bca45bc6
3 changed files with 44 additions and 2 deletions

View File

@@ -679,6 +679,11 @@ class ScheduleBatch:
setattr(self, item, self_val[new_indices])
def merge(self, other: "ScheduleBatch"):
# Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
# orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it
# needs to be called with pre-merged Batch.reqs.
self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
self.reqs.extend(other.reqs)
self.req_pool_indices = torch.concat(
@@ -692,8 +697,6 @@ class ScheduleBatch:
self.top_logprobs_nums.extend(other.top_logprobs_nums)
self.return_logprob = any(req.return_logprob for req in self.reqs)
self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
for item in [
"temperatures",
"top_ps",