From 10bca45bc6415afc2d6fb764c697626875831af9 Mon Sep 17 00:00:00 2001 From: Juwan Yoo Date: Fri, 9 Aug 2024 04:46:24 -0700 Subject: [PATCH] bugfix: penalizers to be merged before reqs (#1001) --- python/sglang/srt/managers/schedule_batch.py | 7 ++-- .../srt/sampling/penaltylib/orchestrator.py | 4 +++ .../test_srt_endpoint_with_penalizers.py | 35 +++++++++++++++++++ 3 files changed, 44 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index e7c5cba92..d2101d2c0 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -679,6 +679,11 @@ class ScheduleBatch: setattr(self, item, self_val[new_indices]) def merge(self, other: "ScheduleBatch"): + # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because + # orchestrator.merge() depends on Batch.reqs during preparation of each penalizer, so it + # needs to be called with the pre-merge Batch.reqs. + self.penalizer_orchestrator.merge(other.penalizer_orchestrator) + self.reqs.extend(other.reqs) self.req_pool_indices = torch.concat( @@ -692,8 +697,6 @@ class ScheduleBatch: self.top_logprobs_nums.extend(other.top_logprobs_nums) self.return_logprob = any(req.return_logprob for req in self.reqs) - self.penalizer_orchestrator.merge(other.penalizer_orchestrator) - for item in [ "temperatures", "top_ps", diff --git a/python/sglang/srt/sampling/penaltylib/orchestrator.py b/python/sglang/srt/sampling/penaltylib/orchestrator.py index 969a5d820..4214a746b 100644 --- a/python/sglang/srt/sampling/penaltylib/orchestrator.py +++ b/python/sglang/srt/sampling/penaltylib/orchestrator.py @@ -133,6 +133,10 @@ class BatchedPenalizerOrchestrator: """ Merge the penalizers of another orchestrator into this one. + Note that this function **must** be called _before_ self.batch.reqs is updated (merged). + Each unprepared penalizer would have to be prepared (creating tensors, etc.)
first before merging. + This step requires the original batch.reqs, before it gets merged with other batch.reqs. + + Args: their (BatchedPenalizerOrchestrator): The orchestrator to merge into this one. """ diff --git a/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py b/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py index 5ea6af7cc..e72dc30f9 100644 --- a/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py +++ b/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py @@ -1,5 +1,6 @@ import json import unittest +from multiprocessing import Process import requests @@ -58,6 +59,40 @@ class TestBatchPenalizerE2E(unittest.TestCase): def test_default_values(self): self.run_decode() + def test_mixed(self): + """ + Sends two requests, one with penalizers disabled and the other with penalizers enabled. + This will cause two different {ScheduleBatch} to be initialized and eventually merged. + + Merging a batch with penalizers enabled with enabled, or disabled, is trivial. However, disabled + enabled is not. + This is because a penalizer is not prepared if it is not required; it is then prepared during the merge. + + This test triggers the merge of disabled + enabled. + """ + + processes = [] + + p = Process( + target=self.run_decode, + ) + processes.append(p) + p.start() + + p = Process( + target=self.run_decode, + kwargs={ + "frequency_penalty": 2, + "min_new_tokens": 16, + "presence_penalty": 2, + "repetition_penalty": 2, + }, + ) + processes.append(p) + p.start() + + for p in processes: + p.join() + def test_frequency_penalty(self): self.run_decode(frequency_penalty=2)