bugfix: penalizers to be merged before reqs (#1001)
This commit is contained in:
@@ -679,6 +679,11 @@ class ScheduleBatch:
|
|||||||
setattr(self, item, self_val[new_indices])
|
setattr(self, item, self_val[new_indices])
|
||||||
|
|
||||||
def merge(self, other: "ScheduleBatch"):
|
def merge(self, other: "ScheduleBatch"):
|
||||||
|
# Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
|
||||||
|
# orchestrator.merge() depends on Batch.reqs while preparing each penalizer, so it
|
||||||
|
# needs to be called with pre-merged Batch.reqs.
|
||||||
|
self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
|
||||||
|
|
||||||
self.reqs.extend(other.reqs)
|
self.reqs.extend(other.reqs)
|
||||||
|
|
||||||
self.req_pool_indices = torch.concat(
|
self.req_pool_indices = torch.concat(
|
||||||
@@ -692,8 +697,6 @@ class ScheduleBatch:
|
|||||||
self.top_logprobs_nums.extend(other.top_logprobs_nums)
|
self.top_logprobs_nums.extend(other.top_logprobs_nums)
|
||||||
self.return_logprob = any(req.return_logprob for req in self.reqs)
|
self.return_logprob = any(req.return_logprob for req in self.reqs)
|
||||||
|
|
||||||
self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
|
|
||||||
|
|
||||||
for item in [
|
for item in [
|
||||||
"temperatures",
|
"temperatures",
|
||||||
"top_ps",
|
"top_ps",
|
||||||
|
|||||||
@@ -133,6 +133,10 @@ class BatchedPenalizerOrchestrator:
|
|||||||
"""
|
"""
|
||||||
Merge the penalizers of another orchestrator into this one.
|
Merge the penalizers of another orchestrator into this one.
|
||||||
|
|
||||||
|
Note that this function **must** be called _before_ self.batch.reqs is updated (merged).
|
||||||
|
Each unprepared penalizer has to be prepared (creating tensors, etc.) before merging.
|
||||||
|
This step requires the original batch.reqs, before it gets merged with other batch.reqs.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
their (BatchedPenalizerOrchestrator): The orchestrator to merge into this one.
|
their (BatchedPenalizerOrchestrator): The orchestrator to merge into this one.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
import unittest
|
import unittest
|
||||||
|
from multiprocessing import Process
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
@@ -58,6 +59,40 @@ class TestBatchPenalizerE2E(unittest.TestCase):
|
|||||||
def test_default_values(self):
|
def test_default_values(self):
|
||||||
self.run_decode()
|
self.run_decode()
|
||||||
|
|
||||||
|
def test_mixed(self):
|
||||||
|
"""
|
||||||
|
Sends two requests: one with penalizers disabled, and the other with penalizers enabled.
|
||||||
|
This will cause two different {ScheduleBatch} instances to be initialized and eventually merged.
|
||||||
|
|
||||||
|
Merging a batch that has penalizers enabled with another batch (penalizers enabled or disabled) is trivial. However, disabled + enabled is not.
|
||||||
|
This is because a penalizer is not prepared unless it is required, so it has to be prepared during the merge.
|
||||||
|
|
||||||
|
This test triggers the merge of disabled + enabled.
|
||||||
|
"""
|
||||||
|
|
||||||
|
processes = []
|
||||||
|
|
||||||
|
p = Process(
|
||||||
|
target=self.run_decode,
|
||||||
|
)
|
||||||
|
processes.append(p)
|
||||||
|
p.start()
|
||||||
|
|
||||||
|
p = Process(
|
||||||
|
target=self.run_decode,
|
||||||
|
kwargs={
|
||||||
|
"frequency_penalty": 2,
|
||||||
|
"min_new_tokens": 16,
|
||||||
|
"presence_penalty": 2,
|
||||||
|
"repetition_penalty": 2,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
processes.append(p)
|
||||||
|
p.start()
|
||||||
|
|
||||||
|
for p in processes:
|
||||||
|
p.join()
|
||||||
|
|
||||||
def test_frequency_penalty(self):
|
def test_frequency_penalty(self):
|
||||||
self.run_decode(frequency_penalty=2)
|
self.run_decode(frequency_penalty=2)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user