Improve performance of two batch overlap in some imbalanced cases (#6593)

This commit is contained in:
fzyzcjy
2025-05-26 13:36:18 +08:00
committed by GitHub
parent 8c7279c24e
commit a191a0e47c
2 changed files with 50 additions and 6 deletions

View File

@@ -40,13 +40,21 @@ def compute_split_seq_index(
 def _split_array_by_half_sum(arr: Sequence[int]) -> int:
     overall_sum = sum(arr)
-    accumulator, split_index = 0, 0
-    for value in arr[:-1]:
-        accumulator += value
-        split_index += 1
-        if accumulator >= overall_sum // 2:
+    left_sum = 0
+    min_diff = float("inf")
+    best_index = 0
+    for i in range(1, len(arr)):
+        left_sum += arr[i - 1]
+        right_sum = overall_sum - left_sum
+        diff = abs(left_sum - right_sum)
+        if diff <= min_diff:
+            min_diff = diff
+            best_index = i
+        else:
             break
-    return split_index
+    return best_index
 def compute_split_token_index(