3fs zerocopy (#9109)

Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
pansicheng
2025-08-22 17:56:38 +08:00
committed by GitHub
parent cebf45994b
commit 70cf4abccc
7 changed files with 310 additions and 29 deletions

View File

@@ -268,9 +268,14 @@ class HiCacheController:
)
rank = get_tensor_model_parallel_rank()
bytes_per_page = (
mem_pool_host.get_size_per_token() * mem_pool_host.page_size
)
if self.mem_pool_host.layout == "page_first":
bytes_per_page = (
mem_pool_host.get_ksize_per_token() * mem_pool_host.page_size
)
elif self.mem_pool_host.layout == "layer_first":
bytes_per_page = (
mem_pool_host.get_size_per_token() * mem_pool_host.page_size
)
dtype = mem_pool_host.dtype
self.storage_backend = HiCacheHF3FS.from_env_config(
rank, bytes_per_page, dtype
@@ -555,13 +560,34 @@ class HiCacheController:
operation.mark_done()
return operation.completed_tokens, operation.hash_value
def zerocopy_page_transfer(self, operation, batch_size=8):
    """Fetch pages from the storage backend directly into host buffers.

    Zero-copy variant of the prefetch transfer: the destination buffers
    returned by ``get_buffer_with_hash`` are handed to ``batch_get`` so the
    backend writes straight into host memory (no dummy staging page).

    Args:
        operation: prefetch operation carrying ``hash_value``,
            ``host_indices``, ``request_id`` and the ``increment`` progress
            callback.
        batch_size: number of pages requested from storage per call.
    """
    hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
        operation.hash_value, operation.host_indices
    )
    for i in range(0, len(hashes), batch_size):
        page_hashes = hashes[i : i + batch_size]
        page_dsts = dsts[i : i + batch_size]
        page_data = self.storage_backend.batch_get(page_hashes, page_dsts)
        if page_data is None:
            logger.warning(
                f"Prefetch operation {operation.request_id} failed to retrieve page {page_hashes}."
            )
            break
        # increment() records the fetched tokens on the operation; a False
        # return means the operation was terminated elsewhere, so stop.
        # (The original accumulated a local `completed_tokens` copy here in a
        # loop that shadowed `i`; that value was never used — dead code
        # removed.)
        if not operation.increment(self.page_size * len(page_hashes)):
            break
def generic_page_transfer(self, operation, batch_size=8):
for i in range(0, len(operation.hash_value), batch_size):
page_hashes = operation.hash_value[i : i + batch_size]
# todo: zero copy
dummy_page_dst = [self.mem_pool_host.get_dummy_flat_data_page()] * len(
page_hashes
)
dummy_page_dst = [
self.mem_pool_host.get_dummy_flat_data_page()
for _ in range(len(page_hashes))
]
page_data = self.storage_backend.batch_get(page_hashes, dummy_page_dst)
if page_data is None:
logger.warning(
@@ -599,7 +625,10 @@ class HiCacheController:
if self.is_mooncake_backend():
self.mooncake_page_transfer(operation)
elif self.storage_backend_type == "hf3fs":
self.generic_page_transfer(operation, batch_size=128)
if self.mem_pool_host.layout == "page_first":
self.zerocopy_page_transfer(operation, batch_size=128)
elif self.mem_pool_host.layout == "layer_first":
self.generic_page_transfer(operation, batch_size=128)
else:
self.generic_page_transfer(operation)
@@ -716,6 +745,19 @@ class HiCacheController:
self.backup_queue.put(operation)
return operation.id
def zerocopy_page_backup(self, operation, batch_size=8):
    """Write host pages to the storage backend in batches.

    Zero-copy variant of the backup path: the host-side buffers returned by
    ``get_buffer_with_hash`` are passed to ``batch_set`` as-is, so no staging
    copy is made. ``operation.completed_tokens`` advances by one page worth of
    tokens per page written; the loop stops on the first failed batch.
    """
    all_hashes, all_buffers = self.mem_pool_host.get_buffer_with_hash(
        operation.hash_value, operation.host_indices
    )
    for start in range(0, len(all_hashes), batch_size):
        page_hashes = all_hashes[start : start + batch_size]
        page_buffers = all_buffers[start : start + batch_size]
        if not self.storage_backend.batch_set(page_hashes, page_buffers):
            logger.warning(f"Failed to write page {page_hashes} to storage.")
            break
        operation.completed_tokens += self.page_size * len(page_hashes)
def generic_page_backup(self, operation, batch_size=8):
for i in range(0, len(operation.hash_value), batch_size):
page_hashes = operation.hash_value[i : i + batch_size]
@@ -770,7 +812,10 @@ class HiCacheController:
if self.is_mooncake_backend():
self.mooncake_page_backup(operation)
elif self.storage_backend_type == "hf3fs":
self.generic_page_backup(operation, batch_size=128)
if self.mem_pool_host.layout == "page_first":
self.zerocopy_page_backup(operation, batch_size=128)
elif self.mem_pool_host.layout == "layer_first":
self.generic_page_backup(operation, batch_size=128)
else:
self.generic_page_backup(operation)