3fs zerocopy (#9109)
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
@@ -268,9 +268,14 @@ class HiCacheController:
|
||||
)
|
||||
|
||||
rank = get_tensor_model_parallel_rank()
|
||||
bytes_per_page = (
|
||||
mem_pool_host.get_size_per_token() * mem_pool_host.page_size
|
||||
)
|
||||
if self.mem_pool_host.layout == "page_first":
|
||||
bytes_per_page = (
|
||||
mem_pool_host.get_ksize_per_token() * mem_pool_host.page_size
|
||||
)
|
||||
elif self.mem_pool_host.layout == "layer_first":
|
||||
bytes_per_page = (
|
||||
mem_pool_host.get_size_per_token() * mem_pool_host.page_size
|
||||
)
|
||||
dtype = mem_pool_host.dtype
|
||||
self.storage_backend = HiCacheHF3FS.from_env_config(
|
||||
rank, bytes_per_page, dtype
|
||||
@@ -555,13 +560,34 @@ class HiCacheController:
|
||||
operation.mark_done()
|
||||
return operation.completed_tokens, operation.hash_value
|
||||
|
||||
def zerocopy_page_transfer(self, operation, batch_size=8):
    """Prefetch an operation's pages from storage directly into host buffers.

    The host pool's ``get_buffer_with_hash`` supplies the storage keys and
    the matching destination buffers; they are handed to the storage
    backend's ``batch_get`` in chunks of ``batch_size`` entries.

    Args:
        operation: Prefetch operation. Reads ``hash_value``,
            ``host_indices`` and ``request_id``; progress is reported
            through ``operation.increment``.
        batch_size: Maximum number of (key, buffer) entries per
            ``batch_get`` call.

    The loop stops early when ``batch_get`` returns ``None`` (a warning is
    logged) or when ``operation.increment`` returns a falsy value.
    """
    hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
        operation.hash_value, operation.host_indices
    )
    for start in range(0, len(hashes), batch_size):
        page_hashes = hashes[start : start + batch_size]
        page_dsts = dsts[start : start + batch_size]
        page_data = self.storage_backend.batch_get(page_hashes, page_dsts)
        if page_data is None:
            logger.warning(
                f"Prefetch operation {operation.request_id} failed to retrieve page {page_hashes}."
            )
            break
        # NOTE(review): progress is counted as page_size tokens per returned
        # entry. If get_buffer_with_hash yields more than one entry per page
        # (e.g. separate "-k" / "-v" entries), this counts each page more
        # than once - confirm that is intended.
        if not operation.increment(self.page_size * len(page_hashes)):
            # increment() refused the update; stop fetching further batches.
            break
|
||||
|
||||
def generic_page_transfer(self, operation, batch_size=8):
|
||||
for i in range(0, len(operation.hash_value), batch_size):
|
||||
page_hashes = operation.hash_value[i : i + batch_size]
|
||||
# todo: zero copy
|
||||
dummy_page_dst = [self.mem_pool_host.get_dummy_flat_data_page()] * len(
|
||||
page_hashes
|
||||
)
|
||||
dummy_page_dst = [
|
||||
self.mem_pool_host.get_dummy_flat_data_page()
|
||||
for _ in range(len(page_hashes))
|
||||
]
|
||||
page_data = self.storage_backend.batch_get(page_hashes, dummy_page_dst)
|
||||
if page_data is None:
|
||||
logger.warning(
|
||||
@@ -599,7 +625,10 @@ class HiCacheController:
|
||||
if self.is_mooncake_backend():
|
||||
self.mooncake_page_transfer(operation)
|
||||
elif self.storage_backend_type == "hf3fs":
|
||||
self.generic_page_transfer(operation, batch_size=128)
|
||||
if self.mem_pool_host.layout == "page_first":
|
||||
self.zerocopy_page_transfer(operation, batch_size=128)
|
||||
elif self.mem_pool_host.layout == "layer_first":
|
||||
self.generic_page_transfer(operation, batch_size=128)
|
||||
else:
|
||||
self.generic_page_transfer(operation)
|
||||
|
||||
@@ -716,6 +745,19 @@ class HiCacheController:
|
||||
self.backup_queue.put(operation)
|
||||
return operation.id
|
||||
|
||||
def zerocopy_page_backup(self, operation, batch_size=8):
    """Write an operation's host pages to storage without staging copies.

    Keys and source buffers come from the host pool's
    ``get_buffer_with_hash``; they are passed to the storage backend's
    ``batch_set`` in chunks of at most ``batch_size`` entries. After each
    successful chunk ``operation.completed_tokens`` grows by
    ``page_size`` times the number of entries written. The first failed
    chunk logs a warning and ends the backup.
    """
    pool = self.mem_pool_host
    all_keys, all_buffers = pool.get_buffer_with_hash(
        operation.hash_value, operation.host_indices
    )
    for offset in range(0, len(all_keys), batch_size):
        upper = offset + batch_size
        page_hashes = all_keys[offset:upper]
        written = self.storage_backend.batch_set(page_hashes, all_buffers[offset:upper])
        if written:
            operation.completed_tokens += self.page_size * len(page_hashes)
            continue
        # A failed write aborts the remaining batches.
        logger.warning(f"Failed to write page {page_hashes} to storage.")
        break
|
||||
|
||||
def generic_page_backup(self, operation, batch_size=8):
|
||||
for i in range(0, len(operation.hash_value), batch_size):
|
||||
page_hashes = operation.hash_value[i : i + batch_size]
|
||||
@@ -770,7 +812,10 @@ class HiCacheController:
|
||||
if self.is_mooncake_backend():
|
||||
self.mooncake_page_backup(operation)
|
||||
elif self.storage_backend_type == "hf3fs":
|
||||
self.generic_page_backup(operation, batch_size=128)
|
||||
if self.mem_pool_host.layout == "page_first":
|
||||
self.zerocopy_page_backup(operation, batch_size=128)
|
||||
elif self.mem_pool_host.layout == "layer_first":
|
||||
self.generic_page_backup(operation, batch_size=128)
|
||||
else:
|
||||
self.generic_page_backup(operation)
|
||||
|
||||
|
||||
@@ -307,6 +307,9 @@ class MHATokenToKVPoolHost(HostKVCache):
|
||||
|
||||
return self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2
|
||||
|
||||
def get_ksize_per_token(self):
    """Return half of ``get_size_per_token()`` (integer division)."""
    kv_bytes = self.get_size_per_token()
    return kv_bytes // 2
|
||||
|
||||
def init_kv_buffer(self):
|
||||
if self.layout == "layer_first":
|
||||
dims = (2, self.layer_num, self.size, self.head_num, self.head_dim)
|
||||
@@ -496,6 +499,21 @@ class MHATokenToKVPoolHost(HostKVCache):
|
||||
element_size_list = [element_size] * len(key_list)
|
||||
return key_list, ptr_list, element_size_list
|
||||
|
||||
def get_buffer_with_hash(self, keys, indices):
    """Pair each page hash with the K and V buffer slices that hold the page.

    Args:
        keys: One hash per page.
        indices: Host slot indices, ``page_size`` consecutive entries per
            page, in the same page order as ``keys``.

    Returns:
        ``(key_list, buf_list)`` with two entries per page, in order:
        ``"<key>-k"`` paired with a ``k_buffer`` slice and ``"<key>-v"``
        paired with a ``v_buffer`` slice.

    Raises:
        AssertionError: If the layout is not ``"page_first"`` or the number
            of keys does not match ``len(indices) // page_size``.
    """
    assert self.layout == "page_first"
    assert len(keys) == (len(indices) // self.page_size)

    key_list = []
    buf_list = []

    for key, pos in zip(keys, range(0, len(indices), self.page_size)):
        # The slice must begin at the host slot stored in ``indices``, not at
        # the position inside ``indices``; the latter always addressed slots
        # 0..N of the pool regardless of where the pages actually live.
        # Assumes the page_size slots of one page are consecutive, so the
        # first index identifies the whole page - TODO confirm with allocator.
        start = int(indices[pos])
        stop = start + self.page_size
        key_list.append(f"{key}-k")
        buf_list.append(self.k_buffer[start:stop])
        key_list.append(f"{key}-v")
        buf_list.append(self.v_buffer[start:stop])

    return key_list, buf_list
|
||||
|
||||
|
||||
class MLATokenToKVPoolHost(HostKVCache):
|
||||
device_pool: MLATokenToKVPool
|
||||
@@ -538,6 +556,9 @@ class MLATokenToKVPoolHost(HostKVCache):
|
||||
* self.layer_num
|
||||
)
|
||||
|
||||
def get_ksize_per_token(self):
    """Return the same value as ``get_size_per_token()``."""
    per_token_bytes = self.get_size_per_token()
    return per_token_bytes
|
||||
|
||||
def init_kv_buffer(self):
|
||||
if self.layout == "layer_first":
|
||||
dims = (
|
||||
@@ -704,3 +725,14 @@ class MLATokenToKVPoolHost(HostKVCache):
|
||||
)
|
||||
element_size_list = [element_size] * len(key_list)
|
||||
return key_list, ptr_list, element_size_list
|
||||
|
||||
def get_buffer_with_hash(self, keys, indices):
    """Pair each page hash with the ``kv_buffer`` slice that holds the page.

    Args:
        keys: One hash per page; returned unchanged.
        indices: Host slot indices, ``page_size`` consecutive entries per
            page, in the same page order as ``keys``.

    Returns:
        ``(keys, buf_list)`` where ``buf_list[n]`` is the ``kv_buffer``
        slice for the n-th page.

    Raises:
        AssertionError: If the layout is not ``"page_first"`` or the number
            of keys does not match ``len(indices) // page_size``.
    """
    assert self.layout == "page_first"
    assert len(keys) == (len(indices) // self.page_size)

    buf_list = []

    for pos in range(0, len(indices), self.page_size):
        # Slice at the host slot recorded in ``indices`` rather than at the
        # position inside ``indices``, which ignored where the page lives.
        # Assumes the page_size slots of one page are consecutive, so the
        # first index identifies the whole page - TODO confirm with allocator.
        start = int(indices[pos])
        buf_list.append(self.kv_buffer[start : start + self.page_size])

    return keys, buf_list
|
||||
|
||||
@@ -34,6 +34,9 @@ apt-get update \
|
||||
python3 python3-pip \
|
||||
&& apt-get clean \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
# apt install python3.12 python3.12-venv python3.12-dev
|
||||
# curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
|
||||
# python3.12 get-pip.py
|
||||
|
||||
# Generated wheel location: dist/hf3fs_py_usrbio-1.2.9+2db69ce-cp310-cp310-linux_x86_64.whl
|
||||
python3 setup.py bdist_wheel
|
||||
@@ -60,6 +63,6 @@ apt update && apt install -y \
|
||||
libuv1-dev
|
||||
|
||||
# Install Python Package
|
||||
pip install hf3fs_py_usrbio-1.2.9+2db69ce-cp310-cp310-linux_x86_64.whl
|
||||
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.10/dist-packages
|
||||
pip install hf3fs_py_usrbio-1.2.9+394583d-cp312-cp312-linux_x86_64.whl
|
||||
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.12/dist-packages
|
||||
```
|
||||
|
||||
@@ -7,7 +7,7 @@ import signal
|
||||
import threading
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import wraps
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@@ -228,15 +228,23 @@ class HiCacheHF3FS(HiCacheStorage):
|
||||
)
|
||||
|
||||
def get(
|
||||
self, key: str, target_location: Optional[torch.Tensor] = None
|
||||
self,
|
||||
key: str,
|
||||
target_location: Optional[Any] = None,
|
||||
target_sizes: Optional[Any] = None,
|
||||
) -> torch.Tensor | None:
|
||||
return self.batch_get([key], [target_location] if target_location else None)[0]
|
||||
return self.batch_get(
|
||||
[key],
|
||||
[target_location] if target_location is not None else None,
|
||||
[target_sizes] if target_sizes is not None else None,
|
||||
)[0]
|
||||
|
||||
@synchronized()
|
||||
def batch_get(
|
||||
self,
|
||||
keys: List[str],
|
||||
target_locations: Optional[List[torch.Tensor]] = None,
|
||||
target_locations: Optional[Any] = None,
|
||||
target_sizes: Optional[Any] = None,
|
||||
) -> List[torch.Tensor | None]:
|
||||
page_indices = self.metadata_client.get_page_indices(self.rank, keys)
|
||||
|
||||
@@ -246,9 +254,15 @@ class HiCacheHF3FS(HiCacheStorage):
|
||||
batch_indices.append(i)
|
||||
file_offsets.append(page_index * self.bytes_per_page)
|
||||
|
||||
file_results = [
|
||||
torch.empty(self.numel, dtype=self.dtype) for _ in range(len(batch_indices))
|
||||
]
|
||||
if target_locations is not None:
|
||||
for target_location in target_locations:
|
||||
assert target_location.is_contiguous()
|
||||
file_results = target_locations
|
||||
else:
|
||||
file_results = [
|
||||
torch.empty(self.numel, dtype=self.dtype)
|
||||
for _ in range(len(batch_indices))
|
||||
]
|
||||
|
||||
futures = [
|
||||
self.executor.submit(
|
||||
@@ -273,10 +287,27 @@ class HiCacheHF3FS(HiCacheStorage):
|
||||
|
||||
return results
|
||||
|
||||
def set(self, key: str, value: torch.Tensor) -> bool:
|
||||
return self.batch_set([key], [value])
|
||||
def set(
|
||||
self,
|
||||
key: str,
|
||||
value: Optional[Any] = None,
|
||||
target_location: Optional[Any] = None,
|
||||
target_sizes: Optional[Any] = None,
|
||||
) -> bool:
|
||||
return self.batch_set(
|
||||
[key],
|
||||
[value] if value is not None else None,
|
||||
[target_location] if target_location is not None else None,
|
||||
[target_sizes] if target_sizes is not None else None,
|
||||
)
|
||||
|
||||
def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
|
||||
def batch_set(
|
||||
self,
|
||||
keys: List[str],
|
||||
values: Optional[Any] = None,
|
||||
target_locations: Optional[Any] = None,
|
||||
target_sizes: Optional[Any] = None,
|
||||
) -> bool:
|
||||
# Todo: Add prefix block's hash key
|
||||
key_with_prefix = [(key, "") for key in keys]
|
||||
indices = self.metadata_client.reserve_and_allocate_page_indices(
|
||||
@@ -292,7 +323,8 @@ class HiCacheHF3FS(HiCacheStorage):
|
||||
|
||||
batch_indices.append(i)
|
||||
file_offsets.append(page_index * self.bytes_per_page)
|
||||
file_values.append(value.contiguous())
|
||||
assert value.is_contiguous()
|
||||
file_values.append(value)
|
||||
|
||||
futures = [
|
||||
self.executor.submit(
|
||||
|
||||
Reference in New Issue
Block a user