[PD] Vectorise group_concurrent_contiguous in NumPy (#5834)
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
This commit is contained in:
@@ -37,25 +37,16 @@ logger = logging.getLogger(__name__)
|
||||
def group_concurrent_contiguous(
|
||||
src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
|
||||
) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
|
||||
src_groups = []
|
||||
dst_groups = []
|
||||
current_src = [src_indices[0]]
|
||||
current_dst = [dst_indices[0]]
|
||||
"""Vectorised NumPy implementation."""
|
||||
if src_indices.size == 0:
|
||||
return [], []
|
||||
|
||||
for i in range(1, len(src_indices)):
|
||||
src_contiguous = src_indices[i] == src_indices[i - 1] + 1
|
||||
dst_contiguous = dst_indices[i] == dst_indices[i - 1] + 1
|
||||
if src_contiguous and dst_contiguous:
|
||||
current_src.append(src_indices[i])
|
||||
current_dst.append(dst_indices[i])
|
||||
else:
|
||||
src_groups.append(current_src)
|
||||
dst_groups.append(current_dst)
|
||||
current_src = [src_indices[i]]
|
||||
current_dst = [dst_indices[i]]
|
||||
brk = np.where((np.diff(src_indices) != 1) | (np.diff(dst_indices) != 1))[0] + 1
|
||||
src_groups = np.split(src_indices, brk)
|
||||
dst_groups = np.split(dst_indices, brk)
|
||||
|
||||
src_groups.append(current_src)
|
||||
dst_groups.append(current_dst)
|
||||
src_groups = [g.tolist() for g in src_groups]
|
||||
dst_groups = [g.tolist() for g in dst_groups]
|
||||
|
||||
return src_groups, dst_groups
|
||||
|
||||
|
||||
Reference in New Issue
Block a user