[PD Disaggregation] replace transfer with batch transfer for better performance (#7236)
This commit is contained in:
@@ -251,17 +251,19 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
|
|
||||||
# Worker function for processing a single layer
|
# Worker function for processing a single layer
|
||||||
def process_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int:
|
def process_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int:
|
||||||
|
src_addr_list = []
|
||||||
|
dst_addr_list = []
|
||||||
|
length_list = []
|
||||||
for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
|
for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
|
||||||
src_addr = src_ptr + int(prefill_index[0]) * item_len
|
src_addr = src_ptr + int(prefill_index[0]) * item_len
|
||||||
dst_addr = dst_ptr + int(decode_index[0]) * item_len
|
dst_addr = dst_ptr + int(decode_index[0]) * item_len
|
||||||
length = item_len * len(prefill_index)
|
length = item_len * len(prefill_index)
|
||||||
|
src_addr_list.append(src_addr)
|
||||||
status = self.engine.transfer_sync(
|
dst_addr_list.append(dst_addr)
|
||||||
mooncake_session_id, src_addr, dst_addr, length
|
length_list.append(length)
|
||||||
)
|
return self.engine.batch_transfer_sync(
|
||||||
if status != 0:
|
mooncake_session_id, src_addr_list, dst_addr_list, length_list
|
||||||
return status
|
)
|
||||||
return 0
|
|
||||||
|
|
||||||
futures = [
|
futures = [
|
||||||
executor.submit(
|
executor.submit(
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -90,5 +90,29 @@ class MooncakeTransferEngine:
|
|||||||
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
def batch_transfer_sync(
|
||||||
|
self,
|
||||||
|
session_id: str,
|
||||||
|
buffers: List[int],
|
||||||
|
peer_buffer_addresses: List[int],
|
||||||
|
lengths: List[int],
|
||||||
|
) -> int:
|
||||||
|
"""Synchronously transfer data to the specified address."""
|
||||||
|
try:
|
||||||
|
ret = self.engine.batch_transfer_sync_write(
|
||||||
|
session_id, buffers, peer_buffer_addresses, lengths
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
ret = -1
|
||||||
|
|
||||||
|
if ret < 0:
|
||||||
|
logger.debug(
|
||||||
|
"Failed to batch transfer data. Buffers: %s, Session: %s, Peer addresses: %s",
|
||||||
|
buffers,
|
||||||
|
session_id,
|
||||||
|
peer_buffer_addresses,
|
||||||
|
)
|
||||||
|
return ret
|
||||||
|
|
||||||
def get_session_id(self):
|
def get_session_id(self):
|
||||||
return self.session_id
|
return self.session_id
|
||||||
|
|||||||
Reference in New Issue
Block a user