3fs zerocopy (#9109)
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
@@ -1,6 +1,16 @@
|
|||||||
|
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.12/dist-packages:/usr/local/lib/python3.12/dist-packages/torch/lib
|
||||||
|
python3 benchmark/hf3fs/bench_client.py
|
||||||
|
|
||||||
|
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.12/dist-packages:/usr/local/lib/python3.12/dist-packages/torch/lib
|
||||||
SGLANG_HICACHE_HF3FS_CONFIG_PATH=/sgl-workspace/sglang/benchmark/hf3fs/hf3fs.json \
|
SGLANG_HICACHE_HF3FS_CONFIG_PATH=/sgl-workspace/sglang/benchmark/hf3fs/hf3fs.json \
|
||||||
python3 benchmark/hf3fs/bench_storage.py
|
python3 benchmark/hf3fs/bench_storage.py
|
||||||
|
|
||||||
|
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.12/dist-packages:/usr/local/lib/python3.12/dist-packages/torch/lib
|
||||||
|
export SGLANG_HICACHE_HF3FS_CONFIG_PATH=/sgl-workspace/sglang/benchmark/hf3fs/hf3fs.json
|
||||||
|
echo '{"file_path_prefix": "/data/hf3fs-test-0", "file_size": 1099511627776, "numjobs": 16, "entries": 8}' > \
|
||||||
|
${SGLANG_HICACHE_HF3FS_CONFIG_PATH}
|
||||||
|
python3 benchmark/hf3fs/bench_zerocopy.py
|
||||||
|
|
||||||
####################################################################################################
|
####################################################################################################
|
||||||
|
|
||||||
rm -rf nohup.out && \
|
rm -rf nohup.out && \
|
||||||
|
|||||||
@@ -8,6 +8,9 @@ from typing import List
|
|||||||
import torch
|
import torch
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server import (
|
||||||
|
Hf3fsLocalMetadataClient,
|
||||||
|
)
|
||||||
from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import HiCacheHF3FS
|
from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import HiCacheHF3FS
|
||||||
|
|
||||||
|
|
||||||
@@ -67,12 +70,15 @@ def test():
|
|||||||
k = f"key_{i}"
|
k = f"key_{i}"
|
||||||
v = torch.randn((numel,)).to(dtype=dtype)
|
v = torch.randn((numel,)).to(dtype=dtype)
|
||||||
ok = hicache_hf3fs.set(k, v)
|
ok = hicache_hf3fs.set(k, v)
|
||||||
|
if i < (file_size // bytes_per_page):
|
||||||
assert ok, f"Failed to insert {k}"
|
assert ok, f"Failed to insert {k}"
|
||||||
|
else:
|
||||||
|
assert not ok
|
||||||
tensors[k] = v
|
tensors[k] = v
|
||||||
assert hicache_hf3fs.get("key_0") is None
|
assert hicache_hf3fs.get("key_8") is None
|
||||||
assert hicache_hf3fs.get("key_1") is None
|
assert hicache_hf3fs.get("key_9") is None
|
||||||
|
|
||||||
start = num_pages - hicache_hf3fs.num_pages
|
start = 0
|
||||||
for i in range(start, start + hicache_hf3fs.num_pages):
|
for i in range(start, start + hicache_hf3fs.num_pages):
|
||||||
k = f"key_{i}"
|
k = f"key_{i}"
|
||||||
assert hicache_hf3fs.exists(k)
|
assert hicache_hf3fs.exists(k)
|
||||||
@@ -83,13 +89,16 @@ def test():
|
|||||||
|
|
||||||
assert not hicache_hf3fs.exists("not_exists")
|
assert not hicache_hf3fs.exists("not_exists")
|
||||||
|
|
||||||
hicache_hf3fs.delete("key_9")
|
hicache_hf3fs.delete("key_7")
|
||||||
v2 = torch.randn((numel,)).to(dtype=dtype)
|
v2 = torch.randn((numel,)).to(dtype=dtype)
|
||||||
assert hicache_hf3fs.set("key_new", v2)
|
assert hicache_hf3fs.set("key_new", v2)
|
||||||
assert torch.allclose(hicache_hf3fs.get("key_new"), v2, atol=1e-3)
|
assert torch.allclose(hicache_hf3fs.get("key_new"), v2, atol=1e-3)
|
||||||
|
|
||||||
hicache_hf3fs.clear()
|
hicache_hf3fs.clear()
|
||||||
assert len(hicache_hf3fs.free_pages) == hicache_hf3fs.num_pages
|
assert (
|
||||||
|
len(hicache_hf3fs.metadata_client.rank_metadata.free_pages)
|
||||||
|
== hicache_hf3fs.metadata_client.rank_metadata.num_pages
|
||||||
|
)
|
||||||
|
|
||||||
# batch
|
# batch
|
||||||
num_pages = 10
|
num_pages = 10
|
||||||
@@ -134,12 +143,14 @@ def bench():
|
|||||||
entries = 8
|
entries = 8
|
||||||
dtype = store_dtype
|
dtype = store_dtype
|
||||||
hicache_hf3fs = HiCacheHF3FS(
|
hicache_hf3fs = HiCacheHF3FS(
|
||||||
|
rank=0,
|
||||||
file_path=file_path,
|
file_path=file_path,
|
||||||
file_size=file_size,
|
file_size=file_size,
|
||||||
numjobs=numjobs,
|
numjobs=numjobs,
|
||||||
bytes_per_page=bytes_per_page,
|
bytes_per_page=bytes_per_page,
|
||||||
entries=entries,
|
entries=entries,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
|
metadata_client=Hf3fsLocalMetadataClient(),
|
||||||
)
|
)
|
||||||
|
|
||||||
numel = 2 * tokens_per_page * layer_num * head_num * head_dim
|
numel = 2 * tokens_per_page * layer_num * head_num * head_dim
|
||||||
@@ -167,7 +178,10 @@ def bench():
|
|||||||
r_bw = []
|
r_bw = []
|
||||||
r_size = num_page * bytes_per_page / (1 << 30)
|
r_size = num_page * bytes_per_page / (1 << 30)
|
||||||
for i in tqdm(range(warmup + iteration), desc="Benchmarking read (GB/s)"):
|
for i in tqdm(range(warmup + iteration), desc="Benchmarking read (GB/s)"):
|
||||||
keys = random.sample(list(hicache_hf3fs.key_to_index.keys()), num_page)
|
keys = random.sample(
|
||||||
|
list(hicache_hf3fs.metadata_client.rank_metadata.key_to_index.keys()),
|
||||||
|
num_page,
|
||||||
|
)
|
||||||
tik = time.perf_counter()
|
tik = time.perf_counter()
|
||||||
results = hicache_hf3fs.batch_get(keys)
|
results = hicache_hf3fs.batch_get(keys)
|
||||||
tok = time.perf_counter()
|
tok = time.perf_counter()
|
||||||
@@ -195,12 +209,14 @@ def allclose():
|
|||||||
entries = 8
|
entries = 8
|
||||||
dtype = store_dtype
|
dtype = store_dtype
|
||||||
hicache_hf3fs = HiCacheHF3FS(
|
hicache_hf3fs = HiCacheHF3FS(
|
||||||
|
rank=0,
|
||||||
file_path=file_path,
|
file_path=file_path,
|
||||||
file_size=file_size,
|
file_size=file_size,
|
||||||
numjobs=numjobs,
|
numjobs=numjobs,
|
||||||
bytes_per_page=bytes_per_page,
|
bytes_per_page=bytes_per_page,
|
||||||
entries=entries,
|
entries=entries,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
|
metadata_client=Hf3fsLocalMetadataClient(),
|
||||||
)
|
)
|
||||||
|
|
||||||
numel = 2 * tokens_per_page * layer_num * head_num * head_dim
|
numel = 2 * tokens_per_page * layer_num * head_num * head_dim
|
||||||
@@ -218,7 +234,10 @@ def allclose():
|
|||||||
|
|
||||||
read_keys, read_results = [], []
|
read_keys, read_results = [], []
|
||||||
for i in tqdm(range(iteration), desc="Benchmarking read (GB/s)"):
|
for i in tqdm(range(iteration), desc="Benchmarking read (GB/s)"):
|
||||||
keys = random.sample(list(hicache_hf3fs.key_to_index.keys()), num_page)
|
keys = random.sample(
|
||||||
|
list(hicache_hf3fs.metadata_client.rank_metadata.key_to_index.keys()),
|
||||||
|
num_page,
|
||||||
|
)
|
||||||
results = hicache_hf3fs.batch_get(keys)
|
results = hicache_hf3fs.batch_get(keys)
|
||||||
read_keys.extend(keys)
|
read_keys.extend(keys)
|
||||||
read_results.extend(results)
|
read_results.extend(results)
|
||||||
|
|||||||
140
benchmark/hf3fs/bench_zerocopy.py
Normal file
140
benchmark/hf3fs/bench_zerocopy.py
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
import threading
|
||||||
|
import time
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from sglang.srt.distributed import (
|
||||||
|
get_world_group,
|
||||||
|
init_distributed_environment,
|
||||||
|
initialize_model_parallel,
|
||||||
|
)
|
||||||
|
from sglang.srt.managers.cache_controller import (
|
||||||
|
HiCacheController,
|
||||||
|
PrefetchOperation,
|
||||||
|
StorageOperation,
|
||||||
|
)
|
||||||
|
from sglang.srt.mem_cache.allocator import TokenToKVPoolAllocator
|
||||||
|
from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool
|
||||||
|
from sglang.srt.mem_cache.memory_pool_host import MHATokenToKVPoolHost
|
||||||
|
|
||||||
|
# Benchmark script body: builds a single-process HiCacheController backed by
# the "hf3fs" storage backend, then times (1) backing up a set of pages and
# (2) transferring (prefetching) them back. Everything below runs at module
# import time; there is no `main()` guard.

# Single-process "distributed" setup: one rank, gloo backend, local TCP
# rendezvous on port 23456.
init_distributed_environment(
    world_size=1,
    rank=0,
    distributed_init_method="tcp://127.0.0.1:23456",
    local_rank=0,
    backend="gloo",
)

# No tensor or pipeline parallelism.
initialize_model_parallel(
    tensor_model_parallel_size=1,
    pipeline_model_parallel_size=1,
)

# CPU process group handed to the cache controller below.
group = get_world_group().cpu_group

# ---- Benchmark configuration -------------------------------------------
max_total_num_tokens = 524288
page_size = 64
kv_cache_dtype = torch.bfloat16
layer_num = 64
head_num, head_dim = 8, 128
device = "cuda"  # NOTE(review): presumably requires a CUDA device -- confirm
hicache_ratio = 2
hicache_size = 0
# The layout selects which code path is timed below:
#   "page_first"  -> zerocopy_page_backup / zerocopy_page_transfer
#   "layer_first" -> generic_page_backup  / generic_page_transfer
hicache_mem_layout = "page_first"
# hicache_mem_layout = "layer_first"
hicache_write_policy = "write_through"
hicache_io_backend = "kernel"
hicache_storage_backend = "hf3fs"
prefetch_threshold = 256

# Each operation covers `op_size` consecutive integer indices; `op_num`
# operations are issued per phase.
op_size = 1024
op_num = 16

# ---- Device pool, allocator and host pool ------------------------------
token_to_kv_pool = MHATokenToKVPool(
    max_total_num_tokens,
    page_size=page_size,
    dtype=kv_cache_dtype,
    head_num=head_num,
    head_dim=head_dim,
    layer_num=layer_num,
    device=device,
    enable_memory_saver=True,
)

token_to_kv_pool_allocator = TokenToKVPoolAllocator(
    max_total_num_tokens,
    dtype=kv_cache_dtype,
    device=device,
    kvcache=token_to_kv_pool,
    need_sort=False,
)

kv_cache = token_to_kv_pool_allocator.get_kvcache()
token_to_kv_pool_host = MHATokenToKVPoolHost(
    kv_cache,
    hicache_ratio,
    hicache_size,
    page_size,
    hicache_mem_layout,
)

# ---- Cache controller under test ---------------------------------------
load_cache_event = threading.Event()
cache_controller = HiCacheController(
    token_to_kv_pool_allocator,
    token_to_kv_pool_host,
    page_size,
    group,
    load_cache_event=load_cache_event,
    write_policy=hicache_write_policy,
    io_backend=hicache_io_backend,
    storage_backend=hicache_storage_backend,
    prefetch_threshold=prefetch_threshold,
)

# ---- Phase 1: backup (write to storage) --------------------------------
# One StorageOperation per contiguous range [i, i + op_size). The same range
# is passed both as a tensor and as a plain list; `hash_value` holds one
# string key per page (every `page_size`-th index of the range).
operations = [
    StorageOperation(
        torch.tensor(list(range(i, i + op_size))),
        list(range(i, i + op_size)),
        hash_value=[f"{j}" for j in range(i, i + op_size, page_size)],
    )
    for i in tqdm(range(0, op_num * op_size, op_size))
]

# Time the whole backup phase. If the layout is neither value, nothing is
# executed and only the timer overhead is printed.
tik = time.monotonic()
if hicache_mem_layout == "page_first":
    for operation in operations:
        cache_controller.zerocopy_page_backup(operation, batch_size=128)
elif hicache_mem_layout == "layer_first":
    for operation in operations:
        cache_controller.generic_page_backup(operation, batch_size=128)
tok = time.monotonic()
print(f"{tok-tik:.6f} s")

# ---- Phase 2: transfer (read back from storage) ------------------------
# NOTE(review): the positional arguments are presumably
# (request_id, host_indices, token_ids, last_hash) -- confirm against
# PrefetchOperation's signature.
operations = [
    PrefetchOperation(
        f"{i}",
        torch.tensor(list(range(i, i + op_size))),
        list(range(i, i + op_size)),
        f"{i}",
    )
    for i in tqdm(range(0, op_num * op_size, op_size))
]

# Rebuild the per-page keys from `last_hash` so they match the keys written
# in phase 1 (same start, same `page_size` stride).
for operation in operations:
    operation.hash_value = [
        f"{j}"
        for j in range(
            int(operation.last_hash), int(operation.last_hash) + op_size, page_size
        )
    ]

# Time the whole transfer phase, mirroring the layout dispatch used above.
tik = time.monotonic()
if hicache_mem_layout == "page_first":
    for operation in operations:
        cache_controller.zerocopy_page_transfer(operation, batch_size=128)
elif hicache_mem_layout == "layer_first":
    for operation in operations:
        cache_controller.generic_page_transfer(operation, batch_size=128)
tok = time.monotonic()
print(f"{tok-tik:.6f} s")
|
||||||
@@ -268,6 +268,11 @@ class HiCacheController:
|
|||||||
)
|
)
|
||||||
|
|
||||||
rank = get_tensor_model_parallel_rank()
|
rank = get_tensor_model_parallel_rank()
|
||||||
|
if self.mem_pool_host.layout == "page_first":
|
||||||
|
bytes_per_page = (
|
||||||
|
mem_pool_host.get_ksize_per_token() * mem_pool_host.page_size
|
||||||
|
)
|
||||||
|
elif self.mem_pool_host.layout == "layer_first":
|
||||||
bytes_per_page = (
|
bytes_per_page = (
|
||||||
mem_pool_host.get_size_per_token() * mem_pool_host.page_size
|
mem_pool_host.get_size_per_token() * mem_pool_host.page_size
|
||||||
)
|
)
|
||||||
@@ -555,13 +560,34 @@ class HiCacheController:
|
|||||||
operation.mark_done()
|
operation.mark_done()
|
||||||
return operation.completed_tokens, operation.hash_value
|
return operation.completed_tokens, operation.hash_value
|
||||||
|
|
||||||
|
def zerocopy_page_transfer(self, operation, batch_size=8):
    """Read pages from the storage backend directly into host-pool buffers.

    The host pool supplies the storage keys and destination buffers via
    ``get_buffer_with_hash``; they are handed to ``storage_backend.batch_get``
    in chunks of ``batch_size`` entries. No copy into the host pool is done
    here: the buffers returned by the pool are the destinations themselves.

    Args:
        operation: Prefetch operation providing ``hash_value``,
            ``host_indices``, ``request_id`` and ``increment(n)``.
        batch_size: Maximum number of (key, buffer) pairs per ``batch_get``.

    Returns:
        None. Progress is recorded through ``operation.increment``.
    """
    hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
        operation.hash_value, operation.host_indices
    )
    for start in range(0, len(hashes), batch_size):
        page_hashes = hashes[start : start + batch_size]
        page_dsts = dsts[start : start + batch_size]
        page_data = self.storage_backend.batch_get(page_hashes, page_dsts)
        if page_data is None:
            logger.warning(
                f"Prefetch operation {operation.request_id} failed to retrieve page {page_hashes}."
            )
            break
        # NOTE(review): `page_hashes` holds storage keys, which may be more
        # than one per page depending on the host pool -- confirm that
        # page_size * len(page_hashes) is the intended token count.
        # A falsy result from increment() stops the transfer early.
        if not operation.increment(self.page_size * len(page_hashes)):
            break
|
||||||
|
|
||||||
def generic_page_transfer(self, operation, batch_size=8):
|
def generic_page_transfer(self, operation, batch_size=8):
|
||||||
for i in range(0, len(operation.hash_value), batch_size):
|
for i in range(0, len(operation.hash_value), batch_size):
|
||||||
page_hashes = operation.hash_value[i : i + batch_size]
|
page_hashes = operation.hash_value[i : i + batch_size]
|
||||||
# todo: zero copy
|
# todo: zero copy
|
||||||
dummy_page_dst = [self.mem_pool_host.get_dummy_flat_data_page()] * len(
|
dummy_page_dst = [
|
||||||
page_hashes
|
self.mem_pool_host.get_dummy_flat_data_page()
|
||||||
)
|
for _ in range(len(page_hashes))
|
||||||
|
]
|
||||||
page_data = self.storage_backend.batch_get(page_hashes, dummy_page_dst)
|
page_data = self.storage_backend.batch_get(page_hashes, dummy_page_dst)
|
||||||
if page_data is None:
|
if page_data is None:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@@ -599,6 +625,9 @@ class HiCacheController:
|
|||||||
if self.is_mooncake_backend():
|
if self.is_mooncake_backend():
|
||||||
self.mooncake_page_transfer(operation)
|
self.mooncake_page_transfer(operation)
|
||||||
elif self.storage_backend_type == "hf3fs":
|
elif self.storage_backend_type == "hf3fs":
|
||||||
|
if self.mem_pool_host.layout == "page_first":
|
||||||
|
self.zerocopy_page_transfer(operation, batch_size=128)
|
||||||
|
elif self.mem_pool_host.layout == "layer_first":
|
||||||
self.generic_page_transfer(operation, batch_size=128)
|
self.generic_page_transfer(operation, batch_size=128)
|
||||||
else:
|
else:
|
||||||
self.generic_page_transfer(operation)
|
self.generic_page_transfer(operation)
|
||||||
@@ -716,6 +745,19 @@ class HiCacheController:
|
|||||||
self.backup_queue.put(operation)
|
self.backup_queue.put(operation)
|
||||||
return operation.id
|
return operation.id
|
||||||
|
|
||||||
|
def zerocopy_page_backup(self, operation, batch_size=8):
    """Write host-pool pages to the storage backend without staging copies.

    Keys and source buffers come from ``mem_pool_host.get_buffer_with_hash``
    and are passed to ``storage_backend.batch_set`` in chunks of
    ``batch_size``. After each successful chunk,
    ``operation.completed_tokens`` grows by ``page_size`` times the number of
    keys in that chunk; the first failed chunk logs a warning and stops.

    Args:
        operation: Storage operation providing ``hash_value``,
            ``host_indices`` and a mutable ``completed_tokens`` counter.
        batch_size: Maximum number of (key, buffer) pairs per ``batch_set``.
    """
    keys, buffers = self.mem_pool_host.get_buffer_with_hash(
        operation.hash_value, operation.host_indices
    )
    for offset in range(0, len(keys), batch_size):
        window = slice(offset, offset + batch_size)
        chunk_keys = keys[window]
        chunk_buffers = buffers[window]
        if self.storage_backend.batch_set(chunk_keys, chunk_buffers):
            operation.completed_tokens += self.page_size * len(chunk_keys)
        else:
            logger.warning(f"Failed to write page {chunk_keys} to storage.")
            break
|
||||||
|
|
||||||
def generic_page_backup(self, operation, batch_size=8):
|
def generic_page_backup(self, operation, batch_size=8):
|
||||||
for i in range(0, len(operation.hash_value), batch_size):
|
for i in range(0, len(operation.hash_value), batch_size):
|
||||||
page_hashes = operation.hash_value[i : i + batch_size]
|
page_hashes = operation.hash_value[i : i + batch_size]
|
||||||
@@ -770,6 +812,9 @@ class HiCacheController:
|
|||||||
if self.is_mooncake_backend():
|
if self.is_mooncake_backend():
|
||||||
self.mooncake_page_backup(operation)
|
self.mooncake_page_backup(operation)
|
||||||
elif self.storage_backend_type == "hf3fs":
|
elif self.storage_backend_type == "hf3fs":
|
||||||
|
if self.mem_pool_host.layout == "page_first":
|
||||||
|
self.zerocopy_page_backup(operation, batch_size=128)
|
||||||
|
elif self.mem_pool_host.layout == "layer_first":
|
||||||
self.generic_page_backup(operation, batch_size=128)
|
self.generic_page_backup(operation, batch_size=128)
|
||||||
else:
|
else:
|
||||||
self.generic_page_backup(operation)
|
self.generic_page_backup(operation)
|
||||||
|
|||||||
@@ -307,6 +307,9 @@ class MHATokenToKVPoolHost(HostKVCache):
|
|||||||
|
|
||||||
return self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2
|
return self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2
|
||||||
|
|
||||||
|
def get_ksize_per_token(self):
    """Return the per-token size of the K part alone.

    This is half of ``get_size_per_token()`` (integer division), whose
    formula carries a factor of 2.
    """
    kv_bytes = self.get_size_per_token()
    return kv_bytes // 2
|
||||||
|
|
||||||
def init_kv_buffer(self):
|
def init_kv_buffer(self):
|
||||||
if self.layout == "layer_first":
|
if self.layout == "layer_first":
|
||||||
dims = (2, self.layer_num, self.size, self.head_num, self.head_dim)
|
dims = (2, self.layer_num, self.size, self.head_num, self.head_dim)
|
||||||
@@ -496,6 +499,21 @@ class MHATokenToKVPoolHost(HostKVCache):
|
|||||||
element_size_list = [element_size] * len(key_list)
|
element_size_list = [element_size] * len(key_list)
|
||||||
return key_list, ptr_list, element_size_list
|
return key_list, ptr_list, element_size_list
|
||||||
|
|
||||||
|
def get_buffer_with_hash(self, keys, indices):
    """Return storage keys and matching K/V buffer slices for each page.

    Each page key yields two entries, in this order: ``"<key>-k"`` paired
    with a slice of ``k_buffer`` and ``"<key>-v"`` paired with a slice of
    ``v_buffer``.

    Args:
        keys: One hash key per page.
        indices: Host-pool token slots, ``page_size`` entries per page. The
            first entry of each page locates the slice.
            NOTE(review): assumes the ``page_size`` slots of a page are
            contiguous in the pool -- confirm against the allocator.

    Returns:
        ``(key_list, buf_list)``, two parallel lists of length
        ``2 * len(keys)``.

    Raises:
        AssertionError: If the layout is not ``"page_first"`` or the number
            of keys does not match the number of pages in ``indices``.
    """
    assert self.layout == "page_first"
    assert len(keys) == (len(indices) // self.page_size)

    key_list = []
    buf_list = []

    for page_no, key in enumerate(keys):
        # Slice at the slot recorded in `indices`, not at the loop offset;
        # otherwise the values in `indices` would be ignored entirely.
        start = indices[page_no * self.page_size]
        stop = start + self.page_size
        key_list.append(f"{key}-k")
        buf_list.append(self.k_buffer[start:stop])
        key_list.append(f"{key}-v")
        buf_list.append(self.v_buffer[start:stop])

    return key_list, buf_list
|
||||||
|
|
||||||
|
|
||||||
class MLATokenToKVPoolHost(HostKVCache):
|
class MLATokenToKVPoolHost(HostKVCache):
|
||||||
device_pool: MLATokenToKVPool
|
device_pool: MLATokenToKVPool
|
||||||
@@ -538,6 +556,9 @@ class MLATokenToKVPoolHost(HostKVCache):
|
|||||||
* self.layer_num
|
* self.layer_num
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_ksize_per_token(self):
    """Return the per-token size used for a single stored entry.

    For this pool it is simply ``get_size_per_token()``, unchanged.
    """
    per_token_bytes = self.get_size_per_token()
    return per_token_bytes
|
||||||
|
|
||||||
def init_kv_buffer(self):
|
def init_kv_buffer(self):
|
||||||
if self.layout == "layer_first":
|
if self.layout == "layer_first":
|
||||||
dims = (
|
dims = (
|
||||||
@@ -704,3 +725,14 @@ class MLATokenToKVPoolHost(HostKVCache):
|
|||||||
)
|
)
|
||||||
element_size_list = [element_size] * len(key_list)
|
element_size_list = [element_size] * len(key_list)
|
||||||
return key_list, ptr_list, element_size_list
|
return key_list, ptr_list, element_size_list
|
||||||
|
|
||||||
|
def get_buffer_with_hash(self, keys, indices):
    """Return storage keys and matching ``kv_buffer`` slices, one per page.

    Args:
        keys: One hash key per page; returned unchanged.
        indices: Host-pool token slots, ``page_size`` entries per page. The
            first entry of each page locates the slice.
            NOTE(review): assumes the ``page_size`` slots of a page are
            contiguous in the pool -- confirm against the allocator.

    Returns:
        ``(keys, buf_list)`` where ``buf_list[n]`` is the ``kv_buffer`` slice
        for the n-th page.

    Raises:
        AssertionError: If the layout is not ``"page_first"`` or the number
            of keys does not match the number of pages in ``indices``.
    """
    assert self.layout == "page_first"
    assert len(keys) == (len(indices) // self.page_size)

    buf_list = []

    for offset in range(0, len(indices), self.page_size):
        # Slice at the slot recorded in `indices`, not at the loop offset;
        # otherwise the values in `indices` would be ignored entirely.
        start = indices[offset]
        buf_list.append(self.kv_buffer[start : start + self.page_size])

    return keys, buf_list
|
||||||
|
|||||||
@@ -34,6 +34,9 @@ apt-get update \
|
|||||||
python3 python3-pip \
|
python3 python3-pip \
|
||||||
&& apt-get clean \
|
&& apt-get clean \
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
# apt install python3.12 python3.12-venv python3.12-dev
|
||||||
|
# curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
|
||||||
|
# python3.12 get-pip.py
|
||||||
|
|
||||||
# Generated wheel location: dist/hf3fs_py_usrbio-1.2.9+2db69ce-cp310-cp310-linux_x86_64.whl
|
# Generated wheel location: dist/hf3fs_py_usrbio-1.2.9+2db69ce-cp310-cp310-linux_x86_64.whl
|
||||||
python3 setup.py bdist_wheel
|
python3 setup.py bdist_wheel
|
||||||
@@ -60,6 +63,6 @@ apt update && apt install -y \
|
|||||||
libuv1-dev
|
libuv1-dev
|
||||||
|
|
||||||
# Install Python Package
|
# Install Python Package
|
||||||
pip install hf3fs_py_usrbio-1.2.9+2db69ce-cp310-cp310-linux_x86_64.whl
|
pip install hf3fs_py_usrbio-1.2.9+394583d-cp312-cp312-linux_x86_64.whl
|
||||||
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.10/dist-packages
|
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.12/dist-packages
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import signal
|
|||||||
import threading
|
import threading
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import List, Optional, Tuple
|
from typing import Any, List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -228,15 +228,23 @@ class HiCacheHF3FS(HiCacheStorage):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def get(
|
def get(
|
||||||
self, key: str, target_location: Optional[torch.Tensor] = None
|
self,
|
||||||
|
key: str,
|
||||||
|
target_location: Optional[Any] = None,
|
||||||
|
target_sizes: Optional[Any] = None,
|
||||||
) -> torch.Tensor | None:
|
) -> torch.Tensor | None:
|
||||||
return self.batch_get([key], [target_location] if target_location else None)[0]
|
return self.batch_get(
|
||||||
|
[key],
|
||||||
|
[target_location] if target_location is not None else None,
|
||||||
|
[target_sizes] if target_sizes is not None else None,
|
||||||
|
)[0]
|
||||||
|
|
||||||
@synchronized()
|
@synchronized()
|
||||||
def batch_get(
|
def batch_get(
|
||||||
self,
|
self,
|
||||||
keys: List[str],
|
keys: List[str],
|
||||||
target_locations: Optional[List[torch.Tensor]] = None,
|
target_locations: Optional[Any] = None,
|
||||||
|
target_sizes: Optional[Any] = None,
|
||||||
) -> List[torch.Tensor | None]:
|
) -> List[torch.Tensor | None]:
|
||||||
page_indices = self.metadata_client.get_page_indices(self.rank, keys)
|
page_indices = self.metadata_client.get_page_indices(self.rank, keys)
|
||||||
|
|
||||||
@@ -246,8 +254,14 @@ class HiCacheHF3FS(HiCacheStorage):
|
|||||||
batch_indices.append(i)
|
batch_indices.append(i)
|
||||||
file_offsets.append(page_index * self.bytes_per_page)
|
file_offsets.append(page_index * self.bytes_per_page)
|
||||||
|
|
||||||
|
if target_locations is not None:
|
||||||
|
for target_location in target_locations:
|
||||||
|
assert target_location.is_contiguous()
|
||||||
|
file_results = target_locations
|
||||||
|
else:
|
||||||
file_results = [
|
file_results = [
|
||||||
torch.empty(self.numel, dtype=self.dtype) for _ in range(len(batch_indices))
|
torch.empty(self.numel, dtype=self.dtype)
|
||||||
|
for _ in range(len(batch_indices))
|
||||||
]
|
]
|
||||||
|
|
||||||
futures = [
|
futures = [
|
||||||
@@ -273,10 +287,27 @@ class HiCacheHF3FS(HiCacheStorage):
|
|||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def set(self, key: str, value: torch.Tensor) -> bool:
|
def set(
|
||||||
return self.batch_set([key], [value])
|
self,
|
||||||
|
key: str,
|
||||||
|
value: Optional[Any] = None,
|
||||||
|
target_location: Optional[Any] = None,
|
||||||
|
target_sizes: Optional[Any] = None,
|
||||||
|
) -> bool:
|
||||||
|
return self.batch_set(
|
||||||
|
[key],
|
||||||
|
[value] if value is not None else None,
|
||||||
|
[target_location] if target_location is not None else None,
|
||||||
|
[target_sizes] if target_sizes is not None else None,
|
||||||
|
)
|
||||||
|
|
||||||
def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
|
def batch_set(
|
||||||
|
self,
|
||||||
|
keys: List[str],
|
||||||
|
values: Optional[Any] = None,
|
||||||
|
target_locations: Optional[Any] = None,
|
||||||
|
target_sizes: Optional[Any] = None,
|
||||||
|
) -> bool:
|
||||||
# Todo: Add prefix block's hash key
|
# Todo: Add prefix block's hash key
|
||||||
key_with_prefix = [(key, "") for key in keys]
|
key_with_prefix = [(key, "") for key in keys]
|
||||||
indices = self.metadata_client.reserve_and_allocate_page_indices(
|
indices = self.metadata_client.reserve_and_allocate_page_indices(
|
||||||
@@ -292,7 +323,8 @@ class HiCacheHF3FS(HiCacheStorage):
|
|||||||
|
|
||||||
batch_indices.append(i)
|
batch_indices.append(i)
|
||||||
file_offsets.append(page_index * self.bytes_per_page)
|
file_offsets.append(page_index * self.bytes_per_page)
|
||||||
file_values.append(value.contiguous())
|
assert value.is_contiguous()
|
||||||
|
file_values.append(value)
|
||||||
|
|
||||||
futures = [
|
futures = [
|
||||||
self.executor.submit(
|
self.executor.submit(
|
||||||
|
|||||||
Reference in New Issue
Block a user