diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py index 330be4ab8..40b083bf0 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -37,25 +37,16 @@ logger = logging.getLogger(__name__) def group_concurrent_contiguous( src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64] ) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]: - src_groups = [] - dst_groups = [] - current_src = [src_indices[0]] - current_dst = [dst_indices[0]] + """Vectorised NumPy implementation.""" + if src_indices.size == 0: + return [], [] - for i in range(1, len(src_indices)): - src_contiguous = src_indices[i] == src_indices[i - 1] + 1 - dst_contiguous = dst_indices[i] == dst_indices[i - 1] + 1 - if src_contiguous and dst_contiguous: - current_src.append(src_indices[i]) - current_dst.append(dst_indices[i]) - else: - src_groups.append(current_src) - dst_groups.append(current_dst) - current_src = [src_indices[i]] - current_dst = [dst_indices[i]] + brk = np.where((np.diff(src_indices) != 1) | (np.diff(dst_indices) != 1))[0] + 1 + src_groups = np.split(src_indices, brk) + dst_groups = np.split(dst_indices, brk) - src_groups.append(current_src) - dst_groups.append(current_dst) + src_groups = [g.tolist() for g in src_groups] + dst_groups = [g.tolist() for g in dst_groups] return src_groups, dst_groups