fix 3fs zerocopy (#9938)
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
@@ -324,6 +324,22 @@ class HiCacheController:
|
|||||||
group_ranks, backend="gloo"
|
group_ranks, backend="gloo"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Select the get and set functions
|
||||||
|
self.page_get_func = self._generic_page_get
|
||||||
|
self.page_set_func = self._generic_page_set
|
||||||
|
self.batch_exists_func = self.storage_backend.batch_exists
|
||||||
|
self.is_3fs_zerocopy = (
|
||||||
|
self.storage_backend_type == "hf3fs"
|
||||||
|
and self.mem_pool_host.layout == "page_first"
|
||||||
|
)
|
||||||
|
if self.storage_backend_type == "mooncake":
|
||||||
|
self.page_get_func = self._mooncake_page_get
|
||||||
|
self.page_set_func = self._mooncake_page_set
|
||||||
|
elif self.is_3fs_zerocopy:
|
||||||
|
self.page_get_func = self._3fs_zero_copy_page_get
|
||||||
|
self.page_set_func = self._3fs_zero_copy_page_set
|
||||||
|
self.batch_exists_func = self._3fs_zero_copy_batch_exists
|
||||||
|
|
||||||
self.load_cache_event = load_cache_event
|
self.load_cache_event = load_cache_event
|
||||||
self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
|
self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
|
||||||
self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter)
|
self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter)
|
||||||
@@ -617,13 +633,19 @@ class HiCacheController:
|
|||||||
for chunk in chunks:
|
for chunk in chunks:
|
||||||
self.host_mem_release_queue.put(chunk)
|
self.host_mem_release_queue.put(chunk)
|
||||||
|
|
||||||
|
def _3fs_zero_copy_batch_exists(self, batch_hashes):
|
||||||
|
_batch_hashes, _, factor = self.mem_pool_host.get_buffer_with_hash(batch_hashes)
|
||||||
|
hit_page_num = self.storage_backend.batch_exists(_batch_hashes) // factor
|
||||||
|
return hit_page_num
|
||||||
|
|
||||||
def _3fs_zero_copy_page_get(self, operation, hash_values, host_indices):
|
def _3fs_zero_copy_page_get(self, operation, hash_values, host_indices):
|
||||||
hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
|
hashes, dsts, factor = self.mem_pool_host.get_buffer_with_hash(
|
||||||
hash_values, host_indices
|
hash_values, host_indices
|
||||||
)
|
)
|
||||||
page_data = self.storage_backend.batch_get(hashes, dsts)
|
page_data = self.storage_backend.batch_get(hashes, dsts)
|
||||||
if page_data:
|
if page_data:
|
||||||
operation.increment(self.page_size * len(hashes))
|
inc = self.page_size * len(hashes) // factor
|
||||||
|
operation.increment(inc)
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Prefetch operation {operation.request_id} failed to retrieve page {hashes}."
|
f"Prefetch operation {operation.request_id} failed to retrieve page {hashes}."
|
||||||
@@ -670,17 +692,6 @@ class HiCacheController:
|
|||||||
break # Operation terminated by controller
|
break # Operation terminated by controller
|
||||||
|
|
||||||
def _page_transfer(self, operation):
|
def _page_transfer(self, operation):
|
||||||
# Select the get function and batch size
|
|
||||||
if self.storage_backend_type == "mooncake":
|
|
||||||
get_func = self._mooncake_page_get
|
|
||||||
elif (
|
|
||||||
self.storage_backend_type == "hf3fs"
|
|
||||||
and self.mem_pool_host.layout == "page_first"
|
|
||||||
):
|
|
||||||
get_func = self._3fs_zero_copy_page_get
|
|
||||||
else:
|
|
||||||
get_func = self._generic_page_get
|
|
||||||
|
|
||||||
# Transfer batch by batch
|
# Transfer batch by batch
|
||||||
for i in range(0, len(operation.hash_value), self.storage_batch_size):
|
for i in range(0, len(operation.hash_value), self.storage_batch_size):
|
||||||
batch_hashes = operation.hash_value[i : i + self.storage_batch_size]
|
batch_hashes = operation.hash_value[i : i + self.storage_batch_size]
|
||||||
@@ -689,7 +700,7 @@ class HiCacheController:
|
|||||||
]
|
]
|
||||||
prev_completed_tokens = operation.completed_tokens
|
prev_completed_tokens = operation.completed_tokens
|
||||||
# Get one batch token, and update the completed_tokens if succeed
|
# Get one batch token, and update the completed_tokens if succeed
|
||||||
get_func(operation, batch_hashes, batch_host_indices)
|
self.page_get_func(operation, batch_hashes, batch_host_indices)
|
||||||
# Check termination
|
# Check termination
|
||||||
if (
|
if (
|
||||||
operation.completed_tokens
|
operation.completed_tokens
|
||||||
@@ -746,7 +757,7 @@ class HiCacheController:
|
|||||||
batch_tokens[i : i + self.page_size], last_hash
|
batch_tokens[i : i + self.page_size], last_hash
|
||||||
)
|
)
|
||||||
batch_hashes.append(last_hash)
|
batch_hashes.append(last_hash)
|
||||||
hit_page_num = self.storage_backend.batch_exists(batch_hashes)
|
hit_page_num = self.batch_exists_func(batch_hashes)
|
||||||
hash_value.extend(batch_hashes[:hit_page_num])
|
hash_value.extend(batch_hashes[:hit_page_num])
|
||||||
storage_query_count += hit_page_num * self.page_size
|
storage_query_count += hit_page_num * self.page_size
|
||||||
if hit_page_num < len(batch_hashes):
|
if hit_page_num < len(batch_hashes):
|
||||||
@@ -839,23 +850,13 @@ class HiCacheController:
|
|||||||
|
|
||||||
# zero copy
|
# zero copy
|
||||||
def _3fs_zero_copy_page_set(self, hash_values, host_indices) -> bool:
|
def _3fs_zero_copy_page_set(self, hash_values, host_indices) -> bool:
|
||||||
hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
|
hashes, dsts, _ = self.mem_pool_host.get_buffer_with_hash(
|
||||||
hash_values, host_indices
|
hash_values, host_indices
|
||||||
)
|
)
|
||||||
return self.storage_backend.batch_set(hashes, dsts)
|
return self.storage_backend.batch_set(hashes, dsts)
|
||||||
|
|
||||||
# Backup batch by batch
|
# Backup batch by batch
|
||||||
def _page_backup(self, operation):
|
def _page_backup(self, operation):
|
||||||
# Select the set function and batch size
|
|
||||||
if self.storage_backend_type == "mooncake":
|
|
||||||
backup_set_func = self._mooncake_page_set
|
|
||||||
elif (
|
|
||||||
self.storage_backend_type == "hf3fs"
|
|
||||||
and self.mem_pool_host.layout == "page_first"
|
|
||||||
):
|
|
||||||
backup_set_func = self._3fs_zero_copy_page_set
|
|
||||||
else:
|
|
||||||
backup_set_func = self._generic_page_set
|
|
||||||
# Backup batch by batch
|
# Backup batch by batch
|
||||||
for i in range(0, len(operation.hash_value), self.storage_batch_size):
|
for i in range(0, len(operation.hash_value), self.storage_batch_size):
|
||||||
batch_hashes = operation.hash_value[i : i + self.storage_batch_size]
|
batch_hashes = operation.hash_value[i : i + self.storage_batch_size]
|
||||||
@@ -864,7 +865,7 @@ class HiCacheController:
|
|||||||
]
|
]
|
||||||
# Set one batch token, and record if success.
|
# Set one batch token, and record if success.
|
||||||
# todo: allow partial success
|
# todo: allow partial success
|
||||||
success = backup_set_func(batch_hashes, batch_host_indices)
|
success = self.page_set_func(batch_hashes, batch_host_indices)
|
||||||
if not success:
|
if not success:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Write page to storage: {len(batch_hashes)} pages failed."
|
f"Write page to storage: {len(batch_hashes)} pages failed."
|
||||||
|
|||||||
@@ -500,20 +500,23 @@ class MHATokenToKVPoolHost(HostKVCache):
|
|||||||
element_size_list = [element_size] * len(key_list)
|
element_size_list = [element_size] * len(key_list)
|
||||||
return key_list, ptr_list, element_size_list
|
return key_list, ptr_list, element_size_list
|
||||||
|
|
||||||
def get_buffer_with_hash(self, keys, indices):
|
def get_buffer_with_hash(self, keys, indices=None):
|
||||||
assert self.layout == "page_first"
|
assert self.layout == "page_first"
|
||||||
assert len(keys) == (len(indices) // self.page_size)
|
assert indices is None or (len(keys) == (len(indices) // self.page_size))
|
||||||
|
|
||||||
key_list = []
|
key_list = []
|
||||||
buf_list = []
|
buf_list = []
|
||||||
|
|
||||||
for key, i in zip(keys, range(0, len(indices), self.page_size)):
|
for i in range(len(keys)):
|
||||||
|
key = keys[i]
|
||||||
key_list.append(f"{key}-k")
|
key_list.append(f"{key}-k")
|
||||||
buf_list.append(self.k_buffer[i : i + self.page_size])
|
|
||||||
key_list.append(f"{key}-v")
|
key_list.append(f"{key}-v")
|
||||||
buf_list.append(self.v_buffer[i : i + self.page_size])
|
if indices is not None:
|
||||||
|
index = indices[i * self.page_size]
|
||||||
|
buf_list.append(self.k_buffer[index : index + self.page_size])
|
||||||
|
buf_list.append(self.v_buffer[index : index + self.page_size])
|
||||||
|
|
||||||
return key_list, buf_list
|
return key_list, buf_list, 2
|
||||||
|
|
||||||
|
|
||||||
class MLATokenToKVPoolHost(HostKVCache):
|
class MLATokenToKVPoolHost(HostKVCache):
|
||||||
@@ -728,13 +731,15 @@ class MLATokenToKVPoolHost(HostKVCache):
|
|||||||
element_size_list = [element_size] * len(key_list)
|
element_size_list = [element_size] * len(key_list)
|
||||||
return key_list, ptr_list, element_size_list
|
return key_list, ptr_list, element_size_list
|
||||||
|
|
||||||
def get_buffer_with_hash(self, keys, indices):
|
def get_buffer_with_hash(self, keys, indices=None):
|
||||||
assert self.layout == "page_first"
|
assert self.layout == "page_first"
|
||||||
assert len(keys) == (len(indices) // self.page_size)
|
assert indices is None or (len(keys) == (len(indices) // self.page_size))
|
||||||
|
|
||||||
buf_list = []
|
buf_list = []
|
||||||
|
|
||||||
for i in range(0, len(indices), self.page_size):
|
if indices is not None:
|
||||||
buf_list.append(self.kv_buffer[i : i + self.page_size])
|
for i in range(len(keys)):
|
||||||
|
index = indices[i * self.page_size]
|
||||||
|
buf_list.append(self.kv_buffer[index : index + self.page_size])
|
||||||
|
|
||||||
return keys, buf_list
|
return keys, buf_list, 1
|
||||||
|
|||||||
@@ -415,22 +415,12 @@ class HiCacheHF3FS(HiCacheStorage):
|
|||||||
return result[0] if result else False
|
return result[0] if result else False
|
||||||
|
|
||||||
def batch_exists(self, keys: List[str]) -> int:
|
def batch_exists(self, keys: List[str]) -> int:
|
||||||
if self.is_page_first_layout and not self.is_mla_model:
|
results = self.metadata_client.exists(self.rank, keys)
|
||||||
query_keys = []
|
for i in range(len(keys)):
|
||||||
# Compatible with page_first layout's key format, Refer to memory_pool_host.py#get_buffer_with_hash
|
if not results[i]:
|
||||||
for key in keys:
|
return i
|
||||||
query_keys.append(f"{key}-k")
|
|
||||||
query_keys.append(f"{key}-v")
|
|
||||||
key_multiplier = 2
|
|
||||||
else:
|
|
||||||
query_keys = keys
|
|
||||||
key_multiplier = 1
|
|
||||||
|
|
||||||
exist_result = self.metadata_client.exists(self.rank, query_keys)
|
return len(keys)
|
||||||
for i in range(len(query_keys)):
|
|
||||||
if not exist_result[i]:
|
|
||||||
return i // key_multiplier
|
|
||||||
return len(query_keys) // key_multiplier
|
|
||||||
|
|
||||||
def clear(self) -> bool:
|
def clear(self) -> bool:
|
||||||
try:
|
try:
|
||||||
|
|||||||
Reference in New Issue
Block a user