diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py index e58186d33..e59497dc9 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -2,6 +2,7 @@ from __future__ import annotations import asyncio import concurrent.futures +import ctypes import dataclasses import logging import os @@ -138,7 +139,29 @@ class KVArgsRegisterInfo: ) +class AuxDataCodec: + """Handles serialization and deserialization of auxiliary data buffers""" + + @staticmethod + def serialize_data_from_buffer(src_addr, data_length): + """Serialize data from memory buffer to bytes""" + buffer = (ctypes.c_byte * data_length).from_address(src_addr) + return bytes(buffer) + + @staticmethod + def deserialize_data_to_buffer(kv_args, buffer_index, aux_index, data): + """Deserialize bytes into target memory buffer""" + dst_aux_ptr = kv_args.aux_data_ptrs[buffer_index] + item_len = kv_args.aux_item_lens[buffer_index] + dst_addr = dst_aux_ptr + item_len * aux_index + buffer = (ctypes.c_byte * len(data)).from_address(dst_addr) + buffer[:] = data + return + + class MooncakeKVManager(BaseKVManager): + AUX_DATA_HEADER = b"AUX_DATA" + def __init__( self, args: KVArgs, @@ -283,21 +306,10 @@ class MooncakeKVManager(BaseKVManager): if not transfer_blocks: return 0 - # TODO(shangming): Fix me when nvlink_transport of Mooncake is bug-free - if self.enable_custom_mem_pool: - # batch_transfer_sync has a higher chance to trigger an accuracy drop for MNNVL, fallback to transfer_sync temporarily - for src_addr, dst_addr, length in transfer_blocks: - status = self.engine.transfer_sync( - mooncake_session_id, src_addr, dst_addr, length - ) - if status != 0: - return status - return 0 - else: - src_addrs, dst_addrs, lengths = zip(*transfer_blocks) - return self.engine.batch_transfer_sync( - mooncake_session_id, list(src_addrs), list(dst_addrs), list(lengths) - ) + src_addrs, dst_addrs, lengths = zip(*transfer_blocks) + return self.engine.batch_transfer_sync( + mooncake_session_id, list(src_addrs), list(dst_addrs), list(lengths) + ) def send_kvcache( self, @@ -570,11 +582,14 @@ class MooncakeKVManager(BaseKVManager): def send_aux( self, - mooncake_session_id: str, + req: TransferInfo, prefill_aux_index: int, dst_aux_ptrs: list[int], - dst_aux_index: int, ): + # TODO(shangming): Fix me when nvlink_transport of Mooncake is bug-free + if self.enable_custom_mem_pool: + return self.send_aux_tcp(req, prefill_aux_index, dst_aux_ptrs) + transfer_blocks = [] prefill_aux_ptrs = self.kv_args.aux_data_ptrs prefill_aux_item_lens = self.kv_args.aux_item_lens @@ -582,10 +597,59 @@ class MooncakeKVManager(BaseKVManager): for i, dst_aux_ptr in enumerate(dst_aux_ptrs): length = prefill_aux_item_lens[i] src_addr = prefill_aux_ptrs[i] + length * prefill_aux_index - dst_addr = dst_aux_ptrs[i] + length * dst_aux_index + dst_addr = dst_aux_ptrs[i] + length * req.dst_aux_index transfer_blocks.append((src_addr, dst_addr, length)) - return self._transfer_data(mooncake_session_id, transfer_blocks) + return self._transfer_data(req.mooncake_session_id, transfer_blocks) + + def send_aux_tcp( + self, + req: TransferInfo, + prefill_aux_index: int, + dst_aux_ptrs: list[int], + ): + prefill_aux_ptrs = self.kv_args.aux_data_ptrs + prefill_aux_item_lens = self.kv_args.aux_item_lens + + for i in range(len(prefill_aux_ptrs)): + length = prefill_aux_item_lens[i] + src_addr = prefill_aux_ptrs[i] + length * prefill_aux_index + data = AuxDataCodec.serialize_data_from_buffer(src_addr, length) + + self.send_aux_data_to_endpoint( + remote=req.endpoint, + dst_port=req.dst_port, + room=req.room, + buffer_index=i, + aux_index=req.dst_aux_index, + data=data, + ) + + return 0 + + def send_aux_data_to_endpoint( + self, + remote: str, + dst_port: int, + room: int, + buffer_index: int, + aux_index: int, + data: bytes, + ): + socket = self._connect( + format_tcp_address(remote, dst_port), is_ipv6=is_valid_ipv6_address(remote) + ) + + socket.send_multipart( + [ + MooncakeKVManager.AUX_DATA_HEADER, + str(room).encode("ascii"), + str(buffer_index).encode("ascii"), + str(aux_index).encode("ascii"), + struct.pack(">I", len(data)), + data, + ] + ) def sync_status_to_decode_endpoint( self, remote: str, dst_port: int, room: int, status: int, prefill_rank: int @@ -699,10 +763,9 @@ class MooncakeKVManager(BaseKVManager): if self.pp_group.is_last_rank: # Only the last chunk we need to send the aux data ret = self.send_aux( - req.mooncake_session_id, + req, kv_chunk.prefill_aux_index, target_rank_registration_info.dst_aux_ptrs, - req.dst_aux_index, ) polls.append(True if ret == 0 else False) dst_ranks_infos.append( @@ -778,15 +841,38 @@ class MooncakeKVManager(BaseKVManager): threading.Thread(target=bootstrap_thread).start() + def _handle_aux_data(self, msg: List[bytes]): + """Handle AUX_DATA messages received by the decode thread.""" + room = int(msg[1].decode("ascii")) + buffer_index = int(msg[2].decode("ascii")) + aux_index = int(msg[3].decode("ascii")) + data_length = struct.unpack(">I", msg[4])[0] + data = msg[5] + + if len(data) != data_length: + logger.error(f"AUX_DATA length mismatch for bootstrap_room {room}") + return + + AuxDataCodec.deserialize_data_to_buffer( + self.kv_args, buffer_index, aux_index, data + ) + + logger.debug( + f"Received AUX_DATA for bootstrap_room {room} with length:{len(data)}" + ) + def start_decode_thread(self): self.rank_port = get_free_port() self._bind_server_socket() def decode_thread(): while True: - (bootstrap_room, status, prefill_rank) = ( - self.server_socket.recv_multipart() - ) + msg = self.server_socket.recv_multipart() + if msg[0] == MooncakeKVManager.AUX_DATA_HEADER: + self._handle_aux_data(msg) + continue + + (bootstrap_room, status, prefill_rank) = msg status = int(status.decode("ascii")) bootstrap_room = int(bootstrap_room.decode("ascii")) prefill_rank = int(prefill_rank.decode("ascii")) diff --git a/python/sglang/srt/disaggregation/utils.py b/python/sglang/srt/disaggregation/utils.py index 720c9d5a5..534528087 100644 --- a/python/sglang/srt/disaggregation/utils.py +++ b/python/sglang/srt/disaggregation/utils.py @@ -99,7 +99,8 @@ class MetadataBuffers: # For ascend backend, output tokens are placed in the NPU and will be transferred by D2D channel. device = "npu" elif self.custom_mem_pool: - device = "cuda" + # TODO(shangming): Fix me (use 'cuda') when nvlink_transport of Mooncake is bug-free + device = "cpu" with ( torch.cuda.use_mem_pool(self.custom_mem_pool) if self.custom_mem_pool