bugfix: penalizers to be merged before reqs (#1001)
This commit is contained in:
@@ -679,6 +679,11 @@ class ScheduleBatch:
|
||||
setattr(self, item, self_val[new_indices])
|
||||
|
||||
def merge(self, other: "ScheduleBatch"):
|
||||
# Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
|
||||
# orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it
|
||||
# needs to be called with pre-merged Batch.reqs.
|
||||
self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
|
||||
|
||||
self.reqs.extend(other.reqs)
|
||||
|
||||
self.req_pool_indices = torch.concat(
|
||||
@@ -692,8 +697,6 @@ class ScheduleBatch:
|
||||
self.top_logprobs_nums.extend(other.top_logprobs_nums)
|
||||
self.return_logprob = any(req.return_logprob for req in self.reqs)
|
||||
|
||||
self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
|
||||
|
||||
for item in [
|
||||
"temperatures",
|
||||
"top_ps",
|
||||
|
||||
@@ -133,6 +133,10 @@ class BatchedPenalizerOrchestrator:
|
||||
"""
|
||||
Merge the penalizers of another orchestrator into this one.
|
||||
|
||||
Note that this function **must** be called _before_ self.batch.reqs is updated (filtered).
|
||||
Each unprepared penalizers would have to be prepared (creating tensors, etc.) first before merging.
|
||||
This step requires the original batch.reqs, before it gets merged with other batch.reqs.
|
||||
|
||||
Args:
|
||||
their (BatchedPenalizerOrchestrator): The orchestrator to merge into this one.
|
||||
"""
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import json
|
||||
import unittest
|
||||
from multiprocessing import Process
|
||||
|
||||
import requests
|
||||
|
||||
@@ -58,6 +59,40 @@ class TestBatchPenalizerE2E(unittest.TestCase):
|
||||
def test_default_values(self):
|
||||
self.run_decode()
|
||||
|
||||
def test_mixed(self):
|
||||
"""
|
||||
Sends two requests with one with penalizers disabled, and the other with penalizers enabled.
|
||||
This will cause two different {ScheduleBatch} to be initialized and eventually gets merged.
|
||||
|
||||
Merging batch with penalizers enabled with enabled, or disabled is trivial. However disabled + enabled is not.
|
||||
This is because the penalizer will not be prepared if it is not required, then it will be prepared during the merge.
|
||||
|
||||
This test triggers the merge of disabled + enabled.
|
||||
"""
|
||||
|
||||
processes = []
|
||||
|
||||
p = Process(
|
||||
target=self.run_decode,
|
||||
)
|
||||
processes.append(p)
|
||||
p.start()
|
||||
|
||||
p = Process(
|
||||
target=self.run_decode,
|
||||
kwargs={
|
||||
"frequency_penalty": 2,
|
||||
"min_new_tokens": 16,
|
||||
"presence_penalty": 2,
|
||||
"repetition_penalty": 2,
|
||||
},
|
||||
)
|
||||
processes.append(p)
|
||||
p.start()
|
||||
|
||||
for p in processes:
|
||||
p.join()
|
||||
|
||||
def test_frequency_penalty(self):
|
||||
self.run_decode(frequency_penalty=2)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user