From 88f9c347b288733c87b3056b6e1a8d2717036ce6 Mon Sep 17 00:00:00 2001 From: Byron Hsu Date: Sun, 15 Jun 2025 11:51:03 -0700 Subject: [PATCH] [PD] use int32 for kv indices & get num_reserved_decode_tokens from server_args (#7214) --- python/sglang/srt/disaggregation/base/conn.py | 4 ++-- .../sglang/srt/disaggregation/common/utils.py | 4 ++-- python/sglang/srt/disaggregation/decode.py | 6 ++---- python/sglang/srt/disaggregation/fake/conn.py | 2 +- .../sglang/srt/disaggregation/mooncake/conn.py | 18 +++++++++--------- python/sglang/srt/disaggregation/nixl/conn.py | 14 +++++++------- python/sglang/srt/disaggregation/prefill.py | 1 - python/sglang/srt/managers/scheduler.py | 1 + 8 files changed, 24 insertions(+), 26 deletions(-) diff --git a/python/sglang/srt/disaggregation/base/conn.py b/python/sglang/srt/disaggregation/base/conn.py index d4331c234..8e9487be6 100644 --- a/python/sglang/srt/disaggregation/base/conn.py +++ b/python/sglang/srt/disaggregation/base/conn.py @@ -70,7 +70,7 @@ class BaseKVSender(ABC): ... @abstractmethod - def send(self, kv_indices: npt.NDArray[np.int64]): + def send(self, kv_indices: npt.NDArray[np.int32]): """ Send the kv cache at the given kv indices to the decoder server """ @@ -102,7 +102,7 @@ class BaseKVReceiver(ABC): ): ... @abstractmethod - def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None): + def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None): """ Notify the prefill server about the kv indices and aux index """ diff --git a/python/sglang/srt/disaggregation/common/utils.py b/python/sglang/srt/disaggregation/common/utils.py index ba0cfd6af..6f3da2128 100644 --- a/python/sglang/srt/disaggregation/common/utils.py +++ b/python/sglang/srt/disaggregation/common/utils.py @@ -26,8 +26,8 @@ class FastQueue: def group_concurrent_contiguous( - src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64] -) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]: + src_indices: npt.NDArray[np.int32], dst_indices: npt.NDArray[np.int32] +) -> Tuple[List[npt.NDArray[np.int32]], List[npt.NDArray[np.int32]]]: """Vectorised NumPy implementation.""" if src_indices.size == 0: return [], [] diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index 336b0581d..9c0860cd0 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -158,6 +158,7 @@ class DecodePreallocQueue: bootstrap_port: int, max_total_num_tokens: int, prefill_pp_size: int, + num_reserved_decode_tokens: int, transfer_backend: TransferBackend, ): self.req_to_token_pool = req_to_token_pool @@ -178,9 +179,7 @@ class DecodePreallocQueue: self.bootstrap_port = bootstrap_port self.max_total_num_tokens = max_total_num_tokens self.prefill_pp_size = prefill_pp_size - self.num_reserved_decode_tokens = int( - os.environ.get("SGLANG_NUM_RESERVED_DECODE_TOKENS", "512") - ) + self.num_reserved_decode_tokens = num_reserved_decode_tokens self.transfer_backend = transfer_backend # Queue for requests pending pre-allocation self.queue: List[DecodeRequest] = [] @@ -404,7 +403,6 @@ class DecodePreallocQueue: ] .cpu() .numpy() - .astype(np.int64) ) decode_req.metadata_buffer_index = ( diff --git a/python/sglang/srt/disaggregation/fake/conn.py b/python/sglang/srt/disaggregation/fake/conn.py index 25335dd68..63a39ac2f 100644 --- a/python/sglang/srt/disaggregation/fake/conn.py +++ b/python/sglang/srt/disaggregation/fake/conn.py @@ -48,7 +48,7 @@ class FakeKVSender(BaseKVSender): def send( self, - kv_indices: npt.NDArray[np.int64], + kv_indices: npt.NDArray[np.int32], ): self.has_sent = True logger.info(f"FakeKVSender send with kv_indices: {kv_indices}") diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py index b3d83db69..0e64f634b 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -59,7 +59,7 @@ class KVTransferError(Exception): @dataclasses.dataclass class TransferKVChunk: room: int - prefill_kv_indices: npt.NDArray[np.int64] + prefill_kv_indices: npt.NDArray[np.int32] index_slice: slice is_last: bool prefill_aux_index: Optional[int] @@ -72,7 +72,7 @@ class TransferInfo: endpoint: str dst_port: int mooncake_session_id: str - dst_kv_indices: npt.NDArray[np.int64] + dst_kv_indices: npt.NDArray[np.int32] dst_aux_index: int required_dst_info_num: int is_dummy: bool @@ -81,10 +81,10 @@ class TransferInfo: def from_zmq(cls, msg: List[bytes]): if msg[4] == b"" and msg[5] == b"": is_dummy = True - dst_kv_indices = np.array([], dtype=np.int64) + dst_kv_indices = np.array([], dtype=np.int32) dst_aux_index = None else: - dst_kv_indices = np.frombuffer(msg[4], dtype=np.int64) + dst_kv_indices = np.frombuffer(msg[4], dtype=np.int32) dst_aux_index = int(msg[5].decode("ascii")) is_dummy = False return cls( @@ -233,9 +233,9 @@ class MooncakeKVManager(BaseKVManager): def send_kvcache( self, mooncake_session_id: str, - prefill_kv_indices: npt.NDArray[np.int64], + prefill_kv_indices: npt.NDArray[np.int32], dst_kv_ptrs: list[int], - dst_kv_indices: npt.NDArray[np.int64], + dst_kv_indices: npt.NDArray[np.int32], executor: concurrent.futures.ThreadPoolExecutor, ): # Group by indices @@ -545,7 +545,7 @@ class MooncakeKVManager(BaseKVManager): def add_transfer_request( self, bootstrap_room: int, - kv_indices: npt.NDArray[np.int64], + kv_indices: npt.NDArray[np.int32], index_slice: slice, is_last: bool, aux_index: Optional[int] = None, @@ -701,7 +701,7 @@ class MooncakeKVSender(BaseKVSender): def send( self, - kv_indices: npt.NDArray[np.int64], + kv_indices: npt.NDArray[np.int32], ): index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices)) self.curr_idx += len(kv_indices) @@ -971,7 +971,7 @@ class MooncakeKVReceiver(BaseKVReceiver): cls._socket_locks[endpoint] = threading.Lock() return cls._socket_cache[endpoint], cls._socket_locks[endpoint] - def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None): + def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None): for bootstrap_info in self.bootstrap_infos: self.prefill_server_url = ( f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}" diff --git a/python/sglang/srt/disaggregation/nixl/conn.py b/python/sglang/srt/disaggregation/nixl/conn.py index 18378bbf4..aef6cbaf9 100644 --- a/python/sglang/srt/disaggregation/nixl/conn.py +++ b/python/sglang/srt/disaggregation/nixl/conn.py @@ -44,7 +44,7 @@ class TransferInfo: agent_metadata: bytes agent_name: str dst_kv_ptrs: list[int] - dst_kv_indices: npt.NDArray[np.int64] + dst_kv_indices: npt.NDArray[np.int32] dst_aux_ptrs: list[int] dst_aux_index: int dst_gpu_id: int @@ -62,7 +62,7 @@ class TransferInfo: agent_metadata=msg[3], agent_name=msg[4].decode("ascii"), dst_kv_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])), - dst_kv_indices=np.frombuffer(msg[6], dtype=np.int64), + dst_kv_indices=np.frombuffer(msg[6], dtype=np.int32), dst_aux_ptrs=list(struct.unpack(f"{len(msg[7])//8}Q", msg[7])), dst_aux_index=int(msg[8].decode("ascii")), dst_gpu_id=int(msg[9].decode("ascii")), @@ -162,9 +162,9 @@ class NixlKVManager(CommonKVManager): def send_kvcache( self, peer_name: str, - prefill_kv_indices: npt.NDArray[np.int64], + prefill_kv_indices: npt.NDArray[np.int32], dst_kv_ptrs: list[int], - dst_kv_indices: npt.NDArray[np.int64], + dst_kv_indices: npt.NDArray[np.int32], dst_gpu_id: int, notif: str, ): @@ -246,7 +246,7 @@ class NixlKVManager(CommonKVManager): def add_transfer_request( self, bootstrap_room: int, - kv_indices: npt.NDArray[np.int64], + kv_indices: npt.NDArray[np.int32], index_slice: slice, is_last: bool, chunk_id: int, @@ -373,7 +373,7 @@ class NixlKVSender(BaseKVSender): def send( self, - kv_indices: npt.NDArray[np.int64], + kv_indices: npt.NDArray[np.int32], ): index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices)) self.curr_idx += len(kv_indices) @@ -417,7 +417,7 @@ class NixlKVReceiver(CommonKVReceiver): self.started_transfer = False super().__init__(mgr, bootstrap_addr, bootstrap_room, data_parallel_rank) - def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None): + def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None): for bootstrap_info in self.bootstrap_infos: self.prefill_server_url = ( f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}" diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index 94be8b1f6..2bb6312eb 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -576,7 +576,6 @@ class SchedulerDisaggregationPrefillMixin: self.req_to_token_pool.req_to_token[req.req_pool_idx, start_idx:end_idx] .cpu() .numpy() - .astype(np.int64) ) req.start_send_idx = end_idx if last_chunk: diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index c8852f0be..e49303937 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -656,6 +656,7 @@ class Scheduler( bootstrap_port=self.server_args.disaggregation_bootstrap_port, max_total_num_tokens=self.max_total_num_tokens, prefill_pp_size=self.server_args.disaggregation_prefill_pp, + num_reserved_decode_tokens=self.server_args.num_reserved_decode_tokens, transfer_backend=self.transfer_backend, )