[PD] optimize kv cache transfer directly using batch transfer (#9149)
Co-authored-by: Shangming Cai <csmthu@gmail.com>
This commit is contained in:
@@ -356,33 +356,49 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
]
|
]
|
||||||
assert layers_params is not None
|
assert layers_params is not None
|
||||||
|
|
||||||
# Worker function for processing a single layer
|
def set_transfer_blocks(
|
||||||
def process_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int:
|
src_ptr: int, dst_ptr: int, item_len: int
|
||||||
|
) -> List[Tuple[int, int, int]]:
|
||||||
transfer_blocks = []
|
transfer_blocks = []
|
||||||
for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
|
for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
|
||||||
src_addr = src_ptr + int(prefill_index[0]) * item_len
|
src_addr = src_ptr + int(prefill_index[0]) * item_len
|
||||||
dst_addr = dst_ptr + int(decode_index[0]) * item_len
|
dst_addr = dst_ptr + int(decode_index[0]) * item_len
|
||||||
length = item_len * len(prefill_index)
|
length = item_len * len(prefill_index)
|
||||||
transfer_blocks.append((src_addr, dst_addr, length))
|
transfer_blocks.append((src_addr, dst_addr, length))
|
||||||
|
return transfer_blocks
|
||||||
|
|
||||||
|
# Worker function for processing a single layer
|
||||||
|
def process_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int:
|
||||||
|
transfer_blocks = set_transfer_blocks(src_ptr, dst_ptr, item_len)
|
||||||
return self._transfer_data(mooncake_session_id, transfer_blocks)
|
return self._transfer_data(mooncake_session_id, transfer_blocks)
|
||||||
|
|
||||||
futures = [
|
# Worker function for processing all layers in a batch
|
||||||
executor.submit(
|
def process_layers(layers_params: List[Tuple[int, int, int]]) -> int:
|
||||||
process_layer,
|
transfer_blocks = []
|
||||||
src_ptr,
|
for src_ptr, dst_ptr, item_len in layers_params:
|
||||||
dst_ptr,
|
transfer_blocks.extend(set_transfer_blocks(src_ptr, dst_ptr, item_len))
|
||||||
item_len,
|
return self._transfer_data(mooncake_session_id, transfer_blocks)
|
||||||
)
|
|
||||||
for (src_ptr, dst_ptr, item_len) in layers_params
|
|
||||||
]
|
|
||||||
|
|
||||||
for future in concurrent.futures.as_completed(futures):
|
if self.enable_custom_mem_pool:
|
||||||
status = future.result()
|
futures = [
|
||||||
if status != 0:
|
executor.submit(
|
||||||
for f in futures:
|
process_layer,
|
||||||
f.cancel()
|
src_ptr,
|
||||||
return status
|
dst_ptr,
|
||||||
|
item_len,
|
||||||
|
)
|
||||||
|
for (src_ptr, dst_ptr, item_len) in layers_params
|
||||||
|
]
|
||||||
|
for future in concurrent.futures.as_completed(futures):
|
||||||
|
status = future.result()
|
||||||
|
if status != 0:
|
||||||
|
for f in futures:
|
||||||
|
f.cancel()
|
||||||
|
return status
|
||||||
|
else:
|
||||||
|
# Combining all layers' params in one batch transfer is more efficient
|
||||||
|
# compared to using multiple threads
|
||||||
|
return process_layers(layers_params)
|
||||||
|
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user