3fs zerocopy (#9109)
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
@@ -268,9 +268,14 @@ class HiCacheController:
|
||||
)
|
||||
|
||||
rank = get_tensor_model_parallel_rank()
|
||||
bytes_per_page = (
|
||||
mem_pool_host.get_size_per_token() * mem_pool_host.page_size
|
||||
)
|
||||
if self.mem_pool_host.layout == "page_first":
|
||||
bytes_per_page = (
|
||||
mem_pool_host.get_ksize_per_token() * mem_pool_host.page_size
|
||||
)
|
||||
elif self.mem_pool_host.layout == "layer_first":
|
||||
bytes_per_page = (
|
||||
mem_pool_host.get_size_per_token() * mem_pool_host.page_size
|
||||
)
|
||||
dtype = mem_pool_host.dtype
|
||||
self.storage_backend = HiCacheHF3FS.from_env_config(
|
||||
rank, bytes_per_page, dtype
|
||||
@@ -555,13 +560,34 @@ class HiCacheController:
|
||||
operation.mark_done()
|
||||
return operation.completed_tokens, operation.hash_value
|
||||
|
||||
def zerocopy_page_transfer(self, operation, batch_size=8):
    """Fetch pages for a prefetch operation directly into host buffers.

    The destination buffers are obtained from the host memory pool
    (``get_buffer_with_hash``), so the storage backend can fill them
    in place — the zero-copy path. Pages are requested in batches of
    ``batch_size`` hashes per ``batch_get`` call.

    Args:
        operation: prefetch operation carrying ``hash_value``,
            ``host_indices``, ``request_id`` and an ``increment``
            progress method.
        batch_size: number of pages per ``batch_get`` request.

    Stops early when the backend returns ``None`` for a batch, or when
    ``operation.increment`` returns falsy (operation terminated
    elsewhere — NOTE(review): exact semantics defined by the operation
    class, not visible here).
    """
    hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
        operation.hash_value, operation.host_indices
    )
    for start in range(0, len(hashes), batch_size):
        page_hashes = hashes[start : start + batch_size]
        page_dsts = dsts[start : start + batch_size]
        page_data = self.storage_backend.batch_get(page_hashes, page_dsts)
        if page_data is None:
            logger.warning(
                f"Prefetch operation {operation.request_id} failed to retrieve page {page_hashes}."
            )
            break
        # Progress is tracked solely via operation.increment; the
        # original also kept a local token counter that was never
        # read (dead code) — removed.
        if not operation.increment(self.page_size * len(page_hashes)):
            break
|
||||
|
||||
def generic_page_transfer(self, operation, batch_size=8):
|
||||
for i in range(0, len(operation.hash_value), batch_size):
|
||||
page_hashes = operation.hash_value[i : i + batch_size]
|
||||
# todo: zero copy
|
||||
dummy_page_dst = [self.mem_pool_host.get_dummy_flat_data_page()] * len(
|
||||
page_hashes
|
||||
)
|
||||
dummy_page_dst = [
|
||||
self.mem_pool_host.get_dummy_flat_data_page()
|
||||
for _ in range(len(page_hashes))
|
||||
]
|
||||
page_data = self.storage_backend.batch_get(page_hashes, dummy_page_dst)
|
||||
if page_data is None:
|
||||
logger.warning(
|
||||
@@ -599,7 +625,10 @@ class HiCacheController:
|
||||
if self.is_mooncake_backend():
|
||||
self.mooncake_page_transfer(operation)
|
||||
elif self.storage_backend_type == "hf3fs":
|
||||
self.generic_page_transfer(operation, batch_size=128)
|
||||
if self.mem_pool_host.layout == "page_first":
|
||||
self.zerocopy_page_transfer(operation, batch_size=128)
|
||||
elif self.mem_pool_host.layout == "layer_first":
|
||||
self.generic_page_transfer(operation, batch_size=128)
|
||||
else:
|
||||
self.generic_page_transfer(operation)
|
||||
|
||||
@@ -716,6 +745,19 @@ class HiCacheController:
|
||||
self.backup_queue.put(operation)
|
||||
return operation.id
|
||||
|
||||
def zerocopy_page_backup(self, operation, batch_size=8):
    """Persist host pages for a backup operation to the storage backend.

    The source buffers come straight from the host memory pool, so the
    backend reads them without an intermediate copy. Pages are written
    in batches of ``batch_size``; on the first failed ``batch_set`` the
    loop stops (pages written so far remain counted).
    """
    hashes, buffers = self.mem_pool_host.get_buffer_with_hash(
        operation.hash_value, operation.host_indices
    )
    total = len(hashes)
    start = 0
    while start < total:
        page_hashes = hashes[start : start + batch_size]
        batch_payload = buffers[start : start + batch_size]
        ok = self.storage_backend.batch_set(page_hashes, batch_payload)
        if not ok:
            logger.warning(f"Failed to write page {page_hashes} to storage.")
            break
        # Only count tokens for batches the backend accepted.
        operation.completed_tokens += self.page_size * len(page_hashes)
        start += batch_size
|
||||
|
||||
def generic_page_backup(self, operation, batch_size=8):
|
||||
for i in range(0, len(operation.hash_value), batch_size):
|
||||
page_hashes = operation.hash_value[i : i + batch_size]
|
||||
@@ -770,7 +812,10 @@ class HiCacheController:
|
||||
if self.is_mooncake_backend():
|
||||
self.mooncake_page_backup(operation)
|
||||
elif self.storage_backend_type == "hf3fs":
|
||||
self.generic_page_backup(operation, batch_size=128)
|
||||
if self.mem_pool_host.layout == "page_first":
|
||||
self.zerocopy_page_backup(operation, batch_size=128)
|
||||
elif self.mem_pool_host.layout == "layer_first":
|
||||
self.generic_page_backup(operation, batch_size=128)
|
||||
else:
|
||||
self.generic_page_backup(operation)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user