[PD] Vectorise group_concurrent_contiguous in NumPy (#5834)

Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
This commit is contained in:
Yuan Luo
2025-05-01 22:42:37 +08:00
committed by GitHub
parent 4322c31e24
commit 67b7d5b1df

View File

@@ -37,25 +37,16 @@ logger = logging.getLogger(__name__)
def group_concurrent_contiguous(
src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
src_groups = []
dst_groups = []
current_src = [src_indices[0]]
current_dst = [dst_indices[0]]
"""Vectorised NumPy implementation."""
if src_indices.size == 0:
return [], []
for i in range(1, len(src_indices)):
src_contiguous = src_indices[i] == src_indices[i - 1] + 1
dst_contiguous = dst_indices[i] == dst_indices[i - 1] + 1
if src_contiguous and dst_contiguous:
current_src.append(src_indices[i])
current_dst.append(dst_indices[i])
else:
src_groups.append(current_src)
dst_groups.append(current_dst)
current_src = [src_indices[i]]
current_dst = [dst_indices[i]]
brk = np.where((np.diff(src_indices) != 1) | (np.diff(dst_indices) != 1))[0] + 1
src_groups = np.split(src_indices, brk)
dst_groups = np.split(dst_indices, brk)
src_groups.append(current_src)
dst_groups.append(current_dst)
src_groups = [g.tolist() for g in src_groups]
dst_groups = [g.tolist() for g in dst_groups]
return src_groups, dst_groups