Improve performance of two batch overlap in some imbalanced cases (#6593)
This commit is contained in:
@@ -40,13 +40,21 @@ def compute_split_seq_index(
|
||||
|
||||
def _split_array_by_half_sum(arr: Sequence[int]) -> int:
|
||||
overall_sum = sum(arr)
|
||||
accumulator, split_index = 0, 0
|
||||
for value in arr[:-1]:
|
||||
accumulator += value
|
||||
split_index += 1
|
||||
if accumulator >= overall_sum // 2:
|
||||
left_sum = 0
|
||||
min_diff = float("inf")
|
||||
best_index = 0
|
||||
|
||||
for i in range(1, len(arr)):
|
||||
left_sum += arr[i - 1]
|
||||
right_sum = overall_sum - left_sum
|
||||
diff = abs(left_sum - right_sum)
|
||||
if diff <= min_diff:
|
||||
min_diff = diff
|
||||
best_index = i
|
||||
else:
|
||||
break
|
||||
return split_index
|
||||
|
||||
return best_index
|
||||
|
||||
|
||||
def compute_split_token_index(
|
||||
|
||||
Reference in New Issue
Block a user