From 30ca18f423402ae7704156f027cc91be3eaa5471 Mon Sep 17 00:00:00 2001 From: Yuan Luo Date: Wed, 21 May 2025 11:55:04 +0800 Subject: [PATCH] Refactor group_concurrent_contiguous in NIXL (#6214) Co-authored-by: luoyuan.luo --- python/sglang/srt/disaggregation/nixl/conn.py | 26 ++++++------------- 1 file changed, 8 insertions(+), 18 deletions(-) diff --git a/python/sglang/srt/disaggregation/nixl/conn.py b/python/sglang/srt/disaggregation/nixl/conn.py index 78df3a5ad..feff93216 100644 --- a/python/sglang/srt/disaggregation/nixl/conn.py +++ b/python/sglang/srt/disaggregation/nixl/conn.py @@ -35,29 +35,19 @@ logger = logging.getLogger(__name__) NixlEngineInfo: TypeAlias = Dict[str, Union[str, int]] -# From Mooncake backend. def group_concurrent_contiguous( src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64] ) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]: - src_groups = [] - dst_groups = [] - current_src = [src_indices[0]] - current_dst = [dst_indices[0]] + """Vectorised NumPy implementation.""" + if src_indices.size == 0: + return [], [] - for i in range(1, len(src_indices)): - src_contiguous = src_indices[i] == src_indices[i - 1] + 1 - dst_contiguous = dst_indices[i] == dst_indices[i - 1] + 1 - if src_contiguous and dst_contiguous: - current_src.append(src_indices[i]) - current_dst.append(dst_indices[i]) - else: - src_groups.append(current_src) - dst_groups.append(current_dst) - current_src = [src_indices[i]] - current_dst = [dst_indices[i]] + brk = np.where((np.diff(src_indices) != 1) | (np.diff(dst_indices) != 1))[0] + 1 + src_groups = np.split(src_indices, brk) + dst_groups = np.split(dst_indices, brk) - src_groups.append(current_src) - dst_groups.append(current_dst) + src_groups = [g.tolist() for g in src_groups] + dst_groups = [g.tolist() for g in dst_groups] return src_groups, dst_groups