diff --git a/python/sglang/srt/two_batch_overlap.py b/python/sglang/srt/two_batch_overlap.py index 0fbc3c8e7..79ba76d49 100644 --- a/python/sglang/srt/two_batch_overlap.py +++ b/python/sglang/srt/two_batch_overlap.py @@ -40,13 +40,21 @@ def compute_split_seq_index( def _split_array_by_half_sum(arr: Sequence[int]) -> int: overall_sum = sum(arr) - accumulator, split_index = 0, 0 - for value in arr[:-1]: - accumulator += value - split_index += 1 - if accumulator >= overall_sum // 2: + left_sum = 0 + min_diff = float("inf") + best_index = 0 + + for i in range(1, len(arr)): + left_sum += arr[i - 1] + right_sum = overall_sum - left_sum + diff = abs(left_sum - right_sum) + if diff <= min_diff: + min_diff = diff + best_index = i + else: break - return split_index + + return best_index def compute_split_token_index( diff --git a/test/srt/test_two_batch_overlap.py b/test/srt/test_two_batch_overlap.py index 89e793ca6..765679fc3 100644 --- a/test/srt/test_two_batch_overlap.py +++ b/test/srt/test_two_batch_overlap.py @@ -4,6 +4,8 @@ from types import SimpleNamespace import requests +from sglang.srt.model_executor.forward_batch_info import ForwardMode +from sglang.srt.two_batch_overlap import compute_split_seq_index from sglang.srt.utils import kill_process_tree from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( @@ -68,5 +70,39 @@ class TestTwoBatchOverlap(unittest.TestCase): self.assertGreater(metrics["score"], 0.5) +class TestTwoBatchOverlapUnitTest(unittest.TestCase): + # TODO change tests when having 6328 + def test_compute_split_seq_index(self): + for num_tokens, expect in [ + (0, 0), + (100, 50), + (99, 49), + ]: + actual = compute_split_seq_index( + forward_mode=ForwardMode.DECODE, num_tokens=num_tokens, extend_lens=None + ) + self.assertEqual(actual, expect) + + for extend_lens, expect in [ + ([], 0), + ([42], 0), + ([42, 999], 1), + ([999, 42], 1), + ([4096, 4096, 4096, 4096], 2), + ([4095, 4096, 4096, 4096, 1], 2), + ([1, 4095, 4096, 4096, 4096], 3), + ([4097, 4096, 4096, 4095, 1], 2), + ([1, 1, 1, 1, 99999], 4), + ([99999, 1, 1, 1, 1], 1), + ]: + actual = compute_split_seq_index( + forward_mode=ForwardMode.EXTEND, + num_tokens=None, + extend_lens=extend_lens, + ) + print(f"{extend_lens=} {expect=} {actual=}") + self.assertEqual(actual, expect) + + if __name__ == "__main__": unittest.main()