Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
benchmark/hf3fs/bench.sh (new file, 49 lines)
@@ -0,0 +1,49 @@
SGLANG_HICACHE_HF3FS_CONFIG_PATH=/sgl-workspace/sglang/benchmark/hf3fs/hf3fs.json \
python3 benchmark/hf3fs/bench_storage.py

####################################################################################################

rm -rf nohup.out && \
nohup python3 -m sglang.launch_server \
    --model-path /code/models/Qwen3-32B/ \
    --host 0.0.0.0 --port 33301 \
    --page-size 64 \
    --enable-hierarchical-cache \
    --hicache-ratio 2 --hicache-size 0 \
    --hicache-write-policy write_through \
    --hicache-storage-backend hf3fs &

rm -rf bench_multiturn.out && \
nohup python3 benchmark/hicache/bench_multiturn.py \
    --model-path /code/models/Qwen3-32B \
    --dataset-path /code/models/ShareGPT_V3_unfiltered_cleaned_split.json \
    --port 33301 \
    --request-length 2048 --num-clients 512 --num-rounds 3 --max-parallel 8 \
    > bench_multiturn.out &

####################################################################################################

rm -rf nohup.out && \
nohup python3 -m sglang.launch_server \
    --model-path /code/models/DeepSeek-R1/ \
    --tp 16 --nnodes 2 --node-rank 0 \
    --dist-init-addr 10.74.249.153:5000 \
    --host 0.0.0.0 --port 33301 \
    --page-size 64 \
    --enable-hierarchical-cache \
    --hicache-ratio 2 --hicache-size 60 \
    --hicache-write-policy write_through \
    --hicache-storage-backend hf3fs &

rm -rf bench_multiturn.out && \
nohup python3 benchmark/hicache/bench_multiturn.py \
    --model-path /code/models/Qwen3-32B \
    --dataset-path /code/models/ShareGPT_V3_unfiltered_cleaned_split.json \
    --port 33301 \
    --request-length 2048 --num-clients 1024 --num-rounds 3 --max-parallel 8 \
    > bench_multiturn.out &

####################################################################################################

ps aux | grep "sglang.launch_server" | grep -v grep | awk '{print $2}' | xargs kill -9
ps aux | grep "bench_multiturn.py" | grep -v grep | awk '{print $2}' | xargs kill -9
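Note: bench.sh points `SGLANG_HICACHE_HF3FS_CONFIG_PATH` at an `hf3fs.json` that is not part of this commit. A minimal sketch of such a config, using the four keys that `HiCacheHF3FS.from_env_config` requires (`file_path_prefix`, `file_size`, `numjobs`, `entries`); the prefix and sizes below are illustrative and simply mirror the backend's built-in defaults, they are not mandated by this PR:

```python
# Sketch: write an hf3fs.json accepted by HiCacheHF3FS.from_env_config.
# Values mirror the backend defaults (1 TiB file, 16 jobs, 8 io-ring entries);
# /data/hicache is an assumed prefix and must live on a mounted 3FS path.
import json

config = {
    "file_path_prefix": "/data/hicache",  # backend appends ".<rank>.bin"
    "file_size": 1 << 40,                 # bytes reserved per rank
    "numjobs": 16,                        # Hf3fsClient instances / worker threads
    "entries": 8,                         # io-ring entries per client
}
with open("benchmark/hf3fs/hf3fs.json", "w") as f:
    json.dump(config, f, indent=2)
```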
benchmark/hf3fs/bench_client.py (new file, 162 lines)
@@ -0,0 +1,162 @@
import concurrent.futures
import logging
import random
import time
from typing import List

import torch
from tqdm import tqdm

from sglang.srt.mem_cache.storage.hf3fs.client_hf3fs import Hf3fsClient


def print_stats(x: List[int]):
    x = sorted(x)
    lenx = len(x)
    print(
        f"mean = {sum(x)/len(x):.2f}, "
        f"min = {min(x):.2f}, "
        f"p25 = {x[int(lenx*0.25)]:.2f}, "
        f"p50 = {x[int(lenx*0.5)]:.2f}, "
        f"p75 = {x[int(lenx*0.75)]:.2f}, "
        f"max = {max(x):.2f}"
    )


def test():
    # /path/to/hf3fs
    file_path = "/data/bench.bin"
    file_size = 1 << 40
    bytes_per_page = 16 << 20
    entries = 32
    file_ops = Hf3fsClient(file_path, file_size, bytes_per_page, entries)

    print("test batch_read / batch_write")
    num_pages = 128
    dtype = torch.bfloat16
    numel = bytes_per_page // dtype.itemsize
    offsets = list(range(file_size // bytes_per_page))
    random.shuffle(offsets)
    offsets = offsets[:num_pages]
    offsets = [i * bytes_per_page for i in offsets]
    tensor_writes = [
        torch.randn(numel, dtype=dtype)
        for _ in tqdm(range(num_pages), desc="prepare tensor")
    ]
    for i in tqdm(range(0, num_pages, file_ops.entries), desc="batch_write"):
        results = file_ops.batch_write(
            offsets[i : i + file_ops.entries], tensor_writes[i : i + file_ops.entries]
        )
        assert all([result == numel * dtype.itemsize for result in results])
    tensor_reads = [
        torch.empty(numel, dtype=dtype)
        for _ in tqdm(range(num_pages), desc="prepare tensor")
    ]
    for i in tqdm(range(0, num_pages, file_ops.entries), desc="batch_read"):
        results = file_ops.batch_read(
            offsets[i : i + file_ops.entries], tensor_reads[i : i + file_ops.entries]
        )
        assert all([result == numel * dtype.itemsize for result in results])
    assert all([torch.allclose(r, w) for r, w in zip(tensor_reads, tensor_writes)])

    file_ops.close()
    print("test done")


def bench():
    file_path = "/data/bench.bin"
    file_size = 1 << 40
    bytes_per_page = 16 << 20
    entries = 8
    numjobs = 16

    dtype = torch.bfloat16
    numel = bytes_per_page // dtype.itemsize

    file_ops = [
        Hf3fsClient(file_path, file_size, bytes_per_page, entries)
        for _ in range(numjobs)
    ]

    num_page = entries

    offsets = list(range(file_size // bytes_per_page))
    tensors_write = [torch.randn(numel, dtype=dtype)] * num_page
    tensors_read = [torch.empty(numel, dtype=dtype)] * num_page
    random.shuffle(offsets)

    warmup = 50
    iteration = 100

    executor = concurrent.futures.ThreadPoolExecutor(max_workers=numjobs)

    w_bw = []
    w_size = num_page * numjobs * bytes_per_page / (1 << 30)
    for i in tqdm(range(warmup + iteration), desc="Benchmarking write (GB/s)"):
        _offsets = [
            [
                offset * bytes_per_page
                for offset in offsets[
                    (i * numjobs + j) * num_page : (i * numjobs + j + 1) * num_page
                ]
            ]
            for j in range(numjobs)
        ]
        tik = time.perf_counter()
        futures = [
            executor.submit(file_ops[j].batch_write, offset, tensors_write)
            for j, offset in enumerate(_offsets)
        ]
        results = [future.result() for future in futures]
        tok = time.perf_counter()
        if i < warmup:
            continue
        w_bw.append(w_size / (tok - tik))
        results = [
            _result == bytes_per_page for result in results for _result in result
        ]
        assert all(results)
    print_stats(w_bw)

    r_bw = []
    r_size = w_size
    for i in tqdm(range(warmup + iteration), desc="Benchmarking read (GB/s)"):
        _offsets = [
            [
                offset * bytes_per_page
                for offset in offsets[
                    (i * numjobs + j) * num_page : (i * numjobs + j + 1) * num_page
                ]
            ]
            for j in range(numjobs)
        ]
        tik = time.perf_counter()
        futures = [
            executor.submit(file_ops[j].batch_read, offset, tensors_read)
            for j, offset in enumerate(_offsets)
        ]
        results = [future.result() for future in futures]
        tok = time.perf_counter()
        if i < warmup:
            continue
        r_bw.append(r_size / (tok - tik))
        results = [
            _result == bytes_per_page for result in results for _result in result
        ]
        assert all(results)
    print_stats(r_bw)

    executor.shutdown(wait=True)
    for _file_ops in file_ops:
        _file_ops.close()
    print("bench done")


def main():
    logging.basicConfig(level=logging.INFO)
    test()
    bench()


if __name__ == "__main__":
    main()
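For reference, each timed iteration in `bench()` above moves a fixed volume of data, so the reported GB/s figures follow directly from the hard-coded defaults (`entries = 8`, `numjobs = 16`, `bytes_per_page = 16 MiB`). A small sanity check of that arithmetic:

```python
# Per-iteration I/O volume in bench_client.bench(), using its hard-coded defaults.
bytes_per_page = 16 << 20   # 16 MiB per page
entries = 8                 # pages per batch_read/batch_write call
numjobs = 16                # concurrent Hf3fsClient instances
num_page = entries

io_bytes = num_page * numjobs * bytes_per_page
print(io_bytes / (1 << 30))  # 2.0 -> every timed iteration reads or writes 2 GiB,
                             # so e.g. 0.5 s per iteration corresponds to 4 GB/s.
```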
benchmark/hf3fs/bench_storage.py (new file, 241 lines)
@@ -0,0 +1,241 @@
import json
import logging
import os
import random
import time
from typing import List

import torch
from tqdm import tqdm

from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import HiCacheHF3FS


def print_stats(x: List[int]):
    x = sorted(x)
    lenx = len(x)
    print(
        f"mean = {sum(x)/len(x):.2f}, "
        f"min = {min(x):.2f}, "
        f"p25 = {x[int(lenx*0.25)]:.2f}, "
        f"p50 = {x[int(lenx*0.5)]:.2f}, "
        f"p75 = {x[int(lenx*0.75)]:.2f}, "
        f"max = {max(x):.2f}"
    )


def test():
    # Qwen3-32B
    layer_num = 64
    head_num, head_dim = 8, 128
    kv_lora_rank, qk_rope_head_dim = 0, 0
    store_dtype = torch.bfloat16
    tokens_per_page = 64

    file_path_prefix = "/data/test"
    file_size = 128 << 20
    numjobs = 16
    bytes_per_page = 16 << 20
    entries = 2
    dtype = store_dtype

    config_path = os.getenv(HiCacheHF3FS.default_env_var)
    assert config_path
    try:
        with open(config_path, "w") as f:
            json.dump(
                {
                    "file_path_prefix": file_path_prefix,
                    "file_size": file_size,
                    "numjobs": numjobs,
                    "entries": entries,
                },
                f,
            )
    except Exception as e:
        raise RuntimeError(f"Failed to dump config to {config_path}: {str(e)}")

    rank = 0
    hicache_hf3fs = HiCacheHF3FS.from_env_config(rank, bytes_per_page, dtype)

    numel = 2 * tokens_per_page * layer_num * head_num * head_dim
    assert numel * dtype.itemsize == bytes_per_page

    num_pages = 10
    tensors = {}
    for i in range(num_pages):
        k = f"key_{i}"
        v = torch.randn((numel,)).to(dtype=dtype)
        ok = hicache_hf3fs.set(k, v)
        assert ok, f"Failed to insert {k}"
        tensors[k] = v
    assert hicache_hf3fs.get("key_0") is None
    assert hicache_hf3fs.get("key_1") is None

    start = num_pages - hicache_hf3fs.num_pages
    for i in range(start, start + hicache_hf3fs.num_pages):
        k = f"key_{i}"
        assert hicache_hf3fs.exists(k)
        out = hicache_hf3fs.get(k)
        assert out is not None
        v = tensors[k]
        assert torch.allclose(v, out, atol=1e-3), f"Tensor mismatch for {k}"

    assert not hicache_hf3fs.exists("not_exists")

    hicache_hf3fs.delete("key_9")
    v2 = torch.randn((numel,)).to(dtype=dtype)
    assert hicache_hf3fs.set("key_new", v2)
    assert torch.allclose(hicache_hf3fs.get("key_new"), v2, atol=1e-3)

    hicache_hf3fs.clear()
    assert len(hicache_hf3fs.free_pages) == hicache_hf3fs.num_pages

    # batch
    num_pages = 10
    tensors = {}
    keys = []
    values = []
    for i in range(num_pages):
        k = f"key_{i}"
        keys.append(k)
        v = torch.randn((numel,)).to(dtype=dtype)
        values.append(v)

    ok = hicache_hf3fs.batch_set(keys, values)
    assert not ok
    assert hicache_hf3fs.get("key_8") is None
    assert hicache_hf3fs.get("key_9") is None

    results = hicache_hf3fs.batch_get(keys[: hicache_hf3fs.num_pages])
    for result, key, value in zip(
        results, keys[: hicache_hf3fs.num_pages], values[: hicache_hf3fs.num_pages]
    ):
        assert torch.allclose(value, result, atol=1e-3), f"Tensor mismatch for {key}"

    hicache_hf3fs.close()
    os.remove(hicache_hf3fs.file_path)

    print("All test cases passed.")


def bench():
    # Qwen3-32B
    layer_num = 64
    head_num, head_dim = 8, 128
    kv_lora_rank, qk_rope_head_dim = 0, 0
    store_dtype = torch.bfloat16
    tokens_per_page = 64

    file_path = "/data/test.bin"
    file_size = 1 << 40
    numjobs = 16
    bytes_per_page = 16 << 20
    entries = 8
    dtype = store_dtype
    hicache_hf3fs = HiCacheHF3FS(
        file_path=file_path,
        file_size=file_size,
        numjobs=numjobs,
        bytes_per_page=bytes_per_page,
        entries=entries,
        dtype=dtype,
    )

    numel = 2 * tokens_per_page * layer_num * head_num * head_dim
    assert numel * dtype.itemsize == bytes_per_page

    num_page = 128
    values = [torch.randn((numel,)).to(dtype=dtype) for _ in tqdm(range(num_page))]

    warmup = 50
    iteration = 100

    w_bw = []
    w_size = num_page * bytes_per_page / (1 << 30)
    for i in tqdm(range(warmup + iteration), desc="Benchmarking write (GB/s)"):
        keys = [f"{j}" for j in range(i * num_page, (i + 1) * num_page)]
        tik = time.perf_counter()
        ok = hicache_hf3fs.batch_set(keys, values)
        tok = time.perf_counter()
        if i < warmup:
            continue
        w_bw.append(w_size / (tok - tik))
        assert ok
    print_stats(w_bw)

    r_bw = []
    r_size = num_page * bytes_per_page / (1 << 30)
    for i in tqdm(range(warmup + iteration), desc="Benchmarking read (GB/s)"):
        keys = random.sample(list(hicache_hf3fs.key_to_index.keys()), num_page)
        tik = time.perf_counter()
        results = hicache_hf3fs.batch_get(keys)
        tok = time.perf_counter()
        if i < warmup:
            continue
        r_bw.append(r_size / (tok - tik))
        assert all([r is not None for r in results])
    print_stats(r_bw)

    hicache_hf3fs.close()


def allclose():
    # Qwen3-32B
    layer_num = 64
    head_num, head_dim = 8, 128
    kv_lora_rank, qk_rope_head_dim = 0, 0
    store_dtype = torch.bfloat16
    tokens_per_page = 64

    file_path = "/data/test.bin"
    file_size = 1 << 40
    numjobs = 16
    bytes_per_page = 16 << 20
    entries = 8
    dtype = store_dtype
    hicache_hf3fs = HiCacheHF3FS(
        file_path=file_path,
        file_size=file_size,
        numjobs=numjobs,
        bytes_per_page=bytes_per_page,
        entries=entries,
        dtype=dtype,
    )

    numel = 2 * tokens_per_page * layer_num * head_num * head_dim
    assert numel * dtype.itemsize == bytes_per_page

    num_page = 128
    values = [torch.randn((numel,)).to(dtype=dtype) for _ in tqdm(range(num_page))]

    iteration = 100

    for i in tqdm(range(iteration), desc="Benchmarking write (GB/s)"):
        keys = [f"{j}" for j in range(i * num_page, (i + 1) * num_page)]
        ok = hicache_hf3fs.batch_set(keys, values)
        assert ok

    read_keys, read_results = [], []
    for i in tqdm(range(iteration), desc="Benchmarking read (GB/s)"):
        keys = random.sample(list(hicache_hf3fs.key_to_index.keys()), num_page)
        results = hicache_hf3fs.batch_get(keys)
        read_keys.extend(keys)
        read_results.extend(results)
        assert all([r is not None for r in results])

    for key, result in tqdm(zip(read_keys, read_results)):
        assert torch.allclose(values[int(key) % num_page], result, atol=1e-3)

    hicache_hf3fs.close()


def main():
    logging.basicConfig(level=logging.INFO)
    test()
    bench()
    allclose()


if __name__ == "__main__":
    main()
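The `assert numel * dtype.itemsize == bytes_per_page` checks above encode the KV-cache page geometry used throughout these benchmarks: for the Qwen3-32B settings (64 layers, 8 KV heads, head_dim 128, 64 tokens per page, bfloat16) one page is exactly 16 MiB. The arithmetic, spelled out:

```python
import torch

# Page-size arithmetic for the Qwen3-32B configuration used in bench_storage.py.
layer_num, head_num, head_dim = 64, 8, 128
tokens_per_page = 64
dtype = torch.bfloat16

numel = 2 * tokens_per_page * layer_num * head_num * head_dim  # K and V planes
print(numel)                   # 8388608 elements
print(numel * dtype.itemsize)  # 16777216 bytes == 16 << 20, i.e. one 16 MiB page
```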
@@ -26,6 +26,7 @@ if TYPE_CHECKING:
    from sglang.srt.mem_cache.memory_pool_host import HostKVCache

from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str
from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import HiCacheHF3FS

logger = logging.getLogger(__name__)

@@ -250,17 +251,33 @@ class HiCacheController:
        self.tp_world_size = torch.distributed.get_world_size(group=tp_group)
        if self.tp_world_size > 1:
            group_ranks = torch.distributed.get_process_group_ranks(tp_group)
            self.tp_group = torch.distributed.new_group(group_ranks, backend="gloo")
            self.prefetch_tp_group = torch.distributed.new_group(
                group_ranks, backend="gloo"
            )
            self.backup_tp_group = torch.distributed.new_group(
                group_ranks, backend="gloo"
            )

        if storage_backend == "file":
            self.storage_backend = HiCacheFile()
            self.enable_storage = True
            # todo: threshold policy for prefetching
            self.prefetch_threshold = max(prefetch_threshold, self.page_size)
        elif storage_backend == "hf3fs":
            from sglang.srt.distributed import get_tensor_model_parallel_rank

            rank = get_tensor_model_parallel_rank()
            bytes_per_page = (
                mem_pool_host.get_size_per_token() * mem_pool_host.page_size
            )
            dtype = mem_pool_host.dtype
            self.storage_backend = HiCacheHF3FS.from_env_config(
                rank, bytes_per_page, dtype
            )
        else:
            raise NotImplementedError(
                f"Unsupported storage backend: {storage_backend}"
            )
        self.enable_storage = True
        # todo: threshold policy for prefetching
        self.prefetch_threshold = max(prefetch_threshold, self.page_size)

        self.load_cache_event = load_cache_event
        self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
@@ -522,8 +539,8 @@ class HiCacheController:
        while not self.stop_event.is_set():
            try:
                operation = self.prefetch_buffer.get(block=True, timeout=1)
                for h in operation.hash_value:
                    page_data = self.storage_backend.get(h)
                page_datas = self.storage_backend.batch_get(operation.hash_value)
                for h, page_data in zip(operation.hash_value, page_datas):
                    if page_data is None:
                        logger.warning(
                            f"Prefetch operation {operation.request_id} failed to retrieve page {h}."
@@ -531,7 +548,9 @@ class HiCacheController:
                        break
                    if operation.increment(self.page_size):
                        self.mem_pool_host.set_from_flat_data_page(
                            operation.host_indices[operation.completed_tokens],
                            operation.host_indices[
                                operation.completed_tokens - self.page_size
                            ],
                            page_data,
                        )
                    else:
@@ -583,7 +602,7 @@ class HiCacheController:
                torch.distributed.all_reduce(
                    storage_hit_count_tensor,
                    op=torch.distributed.ReduceOp.MIN,
                    group=self.tp_group,
                    group=self.prefetch_tp_group,
                )
                storage_hit_count = storage_hit_count_tensor.item()

@@ -635,21 +654,23 @@ class HiCacheController:
                last_hash = operation.last_hash
                tokens_to_backup = operation.token_ids

                last_hashes, data_pages = [], []
                for i in range(0, len(tokens_to_backup), self.page_size):
                    last_hash = get_hash_str(
                        tokens_to_backup[i : i + self.page_size], last_hash
                    )
                    success = self.storage_backend.set(
                        last_hash,
                        self.mem_pool_host.get_flat_data_page(
                            operation.host_indices[i]
                        ),
                    data_page = self.mem_pool_host.get_flat_data_page(
                        operation.host_indices[i]
                    )
                    if not success:
                        logger.warning(f"Failed to write page {last_hash} to storage.")
                        break
                    operation.completed_tokens += self.page_size
                    operation.hash_value.append(last_hash)
                    last_hashes.append(last_hash)
                    data_pages.append(data_page)

                success = self.storage_backend.batch_set(last_hashes, data_pages)
                if not success:
                    logger.warning(f"Failed to write page {last_hashes} to storage.")
                else:
                    operation.completed_tokens += len(tokens_to_backup)
                    operation.hash_value.extend(last_hashes)

                min_completed_tokens = operation.completed_tokens
                if self.tp_world_size > 1:
@@ -659,7 +680,7 @@ class HiCacheController:
                    torch.distributed.all_reduce(
                        completed_tokens_tensor,
                        op=torch.distributed.ReduceOp.MIN,
                        group=self.tp_group,
                        group=self.backup_tp_group,
                    )
                    min_completed_tokens = completed_tokens_tensor.item()

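The backup hunk above replaces one `storage_backend.set()` call per page with a single `batch_set()` over all pages of the request; because the diff view interleaves the removed and added lines, here is a simplified, hedged sketch of the new write-back flow (names follow the controller code; ack queues, error accounting, and TP synchronization are omitted):

```python
# Simplified sketch of the batched write-back path in HiCacheController
# (illustration only; the real method also handles ack queues and TP reduction).
def backup_pages(self, operation):
    last_hash = operation.last_hash
    tokens_to_backup = operation.token_ids

    last_hashes, data_pages = [], []
    for i in range(0, len(tokens_to_backup), self.page_size):
        # Chain the page hashes so each page key depends on its prefix.
        last_hash = get_hash_str(tokens_to_backup[i : i + self.page_size], last_hash)
        last_hashes.append(last_hash)
        data_pages.append(
            self.mem_pool_host.get_flat_data_page(operation.host_indices[i])
        )

    # One storage round-trip instead of one set() per page.
    if self.storage_backend.batch_set(last_hashes, data_pages):
        operation.completed_tokens += len(tokens_to_backup)
        operation.hash_value.extend(last_hashes)
    else:
        logger.warning(f"Failed to write page {last_hashes} to storage.")
```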
@@ -79,7 +79,9 @@ class HiRadixCache(RadixCache):
        self.write_through_threshold = (
            1 if hicache_write_policy == "write_through" else 3
        )
        self.write_through_threshold_storage = 3
        self.write_through_threshold_storage = (
            1 if hicache_write_policy == "write_through" else 3
        )
        self.load_back_threshold = 10
        super().__init__(
            req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False
@@ -388,10 +390,14 @@ class HiRadixCache(RadixCache):
                self.cache_controller.ack_backup_queue.get()
            )
            host_node = self.ongoing_backup[ack_id]
            if completed_tokens < len(host_node.key):
            if completed_tokens == 0:
                host_node.hash_value = None
            elif completed_tokens < len(host_node.key):
                # backup is only partially successful, split the node
                new_node = self._split_node(host_node.key, host_node, completed_tokens)
                new_node.hash_value = hash_value
            else:
                host_node.hash_value = hash_value
            host_node.release_host()
            del self.ongoing_backup[ack_id]

@@ -431,6 +437,8 @@ class HiRadixCache(RadixCache):
                written_indices,
                hash_value[:min_completed_tokens],
            )
            if len(written_indices):
                self.cache_controller.mem_pool_host.update_prefetch(written_indices)

            self.cache_controller.mem_pool_host.free(host_indices[:matched_length])
            self.cache_controller.mem_pool_host.free(
@@ -25,7 +25,6 @@ def synchronized(debug_only=False):
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            if (not debug_only) or self.debug:
                return func(self, *args, **kwargs)
                with self.lock:
                    return func(self, *args, **kwargs)
            else:
@@ -181,6 +180,15 @@ class HostKVCache(abc.ABC):
            )
        self.mem_state[indices] = MemoryStateInt.BACKUP

    @synchronized(debug_only=True)
    def update_prefetch(self, indices: torch.Tensor):
        if not self.is_reserved(indices):
            raise ValueError(
                f"The host memory slots should be in RESERVED state before turning into BACKUP. "
                f"Current state: {self.get_state(indices)}"
            )
        self.mem_state[indices] = MemoryStateInt.BACKUP

    @synchronized(debug_only=True)
    def update_synced(self, indices: torch.Tensor):
        self.mem_state[indices] = MemoryStateInt.SYNCED

python/sglang/srt/mem_cache/storage/hf3fs/README.md (new file, 65 lines)
@@ -0,0 +1,65 @@
# HiCacheHF3FS Setup

## Build & Package
### Source Code
https://github.com/deepseek-ai/3FS/blob/main/README.md#check-out-source-code
```sh
git clone https://github.com/deepseek-ai/3fs

cd 3fs
git submodule update --init --recursive
./patches/apply.sh
```

### Build Dev Container
https://github.com/deepseek-ai/3FS/blob/main/dockerfile/dev.dockerfile
```sh
cd 3fs/dockerfile
docker build -t hf3fs:dev -f dev.dockerfile .
```

### Generate Python Wheel
```sh
docker run -it hf3fs:dev bash

# Inside the development container
git clone https://github.com/deepseek-ai/3fs

cd 3fs
git submodule update --init --recursive
./patches/apply.sh

apt-get update \
    && apt-get install -y --no-install-recommends \
    python3 python3-pip \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/*

# Generated wheel location: dist/hf3fs_py_usrbio-1.2.9+2db69ce-cp310-cp310-linux_x86_64.whl
python3 setup.py bdist_wheel
```

## Installation
```sh
# Install Dependencies
# https://github.com/deepseek-ai/3FS/blob/main/dockerfile/dev.dockerfile
apt update && apt install -y \
    libaio-dev \
    libboost-all-dev \
    libdouble-conversion-dev \
    libdwarf-dev \
    libgflags-dev \
    libgmock-dev \
    libgoogle-glog-dev \
    libgoogle-perftools-dev \
    libgtest-dev \
    liblz4-dev \
    liblzma-dev \
    libssl-dev \
    libunwind-dev \
    libuv1-dev

# Install Python Package
pip install hf3fs_py_usrbio-1.2.9+2db69ce-cp310-cp310-linux_x86_64.whl
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.10/dist-packages
```
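Once the wheel is installed, SGLang selects this backend at run time via `--hicache-storage-backend hf3fs`; the backend reads its layout from the JSON file named by `SGLANG_HICACHE_HF3FS_CONFIG_PATH` (see `HiCacheHF3FS.from_env_config`), falling back to built-in defaults when the variable is unset. A hedged sketch of exercising that path directly; the page size, dtype, and config location are illustrative, and the configured `file_path_prefix` must live on a mounted 3FS path:

```python
# Sketch: instantiate the backend the same way the cache controller does.
# bytes_per_page and dtype normally come from the host KV pool; the values
# below are illustrative (16 MiB bfloat16 pages, rank 0), and the config file
# is assumed to exist (e.g. written as in the earlier hf3fs.json sketch).
import os
import torch
from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import HiCacheHF3FS

os.environ["SGLANG_HICACHE_HF3FS_CONFIG_PATH"] = (
    "/sgl-workspace/sglang/benchmark/hf3fs/hf3fs.json"
)
backend = HiCacheHF3FS.from_env_config(rank=0, bytes_per_page=16 << 20, dtype=torch.bfloat16)
page = torch.randn(backend.numel, dtype=backend.dtype)
assert backend.set("example_page_hash", page)          # hypothetical key
assert torch.allclose(backend.get("example_page_hash"), page)
backend.close()
```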
python/sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py (new file, 177 lines)
@@ -0,0 +1,177 @@
import logging
import multiprocessing
import os
import threading
from functools import wraps
from pathlib import Path
from typing import List

import torch
from torch.utils.cpp_extension import load

root = Path(__file__).parent.resolve()
hf3fs_utils = load(name="hf3fs_utils", sources=[f"{root}/hf3fs_utils.cpp"])

logger = logging.getLogger(__name__)

try:
    from hf3fs_fuse.io import (
        deregister_fd,
        extract_mount_point,
        make_ioring,
        make_iovec,
        register_fd,
    )
except ImportError as e:
    logger.warning(f"hf3fs_fuse.io is not available: {e}")


def rsynchronized():
    def _decorator(func):
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            with self.rlock:
                return func(self, *args, **kwargs)

        return wrapper

    return _decorator


def wsynchronized():
    def _decorator(func):
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            with self.wlock:
                return func(self, *args, **kwargs)

        return wrapper

    return _decorator


class Hf3fsClient:
    def __init__(self, path: str, size: int, bytes_per_page: int, entries: int):
        self.path = path
        self.size = size
        self.bytes_per_page = bytes_per_page
        self.entries = entries

        self.file = os.open(self.path, os.O_RDWR | os.O_CREAT)
        os.ftruncate(self.file, size)
        register_fd(self.file)

        self.hf3fs_mount_point = extract_mount_point(path)
        self.bs = self.bytes_per_page
        self.shm_r = multiprocessing.shared_memory.SharedMemory(
            size=self.bs * self.entries, create=True
        )
        self.shm_w = multiprocessing.shared_memory.SharedMemory(
            size=self.bs * self.entries, create=True
        )

        self.shm_r_tensor = torch.frombuffer(self.shm_r.buf, dtype=torch.uint8)
        self.shm_w_tensor = torch.frombuffer(self.shm_w.buf, dtype=torch.uint8)

        self.numa = -1
        self.ior_r = make_ioring(
            self.hf3fs_mount_point,
            self.entries,
            for_read=True,
            timeout=1,
            numa=self.numa,
        )
        self.ior_w = make_ioring(
            self.hf3fs_mount_point,
            self.entries,
            for_read=False,
            timeout=1,
            numa=self.numa,
        )
        self.iov_r = make_iovec(self.shm_r, self.hf3fs_mount_point)
        self.iov_w = make_iovec(self.shm_w, self.hf3fs_mount_point)

        self.rlock = threading.RLock()
        self.wlock = threading.RLock()

    @rsynchronized()
    def batch_read(self, offsets: List[int], tensors: List[torch.Tensor]) -> List[int]:
        self.check(offsets, tensors)

        # prepare
        current = 0
        for offset, tensor in zip(offsets, tensors):
            size = tensor.numel() * tensor.itemsize
            self.ior_r.prepare(
                self.iov_r[current : current + size], True, self.file, offset
            )
            current += size

        # submit
        ionum = len(offsets)
        resv = self.ior_r.submit().wait(min_results=ionum)

        # results
        hf3fs_utils.read_shm(self.shm_r_tensor, tensors)
        results = [res.result for res in resv]

        return results

    @wsynchronized()
    def batch_write(self, offsets: List[int], tensors: List[torch.Tensor]) -> List[int]:
        self.check(offsets, tensors)

        # prepare
        hf3fs_utils.write_shm(tensors, self.shm_w_tensor)
        current = 0
        for offset, tensor in zip(offsets, tensors):
            size = tensor.numel() * tensor.itemsize
            self.ior_w.prepare(
                self.iov_w[current : current + size], False, self.file, offset
            )
            current += size

        # submit
        ionum = len(offsets)
        resv = self.ior_w.submit().wait(min_results=ionum)

        # results
        results = [res.result for res in resv]

        return results

    def check(self, offsets: List[int], tensors: List[torch.Tensor]) -> None:
        sizes = [t.numel() * t.itemsize for t in tensors]
        if any(
            [
                len(offsets) > self.entries,
                len(offsets) != len(sizes),
                all(
                    [
                        offset < 0 or offset + size > self.size
                        for offset, size in zip(offsets, sizes)
                    ]
                ),
                all([size > self.bytes_per_page for size in sizes]),
            ]
        ):
            self.close()
            raise ValueError(f"Hf3fsClient.check: {offsets=}, {sizes=}")

    def get_size(self) -> int:
        return self.size

    def close(self) -> None:
        deregister_fd(self.file)
        os.close(self.file)
        del self.ior_r
        del self.ior_w
        del self.iov_r
        del self.iov_w
        self.shm_r.close()
        self.shm_w.close()
        self.shm_r.unlink()
        self.shm_w.unlink()

    def flush(self) -> None:
        os.fsync(self.file)
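Each `Hf3fsClient` stages I/O through two shared-memory regions (one per direction) of `bytes_per_page * entries` bytes, so the aggregate staging footprint grows with `numjobs`. With the defaults used elsewhere in this PR (16 MiB pages, 8 io-ring entries, 16 jobs) that is roughly 4 GiB of shared memory; a quick check of that figure, assuming those defaults:

```python
# Shared-memory staging footprint of HiCacheHF3FS with illustrative defaults.
bytes_per_page = 16 << 20  # one 16 MiB KV page
entries = 8                # io-ring entries per Hf3fsClient
numjobs = 16               # clients created by HiCacheHF3FS

per_client = 2 * bytes_per_page * entries        # read shm + write shm
total = numjobs * per_client
print(per_client / (1 << 20), total / (1 << 30))  # 256.0 MiB per client, 4.0 GiB total
```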
python/sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp (new file, 35 lines)
@@ -0,0 +1,35 @@
#include <torch/extension.h>

#include <cstring>
#include <vector>

void read_shm(const torch::Tensor &shm, std::vector<torch::Tensor> dst) {
  py::gil_scoped_release release;
  char *src_ptr = static_cast<char *>(shm.data_ptr());
  size_t current = 0;
  for (size_t i = 0; i < dst.size(); ++i) {
    auto &t = dst[i];
    size_t t_bytes = t.numel() * t.element_size();
    char *dst_ptr = static_cast<char *>(t.data_ptr());
    std::memcpy(dst_ptr, src_ptr + current, t_bytes);
    current += t_bytes;
  }
}

void write_shm(const std::vector<torch::Tensor> src, torch::Tensor &shm) {
  py::gil_scoped_release release;
  char *dst_ptr = static_cast<char *>(shm.data_ptr());
  size_t current = 0;
  for (size_t i = 0; i < src.size(); ++i) {
    auto &t = src[i];
    size_t t_bytes = t.numel() * t.element_size();
    char *src_ptr = static_cast<char *>(t.data_ptr());
    std::memcpy(dst_ptr + current, src_ptr, t_bytes);
    current += t_bytes;
  }
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("read_shm", &read_shm, "Read tensors from shared memory");
  m.def("write_shm", &write_shm, "Write tensors to shared memory");
}
python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py (new file, 278 lines)
@@ -0,0 +1,278 @@
import atexit
import concurrent.futures
import json
import logging
import os
import signal
import threading
from collections import OrderedDict
from functools import wraps
from typing import List, Optional

import torch

from sglang.srt.mem_cache.hicache_storage import HiCacheStorage
from sglang.srt.mem_cache.storage.hf3fs.client_hf3fs import Hf3fsClient

logger = logging.getLogger(__name__)


class AtomicCounter:
    def __init__(self, n: int):
        assert n > 0
        self.n = n
        self._value = 0
        self._lock = threading.Lock()

    def next(self) -> int:
        with self._lock:
            current = self._value
            self._value = (current + 1) % self.n
            return current


def synchronized():
    def _decorator(func):
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            with self.lock:
                return func(self, *args, **kwargs)

        return wrapper

    return _decorator


class HiCacheHF3FS(HiCacheStorage):
    default_env_var: str = "SGLANG_HICACHE_HF3FS_CONFIG_PATH"

    def __init__(
        self,
        file_path: str,
        file_size: int,
        numjobs: int,
        bytes_per_page: int,
        entries: int,
        dtype: torch.dtype,
    ):
        self.file_path = file_path
        self.file_size = file_size
        self.numjobs = numjobs
        self.bytes_per_page = bytes_per_page
        self.entries = entries
        self.dtype = dtype

        self.numel = self.bytes_per_page // self.dtype.itemsize

        self.num_pages = self.file_size // self.bytes_per_page

        logger.info(
            "HiCacheHF3FS "
            f"file_path = {self.file_path}, "
            f"file_size = {self.file_size/(2**30):.2f} GB, "
            f"numjobs = {self.numjobs}, "
            f"bytes_per_page = {self.bytes_per_page/(2**20):.2f} MB, "
            f"entries = {self.entries}, "
            f"num_pages = {self.num_pages}"
        )

        self.ac = AtomicCounter(self.numjobs)
        self.clients = [
            Hf3fsClient(
                self.file_path, self.file_size, self.bytes_per_page, self.entries
            )
            for _ in range(numjobs)
        ]
        self.executor = concurrent.futures.ThreadPoolExecutor(
            max_workers=self.numjobs, thread_name_prefix="HiCacheHF3FS"
        )

        # Implemented a preliminary single-file page_hash -> file_offset index as interim storage.
        # Future iterations may adopt a global KVCache manager to coordinate external cache instances
        # through centralized metadata orchestration.
        self.lock = threading.RLock()
        self.free_pages = list(range(self.num_pages))
        self.key_to_index = OrderedDict()

        atexit.register(self.close)

        signal.signal(signal.SIGINT, lambda sig, frame: self.close())
        signal.signal(signal.SIGTERM, lambda sig, frame: self.close())
        signal.signal(signal.SIGQUIT, lambda sig, frame: self.close())

    @staticmethod
    def from_env_config(
        rank: int, bytes_per_page: int, dtype: torch.dtype
    ) -> "HiCacheHF3FS":
        config_path = os.getenv(HiCacheHF3FS.default_env_var)
        if not config_path:
            return HiCacheHF3FS(
                file_path=f"/data/hicache.{rank}.bin",
                file_size=1 << 40,
                numjobs=16,
                bytes_per_page=bytes_per_page,
                entries=8,
                dtype=dtype,
            )

        try:
            with open(config_path, "r") as f:
                config = json.load(f)
        except Exception as e:
            raise RuntimeError(f"Failed to load config from {config_path}: {str(e)}")

        required_keys = {
            "file_path_prefix",
            "file_size",
            "numjobs",
            "entries",
        }
        missing_keys = required_keys - set(config.keys())
        if missing_keys:
            raise ValueError(f"Missing required keys in config: {missing_keys}")

        return HiCacheHF3FS(
            file_path=f"{config['file_path_prefix']}.{rank}.bin",
            file_size=int(config["file_size"]),
            numjobs=int(config["numjobs"]),
            bytes_per_page=bytes_per_page,
            entries=int(config["entries"]),
            dtype=dtype,
        )

    def get(
        self, key: str, target_location: Optional[torch.Tensor] = None
    ) -> torch.Tensor | None:
        return self.batch_get([key], target_location)[0]

    @synchronized()
    def batch_get(
        self,
        keys: List[str],
        target_locations: Optional[List[torch.Tensor]] = None,
    ) -> List[torch.Tensor | None]:
        batch_indices, file_offsets = [], []
        for i, key in enumerate(keys):
            if key not in self.key_to_index:
                continue
            batch_indices.append(i)
            file_offsets.append(self.key_to_index[key] * self.bytes_per_page)
            self.key_to_index.move_to_end(key)
        # TODO: target_locations
        file_results = [
            torch.empty(self.numel, dtype=self.dtype) for _ in range(len(batch_indices))
        ]

        futures = [
            self.executor.submit(
                self.clients[self.ac.next()].batch_read,
                file_offsets[i : i + self.entries],
                file_results[i : i + self.entries],
            )
            for i in range(0, len(batch_indices), self.entries)
        ]
        read_results = [result for future in futures for result in future.result()]

        results = [None] * len(keys)
        for batch_index, file_result, read_result in zip(
            batch_indices, file_results, read_results
        ):
            if read_result == self.bytes_per_page:
                results[batch_index] = file_result
            else:
                logger.error(f"HiCacheHF3FS get {keys[batch_index]} failed")

        return results

    def set(self, key: str, value: torch.Tensor) -> bool:
        return self.batch_set([key], [value])

    def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
        indices = self.get_batch_set_indices(keys)
        batch_indices, file_offsets, file_values = [], [], []
        for i, (value, (is_written, index)) in enumerate(zip(values, indices)):
            if is_written or index == -1:
                continue
            batch_indices.append(i)
            file_offsets.append(index * self.bytes_per_page)
            file_values.append(value.contiguous())

        futures = [
            self.executor.submit(
                self.clients[self.ac.next()].batch_write,
                file_offsets[i : i + self.entries],
                file_values[i : i + self.entries],
            )
            for i in range(0, len(batch_indices), self.entries)
        ]
        write_results = [
            result == self.bytes_per_page
            for future in futures
            for result in future.result()
        ]

        results = [index[0] for index in indices]
        for batch_index, write_result in zip(batch_indices, write_results):
            key = keys[batch_index]
            index = indices[batch_index][1]
            if write_result:
                self.key_to_index[key] = index
                self.key_to_index.move_to_end(key)
            else:
                logger.error(f"HiCacheHF3FS set {key} failed")
                self.free_pages.append(index)
            results[batch_index] = write_result
        return all(results)

    @synchronized()
    def get_batch_set_indices(self, keys: List[str]) -> list:
        ionum = len(keys)
        # results: tuples of (is_written: bool, page_idx: int)
        # - is_written: True = hit (no I/O), False = write (miss)
        # - page_idx: page storing data
        results = [None] * min(ionum, self.num_pages)
        if ionum > self.num_pages:
            results.extend([(False, -1)] * (ionum - self.num_pages))

        new_keys = []
        for batch_index, key in enumerate(keys[: self.num_pages]):
            if key in self.key_to_index:
                results[batch_index] = (True, self.key_to_index[key])
                self.key_to_index.move_to_end(key)
            else:
                new_keys.append((batch_index, key))

        for batch_index, _ in new_keys:
            index = (
                self.free_pages.pop()
                if len(self.free_pages) > 0
                else self.key_to_index.popitem(last=False)[1]
            )
            results[batch_index] = (False, index)

        return results

    @synchronized()
    def delete(self, key: str) -> None:
        if key not in self.key_to_index:
            return
        index = self.key_to_index.pop(key)
        self.free_pages.append(index)

    @synchronized()
    def exists(self, key: str) -> bool:
        return key in self.key_to_index

    @synchronized()
    def clear(self) -> None:
        self.free_pages = list(range(self.num_pages))
        self.key_to_index.clear()

    def close(self) -> None:
        try:
            for c in self.clients:
                c.close()
            self.executor.shutdown(wait=True)
        except Exception as e:
            logger.error(f"close HiCacheHF3FS: {e}")
        logger.info("close HiCacheHF3FS")
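The `OrderedDict`-based index above gives the backend simple LRU behavior: hits call `move_to_end`, misses take a free page or evict the least-recently-used key, and keys beyond `num_pages` in a single `batch_set` are dropped with `(False, -1)`. A toy model of that bookkeeping (illustration only, not the class itself), matching what `bench_storage.py`'s `test()` observes with 8 pages and 10 keys:

```python
from collections import OrderedDict

# Toy model of HiCacheHF3FS's page index: 128 MiB file / 16 MiB pages = 8 pages,
# as in bench_storage.py's test(). Not the real class, just the eviction logic.
num_pages = 8
free_pages = list(range(num_pages))
key_to_index = OrderedDict()

def reserve(key):
    # Hit: refresh recency and reuse the page; miss: take a free page or evict LRU.
    if key in key_to_index:
        key_to_index.move_to_end(key)
        return True, key_to_index[key]
    index = free_pages.pop() if free_pages else key_to_index.popitem(last=False)[1]
    key_to_index[key] = index
    return False, index

for i in range(10):
    reserve(f"key_{i}")
print(list(key_to_index))  # key_0 and key_1 were evicted, mirroring the asserts in test()
```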
@@ -0,0 +1,43 @@
import multiprocessing.shared_memory
from pathlib import Path

import pytest
import torch
from torch.utils.cpp_extension import load
from tqdm import tqdm

root = Path(__file__).parent.resolve()
hf3fs_utils = load(
    name="hf3fs_utils", sources=[f"{root}/hf3fs_utils.cpp"], verbose=True
)


def test_rw_shm():
    numel = 8 << 20
    dtype = torch.bfloat16
    page_num = 128
    page_bytes = numel * dtype.itemsize
    shm = multiprocessing.shared_memory.SharedMemory(
        size=page_num * page_bytes, create=True
    )
    tshm = torch.frombuffer(shm.buf, dtype=torch.uint8)
    a = [
        torch.randn(numel, dtype=dtype)
        for _ in tqdm(range(page_num), desc="prepare input")
    ]
    b = [
        torch.empty(numel, dtype=dtype)
        for _ in tqdm(range(page_num), desc="prepare output")
    ]
    hf3fs_utils.write_shm(a, tshm)
    hf3fs_utils.read_shm(tshm, b)
    for _a, _b in tqdm(zip(a, b), desc="assert_close"):
        torch.testing.assert_close(_a, _b)

    del tshm
    shm.close()
    shm.unlink()


if __name__ == "__main__":
    pytest.main([__file__])
@@ -1476,7 +1476,7 @@ class ServerArgs:
        parser.add_argument(
            "--hicache-storage-backend",
            type=str,
            choices=["file"],  # todo, mooncake
            choices=["file", "hf3fs"],  # todo, mooncake
            default=ServerArgs.hicache_storage_backend,
            help="The storage backend for hierarchical KV cache.",
        )