[PD] Fix nvlink transport accuracy through transferring metadata with tcp (#9261)
Signed-off-by: Shangming Cai <csmthu@gmail.com>
This commit is contained in:
@@ -2,6 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import concurrent.futures
|
import concurrent.futures
|
||||||
|
import ctypes
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
@@ -138,7 +139,29 @@ class KVArgsRegisterInfo:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AuxDataCodec:
    """Copy auxiliary-data items between raw memory addresses and ``bytes``.

    Both helpers work on plain integer addresses through ``ctypes``. The
    caller is responsible for the addressed memory being valid; only the
    checks that can be made from the arguments themselves are performed here.
    """

    @staticmethod
    def serialize_data_from_buffer(src_addr: int, data_length: int) -> bytes:
        """Return a copy of ``data_length`` bytes starting at ``src_addr``.

        Args:
            src_addr: Integer address of the first byte to read.
            data_length: Number of bytes to copy.

        Returns:
            An independent ``bytes`` object holding the copied data.
        """
        view = (ctypes.c_byte * data_length).from_address(src_addr)
        return bytes(view)

    @staticmethod
    def deserialize_data_to_buffer(kv_args, buffer_index, aux_index, data) -> None:
        """Write ``data`` into slot ``aux_index`` of aux buffer ``buffer_index``.

        The destination address is
        ``kv_args.aux_data_ptrs[buffer_index]
        + kv_args.aux_item_lens[buffer_index] * aux_index``.

        Args:
            kv_args: Object exposing ``aux_data_ptrs`` (base addresses) and
                ``aux_item_lens`` (per-slot byte sizes), indexed by buffer.
            buffer_index: Which aux buffer to write into.
            aux_index: Slot number inside that buffer; must not be negative.
            data: Payload bytes; must fit inside one slot.

        Raises:
            ValueError: if ``aux_index`` is negative or ``data`` is longer
                than the slot, since either would write outside the slot.
        """
        dst_aux_ptr = kv_args.aux_data_ptrs[buffer_index]
        item_len = kv_args.aux_item_lens[buffer_index]
        payload_len = len(data)

        # Refuse writes that would land outside the addressed slot instead of
        # silently overwriting neighbouring memory.
        if aux_index < 0:
            raise ValueError(f"aux_index must be non-negative, got {aux_index}")
        if payload_len > item_len:
            raise ValueError(
                f"aux payload of {payload_len} bytes exceeds slot size "
                f"{item_len} for buffer {buffer_index}"
            )
        if payload_len == 0:
            return

        dst_addr = dst_aux_ptr + item_len * aux_index
        # One bulk copy; bytes() is a no-op for bytes input and normalises
        # other bytes-like inputs that memmove cannot take directly.
        ctypes.memmove(dst_addr, bytes(data), payload_len)
|
||||||
|
|
||||||
|
|
||||||
class MooncakeKVManager(BaseKVManager):
|
class MooncakeKVManager(BaseKVManager):
|
||||||
|
AUX_DATA_HEADER = b"AUX_DATA"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
args: KVArgs,
|
args: KVArgs,
|
||||||
@@ -283,21 +306,10 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
if not transfer_blocks:
|
if not transfer_blocks:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
# TODO(shangming): Fix me when nvlink_transport of Mooncake is bug-free
|
src_addrs, dst_addrs, lengths = zip(*transfer_blocks)
|
||||||
if self.enable_custom_mem_pool:
|
return self.engine.batch_transfer_sync(
|
||||||
# batch_transfer_sync has a higher chance to trigger an accuracy drop for MNNVL, fallback to transfer_sync temporarily
|
mooncake_session_id, list(src_addrs), list(dst_addrs), list(lengths)
|
||||||
for src_addr, dst_addr, length in transfer_blocks:
|
)
|
||||||
status = self.engine.transfer_sync(
|
|
||||||
mooncake_session_id, src_addr, dst_addr, length
|
|
||||||
)
|
|
||||||
if status != 0:
|
|
||||||
return status
|
|
||||||
return 0
|
|
||||||
else:
|
|
||||||
src_addrs, dst_addrs, lengths = zip(*transfer_blocks)
|
|
||||||
return self.engine.batch_transfer_sync(
|
|
||||||
mooncake_session_id, list(src_addrs), list(dst_addrs), list(lengths)
|
|
||||||
)
|
|
||||||
|
|
||||||
def send_kvcache(
|
def send_kvcache(
|
||||||
self,
|
self,
|
||||||
@@ -570,11 +582,14 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
|
|
||||||
def send_aux(
|
def send_aux(
|
||||||
self,
|
self,
|
||||||
mooncake_session_id: str,
|
req: TransferInfo,
|
||||||
prefill_aux_index: int,
|
prefill_aux_index: int,
|
||||||
dst_aux_ptrs: list[int],
|
dst_aux_ptrs: list[int],
|
||||||
dst_aux_index: int,
|
|
||||||
):
|
):
|
||||||
|
# TODO(shangming): Fix me when nvlink_transport of Mooncake is bug-free
|
||||||
|
if self.enable_custom_mem_pool:
|
||||||
|
return self.send_aux_tcp(req, prefill_aux_index, dst_aux_ptrs)
|
||||||
|
|
||||||
transfer_blocks = []
|
transfer_blocks = []
|
||||||
prefill_aux_ptrs = self.kv_args.aux_data_ptrs
|
prefill_aux_ptrs = self.kv_args.aux_data_ptrs
|
||||||
prefill_aux_item_lens = self.kv_args.aux_item_lens
|
prefill_aux_item_lens = self.kv_args.aux_item_lens
|
||||||
@@ -582,10 +597,59 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
for i, dst_aux_ptr in enumerate(dst_aux_ptrs):
|
for i, dst_aux_ptr in enumerate(dst_aux_ptrs):
|
||||||
length = prefill_aux_item_lens[i]
|
length = prefill_aux_item_lens[i]
|
||||||
src_addr = prefill_aux_ptrs[i] + length * prefill_aux_index
|
src_addr = prefill_aux_ptrs[i] + length * prefill_aux_index
|
||||||
dst_addr = dst_aux_ptrs[i] + length * dst_aux_index
|
dst_addr = dst_aux_ptrs[i] + length * req.dst_aux_index
|
||||||
transfer_blocks.append((src_addr, dst_addr, length))
|
transfer_blocks.append((src_addr, dst_addr, length))
|
||||||
|
|
||||||
return self._transfer_data(mooncake_session_id, transfer_blocks)
|
return self._transfer_data(req.mooncake_session_id, transfer_blocks)
|
||||||
|
|
||||||
|
def send_aux_tcp(
    self,
    req: TransferInfo,
    prefill_aux_index: int,
    dst_aux_ptrs: list[int],
):
    """Send the aux item at ``prefill_aux_index`` of every local aux buffer.

    For each buffer in ``self.kv_args.aux_data_ptrs`` the item is copied out
    of memory with ``AuxDataCodec.serialize_data_from_buffer`` and handed to
    ``send_aux_data_to_endpoint`` as its own message, addressed with the
    endpoint, port, room and ``dst_aux_index`` taken from ``req``.

    ``dst_aux_ptrs`` is accepted but not read here: only the buffer index and
    ``req.dst_aux_index`` are sent along with the payload.

    Returns:
        Always ``0``.
    """
    item_lens = self.kv_args.aux_item_lens

    for buffer_idx, base_ptr in enumerate(self.kv_args.aux_data_ptrs):
        item_len = item_lens[buffer_idx]
        # Items are laid out back to back, so the wanted one starts at
        # base + item size * item index.
        item_addr = base_ptr + item_len * prefill_aux_index
        payload = AuxDataCodec.serialize_data_from_buffer(item_addr, item_len)

        self.send_aux_data_to_endpoint(
            remote=req.endpoint,
            dst_port=req.dst_port,
            room=req.room,
            buffer_index=buffer_idx,
            aux_index=req.dst_aux_index,
            data=payload,
        )

    return 0
|
||||||
|
|
||||||
|
def send_aux_data_to_endpoint(
    self,
    remote: str,
    dst_port: int,
    room: int,
    buffer_index: int,
    aux_index: int,
    data: bytes,
):
    """Send one aux payload to ``remote:dst_port`` as a multipart message.

    Frames, in order:
        0. ``MooncakeKVManager.AUX_DATA_HEADER``
        1. ``room`` as ASCII decimal
        2. ``buffer_index`` as ASCII decimal
        3. ``aux_index`` as ASCII decimal
        4. ``len(data)`` packed as a big-endian unsigned 32-bit integer
        5. ``data`` itself

    The connection comes from ``self._connect``.
    NOTE(review): presumably a cached ZMQ socket -- confirm against
    ``_connect``.
    """
    address = format_tcp_address(remote, dst_port)
    conn = self._connect(address, is_ipv6=is_valid_ipv6_address(remote))

    frames = [
        MooncakeKVManager.AUX_DATA_HEADER,
        str(room).encode("ascii"),
        str(buffer_index).encode("ascii"),
        str(aux_index).encode("ascii"),
        struct.pack(">I", len(data)),
        data,
    ]
    conn.send_multipart(frames)
|
||||||
|
|
||||||
def sync_status_to_decode_endpoint(
|
def sync_status_to_decode_endpoint(
|
||||||
self, remote: str, dst_port: int, room: int, status: int, prefill_rank: int
|
self, remote: str, dst_port: int, room: int, status: int, prefill_rank: int
|
||||||
@@ -699,10 +763,9 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
if self.pp_group.is_last_rank:
|
if self.pp_group.is_last_rank:
|
||||||
# Only the last chunk we need to send the aux data
|
# Only the last chunk we need to send the aux data
|
||||||
ret = self.send_aux(
|
ret = self.send_aux(
|
||||||
req.mooncake_session_id,
|
req,
|
||||||
kv_chunk.prefill_aux_index,
|
kv_chunk.prefill_aux_index,
|
||||||
target_rank_registration_info.dst_aux_ptrs,
|
target_rank_registration_info.dst_aux_ptrs,
|
||||||
req.dst_aux_index,
|
|
||||||
)
|
)
|
||||||
polls.append(True if ret == 0 else False)
|
polls.append(True if ret == 0 else False)
|
||||||
dst_ranks_infos.append(
|
dst_ranks_infos.append(
|
||||||
@@ -778,15 +841,38 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
|
|
||||||
threading.Thread(target=bootstrap_thread).start()
|
threading.Thread(target=bootstrap_thread).start()
|
||||||
|
|
||||||
|
def _handle_aux_data(self, msg: List[bytes]):
    """Decode an AUX_DATA multipart message and store its payload locally.

    Expects ``msg[1]``..``msg[3]`` to be the room, buffer index and aux index
    as ASCII decimals, ``msg[4]`` a big-endian unsigned 32-bit payload length
    and ``msg[5]`` the payload. ``msg[0]`` is not inspected here.

    If the declared length does not match the payload, an error is logged and
    the message is dropped. Otherwise the payload is written into this
    manager's aux buffers via ``AuxDataCodec.deserialize_data_to_buffer``.
    """
    room = int(msg[1].decode("ascii"))
    target_buffer = int(msg[2].decode("ascii"))
    target_slot = int(msg[3].decode("ascii"))
    (declared_len,) = struct.unpack(">I", msg[4])
    payload = msg[5]

    if len(payload) == declared_len:
        AuxDataCodec.deserialize_data_to_buffer(
            self.kv_args, target_buffer, target_slot, payload
        )
        logger.debug(
            f"Received AUX_DATA for bootstrap_room {room} with length:{len(payload)}"
        )
        return

    # Declared and actual sizes disagree: drop the message without writing.
    logger.error(f"AUX_DATA length mismatch for bootstrap_room {room}")
|
||||||
|
|
||||||
def start_decode_thread(self):
|
def start_decode_thread(self):
|
||||||
self.rank_port = get_free_port()
|
self.rank_port = get_free_port()
|
||||||
self._bind_server_socket()
|
self._bind_server_socket()
|
||||||
|
|
||||||
def decode_thread():
|
def decode_thread():
|
||||||
while True:
|
while True:
|
||||||
(bootstrap_room, status, prefill_rank) = (
|
msg = self.server_socket.recv_multipart()
|
||||||
self.server_socket.recv_multipart()
|
if msg[0] == MooncakeKVManager.AUX_DATA_HEADER:
|
||||||
)
|
self._handle_aux_data(msg)
|
||||||
|
continue
|
||||||
|
|
||||||
|
(bootstrap_room, status, prefill_rank) = msg
|
||||||
status = int(status.decode("ascii"))
|
status = int(status.decode("ascii"))
|
||||||
bootstrap_room = int(bootstrap_room.decode("ascii"))
|
bootstrap_room = int(bootstrap_room.decode("ascii"))
|
||||||
prefill_rank = int(prefill_rank.decode("ascii"))
|
prefill_rank = int(prefill_rank.decode("ascii"))
|
||||||
|
|||||||
@@ -99,7 +99,8 @@ class MetadataBuffers:
|
|||||||
# For ascend backend, output tokens are placed in the NPU and will be transferred by D2D channel.
|
# For ascend backend, output tokens are placed in the NPU and will be transferred by D2D channel.
|
||||||
device = "npu"
|
device = "npu"
|
||||||
elif self.custom_mem_pool:
|
elif self.custom_mem_pool:
|
||||||
device = "cuda"
|
# TODO(shangming): Fix me (use 'cuda') when nvlink_transport of Mooncake is bug-free
|
||||||
|
device = "cpu"
|
||||||
with (
|
with (
|
||||||
torch.cuda.use_mem_pool(self.custom_mem_pool)
|
torch.cuda.use_mem_pool(self.custom_mem_pool)
|
||||||
if self.custom_mem_pool
|
if self.custom_mem_pool
|
||||||
|
|||||||
Reference in New Issue
Block a user