Improve performance of two batch overlap in some imbalanced cases (#6593)
This commit is contained in:
@@ -40,13 +40,21 @@ def compute_split_seq_index(
|
|||||||
|
|
||||||
def _split_array_by_half_sum(arr: Sequence[int]) -> int:
    """Return the split index that best balances the two halves of *arr*.

    The result ``k`` partitions the sequence into ``arr[:k]`` and ``arr[k:]``
    such that ``abs(sum(arr[:k]) - sum(arr[k:]))`` is as small as possible
    among the candidates ``k = 1 .. len(arr) - 1``.  When several candidates
    tie, the largest such ``k`` reached before the difference grows again is
    returned.

    Sequences with fewer than two elements cannot be split and yield ``0``.

    NOTE(review): the early exit below assumes the entries are non-negative
    (e.g. sequence lengths), so that the imbalance shrinks and then grows as
    the split point moves right -- confirm against callers.
    """
    overall_sum = sum(arr)
    left_sum = 0
    min_diff = float("inf")
    best_index = 0

    # Candidate split points are 1 .. len(arr) - 1, so both sides are
    # non-empty whenever a split exists at all.
    for i in range(1, len(arr)):
        left_sum += arr[i - 1]
        right_sum = overall_sum - left_sum
        diff = abs(left_sum - right_sum)
        if diff > min_diff:
            # The imbalance started growing again; later split points can
            # only be worse (given non-negative entries), so stop here.
            break
        # `<=` semantics: on ties prefer the later split point.
        min_diff = diff
        best_index = i

    return best_index
|
||||||
|
|
||||||
|
|
||||||
def compute_split_token_index(
|
def compute_split_token_index(
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ from types import SimpleNamespace
|
|||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
|
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||||
|
from sglang.srt.two_batch_overlap import compute_split_seq_index
|
||||||
from sglang.srt.utils import kill_process_tree
|
from sglang.srt.utils import kill_process_tree
|
||||||
from sglang.test.run_eval import run_eval
|
from sglang.test.run_eval import run_eval
|
||||||
from sglang.test.test_utils import (
|
from sglang.test.test_utils import (
|
||||||
@@ -68,5 +70,39 @@ class TestTwoBatchOverlap(unittest.TestCase):
|
|||||||
self.assertGreater(metrics["score"], 0.5)
|
self.assertGreater(metrics["score"], 0.5)
|
||||||
|
|
||||||
|
|
||||||
|
class TestTwoBatchOverlapUnitTest(unittest.TestCase):
    """Table-driven unit tests for ``compute_split_seq_index``."""

    # TODO change tests when having 6328
    def test_compute_split_seq_index(self):
        """Check the split index for DECODE and EXTEND forward modes.

        Each case runs inside ``subTest`` so that a failure reports the
        offending input and the remaining cases still execute.
        """
        # DECODE mode: the split is derived from the token count alone.
        decode_cases = [
            (0, 0),
            (100, 50),
            (99, 49),
        ]
        for num_tokens, expect in decode_cases:
            with self.subTest(mode="decode", num_tokens=num_tokens):
                actual = compute_split_seq_index(
                    forward_mode=ForwardMode.DECODE,
                    num_tokens=num_tokens,
                    extend_lens=None,
                )
                self.assertEqual(actual, expect)

        # EXTEND mode: the split is derived from the per-sequence lengths,
        # including several imbalanced layouts.
        extend_cases = [
            ([], 0),
            ([42], 0),
            ([42, 999], 1),
            ([999, 42], 1),
            ([4096, 4096, 4096, 4096], 2),
            ([4095, 4096, 4096, 4096, 1], 2),
            ([1, 4095, 4096, 4096, 4096], 3),
            ([4097, 4096, 4096, 4095, 1], 2),
            ([1, 1, 1, 1, 99999], 4),
            ([99999, 1, 1, 1, 1], 1),
        ]
        for extend_lens, expect in extend_cases:
            with self.subTest(mode="extend", extend_lens=extend_lens):
                actual = compute_split_seq_index(
                    forward_mode=ForwardMode.EXTEND,
                    num_tokens=None,
                    extend_lens=extend_lens,
                )
                self.assertEqual(actual, expect)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user