[PD] Fix nvlink transport accuracy through transferring metadata with tcp (#9261)
Signed-off-by: Shangming Cai <csmthu@gmail.com>
This commit is contained in:
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import ctypes
|
||||
import dataclasses
|
||||
import logging
|
||||
import os
|
||||
@@ -138,7 +139,29 @@ class KVArgsRegisterInfo:
|
||||
)
|
||||
|
||||
|
||||
class AuxDataCodec:
|
||||
"""Handles serialization and deserialization of auxiliary data buffers"""
|
||||
|
||||
@staticmethod
|
||||
def serialize_data_from_buffer(src_addr, data_length):
|
||||
"""Serialize data from memory buffer to bytes"""
|
||||
buffer = (ctypes.c_byte * data_length).from_address(src_addr)
|
||||
return bytes(buffer)
|
||||
|
||||
@staticmethod
|
||||
def deserialize_data_to_buffer(kv_args, buffer_index, aux_index, data):
|
||||
"""Deserialize bytes into target memory buffer"""
|
||||
dst_aux_ptr = kv_args.aux_data_ptrs[buffer_index]
|
||||
item_len = kv_args.aux_item_lens[buffer_index]
|
||||
dst_addr = dst_aux_ptr + item_len * aux_index
|
||||
buffer = (ctypes.c_byte * len(data)).from_address(dst_addr)
|
||||
buffer[:] = data
|
||||
return
|
||||
|
||||
|
||||
class MooncakeKVManager(BaseKVManager):
|
||||
AUX_DATA_HEADER = b"AUX_DATA"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
args: KVArgs,
|
||||
@@ -283,21 +306,10 @@ class MooncakeKVManager(BaseKVManager):
|
||||
if not transfer_blocks:
|
||||
return 0
|
||||
|
||||
# TODO(shangming): Fix me when nvlink_transport of Mooncake is bug-free
|
||||
if self.enable_custom_mem_pool:
|
||||
# batch_transfer_sync has a higher chance to trigger an accuracy drop for MNNVL, fallback to transfer_sync temporarily
|
||||
for src_addr, dst_addr, length in transfer_blocks:
|
||||
status = self.engine.transfer_sync(
|
||||
mooncake_session_id, src_addr, dst_addr, length
|
||||
)
|
||||
if status != 0:
|
||||
return status
|
||||
return 0
|
||||
else:
|
||||
src_addrs, dst_addrs, lengths = zip(*transfer_blocks)
|
||||
return self.engine.batch_transfer_sync(
|
||||
mooncake_session_id, list(src_addrs), list(dst_addrs), list(lengths)
|
||||
)
|
||||
src_addrs, dst_addrs, lengths = zip(*transfer_blocks)
|
||||
return self.engine.batch_transfer_sync(
|
||||
mooncake_session_id, list(src_addrs), list(dst_addrs), list(lengths)
|
||||
)
|
||||
|
||||
def send_kvcache(
|
||||
self,
|
||||
@@ -570,11 +582,14 @@ class MooncakeKVManager(BaseKVManager):
|
||||
|
||||
def send_aux(
|
||||
self,
|
||||
mooncake_session_id: str,
|
||||
req: TransferInfo,
|
||||
prefill_aux_index: int,
|
||||
dst_aux_ptrs: list[int],
|
||||
dst_aux_index: int,
|
||||
):
|
||||
# TODO(shangming): Fix me when nvlink_transport of Mooncake is bug-free
|
||||
if self.enable_custom_mem_pool:
|
||||
return self.send_aux_tcp(req, prefill_aux_index, dst_aux_ptrs)
|
||||
|
||||
transfer_blocks = []
|
||||
prefill_aux_ptrs = self.kv_args.aux_data_ptrs
|
||||
prefill_aux_item_lens = self.kv_args.aux_item_lens
|
||||
@@ -582,10 +597,59 @@ class MooncakeKVManager(BaseKVManager):
|
||||
for i, dst_aux_ptr in enumerate(dst_aux_ptrs):
|
||||
length = prefill_aux_item_lens[i]
|
||||
src_addr = prefill_aux_ptrs[i] + length * prefill_aux_index
|
||||
dst_addr = dst_aux_ptrs[i] + length * dst_aux_index
|
||||
dst_addr = dst_aux_ptrs[i] + length * req.dst_aux_index
|
||||
transfer_blocks.append((src_addr, dst_addr, length))
|
||||
|
||||
return self._transfer_data(mooncake_session_id, transfer_blocks)
|
||||
return self._transfer_data(req.mooncake_session_id, transfer_blocks)
|
||||
|
||||
def send_aux_tcp(
|
||||
self,
|
||||
req: TransferInfo,
|
||||
prefill_aux_index: int,
|
||||
dst_aux_ptrs: list[int],
|
||||
):
|
||||
prefill_aux_ptrs = self.kv_args.aux_data_ptrs
|
||||
prefill_aux_item_lens = self.kv_args.aux_item_lens
|
||||
|
||||
for i in range(len(prefill_aux_ptrs)):
|
||||
length = prefill_aux_item_lens[i]
|
||||
src_addr = prefill_aux_ptrs[i] + length * prefill_aux_index
|
||||
data = AuxDataCodec.serialize_data_from_buffer(src_addr, length)
|
||||
|
||||
self.send_aux_data_to_endpoint(
|
||||
remote=req.endpoint,
|
||||
dst_port=req.dst_port,
|
||||
room=req.room,
|
||||
buffer_index=i,
|
||||
aux_index=req.dst_aux_index,
|
||||
data=data,
|
||||
)
|
||||
|
||||
return 0
|
||||
|
||||
def send_aux_data_to_endpoint(
|
||||
self,
|
||||
remote: str,
|
||||
dst_port: int,
|
||||
room: int,
|
||||
buffer_index: int,
|
||||
aux_index: int,
|
||||
data: bytes,
|
||||
):
|
||||
socket = self._connect(
|
||||
format_tcp_address(remote, dst_port), is_ipv6=is_valid_ipv6_address(remote)
|
||||
)
|
||||
|
||||
socket.send_multipart(
|
||||
[
|
||||
MooncakeKVManager.AUX_DATA_HEADER,
|
||||
str(room).encode("ascii"),
|
||||
str(buffer_index).encode("ascii"),
|
||||
str(aux_index).encode("ascii"),
|
||||
struct.pack(">I", len(data)),
|
||||
data,
|
||||
]
|
||||
)
|
||||
|
||||
def sync_status_to_decode_endpoint(
|
||||
self, remote: str, dst_port: int, room: int, status: int, prefill_rank: int
|
||||
@@ -699,10 +763,9 @@ class MooncakeKVManager(BaseKVManager):
|
||||
if self.pp_group.is_last_rank:
|
||||
# Only the last chunk we need to send the aux data
|
||||
ret = self.send_aux(
|
||||
req.mooncake_session_id,
|
||||
req,
|
||||
kv_chunk.prefill_aux_index,
|
||||
target_rank_registration_info.dst_aux_ptrs,
|
||||
req.dst_aux_index,
|
||||
)
|
||||
polls.append(True if ret == 0 else False)
|
||||
dst_ranks_infos.append(
|
||||
@@ -778,15 +841,38 @@ class MooncakeKVManager(BaseKVManager):
|
||||
|
||||
threading.Thread(target=bootstrap_thread).start()
|
||||
|
||||
def _handle_aux_data(self, msg: List[bytes]):
|
||||
"""Handle AUX_DATA messages received by the decode thread."""
|
||||
room = int(msg[1].decode("ascii"))
|
||||
buffer_index = int(msg[2].decode("ascii"))
|
||||
aux_index = int(msg[3].decode("ascii"))
|
||||
data_length = struct.unpack(">I", msg[4])[0]
|
||||
data = msg[5]
|
||||
|
||||
if len(data) != data_length:
|
||||
logger.error(f"AUX_DATA length mismatch for bootstrap_room {room}")
|
||||
return
|
||||
|
||||
AuxDataCodec.deserialize_data_to_buffer(
|
||||
self.kv_args, buffer_index, aux_index, data
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Received AUX_DATA for bootstrap_room {room} with length:{len(data)}"
|
||||
)
|
||||
|
||||
def start_decode_thread(self):
|
||||
self.rank_port = get_free_port()
|
||||
self._bind_server_socket()
|
||||
|
||||
def decode_thread():
|
||||
while True:
|
||||
(bootstrap_room, status, prefill_rank) = (
|
||||
self.server_socket.recv_multipart()
|
||||
)
|
||||
msg = self.server_socket.recv_multipart()
|
||||
if msg[0] == MooncakeKVManager.AUX_DATA_HEADER:
|
||||
self._handle_aux_data(msg)
|
||||
continue
|
||||
|
||||
(bootstrap_room, status, prefill_rank) = msg
|
||||
status = int(status.decode("ascii"))
|
||||
bootstrap_room = int(bootstrap_room.decode("ascii"))
|
||||
prefill_rank = int(prefill_rank.decode("ascii"))
|
||||
|
||||
@@ -99,7 +99,8 @@ class MetadataBuffers:
|
||||
# For ascend backend, output tokens are placed in the NPU and will be transferred by D2D channel.
|
||||
device = "npu"
|
||||
elif self.custom_mem_pool:
|
||||
device = "cuda"
|
||||
# TODO(shangming): Fix me (use 'cuda') when nvlink_transport of Mooncake is bug-free
|
||||
device = "cpu"
|
||||
with (
|
||||
torch.cuda.use_mem_pool(self.custom_mem_pool)
|
||||
if self.custom_mem_pool
|
||||
|
||||
Reference in New Issue
Block a user