[Feature] improve TBO: two chunk overlap (#8144)

This commit is contained in:
HouseWest
2025-08-06 12:11:01 +08:00
committed by GitHub
parent d26ca84f39
commit ca47e24f5d
6 changed files with 218 additions and 29 deletions

View File

@@ -5,7 +5,10 @@ from types import SimpleNamespace
import requests
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.two_batch_overlap import compute_split_seq_index
from sglang.srt.two_batch_overlap import (
compute_split_seq_index,
compute_split_token_index,
)
from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
@@ -73,35 +76,46 @@ class TestTwoBatchOverlap(unittest.TestCase):
class TestTwoBatchOverlapUnitTest(unittest.TestCase):
# TODO change tests when having 6328
def test_compute_split_seq_index(self):
def test_compute_split_seq_and_token_index(self):
for num_tokens, expect in [
(0, 0),
(100, 50),
(99, 49),
]:
actual = compute_split_seq_index(
forward_mode=ForwardMode.DECODE, num_tokens=num_tokens, extend_lens=None
forward_mode=ForwardMode.DECODE,
num_tokens=num_tokens,
extend_lens=None,
token_num_per_seq=1,
)
self.assertEqual(actual, expect)
for extend_lens, expect in [
([], 0),
([42], 0),
([42, 999], 1),
([999, 42], 1),
([4096, 4096, 4096, 4096], 2),
([4095, 4096, 4096, 4096, 1], 2),
([1, 4095, 4096, 4096, 4096], 3),
([4097, 4096, 4096, 4095, 1], 2),
([1, 1, 1, 1, 99999], 4),
([99999, 1, 1, 1, 1], 1),
([], (0, 0)),
([42], (0, 21)),
([42, 999], (1, 520)),
([999, 42], (0, 520)),
([498, 502], (1, 498)),
([4096, 4096, 4096, 4096], (2, 8192)),
([4095, 4096, 4096, 4096, 1], (2, 8191)),
([1, 4095, 4096, 4096, 4096], (3, 8192)),
([4097, 4096, 4096, 4095, 1], (2, 8193)),
([1, 1, 1, 1, 99999], (4, 50001)),
([99999, 1, 1, 1, 1], (0, 50001)),
]:
actual = compute_split_seq_index(
actual_seq_idx = compute_split_seq_index(
forward_mode=ForwardMode.EXTEND,
num_tokens=None,
extend_lens=extend_lens,
token_num_per_seq=None,
)
actual_token_idx = compute_split_token_index(
split_seq_index=actual_seq_idx,
forward_mode=ForwardMode.EXTEND,
extend_seq_lens=extend_lens,
token_num_per_seq=None,
)
actual = (actual_seq_idx, actual_token_idx)
print(f"{extend_lens=} {expect=} {actual=}")
self.assertEqual(actual, expect)